File size: 23,201 Bytes
d56eb1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 |
"""# ────────────
# `mla.py`
Based on: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
## RotaryEmbedding
"""
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .shared_space_config import SharedSpaceDecoderConfig
def create_norm_layer(hidden_size: Optional[int], config: "SharedSpaceDecoderConfig") -> nn.Module:
    """
    Create a normalization layer based on `config.norm_type`.

    Args:
        hidden_size: The dimension to normalize over. If `None`, an identity
            layer is returned and no normalization is applied.
        config: Configuration providing `norm_type` ("layernorm" or "rmsnorm")
            and the matching epsilon (`layer_norm_eps` or `rms_norm_eps`).

    Returns:
        `nn.Identity` when `hidden_size` is None, otherwise an `nn.LayerNorm`
        or a `DeepseekV3RMSNorm` over `hidden_size` features.

    Raises:
        ValueError: If `config.norm_type` is not one of the supported values.
    """
    if hidden_size is None:
        return nn.Identity()
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    if config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    # This should be caught by config validation, but being defensive.
    raise ValueError(f"Unknown norm_type: {config.norm_type}")
# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization (equivalent to T5LayerNorm).

    Each vector along the last dimension is divided by its RMS. The RMS is
    computed in float32, with `eps` added to the mean square before the
    reciprocal square root. The result is cast back to the input dtype and
    scaled by a learned per-feature `weight`, which is initialised to ones.
    There is no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Do the reduction in float32 for numerical stability.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized
# Helper used twice during RoPE (once for the queries, once for the keys).
# Its behavior is walked through step by step in the RoPE comments of the
# attention forward pass.
# TODO - Consider inlining it at both call sites so the code can use the
# descriptive names `query`/`key` instead of the generic `x`.
def rotate_half(x):
    """
    Rotate half the hidden dims of the input for RoPE.

    Splits the last dimension at `size // 2` into a front half `a` and a
    back half `b`, then returns `cat(-b, a)` along the last dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
class RotaryEmbedding(nn.Module):
    """
    Precompute rotary position embedding (RoPE) tables and store them as buffers.

    After construction the module holds two non-persistent float32 buffers,
    `cos` and `sin`. Each has shape [max_position_embeddings, rope_dims],
    assuming `rope_dims` is even. Row `p` holds the cos/sin of the rotation
    angles for position `p`.

    The columns use the "half-split" layout that `rotate_half` expects:
    column `j` and column `j + rope_dims // 2` share the same frequency. Each
    row is therefore cos/sin of
    [angle_0, ..., angle_{d/2-1}, angle_0, ..., angle_{d/2-1}].
    This is not an interleaved layout.
    """

    def __init__(self, config: "SharedSpaceDecoderConfig") -> None:
        super().__init__()

        dim = config.rope_dims
        seq_len = config.max_position_embeddings

        # ------------------------------
        # Compute inverse frequencies
        # ------------------------------
        # Shape: [dim // 2]
        # inv_freq[i] = 1 / (theta^(2i / dim)),  for i = 0 .. dim // 2 - 1
        inv_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        )

        # ------------------------------
        # Apply RoPE scaling if configured
        # ------------------------------
        if config.rope_scaling is not None:
            scaling_type = config.rope_scaling.get("type", "linear")
            scaling_factor = config.rope_scaling.get("factor", 1.0)

            if scaling_type not in ("linear", "dynamic"):
                warnings.warn(
                    f"Unknown RoPE scaling type '{scaling_type}', using linear scaling"
                )

            # Every scaling type currently divides the frequencies by the factor.
            # NOTE: "dynamic" is a simplified placeholder. It does not adapt to
            # the actual sequence length and behaves exactly like "linear".
            inv_freq = inv_freq / scaling_factor

        # ------------------------------
        # Position indices. Shape: [seq_len]
        # ------------------------------
        t = torch.arange(seq_len, dtype=torch.float32)

        # ------------------------------
        # Outer product. Shape: [seq_len, dim // 2]
        # Row p contains p * inv_freq (the rotation angles for position p).
        # ------------------------------
        freqs = torch.outer(t, inv_freq)

        # ------------------------------
        # Duplicate the angles along the last dim. Shape: [seq_len, dim]
        # Half-split layout: [f_0, ..., f_{d/2-1}, f_0, ..., f_{d/2-1}],
        # matching `rotate_half`, which pairs column j with column j + d/2.
        # ------------------------------
        emb = torch.cat((freqs, freqs), dim=-1)

        # ------------------------------
        # Register cos/sin as buffers
        # - Stored in float32
        # - Moved (and, for floating dtypes, cast) by model.to(...)
        # - Not saved with state_dict (persistent=False)
        # ------------------------------
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def forward(self, position_ids: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Look up the precomputed cos/sin rows for the given positions.

        Args:
            position_ids: Integer tensor of positions (any shape `[...]`).
                Each value must be in `[0, max_position_embeddings)`.

        Returns:
            `(cos, sin)`, each of shape `[..., rope_dims]`. For a 1-D
            `position_ids` of length T this is `[T, rope_dims]`.
        """
        return self.cos[position_ids], self.sin[position_ids]
