sky-2002 committed on
Commit fc96ac7 · verified · 1 Parent(s): 43088b5

Upload deepseek_tinystories/modeling_deepseek.py

deepseek_tinystories/modeling_deepseek.py ADDED
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class DeepSeekModelConfig:
    num_attention_heads: int = 8
    input_dim: int = 1024
    embed_dim: int = 1024
    bias: bool = False
    dropout: float = 0.1

    kv_heads: int = 4  # number of key-value heads for grouped query attention

    # configs needed for MLA
    mla_kv_heads: int = 4  # number of groups of attention heads that share the same K and V matrices
    use_mla: bool = False
    num_gpus: int = 1  # number of GPUs (n_local_heads); relevant when computation is sharded across devices

    q_latent_dim: int = 4  # dimension of the latent used to build queries
    kv_latent_dim: int = 4  # dimension of the latent used to build keys and values

    # The official implementation keeps separate head dimensions for the RoPE and
    # no-RoPE parts; here both are kept equal to head_dim. Because the no-RoPE and
    # RoPE queries/keys are concatenated, those two dimensions are added when scaling
    # the attention scores.

    max_batch_size: int = 8
    max_token_len: int = 1024

    num_shared_experts: int = 8
    num_routed_experts: int = 16
    moe_top_k: int = 2
    expert_intermediate_dim: int = 8192
    eta: float = 0.05  # bias update rate for auxiliary-loss-free load balancing

    num_dense_ffn: int = 2
    num_moe_ffn: int = 4

    mtp_depth: int = 3
    vocab_size: int = 50257


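# A minimal usage sketch (added for illustration; these particular values are
# assumptions, not the settings of the uploaded checkpoint): every field of the
# dataclass can be overridden, e.g. to shrink the model for TinyStories-scale runs.
def _example_small_config() -> DeepSeekModelConfig:
    return DeepSeekModelConfig(
        input_dim=256,
        embed_dim=256,
        num_attention_heads=8,
        mla_kv_heads=4,
        num_routed_experts=8,
        moe_top_k=2,
        num_shared_experts=1,
        expert_intermediate_dim=512,
    )

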
class Expert(nn.Module):
    """SwiGLU-style gated feed-forward block used for both routed and shared experts."""

    def __init__(self, input_dim: int, intermediate_dim: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(input_dim, intermediate_dim)
        self.w11 = nn.Linear(input_dim, intermediate_dim)
        self.w2 = nn.Linear(intermediate_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SiLU-gated up projection followed by the down projection
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w11(x)))


class MoE(nn.Module):
    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()
        self.num_shared_experts = config.num_shared_experts
        self.num_routed_experts = config.num_routed_experts
        self.num_local_experts = config.num_routed_experts // config.num_gpus
        self.top_k = config.moe_top_k

        self.expert_selector = nn.Linear(
            config.input_dim, self.num_routed_experts, bias=False
        )
        self.routed_experts = nn.ModuleList(
            [
                Expert(config.input_dim, config.expert_intermediate_dim, config.dropout)
                for _ in range(self.num_routed_experts)
            ]
        )
        self.shared_experts = Expert(
            config.input_dim,
            config.expert_intermediate_dim * self.num_shared_experts,
            config.dropout,
        )
        self.eta = config.eta
        self.register_buffer("expert_bias", torch.zeros(self.num_routed_experts))

    def forward(self, x):
        batch_size, num_tokens, input_dim = x.shape
        gate_output, topk_indices = self.topk_routing(x, self.expert_bias)
        x = x.view(batch_size * num_tokens, input_dim)  # flatten to a list of tokens
        gate_output = gate_output.view(batch_size * num_tokens, -1)
        topk_indices = topk_indices.view(batch_size * num_tokens, -1)

        # --- cache routing info for interpretability ---
        self.last_topk_indices = (
            topk_indices.view(batch_size, num_tokens, -1).detach().cpu()
        )
        self.last_gate_output = (
            gate_output.view(batch_size, num_tokens, -1).detach().cpu()
        )

        expert_counts = torch.bincount(
            topk_indices.flatten(), minlength=self.num_routed_experts
        )

        # auxiliary-loss-free load balancing: nudge the routing bias away from
        # overloaded experts and towards underloaded ones
        with torch.no_grad():
            avg = expert_counts.float().mean()
            err = expert_counts.float() - avg
            self.expert_bias += -self.eta * err.sign()

        # save per-step expert usage for logging
        if hasattr(self, "expert_usage"):
            self.expert_usage.append(expert_counts.detach().cpu())
        else:
            self.expert_usage = [expert_counts.detach().cpu()]

        y = torch.zeros_like(x)
        counts = expert_counts.tolist()
        for i in range(self.num_routed_experts):
            if counts[i] == 0:
                continue
            expert = self.routed_experts[i]

            # rows (tokens) that routed to expert i; the gate weight for expert i
            # sits in column i of the dense gate distribution
            idx, _ = torch.where(topk_indices == i)
            y[idx] += expert(x[idx]) * gate_output[idx, i, None]

        z = self.shared_experts(x)
        return (y + z).view(batch_size, num_tokens, input_dim)

    def topk_routing(self, x, bias=None):
        batch_size, num_tokens, input_dim = x.shape

        expert_logits = self.expert_selector(x)  # B, T, num_experts
        if bias is not None:
            expert_logits = expert_logits + bias
        topk_logits, topk_indices = torch.topk(expert_logits, k=self.top_k, dim=-1)
        zeros = torch.full_like(expert_logits, float("-inf"))
        sparse_logits = zeros.scatter(dim=-1, index=topk_indices, src=topk_logits)
        gate_output = sparse_logits.softmax(dim=-1)
        return gate_output, topk_indices


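# Illustrative sketch (added for clarity, not from the original file) of what
# topk_routing returns: a dense gate distribution that is zero outside the selected
# experts, plus the indices of the top-k experts per token. The config values are
# assumptions chosen only to keep the example small.
def _example_topk_routing():
    cfg = DeepSeekModelConfig(input_dim=32, num_routed_experts=4, moe_top_k=2,
                              num_shared_experts=1, expert_intermediate_dim=64)
    moe = MoE(cfg)
    x = torch.randn(2, 5, cfg.input_dim)
    gate_output, topk_indices = moe.topk_routing(x, moe.expert_bias)
    assert gate_output.shape == (2, 5, cfg.num_routed_experts)
    assert topk_indices.shape == (2, 5, cfg.moe_top_k)
    # each token's gate weights sum to 1 over its selected experts
    assert torch.allclose(gate_output.sum(dim=-1), torch.ones(2, 5), atol=1e-5)
    return gate_output, topk_indices

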
class RoPE(nn.Module):

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = 0

    def _compute_cos_sin(self, seq_len: int, device: torch.device):
        if seq_len > self._cached_seq_len or self._cached_cos is None:
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)

            cos_vals = torch.cos(freqs)
            sin_vals = torch.sin(freqs)

            self._cached_cos = cos_vals
            self._cached_sin = sin_vals
            self._cached_seq_len = seq_len

        return self._cached_cos[:seq_len], self._cached_sin[:seq_len]

    def apply_rope(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Apply RoPE to an input tensor of shape [batch, tokens, heads, head_dim]."""
        batch_size, num_tokens, n_heads, head_dim = x.shape

        cos, sin = self._compute_cos_sin(num_tokens, x.device)

        if position_ids is not None:
            cos = cos[position_ids]
            sin = sin[position_ids]

        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim//2]
        sin = sin.unsqueeze(0).unsqueeze(2)

        x1 = x[..., ::2]  # even indices
        x2 = x[..., 1::2]  # odd indices

        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos

        rotated_x = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)

        return rotated_x


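# Illustrative sketch (an added example, not part of the original file): RoPE is a pure
# rotation applied pairwise to even/odd channels, so it preserves per-head vector norms
# and only changes the relative phase between positions.
def _example_rope_shapes():
    rope = RoPE(dim=64)
    x = torch.randn(2, 10, 8, 64)  # [batch, tokens, heads, head_dim]
    x_rot = rope.apply_rope(x)
    assert x_rot.shape == x.shape
    # norms per (token, head) are unchanged up to numerical error
    assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)
    return x_rot

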
class MultiHeadAttention(nn.Module):
    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.input_dim = config.input_dim
        self.embed_dim = config.embed_dim
        self.head_dim = self.embed_dim // self.num_heads

        self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
        self.Wk = nn.Linear(self.input_dim, self.embed_dim, bias=False)
        self.Wv = nn.Linear(self.input_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)

    def forward(self, x):
        # x is B, T, input_dim
        batch_size, num_tokens, input_dim = x.shape
        Q = (
            self.Wq(x)
            .view(batch_size, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # becomes B, num_heads, T, head_dim
        K = (
            self.Wk(x)
            .view(batch_size, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # becomes B, num_heads, T, head_dim
        V = (
            self.Wv(x)
            .view(batch_size, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # becomes B, num_heads, T, head_dim

        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)

        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )

        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(attention_scores, dim=-1)  # B, num_heads, T, T

        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out


class MultiQueryAttention(nn.Module):
    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.input_dim = config.input_dim
        self.embed_dim = config.embed_dim
        self.head_dim = self.embed_dim // self.num_heads

        self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
        self.Wk = nn.Linear(self.input_dim, self.head_dim, bias=False)
        self.Wv = nn.Linear(self.input_dim, self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)

    def forward(self, x):
        # x is B, T, input_dim
        batch_size, num_tokens, input_dim = x.shape
        Q = (
            self.Wq(x)
            .view(batch_size, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # becomes B, num_heads, T, head_dim
        K = self.Wk(x)  # B, T, head_dim
        V = self.Wv(x)  # B, T, head_dim

        # broadcast the single key/value head across all query heads
        K = K.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # B, num_heads, T, head_dim
        V = V.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # B, num_heads, T, head_dim

        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)

        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )

        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(attention_scores, dim=-1)  # B, num_heads, T, T

        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out


class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.input_dim = config.input_dim
        self.embed_dim = config.embed_dim
        self.head_dim = self.embed_dim // self.num_heads
        self.kv_heads = config.kv_heads

        self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
        self.Wk = nn.Linear(self.input_dim, self.head_dim * config.kv_heads, bias=False)
        self.Wv = nn.Linear(self.input_dim, self.head_dim * config.kv_heads, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)

    def forward(self, x):
        batch_size, num_tokens, input_dim = x.shape
        Q = (
            self.Wq(x)
            .view(batch_size, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # becomes B, num_heads, T, head_dim

        K = self.Wk(x)  # B, T, head_dim*kv_heads
        V = self.Wv(x)  # B, T, head_dim*kv_heads

        K = K.view(batch_size, num_tokens, self.kv_heads, self.head_dim)
        V = V.view(batch_size, num_tokens, self.kv_heads, self.head_dim)

        # if kv_heads is 3 and num_heads is 6, I want
        # k = [k1, k1, k2, k2, k3, k3] (and the same for v),
        # so consecutive query heads share one key-value head
        # (see the sketch after this class)
        K = K.repeat_interleave(
            self.num_heads // self.kv_heads, dim=2
        ).transpose(1, 2)  # B, num_heads, T, head_dim
        V = V.repeat_interleave(
            self.num_heads // self.kv_heads, dim=2
        ).transpose(1, 2)  # B, num_heads, T, head_dim

        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)

        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )

        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(attention_scores, dim=-1)  # B, num_heads, T, T

        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out


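# An illustrative sketch of the head-sharing pattern described in the comment above
# (added for clarity, not part of the original file): repeat_interleave duplicates each
# key-value head so consecutive query heads see the same K/V, i.e. [k1, k1, k2, k2, k3, k3].
def _example_kv_head_sharing():
    kv_heads, num_heads, head_dim = 3, 6, 4
    # fill k[..., h, :] with the kv-head index h so the pattern is easy to read off
    k = torch.arange(kv_heads).view(1, 1, kv_heads, 1).expand(1, 2, kv_heads, head_dim)
    k_rep = k.repeat_interleave(num_heads // kv_heads, dim=2)
    assert k_rep[0, 0, :, 0].tolist() == [0, 0, 1, 1, 2, 2]
    return k_rep

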
# I have copied RMSNorm directly from the DeepSeek-V3 repo
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        return F.rms_norm(x, (self.dim,), self.weight, self.eps)


# TODO:
# 1. Try out grouped-query-attention-style MLA, where each kv head has its own latent cache
# 2. Try out sliding window attention, which I read about in the Gemma paper
class MultiHeadLatentAttention(nn.Module):

    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.input_dim = config.input_dim
        self.embed_dim = config.embed_dim
        self.n_local_heads = config.num_attention_heads // config.num_gpus
        self.head_dim = self.embed_dim // self.num_heads
        self.mla_kv_heads = config.mla_kv_heads
        self.kv_latent_dim = config.kv_latent_dim
        self.q_latent_dim = config.q_latent_dim
        self.dropout = nn.Dropout(config.dropout)

        self.rope = RoPE(dim=self.head_dim)
        self.out_proj = nn.Linear(
            self.num_heads * self.head_dim, self.input_dim, bias=False
        )

        if self.q_latent_dim == 0:
            self.Wq = nn.Linear(
                self.input_dim, self.num_heads * self.head_dim, bias=False
            )
        else:
            # -------------------(decoupled from RoPE)-----------------------------
            # Query path - this feels to me like LoRA on Q, because instead of
            # Wq (input_dim, input_dim) we now have
            # Wdq (input_dim, q_latent_dim) and Wuq (q_latent_dim, input_dim)
            self.Wdq = nn.Linear(self.input_dim, self.q_latent_dim, bias=False)
            self.q_norm = RMSNorm(self.q_latent_dim)
            self.Wuq = nn.Linear(
                self.q_latent_dim, self.num_heads * self.head_dim, bias=False
            )

        # this builds the KV latent and also constructs K and V from it
        self.Wdkv = nn.Linear(self.input_dim, self.kv_latent_dim, bias=False)
        self.kv_norm = RMSNorm(self.kv_latent_dim)
        self.Wuk = nn.Linear(
            self.kv_latent_dim, self.head_dim, bias=False
        )  # not num_heads here because we use kv heads (grouped query attention)
        self.Wuv = nn.Linear(
            self.kv_latent_dim, self.mla_kv_heads * self.head_dim, bias=False
        )

        # cache the kv latent and the roped keys
        self.register_buffer(
            "kv_latent_cache",
            torch.zeros(
                config.max_batch_size, config.max_token_len, self.kv_latent_dim
            ),
            persistent=False,  # I won't store it on disk
        )
        self.register_buffer(
            "keys_roped",
            torch.zeros(
                config.max_batch_size,
                config.max_token_len,
                self.mla_kv_heads,
                # I could have skipped these heads and shared one roped key across all heads;
                # here a key is shared within each group of attention heads under one kv head
                self.head_dim,
            ),
            persistent=False,
        )
        # --------------------------------------------------------------------

        # -------------RoPE path----------------------------------------------
        self.Wkr = nn.Linear(
            self.input_dim, self.mla_kv_heads * self.head_dim, bias=False
        )
        self.Wqr = nn.Linear(self.q_latent_dim, self.embed_dim, bias=False)

    def forward(self, x, start_pos=0):
        batch_size, num_tokens, input_dim = x.shape
        end_pos = start_pos + num_tokens
        S = end_pos  # total cached sequence length

        # ----- Queries -----
        if self.q_latent_dim == 0:
            Q = (
                self.Wq(x)
                .view(batch_size, num_tokens, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )  # [B, num_heads, T, head_dim]
        else:
            query_latent = self.Wdq(x)
            query_latent = self.q_norm(query_latent)
            Q = (
                self.Wuq(query_latent)
                .view(batch_size, num_tokens, self.num_heads, self.head_dim)
                .transpose(1, 2)  # [B, num_heads, T, head_dim]
            )
        # ----- RoPE path (queries) -----
        # apply_rope expects [B, T, num_heads, head_dim]
        if self.q_latent_dim == 0:
            Qr = self.rope.apply_rope(Q.transpose(1, 2)).transpose(1, 2)
        else:
            Qr = self.rope.apply_rope(
                self.Wqr(query_latent).view(
                    batch_size, num_tokens, self.num_heads, self.head_dim
                )
            ).transpose(1, 2)
        # ---------------------

        # ----- KV latent -----
        kv_latent = self.Wdkv(x)  # [B, T, kv_latent_dim]
        # update cache
        self.kv_latent_cache[:batch_size, start_pos:end_pos] = self.kv_norm(
            kv_latent
        ).detach()

        kv_latent_all = self.kv_latent_cache[
            :batch_size, :end_pos
        ]  # [B, S, kv_latent_dim]

        # absorb Wuk into the queries:
        # [B, num_heads, T, head_dim] x [head_dim, kv_latent_dim]
        Q_absorbed = Q @ self.Wuk.weight  # [B, num_heads, T, kv_latent_dim]

        V = self.Wuv(kv_latent_all).view(
            batch_size, S, self.mla_kv_heads, self.head_dim
        )  # [B, S, mla_kv_heads, head_dim]
        # expand V to match num_heads
        V = V.repeat_interleave(
            self.num_heads // self.mla_kv_heads, dim=2
        )  # [B, S, num_heads, head_dim]

        V = V.transpose(1, 2)  # [B, H, S, D]

        # ----- RoPE path (keys) -----
        K_pos_encoding = self.rope.apply_rope(
            self.Wkr(x).view(batch_size, num_tokens, self.mla_kv_heads, self.head_dim)
        )  # [B, T, mla_kv_heads, head_dim]
        self.keys_roped[:batch_size, start_pos:end_pos] = K_pos_encoding.detach()
        keys_roped_all = self.keys_roped[:batch_size, :end_pos]
        Kr = (
            keys_roped_all.repeat_interleave(self.num_heads // self.mla_kv_heads, dim=2)
            .view(batch_size, S, self.num_heads, self.head_dim)
            .transpose(1, 2)  # [B, num_heads, S, head_dim]
        )

        # ----- Attention scores -----
        # unsqueeze to account for heads, since there is one latent cache, not one per head
        attention_scores_1 = Q_absorbed @ kv_latent_all.unsqueeze(1).transpose(2, 3)

        attention_scores_2 = Qr @ Kr.transpose(-2, -1)  # [B, num_heads, T, S]
        attention_scores = (attention_scores_1 + attention_scores_2) / (
            2 * self.head_dim
        ) ** 0.5

        # causal mask: keep only the rows for the current query positions
        causal_mask = torch.triu(
            torch.ones(end_pos, end_pos, device=x.device), diagonal=1
        )
        attention_scores = attention_scores.masked_fill(
            causal_mask.bool()[-num_tokens:, :], float("-inf")
        )

        attention_weights = torch.softmax(attention_scores, dim=-1)
        self.last_attention = attention_weights.detach()
        attention_weights = self.dropout(attention_weights)

        # ----- Context -----
        context = attention_weights @ V  # [B, H, T, D]
        context = (
            context.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_tokens, self.embed_dim)
        )
        out = self.out_proj(context)
        return out


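# Illustrative sketch (added for clarity, not from the original file) of the weight
# "absorption" used above: multiplying queries by Wuk.weight and attending over the
# latent cache gives the same scores as first up-projecting the latent into full keys.
def _example_absorption_identity():
    head_dim, latent_dim = 8, 4
    Wuk = nn.Linear(latent_dim, head_dim, bias=False)
    q = torch.randn(2, 3, 5, head_dim)  # [B, H, T, head_dim]
    c = torch.randn(2, 7, latent_dim)   # [B, S, kv_latent_dim]
    scores_absorbed = (q @ Wuk.weight) @ c.unsqueeze(1).transpose(2, 3)
    k = Wuk(c)                          # [B, S, head_dim]
    scores_explicit = q @ k.unsqueeze(1).transpose(2, 3)
    assert torch.allclose(scores_absorbed, scores_explicit, atol=1e-5)
    return scores_absorbed

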
# Note: I might not use this in training and instead do normal single-token prediction only
class BasicMultiTokenPrediction(nn.Module):

    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()

        # If k is mtp_depth and the current token position is i,
        # this module predicts the next k tokens, i.e. positions (i+1) to (i+k)
        self.k = config.mtp_depth
        self.vocab_size = config.vocab_size
        self.rms_norm = RMSNorm(config.input_dim)
        self.embed = nn.Embedding(self.vocab_size, config.input_dim)
        self.unembed = nn.Linear(config.input_dim, self.vocab_size, bias=False)
        self.unembed.weight = self.embed.weight

        self.projections = nn.ModuleList(
            [nn.Linear(2 * config.input_dim, config.input_dim) for _ in range(self.k)]
        )

        self.transformers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(config.input_dim, config.num_attention_heads)
                for _ in range(self.k)
            ]
        )

    def forward(self, x):
        # x holds the final hidden states for all tokens after all transformer blocks,
        # i.e. just before the final un-embedding layer
        batch_size, num_tokens, input_size = x.shape
        # if num_tokens is 6, positions are i = 0..5 and k = 3,
        # so the last position that can predict k steps ahead is i = 2 (since 2 + 3 = 5);
        # hence i runs from 0 to num_tokens - k - 1

        logits = []

        for ith_token_pos in range(0, num_tokens - self.k):
            hidden_state_ith_token = x[:, ith_token_pos, :]

            logits_k = []
            for k in range(self.k):
                future_position = ith_token_pos + k + 1
                token_embedding = x[
                    :, future_position, :
                ]  # treating x as the final hidden state after all blocks

                _h = self.rms_norm(hidden_state_ith_token)
                _e = self.rms_norm(token_embedding)
                merged = torch.cat([_h, _e], dim=1)

                proj = self.projections[k](merged).unsqueeze(0)
                out = self.transformers[k](proj)
                hidden_state_current = out.squeeze(0)
                _logits = self.unembed(hidden_state_current)
                logits_k.append(_logits)

                hidden_state_ith_token = hidden_state_current

            logits_k = torch.stack(logits_k, dim=1)
            logits.append(logits_k)

        logits = torch.stack(logits, dim=0)
        logits = logits.permute(1, 0, 2, 3).contiguous()  # [B, T - k, k, vocab_size]
        return logits


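# Illustrative shape check (an added sketch, not part of the original file): for a batch
# of final hidden states [B, T, input_dim], the module above returns one stack of k
# future-token logit tensors per usable position, i.e. [B, T - k, k, vocab_size].
# The small config values here are assumptions chosen only to keep the example cheap.
def _example_mtp_shapes():
    cfg = DeepSeekModelConfig(input_dim=64, embed_dim=64, num_attention_heads=4,
                              mtp_depth=2, vocab_size=100)
    mtp = BasicMultiTokenPrediction(cfg)
    h = torch.randn(2, 8, cfg.input_dim)
    logits = mtp(h)
    assert logits.shape == (2, 8 - cfg.mtp_depth, cfg.mtp_depth, cfg.vocab_size)
    return logits

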
class TransformerBlock(nn.Module):

    def __init__(self, config: DeepSeekModelConfig, moe: bool = True):
        super().__init__()
        self.rms_norm_1 = RMSNorm(config.input_dim)
        self.mhla = MultiHeadLatentAttention(config)
        self.rms_norm_2 = RMSNorm(config.input_dim)

        if moe:
            self.ffn = MoE(config)
        else:
            self.ffn = Expert(
                config.input_dim, config.expert_intermediate_dim, config.dropout
            )

    def forward(self, x):
        x = x + self.mhla(self.rms_norm_1(x))
        x = x + self.ffn(self.rms_norm_2(x))
        return x


class DeepseekInspiredModel(nn.Module):
    def __init__(self, config: DeepSeekModelConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.input_dim)
        self.position_embedding = nn.Embedding(config.max_token_len, config.input_dim)

        _blocks = [
            TransformerBlock(config, moe=False) for _ in range(config.num_dense_ffn)
        ]
        _blocks.extend(
            [TransformerBlock(config, moe=True) for _ in range(config.num_moe_ffn)]
        )
        self.transformer_blocks = nn.ModuleList(_blocks)

        self.ln_f = RMSNorm(config.input_dim)
        self.head = nn.Linear(config.input_dim, config.vocab_size, bias=False)
        self.head.weight = self.token_embedding.weight

    def forward(self, x):
        batch_size, num_tokens = x.shape

        token_embeddings = self.token_embedding(x)
        position_ids = torch.arange(0, num_tokens, device=x.device).unsqueeze(0)
        position_embeddings = self.position_embedding(position_ids)
        h = token_embeddings + position_embeddings

        for block in self.transformer_blocks:
            h = block(h)
        h = self.ln_f(h)
        logits = self.head(h)
        return logits


if __name__ == "__main__":
    config = DeepSeekModelConfig()
    x = torch.randint(0, config.vocab_size, (1, 10))  # dummy token ids

    dim = DeepseekInspiredModel(config)

    print(
        f"Number of parameters (in millions): {sum(p.numel() for p in dim.parameters()) / 1_000_000}"
    )
    print(
        f"Number of parameters (in GB): {sum(p.numel() for p in dim.parameters()) * 4 / 1024**3:.2f} GB"
    )
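    # Illustrative smoke test (an added sketch, not part of the original script):
    # one forward pass over the dummy token ids above; expected logits shape is
    # (1, 10, config.vocab_size). Note that the default config builds a large model,
    # so this needs a correspondingly large amount of memory.
    logits = dim(x)
    print(f"Logits shape: {tuple(logits.shape)}")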