---
license: mit
tags:
- transformer
- autoregressive
- language-modeling
- tinystories
- jax
- flax
datasets:
- roneneldan/TinyStories
language:
- en
pipeline_tag: text-generation
---

# Autoregressive Transformer trained on TinyStories

This is an autoregressive, decoder-only transformer trained on the TinyStories dataset using JAX and Flax NNX.

## Model Details

- **Model Type**: Autoregressive Decoder-only Transformer
- **Framework**: JAX + Flax NNX
- **Dataset**: TinyStories
- **Parameters**: ~85.0M
- **Precision**: Mixed (FP32 parameters, BF16 computation)

## Architecture

```
- Hidden Size: 512
- Number of Layers: 8
- Attention Heads: 8
- Intermediate Size: 2048
- Max Position Embeddings: 256
- Vocab Size: 50257
- Rotary Position Embeddings: True
```

## Training Details

- **Training Steps**: 3,120
- **Batch Size**: 32
- **Gradient Accumulation**: 4
- **Learning Rate**: 0.0003
- **Training Duration**: 0.43 hours (about 26 minutes)
- **Final Eval Loss**: 1.7966
- **Final Eval Perplexity**: 6.1702

## Usage

```python
# This model was trained with JAX/Flax and requires the custom transformer
# implementation to load and use. See the repository for implementation details.
from transformers import AutoTokenizer
import jax.numpy as jnp

# Load the GPT-2 tokenizer; GPT-2 defines no pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
# Tokenize the prompt into a (1, sequence_length) JAX array of token ids
input_ids = jnp.asarray(tokenizer(prompt, return_tensors="np")["input_ids"])

# ... (model loading and generation code)
```

## Training Configuration

```yaml
model:
  hidden_size: 512
  num_layers: 8
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256

training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
```

## Files

- `config.json`: Model configuration
- `train_history.json`: Training metrics and duration
- `tokenizer/`: GPT-2 tokenizer files
- `model_checkpoint/`: Best model checkpoint
- `tensorboard_logs/`: Training logs for TensorBoard

## License

MIT License. See the LICENSE file for details.
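
## Parameter Count Estimate

The ~85.0M parameter figure can be roughly reproduced from the architecture values listed in this card. The sketch below is only one plausible breakdown, not a description of the actual implementation: it assumes untied input and output embeddings, a gated MLP with three projection matrices, no bias terms, and a single scale vector per normalization layer. None of these choices are stated in the card. Rotary position embeddings contribute no learned parameters.

```python
# Rough parameter count from the Architecture values in this card.
# ASSUMPTIONS (not stated in the card): untied input/output embeddings,
# a gated MLP with three projection matrices, no bias terms, and one
# scale vector per normalization layer.
hidden_size = 512
num_layers = 8
intermediate_size = 2048
vocab_size = 50257

embedding = vocab_size * hidden_size        # token embedding table
lm_head = vocab_size * hidden_size          # output projection (assumed untied)
attention = 4 * hidden_size * hidden_size   # Q, K, V and output projections
mlp = 3 * hidden_size * intermediate_size   # gate, up and down projections (assumed gated)
norms = 2 * hidden_size                     # two norm scales per layer (assumed)

per_layer = attention + mlp + norms
# Add one final normalization layer after the last block (assumed).
total = embedding + lm_head + num_layers * per_layer + hidden_size

print(f"per layer: {per_layer:,}")
print(f"total:     {total:,} (~{total / 1e6:.1f}M)")
```

Under these assumptions the total comes to 85,026,304, which rounds to the reported ~85.0M. With tied embeddings or a two-matrix MLP the same arithmetic gives a noticeably smaller number, so `config.json` and the repository remain the authoritative source for the real layout.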