"""# β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚
# `mla.py`
Based on: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
## RotaryEmbedding
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .shared_space_config import SharedSpaceDecoderConfig
def create_norm_layer(hidden_size: Optional[int], config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config norm_type.
    If `hidden_size` is `None`, this returns an identity layer.
    Args:
        hidden_size: The dimension to normalize over, or `None` to skip normalization
        config: Configuration containing norm_type and epsilon values
    Returns:
        A LayerNorm, RMSNorm, or Identity layer
    """
if hidden_size is None:
return nn.Identity()
elif config.norm_type == "layernorm":
return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
elif config.norm_type == "rmsnorm":
return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
else:
# This should be caught by config validation, but being defensive
raise ValueError(f"Unknown norm_type: {config.norm_type}")
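# ----------------------------------------------------------------------------
# Illustrative sketch (never called by the model): exercising `create_norm_layer`
# with a stand-in config. `SimpleNamespace` is a hypothetical substitute that
# carries only the fields the factory reads; the real `SharedSpaceDecoderConfig`
# may differ.
def _demo_create_norm_layer() -> None:
    from types import SimpleNamespace
    cfg = SimpleNamespace(norm_type="rmsnorm", rms_norm_eps=1e-6, layer_norm_eps=1e-5)
    norm = create_norm_layer(64, cfg)    # -> DeepseekV3RMSNorm over the last dim (64)
    skip = create_norm_layer(None, cfg)  # -> nn.Identity when hidden_size is None
    x = torch.randn(2, 10, 64)
    assert norm(x).shape == x.shape
    assert torch.equal(skip(x), x)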
# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
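# ----------------------------------------------------------------------------
# Worked example (illustrative only): RMSNorm divides each vector by the root
# mean square of its elements (plus eps) and applies a learned per-dim gain,
# i.e. y = w * x / sqrt(mean(x**2) + eps). With the default weight of all ones,
# the module should match that formula directly.
def _demo_rms_norm() -> None:
    norm = DeepseekV3RMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 5, 8)
    manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), manual, atol=1e-5)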
# Helper needed because it's applied twice during RoPE below -- once for the
# queries and once for the keys.
# TODO - Consider just inlining it in both places; then the argument could be
# named 'query'/'key' instead of 'x'.
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
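# ----------------------------------------------------------------------------
# Worked example (illustrative only, not used by the model): on a single
# 4-dim vector [1, 2, 3, 4], `rotate_half` returns [-3, -4, 1, 2] -- the second
# half is negated and swapped in front of the first half. This is the companion
# term needed for the pairwise RoPE rotation applied further below.
def _demo_rotate_half() -> None:
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
    assert torch.equal(rotate_half(x), expected)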
class RotaryEmbedding(nn.Module):
"""Precompute RoPE embeddings and store them as buffers."""
def __init__(self, config: SharedSpaceDecoderConfig) -> None:
super().__init__()
dim = config.rope_dims
seq_len = config.max_position_embeddings
# ------------------------------
# Compute inverse frequencies
# ------------------------------
# Shape: [dim // 2]
# inv_freq[i] = 1 / (theta^(i / dim))
inv_freq = 1.0 / (
config.rope_theta
** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
# ------------------------------
# Apply RoPE scaling if configured
# ------------------------------
if config.rope_scaling is not None:
scaling_type = config.rope_scaling.get("type", "linear")
scaling_factor = config.rope_scaling.get("factor", 1.0)
if scaling_type == "linear":
# Linear scaling: divide frequencies by scaling factor
inv_freq = inv_freq / scaling_factor
elif scaling_type == "dynamic":
# Dynamic scaling: adjust based on sequence length
# This is a simplified implementation
inv_freq = inv_freq / scaling_factor
else:
print(f"Warning: Unknown RoPE scaling type '{scaling_type}', using linear scaling")
inv_freq = inv_freq / scaling_factor
# ------------------------------
# Compute position indices
# ------------------------------
# Shape: [seq_len]
t = torch.arange(seq_len, dtype=torch.float32)
# ------------------------------
# Outer product: [seq_len, dim // 2]
# Each row i contains: t[i] * inv_freq
# ------------------------------
freqs = torch.outer(t, inv_freq)
# ------------------------------
        # Duplicate the frequencies to span the full rope width: [seq_len, dim]
        # Layout follows the rotate_half convention: [f_0, ..., f_{d/2-1}, f_0, ..., f_{d/2-1}]
# ------------------------------
emb = torch.cat((freqs, freqs), dim=-1)
# ------------------------------
# Register cos/sin as buffers
# - Stored in float32
# - Will be moved to correct device/dtype via model.to(...)
# - Not saved with state_dict (persistent=False)
# ------------------------------
self.register_buffer("cos", emb.cos(), persistent=False)
self.register_buffer("sin", emb.sin(), persistent=False)
    def forward(self, position_ids: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up the precomputed (cos, sin) tables for the given positions."""
        # Callers can also just read the `cos` / `sin` buffers directly.
        return self.cos[position_ids], self.sin[position_ids]
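# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The decoder wiring isn't shown in
# this file, so the stand-in config below carries just the fields
# `RotaryEmbedding.__init__` reads (the values are hypothetical). It shows how
# the cached tables can be sliced to the current sequence length to form the
# `position_embeddings` tuple consumed by the attention layer below.
def _demo_rotary_tables() -> None:
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        rope_dims=16,                  # hypothetical value
        rope_theta=10000.0,
        rope_scaling=None,
        max_position_embeddings=2048,  # hypothetical value
    )
    rope = RotaryEmbedding(cfg)
    seq_len = 65
    cos, sin = rope.cos[:seq_len], rope.sin[:seq_len]  # each [seq_len, rope_dims]
    assert cos.shape == (seq_len, cfg.rope_dims)
    position_embeddings = (cos, sin)  # passed to MultiheadLatentAttention.forward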
"""## MLA"""
class MultiheadLatentAttention(nn.Module):
"""
A variant of MLA with:
- Simplified RoPE handling:
- A portion of the head dimensions are used for position information.
    - Same number of query heads as key/value heads (no MQA/GQA).
- Optional output subspace
"""
def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int):
super().__init__()
self.config = config
# Used to determine if this layer is dense or uses latents.
self.layer_idx = layer_idx
self.attention_dropout_prob = config.attention_dropout_prob
self.num_heads = config.num_attention_heads
self.rope_theta = config.rope_theta
self.rope_dims = config.rope_dims
self.nope_dims = config.nope_dims
self.q_shared_dim = config.q_shared_dim
self.kv_shared_dim = config.kv_shared_dim
self.o_shared_dim = config.o_shared_dim
self.qk_private_dim = config.qk_private_dim
self.vo_private_dim = config.vo_private_dim
self.hidden_size = config.hidden_size
# =========================
# Input Projections
# =========================
# If this is one of the dense layers,
if self.layer_idx < config.num_dense_layers:
# =========================
# Dense Attention
# =========================
# No latent projections.
self.latent_spaces = False
# Define the standard QKV projection
self.qkv_proj = nn.Linear(
config.hidden_size,
self.num_heads * (self.qk_private_dim * 2 + self.vo_private_dim),
bias=config.attention_bias,
)
# Dense output projection
self.o_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
config.hidden_size,
bias=config.attention_bias,
)
# If we're past the dense layers,
else:
# =========================
# Latent Attention
# =========================
# Use latent projections.
self.latent_spaces = True
            # Input latent projections
# If we're using a shared query subspace,
if config.q_shared_dim is not None:
# Set a flag that we'll check in `forward`.
self.query_shared = True
self.q_shared_proj = nn.Linear(
config.hidden_size,
self.q_shared_dim,
bias=config.attention_bias,
)
self.q_shared_norm = create_norm_layer(self.q_shared_dim, config)
else:
print("Using identity for shared projection.")
# Set a flag that we'll check in `forward`.
self.query_shared = False
self.q_shared_dim = config.hidden_size
#print("Updated self.q_shared_dim to", self.q_shared_dim)
# Use identity.
self.q_shared_proj = nn.Identity()
self.q_shared_norm = nn.Identity()
# If we're using a shared key/value subspace,
if config.kv_shared_dim is not None:
# Set a flag that we'll check in `forward`.
self.keyvalue_shared = True
self.kv_shared_proj = nn.Linear(
config.hidden_size,
self.kv_shared_dim,
bias=config.attention_bias,
)
self.kv_shared_norm = create_norm_layer(self.kv_shared_dim, config)
else:
# Set a flag that we'll check in `forward`.
self.keyvalue_shared = False
self.kv_shared_dim = config.hidden_size
# Use identity.
self.kv_shared_proj = nn.Identity()
self.kv_shared_norm = nn.Identity()
#print("config.q_shared_dim", config.q_shared_dim)
#print("self.qk_private_dim", self.qk_private_dim)
# Query heads
self.q_private_proj = nn.Linear(
self.q_shared_dim,
self.num_heads * self.qk_private_dim,
bias=False # TODO
)
# Key and Value heads, concatenated
self.kv_private_proj = nn.Linear(
self.kv_shared_dim,
self.num_heads * (self.qk_private_dim + self.vo_private_dim),
bias=False,
)
# Use output subspace if o_shared_dim is specified
self.output_subspace = config.o_shared_dim is not None
# If we're using an output subspace,
if self.output_subspace:
# ==========================
# Output Subspace
# ==========================
self.o_shared_dim = config.o_shared_dim
# Per-head output projections
# (Similar to original W^O, but projects the scored value vectors
# into a latent space instead of back to the model)
self.o_private_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
self.o_shared_dim,
bias=False
)
# Norm layer between o_private_proj and o_shared_proj
# Note: In previous ViT experiments, this norm step hurt performance, but was beneficial
# in the DeepSeekV3 experiments.
# However, we're making it configurable so it can be tested in different contexts.
self.o_private_norm = create_norm_layer(self.o_shared_dim, config)
# Shared output projection
# The head outputs from `o_private_proj` are first summed together (across
# heads) in the latent space.
# Then we project their combined outputs (a single vector per token)
# back to model space via `o_shared_proj`.
self.o_shared_proj = nn.Linear(
self.o_shared_dim,
self.hidden_size,
bias=config.attention_bias
)
else:
# Dense output projection
self.o_proj = nn.Linear(
self.num_heads * self.vo_private_dim,
config.hidden_size,
bias=config.attention_bias,
)
# Softmax scaling factor.
self.softmax_scale = self.qk_private_dim ** (-0.5)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
#past_key_value: Optional[Cache] = None, # TODO - Can I remove this?
#cache_position: Optional[torch.LongTensor] = None, # TODO - Can I remove this?
**kwargs,
    ) -> torch.Tensor:
# === Tensor Dimension Symbols ===
        # B: batch_size - number of samples in the batch
        # T: seq_len - number of tokens per sample
        # H: n_heads - number of attention heads
        # D: hidden_dim - model embedding size
        # Dq: qk_private_dim - per-head query/key projection dimension
        # Dv: vo_private_dim - per-head value/output projection dimension
        # Dr: rope_dims - the first Dr dimensions of each head receive RoPE
        # Cq: q_shared_dim - query shared subspace size
        # Ckv: kv_shared_dim - key-value shared subspace size
        # Co: o_shared_dim - output shared subspace size
# Input token embeddings
# hidden_states: [B, T, D]
B, T = hidden_states.shape[:2]
H = self.num_heads
Dq = self.qk_private_dim # per-head dim for Q and K
Dv = self.vo_private_dim # per-head dim for V/O
Dc_q, Dc_kv = self.q_shared_dim, self.kv_shared_dim
# ==============================
# QKV Head Projections
# ==============================
# Project tokens into per-head query, key, and value vectors
# If this layer uses latent projections,
if self.latent_spaces:
# ================================
# Shared Space Projections
# ================================
# Project token embeddings into shared latents
# Input:
# hidden_states [B, T, D]
# q_shared_proj [D, Cq]
# kv_shared_proj [D, Ckv]
# Output:
# q_shared [B, T, Cq]
# kv_shared [B, T, Ckv]
# If we're using a shared query subspace,
            # (Check the flag set in `__init__`; `self.q_shared_dim` is
            # overwritten with `hidden_size` when no subspace is used.)
            if self.query_shared:
q_shared = self.q_shared_proj(hidden_states)
# Normalize latent vectors, shapes unchanged.
q_shared = self.q_shared_norm(q_shared)
# Otherwise,
else:
# Use the hidden states
q_shared = hidden_states
# If we're using a shared key/value subspace,
            if self.keyvalue_shared:
# Project token embeddings into shared subspace.
kv_shared = self.kv_shared_proj(hidden_states)
# Normalize latent vectors, shapes unchanged.
kv_shared = self.kv_shared_norm(kv_shared)
# Otherwise,
else:
# Use the hidden states
kv_shared = hidden_states
# ======================================
# Per-Head (Private) Projections
# ======================================
# Project query latents onto query heads.
# Input:
# q_shared [B, T, Cq]
            #   q_private_proj [Cq, H*Dq]
            # Output:
            #   queries [B, T, H*Dq]
queries = self.q_private_proj(q_shared)
            # Project key/value latents onto the key and value heads.
            # All of the key heads are concatenated first, followed by all of
            # the value heads, so the output holds the per-head key and value
            # vectors side by side.
            #
            # Input:
            #   kv_shared [B, T, Ckv]
            #   kv_private_proj [Ckv, H*(Dq+Dv)]
            # Output:
            #   keysvalues [B, T, H*(Dq+Dv)]
            keysvalues = self.kv_private_proj(kv_shared)
            # Split into key and value tensors.
            #   keys [B, T, H*Dq], values [B, T, H*Dv]
            # (torch.split rather than chunk so that Dq and Dv may differ.)
            keys, values = torch.split(keysvalues, [H * Dq, H * Dv], dim=-1)
# If this is a dense attention layer (no latent projections),
else:
# ====================
# Standard MHA
# ====================
# Standard QKV projection
            # Input:
            #   hidden_states [B, T, D]
            #   qkv_proj [D, H*(2*Dq+Dv)]
            # Output:
            #   querieskeysvalues [B, T, H*(2*Dq+Dv)]
            querieskeysvalues = self.qkv_proj(hidden_states)
            # Separate the query, key, and value vectors.
            #   queries/keys [B, T, H*Dq], values [B, T, H*Dv]
            # (torch.split rather than chunk so that Dq and Dv may differ.)
            queries, keys, values = torch.split(
                querieskeysvalues, [H * Dq, H * Dq, H * Dv], dim=-1
            )
# Split up queries so that there's just one per row.
# Same for keys and values.
#
        # Inputs:
        #   queries/keys [B, T, H*Dq], values [B, T, H*Dv]
        # Output:
        #   queries/keys [B, H, T, Dq], values [B, H, T, Dv]
queries = queries.view(B, T, H, Dq).transpose(1, 2)
keys = keys.view(B, T, H, Dq).transpose(1, 2)
values = values.view(B, T, H, Dv).transpose(1, 2)
# ==================
# RoPE
# ==================
# Apply rotary position embeddings to the first `self.rope_dims` of
# each head.
# The slice operations are free, but the concatenation is
# not, because the outputs of the rotation operation are new data
# occupying different memory. Still considered the best option,
# though.
# 1. Unpack the precomputed cosine and sine embeddings
# Position embeddings is a tuple of
# (cos [seq_len, rope_dims],
# sin [seq_len, rope_dims])
cos, sin = position_embeddings
# 2. Split the query and key heads into the part to rotate and the part
# to pass through (early columns get position info, later ones don't)
#
# (Using queries as example)
# Inputs:
        #   queries [B, H, T, Dq]   (Dq = rope_dims + nope_dims)
        # Outputs:
        #   q_rope [B, H, T, Dr]
        #   q_pass [B, H, T, Dq-Dr]
q_rope, q_pass = queries[..., :self.rope_dims], queries[..., self.rope_dims:]
k_rope, k_pass = keys[..., :self.rope_dims], keys[..., self.rope_dims:]
# 3. Apply the rotary embedding to the designated slice
#
# To broadcast cos and sin across the batch and head dimensions, we unsqueeze them.
# Shape change: [T, Dr] -> [1, 1, T, Dr]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
#print("q_rope.shape[-1] // 2:", (q_rope.shape[-1] // 2))
#print("x1 = x[..., :x.shape[-1] // 2 ].shape:", q_rope[..., :q_rope.shape[-1] // 2 ].shape)
#print("sin/cos.shape:", cos.shape)
#print("q_rope.shape:", q_rope.shape)
#print("(q_rope * cos).shape:", (q_rope * cos).shape)
#print("rotate_half(q_rope).shape:", rotate_half(q_rope).shape)
#print("(rotate_half(q_rope) * sin).shape:", (rotate_half(q_rope) * sin).shape)
"""
In this example batch_size = 2, hum_heads = 8, seq_len = 65, rope_dims = 16
q_rope.shape[-1] // 2: 8
x1 = x[..., :x.shape[-1] // 2 ].shape: torch.Size([2, 8, 65, 8])
sin/cos.shape: torch.Size([1, 1, 65, 16]) # After double unsqueeze.
vq_rope.shape: torch.Size([2, 8, 65, 16])
(q_rope * cos).shape: torch.Size([2, 8, 65, 16])
rotate_half(q_rope).shape: torch.Size([2, 8, 65, 16])
(rotate_half(q_rope) * sin).shape: torch.Size([2, 8, 65, 16])
"""
# Let's walk through the queries as the example.
# What does rotate half do?
# dim -1 is the row vectors, the queries
#
# Step 1: Split the vector in half.
# "q_rope.shape[-1] // 2" <- How much to select. Half the length of the q_rope vector
# x1 = x[..., :x.shape[-1] // 2 ] # Select the first half of the vector.
# x2 = x[..., x.shape[-1] // 2:] # Select the second half.
#
# Step 2:
# - Apply negative to the values in the second half.
# - Reverse the order of the halves.
# return torch.cat((-x2, x1), dim=-1)
#
# ---- (q_rope * cos) ----
# Element-wise multiply the values in each `cos` vector with the
# corresponding (i.e., same sequence position) `q_rope` vector.
#
# Inputs:
# q_rope [B, H, T, Dr]
# cos [1, 1, T, Dr]
#
# Outputs:
# x [B, H, T, Dr]
#
        # ---- rotate_half(q_rope) ----
        # Pairs dimension i in the first half with dimension i + Dr/2 in the
        # second half and returns (-x2, x1), i.e. the companion components
        # needed to rotate each (x1, x2) pair.
        #
        # Inputs:
        #   q_rope [B, H, T, Dr]
        #
        # Outputs:
        #   rot_q_rope [B, H, T, Dr]
        #
        # ---- rotated * sin ----
        # Element-wise multiply by `sin` and add to (q_rope * cos). Together
        # this applies the standard 2D rotation to every (x1, x2) pair:
        #   x1' = x1*cos - x2*sin
        #   x2' = x2*cos + x1*sin
q_rotated = (q_rope * cos) + (rotate_half(q_rope) * sin)
k_rotated = (k_rope * cos) + (rotate_half(k_rope) * sin)
# 4. Concatenate the rotated and pass-through parts back together
# Input (each): [B, H, T, Dr] and [B, H, T, Dq-Dr]
# Output (each): [B, H, T, Dq]
queries = torch.cat((q_rotated, q_pass), dim=-1)
keys = torch.cat((k_rotated, k_pass), dim=-1)
# ===================
# Attention
# ===================
        # queries and keys now have shape [B, H, T, Dq], values [B, H, T, Dv],
        # and all are ready for the attention score calculation.
# Only apply dropout during training.
# self.training is a pytorch flag.
if self.training:
dropout_p = self.attention_dropout_prob
else:
dropout_p = 0.0
# Call SDPA / Flash Attention
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
attn_output = F.scaled_dot_product_attention(
queries,
keys,
values,
            attn_mask=None,  # Padding mask not applied here; causal masking handled via is_causal below.
dropout_p=dropout_p,
scale=self.softmax_scale,
is_causal=True, # This is a decoder - apply causal masking
)
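        # Equivalent math (what SDPA computes internally):
        #   scores      = (queries @ keys.transpose(-2, -1)) * softmax_scale   # [B, H, T, T]
        #   scores      = scores + causal_mask                                 # future positions -> -inf
        #   attn_output = softmax(scores, dim=-1) @ values                     # [B, H, T, Dv]
        # (plus dropout on the attention weights during training)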
# Reshape output back to [B, T, H * Dv] from [B, H, T, Dv]
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, H * Dv)
# =========================
# Output Projection
# =========================
# If we are using an output latent projection,
if self.latent_spaces and self.output_subspace:
# Project the attention output into the output latent space.
# This is analogous to the W^O matrix in standard attention but
# projects to an intermediate latent dimension.
attn_output = self.o_private_proj(attn_output)
# Apply normalization to the output latents
attn_output = self.o_private_norm(attn_output)
# Re-project the output latent representation back to model space.
attn_output = self.o_shared_proj(attn_output)
# If this is a dense layer,
else:
# Project the values back into model space.
attn_output = self.o_proj(attn_output)
# -----------------------------------------
return attn_output
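# ----------------------------------------------------------------------------
# End-to-end shape check (illustrative only; never called by the model).
# `SharedSpaceDecoderConfig` isn't constructed here -- a SimpleNamespace
# stand-in with only the attributes this module reads is used instead, so the
# field names and values below are assumptions based on this file, not the
# real config API.
def _demo_mla_forward() -> None:
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        num_dense_layers=0,         # layer 0 already uses the latent path
        attention_dropout_prob=0.0,
        attention_bias=False,
        rope_theta=10000.0,
        rope_scaling=None,
        rope_dims=8,
        nope_dims=8,
        qk_private_dim=16,          # rope_dims + nope_dims
        vo_private_dim=16,
        q_shared_dim=32,
        kv_shared_dim=32,
        o_shared_dim=32,
        norm_type="rmsnorm",
        rms_norm_eps=1e-6,
        layer_norm_eps=1e-5,
        max_position_embeddings=128,
    )
    attn = MultiheadLatentAttention(cfg, layer_idx=0)
    rope = RotaryEmbedding(cfg)
    B, T = 2, 10
    hidden_states = torch.randn(B, T, cfg.hidden_size)
    position_embeddings = (rope.cos[:T], rope.sin[:T])
    out = attn(hidden_states, position_embeddings, attention_mask=None)
    assert out.shape == (B, T, cfg.hidden_size)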