"""## MLA"""
class MultiheadLatentAttention(nn.Module):
"""
A variant of MLA with:
- Simplified RoPE handling:
- A portion of the head dimensions are used for position information.
- Same number of queries as keys. (no MQA)
- Optional output subspace
"""
def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int):
super().__init__()
self.config = config
# Used to determine if this layer is dense or uses latents.
self.layer_idx = layer_idx
self.attention_dropout_prob = config.attention_dropout_prob
self.num_heads = config.num_attention_heads
self.rope_theta = config.rope_theta
self.rope_dims = config.rope_dims
self.nope_dims = config.nope_dims
self.q_shared_dim = config.q_shared_dim
self.kv_shared_dim = config.kv_shared_dim
self.o_shared_dim = config.o_shared_dim
self.qk_private_dim = config.qk_private_dim
self.vo_private_dim = config.vo_private_dim
self.hidden_size = config.hidden_size
# =========================
# Input Projections
# =========================
# If this is one of the dense layers,
if self.layer_idx < config.num_dense_layers:
# =========================
# Dense Attention
# =========================
# No latent projections.
self.latent_spaces = False
# Define the standard QKV projection
self.qkv_proj = nn.Linear(
config.hidden_size,
self.num_heads * (self.qk_private_dim * 2 + self.vo_private_dim),
bias=config.attention_bias,
)
# Dense output projection
self.o_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
config.hidden_size,
bias=config.attention_bias,
)
# If we're past the dense layers,
else:
# =========================
# Latent Attention
# =========================
# Use latent projections.
self.latent_spaces = True
# Input latent projections
print("config.q_shared_dim", config.q_shared_dim)
# If we're using a shared query subspace,
if config.q_shared_dim is not None:
# Set a flag that we'll check in `forward`.
self.query_shared = True
self.q_shared_proj = nn.Linear(
config.hidden_size,
self.q_shared_dim,
bias=config.attention_bias,
)
self.q_shared_norm = create_norm_layer(self.q_shared_dim, config)
else:
print("Using identity for shared projection.")
# Set a flag that we'll check in `forward`.
self.query_shared = False
self.q_shared_dim = config.hidden_size
#print("Updated self.q_shared_dim to", self.q_shared_dim)
# Use identity.
self.q_shared_proj = nn.Identity()
self.q_shared_norm = nn.Identity()
# If we're using a shared key/value subspace,
if config.kv_shared_dim is not None:
# Set a flag that we'll check in `forward`.
self.keyvalue_shared = True
self.kv_shared_proj = nn.Linear(
config.hidden_size,
self.kv_shared_dim,
bias=config.attention_bias,
)
self.kv_shared_norm = create_norm_layer(self.kv_shared_dim, config)
else:
# Set a flag that we'll check in `forward`.
self.keyvalue_shared = False
self.kv_shared_dim = config.hidden_size
# Use identity.
self.kv_shared_proj = nn.Identity()
self.kv_shared_norm = nn.Identity()
#print("config.q_shared_dim", config.q_shared_dim)
#print("self.qk_private_dim", self.qk_private_dim)
# Query heads
self.q_private_proj = nn.Linear(
self.q_shared_dim,
self.num_heads * self.qk_private_dim,
bias=False # TODO
)
# Key and Value heads, concatenated
self.kv_private_proj = nn.Linear(
self.kv_shared_dim,
self.num_heads * (self.qk_private_dim + self.vo_private_dim),
bias=False,
)
# Use output subspace if o_shared_dim is specified
self.output_subspace = config.o_shared_dim is not None
# If we're using an output subspace,
if self.output_subspace:
# ==========================
# Output Subspace
# ==========================
self.o_shared_dim = config.o_shared_dim
# Per-head output projections
# (Similar to original W^O, but projects the scored value vectors
# into a latent space instead of back to the model)
self.o_private_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
self.o_shared_dim,
bias=False
)
# Norm layer between o_private_proj and o_shared_proj
# Note: In previous ViT experiments, this norm step hurt performance, but was beneficial
# in the DeepSeekV3 experiments.
# However, we're making it configurable so it can be tested in different contexts.
self.o_private_norm = create_norm_layer(self.o_shared_dim, config)
# Shared output projection
# The head outputs from `o_private_proj` are first summed together (across
# heads) in the latent space.
# Then we project their combined outputs (a single vector per token)
# back to model space via `o_shared_proj`.
self.o_shared_proj = nn.Linear(
self.o_shared_dim,
self.hidden_size,
bias=config.attention_bias
)
else:
# Dense output projection
self.o_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
config.hidden_size,
bias=config.attention_bias,
)
# Softmax scaling factor.
self.softmax_scale = self.qk_private_dim ** (-0.5)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
#past_key_value: Optional[Cache] = None, # TODO - Can I remove this?
#cache_position: Optional[torch.LongTensor] = None, # TODO - Can I remove this?
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
# === Tensor Dimension Symbols ===
# B: batch_size β number of samples in the batch
# T: seq_len β number of tokens per sample
# H: n_heads β number of attention heads
# D: hidden_dim β model embedding size
# Dv: vo_private_dim - per-head value/output projection dimension
# Dr: rope_dims - The first Dr dimensions receive rope.
# Cq: q_shared_dim - query shared subspace size
# Ckv: kv_shared_dim - key-value shared subspace size
# Co: o_shared_dim - output shared subspace size
# Input token embeddings
# hidden_states: [B, T, D]
B, T = hidden_states.shape[:2]
H = self.num_heads
Dq = self.qk_private_dim # per-head dim for Q and K
Dv = self.vo_private_dim # per-head dim for V/O
Dc_q, Dc_kv = self.q_shared_dim, self.kv_shared_dim
# ==============================
# QKV Head Projections
# ==============================
# Project tokens into per-head query, key, and value vectors
# If this layer uses latent projections,
if self.latent_spaces:
# ================================
# Shared Space Projections
# ================================
# Project token embeddings into shared latents
# Input:
# hidden_states [B, T, D]
# q_shared_proj [D, Cq]
# kv_shared_proj [D, Ckv]
# Output:
# q_shared [B, T, Cq]
# kv_shared [B, T, Ckv]
# If we're using a shared query subspace,
if self.q_shared_dim is not None:
q_shared = self.q_shared_proj(hidden_states)
# Normalize latent vectors, shapes unchanged.
q_shared = self.q_shared_norm(q_shared)
# Otherwise,
else:
# Use the hidden states
q_shared = hidden_states
# If we're using a shared key/value subspace,
if self.kv_shared_dim is not None:
# Project token embeddings into shared subspace.
kv_shared = self.kv_shared_proj(hidden_states)
# Normalize latent vectors, shapes unchanged.
kv_shared = self.kv_shared_norm(kv_shared)
# Otherwise,
else:
# Use the hidden states
kv_shared = hidden_states
# ======================================
# Per-Head (Private) Projections
# ======================================
# Project query latents onto query heads.
# Input:
# q_shared [B, T, Cq]
# q_private_proj [Cq, H*Dh]
# Output:
# queries [B, T, H*Dh]
queries = self.q_private_proj(q_shared)
# Project key/value latents onto key and value heads.
# The key and value heads are all concatenated, each head occupies
# Dh columns of the kv_private_proj. This yields the key and value
# vectors concatenated in the same way.
#
# Input:
# kv_shared [B, T, Ckv]
# kv_private_proj [Ckv, 2*H*Dh]
# Output:
# keysvalues [B, T, 2*H*Dh]
keysvalues = self.kv_private_proj(kv_shared)
# Split into key and value tensors
# Each: [B, T, H * Dh]
keys, values = keysvalues.chunk(2, dim=-1)
# If this is a dense attention layer (no latent projections),
else:
# ====================
# Standard MHA
# ====================
# Standard QKV projection
# Input:
# hidden_states [B, T, D]
# qkv_proj [D, 3*H*Dh]
# Output:
# querieskeysvalues [B, T, 3*H*Dh]
querieskeysvalues = self.qkv_proj(hidden_states)
# Separate query, key, and value vectors
# Each: [B, T, H * Dh]
queries, keys, values = querieskeysvalues.chunk(3, dim=-1)
# Split up queries so that there's just one per row.
# Same for keys and values.
#
# Inputs:
# Each [B, T, H*Dh]
# Output:
# Each [B, H, T, Dh]
queries = queries.view(B, T, H, Dq).transpose(1, 2)
keys = keys.view(B, T, H, Dq).transpose(1, 2)
values = values.view(B, T, H, Dv).transpose(1, 2)
# ==================
# RoPE
# ==================
# Apply rotary position embeddings to the first `self.rope_dims` of
# each head.
# The slice operations are free, but the concatenation is
# not, because the outputs of the rotation operation are new data
# occupying different memory. Still considered the best option,
# though.
# 1. Unpack the precomputed cosine and sine embeddings
# Position embeddings is a tuple of
# (cos [seq_len, rope_dims],
# sin [seq_len, rope_dims])
cos, sin = position_embeddings
# 2. Split the query and key heads into the part to rotate and the part
# to pass through (early columns get position info, later ones don't)
#
# (Using queries as example)
# Inputs:
# queries [B, H, T, Dh] Dh = rope_dims + not_rope_dims
# Outputs:
# q_rope [B, H, T, Dr]
# q_pass [B, H, T, Dh-Dr]
q_rope, q_pass = queries[..., :self.rope_dims], queries[..., self.rope_dims:]
k_rope, k_pass = keys[..., :self.rope_dims], keys[..., self.rope_dims:]
# 3. Apply the rotary embedding to the designated slice
#
# To broadcast cos and sin across the batch and head dimensions, we unsqueeze them.
# Shape change: [T, Dr] -> [1, 1, T, Dr]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
#print("q_rope.shape[-1] // 2:", (q_rope.shape[-1] // 2))
#print("x1 = x[..., :x.shape[-1] // 2 ].shape:", q_rope[..., :q_rope.shape[-1] // 2 ].shape)
#print("sin/cos.shape:", cos.shape)
#print("q_rope.shape:", q_rope.shape)
#print("(q_rope * cos).shape:", (q_rope * cos).shape)
#print("rotate_half(q_rope).shape:", rotate_half(q_rope).shape)
#print("(rotate_half(q_rope) * sin).shape:", (rotate_half(q_rope) * sin).shape)
"""
In this example batch_size = 2, hum_heads = 8, seq_len = 65, rope_dims = 16
q_rope.shape[-1] // 2: 8
x1 = x[..., :x.shape[-1] // 2 ].shape: torch.Size([2, 8, 65, 8])
sin/cos.shape: torch.Size([1, 1, 65, 16]) # After double unsqueeze.
vq_rope.shape: torch.Size([2, 8, 65, 16])
(q_rope * cos).shape: torch.Size([2, 8, 65, 16])
rotate_half(q_rope).shape: torch.Size([2, 8, 65, 16])
(rotate_half(q_rope) * sin).shape: torch.Size([2, 8, 65, 16])
"""
# Let's walk through the queries as the example.
# What does rotate half do?
# dim -1 is the row vectors, the queries
#
# Step 1: Split the vector in half.
# "q_rope.shape[-1] // 2" <- How much to select. Half the length of the q_rope vector
# x1 = x[..., :x.shape[-1] // 2 ] # Select the first half of the vector.
# x2 = x[..., x.shape[-1] // 2:] # Select the second half.
#
# Step 2:
# - Apply negative to the values in the second half.
# - Reverse the order of the halves.
# return torch.cat((-x2, x1), dim=-1)
#
# ---- (q_rope * cos) ----
# Element-wise multiply the values in each `cos` vector with the
# corresponding (i.e., same sequence position) `q_rope` vector.
#
# Inputs:
# q_rope [B, H, T, Dr]
# cos [1, 1, T, Dr]
#
# Outputs:
# x [B, H, T, Dr]
#
# ---- (rotate_half(q_rope)) ----
# TODO
#
# Inputs:
# q_rope [B, T, Dr]
#
# Outputs:
# rot_q_rope [B, T, Dr]
#
# ---- rotated * sin ----
# TODO
q_rotated = (q_rope * cos) + (rotate_half(q_rope) * sin)
k_rotated = (k_rope * cos) + (rotate_half(k_rope) * sin)
# 4. Concatenate the rotated and pass-through parts back together
# Input (each): [B, H, T, Dr] and [B, H, T, Dq-Dr]
# Output (each): [B, H, T, Dq]
queries = torch.cat((q_rotated, q_pass), dim=-1)
keys = torch.cat((k_rotated, k_pass), dim=-1)
# ===================
# Attention
# ===================
# The tensors (queries, keys, values) now have shape [B, H, T, Dq]
# and are ready for the attention score calculation.
# Only apply dropout during training.
# self.training is a pytorch flag.
if self.training:
dropout_p = self.attention_dropout_prob
else:
dropout_p = 0.0
# Call SDPA / Flash Attention
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
attn_output = F.scaled_dot_product_attention(
queries,
keys,
values,
attn_mask=None, # attention_mask,
dropout_p=dropout_p,
scale=self.softmax_scale,
is_causal=True, # This is a decoder - apply causal masking
)
# Reshape output back to [B, T, H * Dv] from [B, H, T, Dv]
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, H * Dv)
# =========================
# Output Projection
# =========================
# If we are using an output latent projection,
if self.latent_spaces and self.output_subspace:
# Project the attention output into the output latent space.
# This is analogous to the W^O matrix in standard attention but
# projects to an intermediate latent dimension.
attn_output = self.o_private_proj(attn_output)
# Apply normalization to the output latents
attn_output = self.o_private_norm(attn_output)
# Re-project the output latent representation back to model space.
attn_output = self.o_shared_proj(attn_output)
# If this is a dense layer,
else:
# Project the values back into model space.
attn_output = self.o_proj(attn_output)
# -----------------------------------------
return attn_output
|