| """# ββββββββββββ | |
| # `mla.py` | |
| Based on: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py | |
| ## RotaryEmbedding | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional | |
| from .shared_space_config import SharedSpaceDecoderConfig | |
def create_norm_layer(hidden_size: Optional[int], config: SharedSpaceDecoderConfig) -> nn.Module:
| """ | |
| Create a normalization layer based on the config norm_type. | |
| If `hidden_size` is `None`, this returns an identity layer. | |
| Args: | |
| hidden_size: The dimension to normalize over | |
| config: Configuration containing norm_type and epsilon values | |
    Returns:
        An `nn.LayerNorm`, `DeepseekV3RMSNorm`, or `nn.Identity` module.
| """ | |
| if hidden_size is None: | |
| return nn.Identity() | |
| elif config.norm_type == "layernorm": | |
| return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) | |
| elif config.norm_type == "rmsnorm": | |
| return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps) | |
| else: | |
| # This should be caught by config validation, but being defensive | |
| raise ValueError(f"Unknown norm_type: {config.norm_type}") | |
| # TODO - Find a shared place to put this. | |
| class DeepseekV3RMSNorm(nn.Module): | |
| def __init__(self, hidden_size, eps=1e-6): | |
| """ | |
| DeepseekV3RMSNorm is equivalent to T5LayerNorm | |
| """ | |
| super().__init__() | |
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |
| self.variance_epsilon = eps | |
| def forward(self, hidden_states): | |
| input_dtype = hidden_states.dtype | |
| hidden_states = hidden_states.to(torch.float32) | |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) | |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
| return self.weight * hidden_states.to(input_dtype) | |
# Helper used when applying RoPE: it is applied to both the queries and the
# keys, so it is kept as a small standalone function rather than inlined twice.
| def rotate_half(x): | |
| """Rotates half the hidden dims of the input.""" | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
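
# A quick illustration of `rotate_half` (not executed):
#   x = torch.arange(4.)   # [ 0.,  1.,  2.,  3.]
#   rotate_half(x)         # [-2., -3.,  0.,  1.]
# Combined with the element-wise cos/sin products in the attention module below,
# this pairs dimension i with dimension i + d/2 and applies a 2D rotation to
# each pair by its position-dependent angle.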
| class RotaryEmbedding(nn.Module): | |
| """Precompute RoPE embeddings and store them as buffers.""" | |
| def __init__(self, config: SharedSpaceDecoderConfig) -> None: | |
| super().__init__() | |
| dim = config.rope_dims | |
| seq_len = config.max_position_embeddings | |
| # ------------------------------ | |
| # Compute inverse frequencies | |
| # ------------------------------ | |
        # Shape: [dim // 2]
        # inv_freq[i] = 1 / (theta^(2i / dim))
| inv_freq = 1.0 / ( | |
| config.rope_theta | |
| ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) | |
| ) | |
| # ------------------------------ | |
| # Apply RoPE scaling if configured | |
| # ------------------------------ | |
| if config.rope_scaling is not None: | |
| scaling_type = config.rope_scaling.get("type", "linear") | |
| scaling_factor = config.rope_scaling.get("factor", 1.0) | |
| if scaling_type == "linear": | |
| # Linear scaling: divide frequencies by scaling factor | |
| inv_freq = inv_freq / scaling_factor | |
            elif scaling_type == "dynamic":
                # Dynamic (NTK-style) scaling would normally recompute inv_freq based
                # on the current sequence length; this simplified version applies the
                # same static division as linear scaling.
                inv_freq = inv_freq / scaling_factor
| else: | |
| print(f"Warning: Unknown RoPE scaling type '{scaling_type}', using linear scaling") | |
| inv_freq = inv_freq / scaling_factor | |
| # ------------------------------ | |
| # Compute position indices | |
| # ------------------------------ | |
| # Shape: [seq_len] | |
| t = torch.arange(seq_len, dtype=torch.float32) | |
| # ------------------------------ | |
| # Outer product: [seq_len, dim // 2] | |
| # Each row i contains: t[i] * inv_freq | |
| # ------------------------------ | |
| freqs = torch.outer(t, inv_freq) | |
| # ------------------------------ | |
        # Duplicate the frequencies to cover the full rope width: [seq_len, dim]
        # This produces the "half-and-half" layout expected by `rotate_half`:
        # [f_0, ..., f_{d/2-1}, f_0, ..., f_{d/2-1}]  (not interleaved sin/cos pairs)
| # ------------------------------ | |
| emb = torch.cat((freqs, freqs), dim=-1) | |
| # ------------------------------ | |
| # Register cos/sin as buffers | |
| # - Stored in float32 | |
| # - Will be moved to correct device/dtype via model.to(...) | |
| # - Not saved with state_dict (persistent=False) | |
| # ------------------------------ | |
| self.register_buffer("cos", emb.cos(), persistent=False) | |
| self.register_buffer("sin", emb.sin(), persistent=False) | |
    def forward(self, position_ids: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) tables for the given positions.

        The precomputed `cos` / `sin` buffers can also be read directly; this
        method is just a convenience wrapper around indexing them.
        """
        return self.cos[position_ids], self.sin[position_ids]
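
# Usage sketch (illustrative): the tables are built once from the config and then
# sliced (or called) per forward pass.
#   rope = RotaryEmbedding(config)
#   cos, sin = rope(torch.arange(seq_len))   # each: [seq_len, config.rope_dims]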
| """## MLA""" | |
| class MultiheadLatentAttention(nn.Module): | |
| """ | |
| A variant of MLA with: | |
| - Simplified RoPE handling: | |
| - A portion of the head dimensions are used for position information. | |
| - Same number of queries as keys. (no MQA) | |
| - Optional output subspace | |
| """ | |
| def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int): | |
| super().__init__() | |
| self.config = config | |
| # Used to determine if this layer is dense or uses latents. | |
| self.layer_idx = layer_idx | |
| self.attention_dropout_prob = config.attention_dropout_prob | |
| self.num_heads = config.num_attention_heads | |
| self.rope_theta = config.rope_theta | |
| self.rope_dims = config.rope_dims | |
| self.nope_dims = config.nope_dims | |
| self.q_shared_dim = config.q_shared_dim | |
| self.kv_shared_dim = config.kv_shared_dim | |
| self.o_shared_dim = config.o_shared_dim | |
| self.qk_private_dim = config.qk_private_dim | |
| self.vo_private_dim = config.vo_private_dim | |
| self.hidden_size = config.hidden_size | |
| # ========================= | |
| # Input Projections | |
| # ========================= | |
| # If this is one of the dense layers, | |
| if self.layer_idx < config.num_dense_layers: | |
| # ========================= | |
| # Dense Attention | |
| # ========================= | |
| # No latent projections. | |
| self.latent_spaces = False | |
| # Define the standard QKV projection | |
| self.qkv_proj = nn.Linear( | |
| config.hidden_size, | |
| self.num_heads * (self.qk_private_dim * 2 + self.vo_private_dim), | |
| bias=config.attention_bias, | |
| ) | |
| # Dense output projection | |
| self.o_proj = nn.Linear( | |
| self.num_heads * self.vo_private_dim, | |
| config.hidden_size, | |
| bias=config.attention_bias, | |
| ) | |
| # If we're past the dense layers, | |
| else: | |
| # ========================= | |
| # Latent Attention | |
| # ========================= | |
| # Use latent projections. | |
| self.latent_spaces = True | |
| # Input latent projections | |
| print("config.q_shared_dim", config.q_shared_dim) | |
| # If we're using a shared query subspace, | |
| if config.q_shared_dim is not None: | |
| # Set a flag that we'll check in `forward`. | |
| self.query_shared = True | |
| self.q_shared_proj = nn.Linear( | |
| config.hidden_size, | |
| self.q_shared_dim, | |
| bias=config.attention_bias, | |
| ) | |
| self.q_shared_norm = create_norm_layer(self.q_shared_dim, config) | |
| else: | |
| print("Using identity for shared projection.") | |
| # Set a flag that we'll check in `forward`. | |
| self.query_shared = False | |
| self.q_shared_dim = config.hidden_size | |
| #print("Updated self.q_shared_dim to", self.q_shared_dim) | |
| # Use identity. | |
| self.q_shared_proj = nn.Identity() | |
| self.q_shared_norm = nn.Identity() | |
| # If we're using a shared key/value subspace, | |
| if config.kv_shared_dim is not None: | |
| # Set a flag that we'll check in `forward`. | |
| self.keyvalue_shared = True | |
| self.kv_shared_proj = nn.Linear( | |
| config.hidden_size, | |
| self.kv_shared_dim, | |
| bias=config.attention_bias, | |
| ) | |
| self.kv_shared_norm = create_norm_layer(self.kv_shared_dim, config) | |
| else: | |
| # Set a flag that we'll check in `forward`. | |
| self.keyvalue_shared = False | |
| self.kv_shared_dim = config.hidden_size | |
| # Use identity. | |
| self.kv_shared_proj = nn.Identity() | |
| self.kv_shared_norm = nn.Identity() | |
| #print("config.q_shared_dim", config.q_shared_dim) | |
| #print("self.qk_private_dim", self.qk_private_dim) | |
| # Query heads | |
| self.q_private_proj = nn.Linear( | |
| self.q_shared_dim, | |
| self.num_heads * self.qk_private_dim, | |
                bias=False  # TODO: consider following config.attention_bias like the other projections
| ) | |
| # Key and Value heads, concatenated | |
| self.kv_private_proj = nn.Linear( | |
| self.kv_shared_dim, | |
| self.num_heads * (self.qk_private_dim + self.vo_private_dim), | |
| bias=False, | |
| ) | |
| # Use output subspace if o_shared_dim is specified | |
| self.output_subspace = config.o_shared_dim is not None | |
| # If we're using an output subspace, | |
| if self.output_subspace: | |
| # ========================== | |
| # Output Subspace | |
| # ========================== | |
| self.o_shared_dim = config.o_shared_dim | |
| # Per-head output projections | |
| # (Similar to original W^O, but projects the scored value vectors | |
| # into a latent space instead of back to the model) | |
| self.o_private_proj = nn.Linear( | |
| self.num_heads * self.vo_private_dim, | |
| self.o_shared_dim, | |
| bias=False | |
| ) | |
| # Norm layer between o_private_proj and o_shared_proj | |
| # Note: In previous ViT experiments, this norm step hurt performance, but was beneficial | |
| # in the DeepSeekV3 experiments. | |
| # However, we're making it configurable so it can be tested in different contexts. | |
| self.o_private_norm = create_norm_layer(self.o_shared_dim, config) | |
| # Shared output projection | |
| # The head outputs from `o_private_proj` are first summed together (across | |
| # heads) in the latent space. | |
| # Then we project their combined outputs (a single vector per token) | |
| # back to model space via `o_shared_proj`. | |
| self.o_shared_proj = nn.Linear( | |
| self.o_shared_dim, | |
| self.hidden_size, | |
| bias=config.attention_bias | |
| ) | |
| else: | |
| # Dense output projection | |
| self.o_proj = nn.Linear( | |
| self.num_heads * self.vo_private_dim, | |
| config.hidden_size, | |
| bias=config.attention_bias, | |
| ) | |
| # Softmax scaling factor. | |
| self.softmax_scale = self.qk_private_dim ** (-0.5) | |
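
    # Shape summary for the latent path (shared dims may be None, in which case
    # the corresponding shared projection is an identity on the hidden state):
    #   hidden_states [B, T, D]
    #     --q_shared_proj-->   [B, T, Cq]   --q_private_proj-->   [B, T, H*Dq]
    #     --kv_shared_proj-->  [B, T, Ckv]  --kv_private_proj-->  [B, T, H*(Dq+Dv)]
    #   attention output [B, T, H*Dv]
    #     --o_private_proj-->  [B, T, Co]   --o_shared_proj-->    [B, T, D]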
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], | |
| attention_mask: Optional[torch.Tensor], | |
| #past_key_value: Optional[Cache] = None, # TODO - Can I remove this? | |
| #cache_position: Optional[torch.LongTensor] = None, # TODO - Can I remove this? | |
| **kwargs, | |
    ) -> torch.Tensor:
        # === Tensor Dimension Symbols ===
        # B:   batch_size      - number of samples in the batch
        # T:   seq_len         - number of tokens per sample
        # H:   n_heads         - number of attention heads
        # D:   hidden_dim      - model embedding size
        # Dq:  qk_private_dim  - per-head query/key dimension
        # Dv:  vo_private_dim  - per-head value/output dimension
        # Dr:  rope_dims       - the first Dr dims of each head receive RoPE
        # Cq:  q_shared_dim    - query shared subspace size
        # Ckv: kv_shared_dim   - key-value shared subspace size
        # Co:  o_shared_dim    - output shared subspace size
| # Input token embeddings | |
| # hidden_states: [B, T, D] | |
| B, T = hidden_states.shape[:2] | |
| H = self.num_heads | |
| Dq = self.qk_private_dim # per-head dim for Q and K | |
| Dv = self.vo_private_dim # per-head dim for V/O | |
| Dc_q, Dc_kv = self.q_shared_dim, self.kv_shared_dim | |
| # ============================== | |
| # QKV Head Projections | |
| # ============================== | |
| # Project tokens into per-head query, key, and value vectors | |
| # If this layer uses latent projections, | |
| if self.latent_spaces: | |
| # ================================ | |
| # Shared Space Projections | |
| # ================================ | |
| # Project token embeddings into shared latents | |
| # Input: | |
| # hidden_states [B, T, D] | |
| # q_shared_proj [D, Cq] | |
| # kv_shared_proj [D, Ckv] | |
| # Output: | |
| # q_shared [B, T, Cq] | |
| # kv_shared [B, T, Ckv] | |
| # If we're using a shared query subspace, | |
            if self.query_shared:
| q_shared = self.q_shared_proj(hidden_states) | |
| # Normalize latent vectors, shapes unchanged. | |
| q_shared = self.q_shared_norm(q_shared) | |
| # Otherwise, | |
| else: | |
| # Use the hidden states | |
| q_shared = hidden_states | |
| # If we're using a shared key/value subspace, | |
            if self.keyvalue_shared:
| # Project token embeddings into shared subspace. | |
| kv_shared = self.kv_shared_proj(hidden_states) | |
| # Normalize latent vectors, shapes unchanged. | |
| kv_shared = self.kv_shared_norm(kv_shared) | |
| # Otherwise, | |
| else: | |
| # Use the hidden states | |
| kv_shared = hidden_states | |
| # ====================================== | |
| # Per-Head (Private) Projections | |
| # ====================================== | |
| # Project query latents onto query heads. | |
            # Input:
            #    q_shared        [B, T, Cq]
            #    q_private_proj  [Cq, H*Dq]
            # Output:
            #    queries         [B, T, H*Dq]
| queries = self.q_private_proj(q_shared) | |
| # Project key/value latents onto key and value heads. | |
            # Each head occupies Dq columns (key) and Dv columns (value) of
            # kv_private_proj, with all heads concatenated along the output
            # dimension. This yields the key and value vectors concatenated in
            # the same way.
            #
            # Input:
            #    kv_shared       [B, T, Ckv]
            #    kv_private_proj [Ckv, H*(Dq+Dv)]
            # Output:
            #    keysvalues      [B, T, H*(Dq+Dv)]
            keysvalues = self.kv_private_proj(kv_shared)

            # Split into key and value tensors.
            # keys: [B, T, H*Dq], values: [B, T, H*Dv]
            # (split is used instead of chunk so that Dq and Dv may differ)
            keys, values = keysvalues.split([H * Dq, H * Dv], dim=-1)
| # If this is a dense attention layer (no latent projections), | |
| else: | |
| # ==================== | |
| # Standard MHA | |
| # ==================== | |
| # Standard QKV projection | |
            # Input:
            #    hidden_states     [B, T, D]
            #    qkv_proj          [D, H*(2*Dq+Dv)]
            # Output:
            #    querieskeysvalues [B, T, H*(2*Dq+Dv)]
            querieskeysvalues = self.qkv_proj(hidden_states)

            # Separate the query, key, and value vectors.
            # queries/keys: [B, T, H*Dq], values: [B, T, H*Dv]
            # (split is used instead of chunk so that Dq and Dv may differ)
            queries, keys, values = querieskeysvalues.split([H * Dq, H * Dq, H * Dv], dim=-1)
        # Reshape so each head gets its own dimension, then move the head
        # dimension in front of the sequence dimension.
        #
        # Inputs:
        #    queries/keys [B, T, H*Dq],   values [B, T, H*Dv]
        # Outputs:
        #    queries/keys [B, H, T, Dq],  values [B, H, T, Dv]
| queries = queries.view(B, T, H, Dq).transpose(1, 2) | |
| keys = keys.view(B, T, H, Dq).transpose(1, 2) | |
| values = values.view(B, T, H, Dv).transpose(1, 2) | |
| # ================== | |
| # RoPE | |
| # ================== | |
| # Apply rotary position embeddings to the first `self.rope_dims` of | |
| # each head. | |
| # The slice operations are free, but the concatenation is | |
| # not, because the outputs of the rotation operation are new data | |
| # occupying different memory. Still considered the best option, | |
| # though. | |
| # 1. Unpack the precomputed cosine and sine embeddings | |
| # Position embeddings is a tuple of | |
| # (cos [seq_len, rope_dims], | |
| # sin [seq_len, rope_dims]) | |
| cos, sin = position_embeddings | |
| # 2. Split the query and key heads into the part to rotate and the part | |
| # to pass through (early columns get position info, later ones don't) | |
| # | |
        # (Using queries as the example)
        # Inputs:
        #    queries [B, H, T, Dq]      (Dq = rope_dims + nope_dims)
        # Outputs:
        #    q_rope  [B, H, T, Dr]
        #    q_pass  [B, H, T, Dq-Dr]
| q_rope, q_pass = queries[..., :self.rope_dims], queries[..., self.rope_dims:] | |
| k_rope, k_pass = keys[..., :self.rope_dims], keys[..., self.rope_dims:] | |
| # 3. Apply the rotary embedding to the designated slice | |
| # | |
| # To broadcast cos and sin across the batch and head dimensions, we unsqueeze them. | |
| # Shape change: [T, Dr] -> [1, 1, T, Dr] | |
| cos = cos.unsqueeze(0).unsqueeze(0) | |
| sin = sin.unsqueeze(0).unsqueeze(0) | |
| #print("q_rope.shape[-1] // 2:", (q_rope.shape[-1] // 2)) | |
| #print("x1 = x[..., :x.shape[-1] // 2 ].shape:", q_rope[..., :q_rope.shape[-1] // 2 ].shape) | |
| #print("sin/cos.shape:", cos.shape) | |
| #print("q_rope.shape:", q_rope.shape) | |
| #print("(q_rope * cos).shape:", (q_rope * cos).shape) | |
| #print("rotate_half(q_rope).shape:", rotate_half(q_rope).shape) | |
| #print("(rotate_half(q_rope) * sin).shape:", (rotate_half(q_rope) * sin).shape) | |
| """ | |
| In this example batch_size = 2, hum_heads = 8, seq_len = 65, rope_dims = 16 | |
| q_rope.shape[-1] // 2: 8 | |
| x1 = x[..., :x.shape[-1] // 2 ].shape: torch.Size([2, 8, 65, 8]) | |
| sin/cos.shape: torch.Size([1, 1, 65, 16]) # After double unsqueeze. | |
| vq_rope.shape: torch.Size([2, 8, 65, 16]) | |
| (q_rope * cos).shape: torch.Size([2, 8, 65, 16]) | |
| rotate_half(q_rope).shape: torch.Size([2, 8, 65, 16]) | |
| (rotate_half(q_rope) * sin).shape: torch.Size([2, 8, 65, 16]) | |
| """ | |
| # Let's walk through the queries as the example. | |
| # What does rotate half do? | |
| # dim -1 is the row vectors, the queries | |
| # | |
| # Step 1: Split the vector in half. | |
| # "q_rope.shape[-1] // 2" <- How much to select. Half the length of the q_rope vector | |
| # x1 = x[..., :x.shape[-1] // 2 ] # Select the first half of the vector. | |
| # x2 = x[..., x.shape[-1] // 2:] # Select the second half. | |
| # | |
| # Step 2: | |
| # - Apply negative to the values in the second half. | |
| # - Reverse the order of the halves. | |
| # return torch.cat((-x2, x1), dim=-1) | |
| # | |
| # ---- (q_rope * cos) ---- | |
| # Element-wise multiply the values in each `cos` vector with the | |
| # corresponding (i.e., same sequence position) `q_rope` vector. | |
| # | |
| # Inputs: | |
| # q_rope [B, H, T, Dr] | |
| # cos [1, 1, T, Dr] | |
| # | |
| # Outputs: | |
| # x [B, H, T, Dr] | |
| # | |
        # ---- rotate_half(q_rope) ----
        # Swaps the two halves of the rope dims and negates the (new) first half:
        #    [x_0 ... x_{Dr/2-1}, x_{Dr/2} ... x_{Dr-1}]
        #      -> [-x_{Dr/2} ... -x_{Dr-1}, x_0 ... x_{Dr/2-1}]
        #
        # Inputs:
        #    q_rope     [B, H, T, Dr]
        # Outputs:
        #    rot_q_rope [B, H, T, Dr]
        #
        # ---- rotated * sin ----
        # Element-wise multiply by the sine table and add the result to
        # (q_rope * cos). Together, the two terms rotate each pair
        # (x_i, x_{i+Dr/2}) by the position-dependent angle theta_i:
        #    out_i        = x_i * cos(theta_i) - x_{i+Dr/2} * sin(theta_i)
        #    out_{i+Dr/2} = x_{i+Dr/2} * cos(theta_i) + x_i * sin(theta_i)
| q_rotated = (q_rope * cos) + (rotate_half(q_rope) * sin) | |
| k_rotated = (k_rope * cos) + (rotate_half(k_rope) * sin) | |
| # 4. Concatenate the rotated and pass-through parts back together | |
| # Input (each): [B, H, T, Dr] and [B, H, T, Dq-Dr] | |
| # Output (each): [B, H, T, Dq] | |
| queries = torch.cat((q_rotated, q_pass), dim=-1) | |
| keys = torch.cat((k_rotated, k_pass), dim=-1) | |
| # =================== | |
| # Attention | |
| # =================== | |
        # queries and keys now have shape [B, H, T, Dq], values [B, H, T, Dv],
        # and all are ready for the attention score calculation.
| # Only apply dropout during training. | |
| # self.training is a pytorch flag. | |
| if self.training: | |
| dropout_p = self.attention_dropout_prob | |
| else: | |
| dropout_p = 0.0 | |
| # Call SDPA / Flash Attention | |
| # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html | |
| attn_output = F.scaled_dot_product_attention( | |
| queries, | |
| keys, | |
| values, | |
            attn_mask=None,    # attention_mask is unused here; masking comes from is_causal
| dropout_p=dropout_p, | |
| scale=self.softmax_scale, | |
| is_causal=True, # This is a decoder - apply causal masking | |
| ) | |
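        # SDPA computes softmax(Q @ K^T * softmax_scale + causal mask) @ V for each
        # head, dispatching to a fused kernel (e.g. Flash Attention) when available.
        # Output shape: [B, H, T, Dv]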
| # Reshape output back to [B, T, H * Dv] from [B, H, T, Dv] | |
| attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, H * Dv) | |
| # ========================= | |
| # Output Projection | |
| # ========================= | |
| # If we are using an output latent projection, | |
| if self.latent_spaces and self.output_subspace: | |
| # Project the attention output into the output latent space. | |
| # This is analogous to the W^O matrix in standard attention but | |
| # projects to an intermediate latent dimension. | |
| attn_output = self.o_private_proj(attn_output) | |
| # Apply normalization to the output latents | |
| attn_output = self.o_private_norm(attn_output) | |
| # Re-project the output latent representation back to model space. | |
| attn_output = self.o_shared_proj(attn_output) | |
| # If this is a dense layer, | |
| else: | |
| # Project the values back into model space. | |
| attn_output = self.o_proj(attn_output) | |
| # ----------------------------------------- | |
| return attn_output | |
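
# Usage sketch (illustrative only; assumes `SharedSpaceDecoderConfig` exposes the
# fields referenced above -- the exact constructor arguments are not shown here):
#
#   config = SharedSpaceDecoderConfig(...)
#   rope   = RotaryEmbedding(config)
#   layer  = MultiheadLatentAttention(config, layer_idx=0)
#
#   hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
#   cos, sin      = rope(torch.arange(seq_len))      # each: [seq_len, rope_dims]
#   out           = layer(hidden_states, (cos, sin), attention_mask=None)
#   # out: [batch_size, seq_len, config.hidden_size]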