thiomajid committed on
Commit 209716f · verified · 1 Parent(s): 20825b2

Upload transformer model after 3120 steps (perplexity: 6.170201301574707)

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model_checkpoint/default/ocdbt.process_0/d/7cda073904d823e88c6221b1b1753d92 filter=lfs diff=lfs merge=lfs -text
+ model_checkpoint/default/ocdbt.process_0/d/8eb7ace34ec179a3f15dc99655709dee filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,97 @@
+ ---
+ license: mit
+ tags:
+ - transformer
+ - autoregressive
+ - language-modeling
+ - tinystories
+ - jax
+ - flax
+ datasets:
+ - roneneldan/TinyStories
+ language:
+ - en
+ pipeline_tag: text-generation
+ ---
+
+ # Autoregressive Transformer trained on TinyStories
+
+ This is an autoregressive decoder-only transformer model trained on the TinyStories dataset using JAX and Flax NNX.
+
+ ## Model Details
+
+ - **Model Type**: Autoregressive Decoder-only Transformer
+ - **Framework**: JAX + Flax NNX
+ - **Dataset**: TinyStories
+ - **Parameters**: ~85.0M
+ - **Precision**: Mixed (FP32 parameters, BF16 computation)
+
+ ## Architecture
+
+ ```
+ - Hidden Size: 512
+ - Number of Layers: 8
+ - Attention Heads: 8
+ - Intermediate Size: 2048
+ - Max Position Embeddings: 256
+ - Vocab Size: 50257
+ - Rotary Position Embeddings: True
+ ```
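As a sanity check, the ~85.0M parameter figure follows from the architecture above. The sketch below assumes untied input/output embeddings, a gated three-matrix MLP (gate/up/down projections), no biases, and two norm scale vectors per layer, consistent with the parameter names in the checkpoint metadata:

```python
# Parameter count implied by the architecture (no biases, untied embeddings).
vocab, hidden, layers, inter = 50257, 512, 8, 2048

embed = vocab * hidden        # token embedding table
lm_head = hidden * vocab      # output projection (untied from the embedding)
attn = 4 * hidden * hidden    # q, k, v, o projections
mlp = 3 * hidden * inter      # gate, up, down projections
norms = 2 * hidden            # two norm scale vectors per layer
final_norm = hidden

total = embed + lm_head + layers * (attn + mlp + norms) + final_norm
print(total)  # 85026304 — matches total_parameters in train_history.json
```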
+
+ ## Training Details
+
+ - **Training Steps**: 3,120
+ - **Batch Size**: 32
+ - **Gradient Accumulation**: 4
+ - **Learning Rate**: 0.0003
+ - **Training Duration**: 0.43 hours (≈26 minutes)
+ - **Final Eval Loss**: 1.7965960502624512
+ - **Final Eval Perplexity**: 6.170201301574707
+
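The step counts above are internally consistent: 3,120 micro-batch steps at a gradient-accumulation factor of 4 correspond to 780 optimizer updates, which is exactly the `global_optimizer_steps` recorded in `train_history.json`. Assuming the reported batch size is the per-micro-step batch, the effective batch per update is:

```python
# Micro-steps vs. optimizer updates under gradient accumulation.
micro_steps, grad_accum, micro_batch = 3120, 4, 32

optimizer_steps = micro_steps // grad_accum   # parameter updates
effective_batch = micro_batch * grad_accum    # sequences per update (assumed)

print(optimizer_steps, effective_batch)  # 780 128
```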
+ ## Usage
+
+ ```python
+ # This model was trained with JAX/Flax and requires the custom transformer
+ # implementation to load and use. See the repository for implementation details.
+
+ from transformers import AutoTokenizer
+ import jax.numpy as jnp
+
+ # Load the GPT-2 tokenizer used during training
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ # Example text generation (requires custom model loading)
+ prompt = "Once upon a time, there was a little"
+ # ... (model loading and generation code)
+ ```
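Once the model is loaded, generation reduces to a standard autoregressive loop. Below is a minimal, framework-agnostic sketch of greedy decoding; `logits_fn` is a hypothetical stand-in for the custom model's forward pass (it maps the current token-id sequence to next-token logits):

```python
def greedy_generate(logits_fn, prompt_ids, max_new_tokens, eos_id=50256):
    """Greedy decoding: repeatedly append the argmax next token.

    `logits_fn` is a hypothetical stand-in for the model forward pass.
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = logits_fn(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:  # stop at <|endoftext|>
            break
    return ids

# Toy check: a dummy "model" that prefers token 7, then EOS.
dummy = lambda ids: [1.0 if i == (7 if len(ids) < 3 else 50256) else 0.0
                     for i in range(50257)]
print(greedy_generate(dummy, [0, 1], 5))  # [0, 1, 7, 50256]
```

In practice the loop would be wrapped in `jax.jit` with a KV cache (the config sets `use_cache: true`), but the control flow is the same.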
+
+ ## Training Configuration
+
+ ```yaml
+ model:
+   hidden_size: 512
+   num_layers: 8
+   num_attention_heads: 8
+   intermediate_size: 2048
+   max_position_embeddings: 256
+
+ training:
+   learning_rate: 0.0003
+   batch_size: 32
+   epochs: 10
+   warmup_ratio: 0.1
+ ```
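The `warmup_ratio` only becomes a concrete warmup length once you fix what it is a ratio of, and that detail is not recorded here. Both readings are easy to compute; the choice below is an assumption, not something stated in this repository:

```python
# Hypothetical warmup lengths implied by warmup_ratio = 0.1.
optimizer_steps = 3120 // 4                 # with gradient accumulation of 4
warmup_updates = int(optimizer_steps * 0.1) # if the ratio is over updates
warmup_micro = int(3120 * 0.1)              # if the ratio is over micro-steps

print(warmup_updates, warmup_micro)  # 78 312
```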
+
+ ## Files
+
+ - `config.json`: Model configuration
+ - `train_history.json`: Training metrics and duration
+ - `tokenizer/`: GPT-2 tokenizer files
+ - `model_checkpoint/`: Best model checkpoint (Orbax format)
+ - `tensorboard_logs/`: Training logs for TensorBoard
+
+ ## License
+
+ MIT License - see LICENSE file for details.
config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "vocab_size": 50257,
+   "hidden_size": 512,
+   "num_layers": 8,
+   "num_attention_heads": 8,
+   "intermediate_size": 2048,
+   "max_position_embeddings": 256,
+   "attention_dropout": 0.1,
+   "use_rotary_pos_emb": true,
+   "rope_theta": 10000.0,
+   "hidden_dropout": 0.1,
+   "layer_norm_eps": 1e-05,
+   "use_bias": false,
+   "initializer_range": 0.02,
+   "pad_token_id": 50256,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "use_cache": true
+ }
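The config enables rotary position embeddings with `rope_theta = 10000.0` and a per-head dimension of 512 / 8 = 64. The standard RoPE inverse-frequency table for those settings looks like this (a generic sketch of the usual formulation, not necessarily the repository's exact implementation):

```python
# RoPE inverse frequencies: one per rotated coordinate pair (head_dim / 2).
head_dim, theta = 512 // 8, 10000.0

inv_freq = [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]

# The first pair rotates fastest (1 radian per position); later pairs
# rotate progressively slower, encoding coarser position information.
print(len(inv_freq), inv_freq[0])  # 32 1.0
```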
model_checkpoint/_CHECKPOINT_METADATA ADDED
@@ -0,0 +1 @@
+ {"item_handlers": {"default": "orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler", "metrics": "orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1754355577302929106, "commit_timestamp_nsecs": 1754355580447202487, "custom_metadata": {}}
model_checkpoint/default/_METADATA ADDED
@@ -0,0 +1 @@
+ {"tree_metadata": {"('embed_tokens', 'embedding', 'value')": {"key_metadata": [{"key": "embed_tokens", "key_type": 2}, {"key": "embedding", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [50257, 64]}}, "('layers', 'input_layernorm', 'weight', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "input_layernorm", "key_type": 2}, {"key": "weight", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 512]}}, "('layers', 'mlp', 'down_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "down_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 2048, 512]}}, "('layers', 'mlp', 'gate_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "gate_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 64, 2048]}}, "('layers', 'mlp', 'up_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "up_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 64, 2048]}}, "('layers', 'post_attention_layernorm', 'weight', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "post_attention_layernorm", "key_type": 2}, {"key": "weight", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 512]}}, "('layers', 
'self_attn', 'k_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "self_attn", "key_type": 2}, {"key": "k_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 64, 512]}}, "('layers', 'self_attn', 'o_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "self_attn", "key_type": 2}, {"key": "o_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 512, 512]}}, "('layers', 'self_attn', 'q_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "self_attn", "key_type": 2}, {"key": "q_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 64, 512]}}, "('layers', 'self_attn', 'v_proj', 'kernel', 'value')": {"key_metadata": [{"key": "layers", "key_type": 2}, {"key": "self_attn", "key_type": 2}, {"key": "v_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 64, 512]}}, "('lm_head', 'kernel', 'value')": {"key_metadata": [{"key": "lm_head", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [64, 50257]}}, "('norm', 'weight', 'value')": {"key_metadata": [{"key": "norm", "key_type": 2}, {"key": "weight", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [64]}}}, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
model_checkpoint/default/_sharding ADDED
@@ -0,0 +1 @@
+ {"ZW1iZWRfdG9rZW5zLmVtYmVkZGluZy52YWx1ZQ==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bG1faGVhZC5rZXJuZWwudmFsdWU=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLm1scC51cF9wcm9qLmtlcm5lbC52YWx1ZQ==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLm1scC5kb3duX3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLm1scC5nYXRlX3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLmlucHV0X2xheWVybm9ybS53ZWlnaHQudmFsdWU=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLnBvc3RfYXR0ZW50aW9uX2xheWVybm9ybS53ZWlnaHQudmFsdWU=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", 
\"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLnNlbGZfYXR0bi52X3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLnNlbGZfYXR0bi5rX3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLnNlbGZfYXR0bi5vX3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bGF5ZXJzLnNlbGZfYXR0bi5xX3Byb2oua2VybmVsLnZhbHVl":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [null, \"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}","bm9ybS53ZWlnaHQudmFsdWU=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 8], \"axis_names\": [\"dp\", \"tp\"], \"partition_spec\": [\"tp\"], \"device_mesh\": {\"mesh\": [[{\"id\": 0}, {\"id\": 1}, {\"id\": 2}, {\"id\": 3}, {\"id\": 6}, {\"id\": 7}, {\"id\": 4}, {\"id\": 5}]]}}"}
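The keys in the `_sharding` file are base64-encoded parameter paths, each mapped to a `NamedSharding` over an 8-device `(dp, tp)` mesh. Decoding a key recovers the parameter name:

```python
import base64

# First key from the _sharding file above.
key = "ZW1iZWRfdG9rZW5zLmVtYmVkZGluZy52YWx1ZQ=="
print(base64.b64decode(key).decode())  # embed_tokens.embedding.value
```

This also explains the per-shard shapes in `_METADATA`: the embedding's `write_shape` of `[50257, 64]` is the full `[50257, 512]` table split 8 ways along the `tp` axis.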
model_checkpoint/default/array_metadatas/process_0 ADDED
@@ -0,0 +1 @@
+ {"array_metadatas": [{"array_metadata": {"param_name": "embed_tokens.embedding.value", "write_shape": [50257, 64], "chunk_shape": [50257, 64], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.input_layernorm.weight.value", "write_shape": [1, 512], "chunk_shape": [1, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.mlp.down_proj.kernel.value", "write_shape": [1, 2048, 512], "chunk_shape": [1, 2048, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.mlp.gate_proj.kernel.value", "write_shape": [8, 64, 2048], "chunk_shape": [8, 64, 2048], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.mlp.up_proj.kernel.value", "write_shape": [8, 64, 2048], "chunk_shape": [8, 64, 2048], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.post_attention_layernorm.weight.value", "write_shape": [1, 512], "chunk_shape": [1, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.self_attn.k_proj.kernel.value", "write_shape": [8, 64, 512], "chunk_shape": [8, 64, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.self_attn.o_proj.kernel.value", "write_shape": [1, 512, 512], "chunk_shape": [1, 512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.self_attn.q_proj.kernel.value", "write_shape": [8, 64, 512], "chunk_shape": [8, 64, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "layers.self_attn.v_proj.kernel.value", "write_shape": [8, 64, 512], "chunk_shape": [8, 64, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "lm_head.kernel.value", "write_shape": [64, 50257], "chunk_shape": [64, 50257], "ext_metadata": null}}, {"array_metadata": {"param_name": "norm.weight.value", "write_shape": [64], "chunk_shape": [64], "ext_metadata": null}}]}
model_checkpoint/default/d/bbe9d23c311448de661dd075a5e5c8fe ADDED
Binary file (3.3 kB). View file
 
model_checkpoint/default/manifest.ocdbt ADDED
Binary file (117 Bytes). View file
 
model_checkpoint/default/ocdbt.process_0/d/7bed338326f3a30d7a31ea15d2bdf6b8 ADDED
Binary file (203 Bytes). View file
 
model_checkpoint/default/ocdbt.process_0/d/7cda073904d823e88c6221b1b1753d92 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a48e357d8aacd9475a5b559293c52968db70b5d6194a3710c5102e41667e6077
+ size 193635403
model_checkpoint/default/ocdbt.process_0/d/8eb7ace34ec179a3f15dc99655709dee ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1499ddea7c9d643d6dd48045ee5e7b234b847f4baf38ee8649028e05155b3c0
+ size 118000029
model_checkpoint/default/ocdbt.process_0/d/bf31c95ddb90334c53f019bb5658c1f5 ADDED
Binary file (444 Bytes). View file
 
model_checkpoint/default/ocdbt.process_0/manifest.ocdbt ADDED
Binary file (258 Bytes). View file
 
model_checkpoint/metrics/metrics ADDED
@@ -0,0 +1 @@
+ {"eval_loss": 1.7965960502624512}
tensorboard_logs/events.out.tfevents.1754354033.047df4d1a242.2109.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd3e08939b8b6d15aa8bbddf2dfba5e0ebdcdebe169ea2cf7921d2d75979bedd
+ size 14383
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "50256": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "extra_special_tokens": {},
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
train_history.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "total_training_duration": 1547.211970743,
+   "avg_epoch_duration": 154.72109574149997,
+   "num_epochs_completed": 10,
+   "global_steps": 3120,
+   "global_optimizer_steps": 780,
+   "total_parameters": 85026304
+ }
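These durations line up with the README summary: 1547.2 seconds is 0.43 hours, and dividing by the 10 completed epochs approximately reproduces the recorded `avg_epoch_duration` (the small residual suggests per-epoch timings were averaged directly rather than derived from the total):

```python
# Cross-check train_history.json against the README's "0.43 hours".
total_s = 1547.211970743

print(round(total_s / 3600, 2))  # 0.43 hours
print(round(total_s / 10, 2))    # 154.72 s per epoch on average
```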