# Autoregressive Transformer trained on TinyStories

This is an autoregressive decoder-only transformer model trained on the TinyStories dataset using JAX and Flax NNX.

## Model Details

- Model Type: Autoregressive decoder-only transformer
- Framework: JAX + Flax NNX
- Dataset: TinyStories
- Parameters: ~85M
- Precision: Mixed (FP32 parameters, BF16 computation); see the sketch after this list
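
The mixed-precision setup can be expressed directly in Flax NNX by storing parameters in float32 (`param_dtype`) while computing the forward pass in bfloat16 (`dtype`). The snippet below is a minimal sketch under that assumption; `FeedForward` is an illustrative module, not the repository's actual code.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Minimal sketch of the mixed-precision setup: parameters stored in float32
# (param_dtype) while the forward pass computes in bfloat16 (dtype).
# `FeedForward` is illustrative only, not the repository's actual module.
class FeedForward(nnx.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, *, rngs: nnx.Rngs):
        self.up = nnx.Linear(hidden_size, intermediate_size,
                             dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=rngs)
        self.down = nnx.Linear(intermediate_size, hidden_size,
                               dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=rngs)

    def __call__(self, x):
        return self.down(jax.nn.gelu(self.up(x)))

ff = FeedForward(512, 2048, rngs=nnx.Rngs(0))
y = ff(jnp.ones((1, 256, 512)))  # activations and output are bfloat16
```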
 
## Architecture

- Hidden Size: 512
- Number of Layers: 8
- Attention Heads: 8
- Intermediate Size: 2048
- Max Position Embeddings: 256
- Vocab Size: 50257
- Rotary Position Embeddings: Yes (see the sketch after this list)
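
Positions are encoded with rotary embeddings rather than learned absolute position embeddings. The sketch below shows one common formulation (the split-half, GPT-NeoX-style variant); the repository's implementation may differ in channel layout or rotation base.

```python
import jax.numpy as jnp

def apply_rotary_embeddings(x, base: float = 10000.0):
    """Apply rotary position embeddings to queries/keys of shape
    (batch, seq_len, num_heads, head_dim). Split-half variant; illustrative only."""
    _, seq_len, _, head_dim = x.shape
    # One rotation frequency per channel pair.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    angles = jnp.einsum("s,d->sd", positions, inv_freq)  # (seq_len, head_dim // 2)
    sin = jnp.sin(angles)[None, :, None, :]              # broadcast over batch and heads
    cos = jnp.cos(angles)[None, :, None, :]
    x1, x2 = jnp.split(x, 2, axis=-1)
    # Rotate each (x1, x2) channel pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# Example: queries for hidden size 512 with 8 heads (head_dim = 64).
q = jnp.ones((1, 256, 8, 64))
q_rotated = apply_rotary_embeddings(q)
```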

## Training Details

- Training Steps: 3,120
- Batch Size: 32
- Gradient Accumulation Steps: 4 (effective batch size 128; see the optimizer sketch after this list)
- Learning Rate: 0.0003
- Training Duration: 0.43 hours (~26 minutes)
- Final Eval Loss: 1.797
- Final Eval Perplexity: 6.17
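
Gradient accumulation over 4 micro-batches can be set up in optax with `MultiSteps`. This is a hedged sketch of one plausible configuration; the choice of AdamW is an assumption, since the card does not state which optimizer was used.

```python
import optax

# Sketch: gradients from 4 micro-batches of 32 are accumulated before each
# parameter update, giving an effective batch size of 128. AdamW is an
# assumption; the card does not state which optimizer was actually used.
base_tx = optax.adamw(learning_rate=3e-4)
tx = optax.MultiSteps(base_tx, every_k_schedule=4)

# With Flax NNX this is typically wrapped as, e.g.,
#   optimizer = nnx.Optimizer(model, tx)
# and updated once per micro-batch; parameters change only on every 4th call.
```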
 
## Usage

```python
# This model was trained with JAX/Flax and requires the custom transformer
# implementation from the repository to load and use.
import jax.numpy as jnp
from transformers import AutoTokenizer

# Load the GPT-2 tokenizer used during training
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
# ... (model loading and generation code)
```
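
For reference, a minimal greedy decoding loop for an autoregressive model might look like the sketch below. The `model` callable is hypothetical (assumed to map token IDs of shape `(batch, seq_len)` to logits of shape `(batch, seq_len, vocab_size)`); prefer the repository's own generation utilities.

```python
import jax.numpy as jnp

def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 64):
    """Hypothetical greedy decoding loop. `model` is assumed to map token IDs of
    shape (batch, seq) to logits of shape (batch, seq, vocab). Illustrative only."""
    input_ids = tokenizer(prompt, return_tensors="np")["input_ids"]
    token_ids = jnp.asarray(input_ids)
    for _ in range(max_new_tokens):
        # Note: the model's context window is 256 tokens.
        logits = model(token_ids)                        # (1, seq, vocab)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy: most likely next token
        token_ids = jnp.concatenate([token_ids, next_id[:, None]], axis=-1)
        if int(next_id[0]) == tokenizer.eos_token_id:
            break
    return tokenizer.decode(token_ids[0].tolist(), skip_special_tokens=True)

# story = greedy_generate(model, tokenizer, prompt)  # `model` must be loaded separately
```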

## Training Configuration

```yaml
model:
  hidden_size: 512
  num_layers: 8
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256
training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
```
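
With `warmup_ratio: 0.1` and 3,120 training steps, roughly the first 312 updates warm the learning rate up to its peak of 0.0003. Below is a sketch of such a schedule in optax; the decay shape after warmup is an assumption, since the configuration only fixes the warmup ratio and peak learning rate.

```python
import optax

total_steps = 3_120
warmup_steps = int(0.1 * total_steps)  # warmup_ratio: 0.1 -> ~312 warmup steps

# Linear warmup to the peak learning rate followed by cosine decay. The decay
# shape after warmup is an assumption; only the warmup ratio is given above.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,
)
```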

## Files

- `config.json`: Model configuration
- `train_history.json`: Training metrics and duration
- `tokenizer/`: GPT-2 tokenizer files
- `model_checkpoint/`: Best model checkpoint
- `tensorboard_logs/`: Training logs for TensorBoard

## License

MIT License - see LICENSE file for details.