kokolamba's picture
Update model files
d56eb1d
"""# β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚β–‚
# `gla.py`
Based on: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import math
from .shared_space_config import SharedSpaceDecoderConfig
def create_norm_layer(
    hidden_size: Optional[int],
    config: SharedSpaceDecoderConfig,
) -> nn.Module:
    """
    Build the normalization layer selected by `config.norm_type`.

    Args:
        hidden_size: Size of the dimension to normalize over. Pass `None` to
            disable normalization; an identity layer is returned and `config`
            is not consulted.
        config: Configuration providing `norm_type` and the matching epsilon
            (`layer_norm_eps` for "layernorm", `rms_norm_eps` for "rmsnorm").

    Returns:
        `nn.Identity` when `hidden_size` is `None`, otherwise an `nn.LayerNorm`
        or a `DeepseekV3RMSNorm` of width `hidden_size`.

    Raises:
        ValueError: If `config.norm_type` is neither "layernorm" nor "rmsnorm".
    """
    if hidden_size is None:
        return nn.Identity()

    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

    if config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)

    # This should be caught by config validation, but being defensive
    raise ValueError(f"Unknown norm_type: {config.norm_type}")
# TODO - Find a shared place to put this.
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Length of the last dimension; also the size of the
                learned `weight` vector (initialized to ones).
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """
        Scale `hidden_states` by the reciprocal RMS of its last dimension.

        The statistics are computed in float32; the normalized tensor is cast
        back to the input dtype before being multiplied by `weight`.
        """
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = x * inv_rms
        return self.weight * normalized.to(orig_dtype)
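# Helper function for RoPE. It is called once each for the queries and the keys
# (and again to un-rotate the values); a step-by-step explanation lives in the
# comments of `GroupedLatentAttention.forward`.
# TODO - Consider inlining it at each call site instead; then the code could
# use the word 'query' instead of 'x'.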
def rotate_half(x):
    """
    Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the
    negated second part followed by the first part.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
class RotaryEmbedding(nn.Module):
    """
    Precompute RoPE embeddings and store them as buffers.

    Two tables, `cos` and `sin`, of shape
    `[config.max_position_embeddings, config.rope_dims]` are built once at
    construction time. Row `p` holds the cosine / sine of `p * inv_freq`,
    repeated twice along the last dimension.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        """
        Args:
            config: Provides `rope_dims`, `max_position_embeddings`,
                `rope_theta` and the optional `rope_scaling` dict
                (keys "type" and "factor").
        """
        super().__init__()

        dim = config.rope_dims
        seq_len = config.max_position_embeddings

        # ------------------------------
        # Compute inverse frequencies
        # ------------------------------
        # Shape: [dim // 2]
        # inv_freq[j] = 1 / (theta^(2j / dim))
        inv_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        )

        # ------------------------------
        # Apply RoPE scaling if configured
        # ------------------------------
        if config.rope_scaling is not None:
            scaling_type = config.rope_scaling.get("type", "linear")
            scaling_factor = config.rope_scaling.get("factor", 1.0)

            if scaling_type not in ("linear", "dynamic"):
                print(f"Warning: Unknown RoPE scaling type '{scaling_type}', using linear scaling")

            # Every scaling type currently divides the frequencies by the
            # scaling factor. "dynamic" is a simplified implementation that is
            # identical to "linear" (it does not adapt to the sequence length).
            inv_freq = inv_freq / scaling_factor

        # ------------------------------
        # Compute position indices
        # ------------------------------
        # Shape: [seq_len]
        t = torch.arange(seq_len, dtype=torch.float32)

        # ------------------------------
        # Outer product: [seq_len, dim // 2]
        # Each row i contains: t[i] * inv_freq
        # ------------------------------
        freqs = torch.outer(t, inv_freq)

        # ------------------------------
        # Duplicate the angles: [seq_len, dim]
        # The two copies are concatenated (NOT interleaved), i.e. the layout is
        # [f_0, ..., f_{dim/2-1}, f_0, ..., f_{dim/2-1}]. This is the
        # "split in half" layout that `rotate_half` pairs up: element j is
        # rotated together with element j + dim // 2.
        # ------------------------------
        emb = torch.cat((freqs, freqs), dim=-1)

        # ------------------------------
        # Register cos/sin as buffers
        #   - Computed in float32
        #   - Will be moved to correct device/dtype via model.to(...)
        #   - Not saved with state_dict (persistent=False)
        # ------------------------------
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def forward(self, position_ids: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Look up the precomputed embeddings for the given positions.

        Callers may also read the `cos` / `sin` buffers directly; this method
        exists so that the module honours its annotated return type instead
        of returning `None`.

        Args:
            position_ids: Integer tensor of position indices, each in
                `[0, max_position_embeddings)`.

        Returns:
            `(cos, sin)`, each of shape `position_ids.shape + (rope_dims,)`.
        """
        return self.cos[position_ids], self.sin[position_ids]
"""## GLA"""
class GroupedLatentAttention(nn.Module):
    """
    Multihead Latent Attention with the DeepSeekV3 re-ordering trick, framed as MQA.

    This version of Multihead Latent Attention applies the re-ordering trick from DeepSeekV3.
    Instead of comparing the queries and keys in the query-key space, we compare them in the
    kv-shared space.

    For clarity, I've re-interpreted the naming of the heads, and am framing it as MQA.
      * What were previously labeled the query and key heads are now treated as a low-rank
        decomposition of the query heads.
      * What we considered the "shared key/value space" is now a single key head that is also
        used as the value head.
      * What we previously labeled the value and output heads are now treated as a low-rank
        decomposition of the output heads.

    This interpretation / implementation is designed to leverage the performance benefits of GQA.

    The trade-off is that the query-key matching space is now larger--it will require a greater
    number of calculations to match the queries to the keys. The hope is that the memory bandwidth
    savings will outweigh the increased computational cost. The same applies to the value-output
    space.

    Note that, although the query-key and value-output spaces are now large, the low-rank
    decomposition of the query heads and output heads ensures that the heads are still effectively
    low rank / not over-parameterized.

    Finally, note that this implementation also supports the optional use of shared spaces on
    the query and output sides.

    I've named the class "GroupedLatentAttention" because I may expand it to support multiple
    key/value heads (i.e., multiple groups of query heads) in the future.

    Current limitations of this implementation:
      * Layers with `layer_idx < config.num_dense_layers` ("dense" layers) can be constructed,
        but `forward` raises `NotImplementedError` for them.
      * `attention_mask` is accepted but ignored; only causal masking is applied.
      * There is no key/value cache.

    ==== Adding RoPE to VO ====

    ### **Attempt**

    We're extending Rotary Position Embeddings (RoPE) beyond the query-key interaction to the
    **value-output path** in Multihead Latent Attention (MLA).

    * In DeepSeek-V3's MLA framing, the same **full-rank key/value head** provides both the keys
      (for patterns) and the values (for messages).
    * Queries and output heads are low-rank bottlenecks, effectively serving as vocabularies of
      **pattern directions** (Q) and **message directions** (O).
    * Standard RoPE only modulates the Q-K dot product. Our attempt is to also apply RoPE phases
      consistently in the V-O pathway, so that **positional dependence is preserved in both the
      matching (QK) and messaging (VO) sides**.

    --

    ### **Hypothesis**

    If we rotate value vectors by their **source position phase** and then apply the **inverse
    rotation at the destination** before output projection, the model gains a clean
    **relative-position equivariance** in the message path, mirroring the property RoPE provides
    for queries and keys.

    This should:

    1. Make the 1-to-1 correspondence between "pattern templates" (Q) and "message templates" (O)
       more consistent.
    2. Reduce the burden on output heads to learn ad-hoc positional compensation.
    3. Improve long-context generalization, since both attention matching *and* message passing
       would share the same relative-position geometry.
    """

    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int):
        """
        Args:
            config: Model configuration. Reads the attention sizes
                (`num_attention_heads`, `hidden_size`, `rope_dims`, `nope_dims`,
                `q_shared_dim`, `kv_shared_dim`, `o_shared_dim`, `qk_private_dim`,
                `vo_private_dim`), plus `rope_theta`, `attention_dropout_prob`,
                `attention_bias` and `num_dense_layers`.
            layer_idx: Index of this layer in the stack. Layers with
                `layer_idx < config.num_dense_layers` are built as dense layers;
                all others use the latent (decomposed) projections.
        """
        super().__init__()

        self.config = config

        # Used to determine if this layer is dense or uses latents.
        self.layer_idx = layer_idx
        self.attention_dropout_prob = config.attention_dropout_prob

        self.num_heads = config.num_attention_heads

        self.rope_theta = config.rope_theta
        self.rope_dims = config.rope_dims
        self.nope_dims = config.nope_dims

        self.q_shared_dim = config.q_shared_dim

        # What was previously considered the key/value shared dimension is now the
        # size of the MQA style single key/value head.
        self.kv_head_dim = config.kv_shared_dim

        self.o_shared_dim = config.o_shared_dim

        # What was previously the query/key head size is now the size of
        # the query head decomposition.
        self.q_inner_dim = config.qk_private_dim

        # What was previously the value/output head size is now the size of
        # the output head decomposition.
        self.o_inner_dim = config.vo_private_dim

        self.hidden_size = config.hidden_size

        # =========================
        #     Input Projections
        # =========================

        # If this is one of the dense layers,
        if self.layer_idx < config.num_dense_layers:

            # =========================
            #     Dense Attention
            # =========================

            # No latent projections.
            self.latent_spaces = False

            # Define the standard QKV projection.
            # Per head: one query and one key of size `q_inner_dim`
            # (config.qk_private_dim) plus one value of size `o_inner_dim`
            # (config.vo_private_dim).
            self.qkv_proj = nn.Linear(
                config.hidden_size,
                self.num_heads * (self.q_inner_dim * 2 + self.o_inner_dim),
                bias=config.attention_bias,
            )

            # Dense output projection
            self.o_proj = nn.Linear(
                self.num_heads * self.o_inner_dim,
                config.hidden_size,
                bias=config.attention_bias,
            )

        # If we're past the dense layers,
        else:

            # =========================
            #     Latent Attention
            # =========================

            # Use latent projections.
            self.latent_spaces = True

            # NOTE(review): debug output on every layer construction; consider
            # routing this through `logging` instead of stdout.
            print("config.q_shared_dim", config.q_shared_dim)

            # ==========================
            #     Shared Query Space
            # ==========================

            # If we're using a shared query subspace,
            if config.q_shared_dim is not None:
                # Set a flag that we'll check in `forward`.
                self.query_shared = True

                self.q_shared_proj = nn.Linear(
                    config.hidden_size,
                    self.q_shared_dim,
                    bias=config.attention_bias,
                )

                self.q_shared_norm = create_norm_layer(self.q_shared_dim, config)

            else:
                print("Using identity for shared projection.")

                # Set a flag that we'll check in `forward`.
                self.query_shared = False

                # Without a shared space the decomposed query heads read the
                # model-space embeddings directly.
                self.q_shared_dim = config.hidden_size

                # Use identity.
                self.q_shared_proj = nn.Identity()
                self.q_shared_norm = nn.Identity()

            # ==========================
            #    Shared Output Space
            # ==========================

            # If we're using a shared output space,
            if config.o_shared_dim is not None:
                # Set a flag that we'll check in `forward`.
                self.output_shared = True

                # Shared output projection
                # The head outputs from `o_priv_b_proj` are first summed together
                # (across heads) in the latent space.
                # Then we project their combined outputs (a single vector per token)
                # back to model space via `o_shared_proj`.
                self.o_shared_proj = nn.Linear(
                    self.o_shared_dim,
                    self.hidden_size,
                    bias=config.attention_bias
                )

                self.o_shared_norm = create_norm_layer(self.o_shared_dim, config)

            else:
                # Set a flag that we'll check in `forward`.
                self.output_shared = False

                # Without a shared space the decomposed output heads write
                # directly into model space.
                self.o_shared_dim = config.hidden_size

                # Use identity.
                self.o_shared_proj = nn.Identity()
                self.o_shared_norm = nn.Identity()

            # ================================
            #     Decomposed Query Heads
            # ================================

            # Query down projections.
            # The query head inner dimension makes the head low rank, as usual.
            self.q_priv_a_proj = nn.Linear(
                self.q_shared_dim,
                self.num_heads * self.q_inner_dim,
                bias=False
            )

            # Query up projections.
            # We project back to the larger key/value space.
            # Rather than create a linear and break it apart, we can create our
            # desired shapes.
            # per-head Dq_c -> Dkv (store as [H, Dq_c, Dkv])
            self.q_priv_b_weight = nn.Parameter(
                torch.empty(self.num_heads, self.q_inner_dim, self.kv_head_dim)
            )
            nn.init.kaiming_uniform_(self.q_priv_b_weight, a=math.sqrt(5))

            # ====================================
            #     Single Joint Key/Value Head
            # ====================================

            # The single joint key/value head.
            self.kv_priv_proj = nn.Linear(
                self.hidden_size,
                self.kv_head_dim,
                bias=False,
            )

            self.kv_priv_norm = create_norm_layer(self.kv_head_dim, config)

            # ================================
            #     Decomposed Output Heads
            # ================================

            # Down: values [B,H,T,Dkv] -> per-head Do_c using weights [H, Dkv, Do_c]
            self.o_priv_a_weight = nn.Parameter(
                torch.empty(self.num_heads, self.kv_head_dim, self.o_inner_dim)
            )
            nn.init.kaiming_uniform_(self.o_priv_a_weight, a=math.sqrt(5))

            # Output up projections.
            # We project back to the larger output subspace (or the model space,
            # if no subspace is used).
            self.o_priv_b_proj = nn.Linear(
                self.num_heads * self.o_inner_dim,
                self.o_shared_dim,
                bias=False
            )

        # Let SDPA choose 1/sqrt(E). If you want explicit: self.kv_head_dim ** -0.5
        self.softmax_scale = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        # TODO - `past_key_value` / `cache_position` are not supported; any
        # extra keyword arguments are accepted and ignored.
        **kwargs,
    ) -> torch.Tensor:
        """
        Run causal MQA-style latent attention over `hidden_states`.

        Args:
            hidden_states: Token embeddings, [B, T, D].
            position_embeddings: Tuple `(cos, sin)`, each [T, rope_dims], for
                the T positions of this sequence.
            attention_mask: Currently ignored. Only causal masking is applied
                (`is_causal=True`), so padding is NOT masked out.
            **kwargs: Ignored.

        Returns:
            The attention output in model space, [B, T, D].

        Raises:
            NotImplementedError: If this layer was built as a dense layer.
        """
        if not self.latent_spaces:
            raise NotImplementedError(
                "GroupedLatentAttention.forward does not support dense layers yet "
                f"(layer_idx={self.layer_idx})."
            )

        # === Tensor Dimension Symbols ===
        #    B: batch_size   - number of samples in the batch
        #    T: seq_len      - number of tokens per sample
        #    H: n_heads      - number of attention heads
        #    D: hidden_dim   - model embedding size
        # Dq_c: q_inner_dim  - per-head decomposition dim for Q
        # Do_c: o_inner_dim  - per-head decomposition dim for O
        #  Dkv: kv_head_dim  - head size of the joint key/value head
        #   Dr: rope_dims    - the first Dr dimensions receive rope
        # Dq_s: q_shared_dim - query shared subspace size
        # Do_s: o_shared_dim - output shared subspace size
        B, T = hidden_states.shape[:2]
        H = self.num_heads
        Dq_c = self.q_inner_dim
        Do_c = self.o_inner_dim
        Dr = self.rope_dims

        # =============================
        #     Shared Query Space
        # =============================
        # Both modules are identity if no shared query space is used.
        #
        # Input:
        #     hidden_states  [B, T, D]
        #     q_shared_proj  [D, Dq_s]
        # Output:
        #     q_shared       [B, T, Dq_s]
        q_shared = self.q_shared_proj(hidden_states)

        # Normalize latent vectors, shapes unchanged.
        q_shared = self.q_shared_norm(q_shared)

        # ================================
        #     Decomposed Query Heads
        # ================================

        # Down projection ('a'), then split the vectors by head.
        # Input:
        #     q_shared       [B, T, Dq_s]
        #     q_priv_a_proj  [Dq_s, H*Dq_c]
        # Output:
        #     queries_c      [B, T, H, Dq_c]
        queries_c = self.q_priv_a_proj(q_shared).view(B, T, H, Dq_c)

        # Up projection ('b')
        # Input:
        #     queries_c        [B, T, H, Dq_c]
        #     q_priv_b_weight  [H, Dq_c, Dkv]
        # Output:
        #     queries          [B, H, T, Dkv]
        queries = torch.einsum("bthd,hdc->bhtc", queries_c, self.q_priv_b_weight)

        # ===================================
        #     Single Joint Key/Value Head
        # ===================================

        # Project token embeddings into the single joint key/value head.
        # Input:
        #     hidden_states  [B, T, D]
        #     kv_priv_proj   [D, Dkv]
        # Output:
        #     keyvalue       [B, T, Dkv]
        keyvalue = self.kv_priv_proj(hidden_states)

        # Apply QK normalization.
        keyvalue = self.kv_priv_norm(keyvalue)

        # Add a singleton head dimension so the layout matches the queries
        # (head dimension in front of the sequence dimension).
        #     keyvalue  [B, T, Dkv] -> [B, 1, T, Dkv]
        keyvalue = keyvalue.unsqueeze(1)

        # ==================
        #        RoPE
        # ==================
        # Apply rotary position embeddings to the first `self.rope_dims` of
        # each head.
        #
        # The slice operations are free, but the concatenation is not, because
        # the outputs of the rotation operation are new data occupying
        # different memory. Still considered the best option, though.

        # 1. Unpack the precomputed cosine and sine embeddings, each
        #    [seq_len, rope_dims], and unsqueeze them so they broadcast across
        #    the batch and head dimensions: [T, Dr] -> [1, 1, T, Dr]
        cos, sin = position_embeddings
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        # 2. Split the query and key heads into the part to rotate and the part
        #    to pass through (early columns get position info, later ones don't).
        #
        #    (Using queries as example)
        #    Inputs:
        #        queries  [B, H, T, Dkv]    Dkv = rope_dims + not_rope_dims
        #    Outputs:
        #        q_rope   [B, H, T, Dr]
        #        q_pass   [B, H, T, Dkv-Dr]
        q_rope, q_pass = queries[..., :Dr], queries[..., Dr:]
        k_rope, k_pass = keyvalue[..., :Dr], keyvalue[..., Dr:]

        # 3. Apply the rotary embedding to the designated slice.
        #
        #    Example shapes with batch_size = 2, num_heads = 8, seq_len = 65,
        #    rope_dims = 16:
        #        sin / cos                 [1, 1, 65, 16]   (after the double unsqueeze)
        #        q_rope                    [2, 8, 65, 16]
        #        q_rope * cos              [2, 8, 65, 16]
        #        rotate_half(q_rope)       [2, 8, 65, 16]
        #        rotate_half(q_rope) * sin [2, 8, 65, 16]
        #
        #    What does `rotate_half` do? (dim -1 holds the row vectors.)
        #      Step 1: Split the vector in half.
        #          x1 = x[..., :x.shape[-1] // 2]   # first half
        #          x2 = x[..., x.shape[-1] // 2:]   # second half
        #      Step 2: Negate the second half and swap the order of the halves.
        #          return torch.cat((-x2, x1), dim=-1)
        #
        #    `q_rope * cos` multiplies each `q_rope` vector element-wise with
        #    the `cos` vector of the same sequence position; the rotated copy
        #    is combined with `sin` in the same way.
        q_rotated = (q_rope * cos) + (rotate_half(q_rope) * sin)
        k_rotated = (k_rope * cos) + (rotate_half(k_rope) * sin)

        # 4. Concatenate the rotated and pass-through parts back together.
        #    Input  (each): [B, h, T, Dr] and [B, h, T, Dkv-Dr]
        #    Output (each): [B, h, T, Dkv]
        #    (Where h = 1 for the key head and h = num_heads for the query heads)
        queries = torch.cat((q_rotated, q_pass), dim=-1)
        keyvalue = torch.cat((k_rotated, k_pass), dim=-1)

        # ====================
        #     GQA / MQA
        # ====================
        # `keyvalue` keeps its singleton head dimension ([B, 1, T, Dkv]); it is
        # not expanded to [B, H, T, Dkv] here.
        # NOTE(review): this relies on SDPA broadcasting the single key/value
        # head across the H query heads - confirm for the backends in use.

        # ===================
        #     Attention
        # ===================

        # Only apply dropout during training (`self.training` is a PyTorch flag).
        dropout_p = self.attention_dropout_prob if self.training else 0.0

        # Call SDPA / Flash Attention
        # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
        # The joint head is passed as both key and value.
        values = F.scaled_dot_product_attention(
            queries,
            keyvalue,  # Single key vector (joint with value) for GQA / MQA.
            keyvalue,  # Single value vector (joint with key) for GQA / MQA.
            attn_mask=None,  # `attention_mask` is intentionally not forwarded.
            dropout_p=dropout_p,
            scale=self.softmax_scale,
            is_causal=True,  # This is a decoder - apply causal masking
        )

        # Attention outputs:
        #     values  [B, H, T, Dkv]
        #
        # The first Dr dims of the value vectors carry RoPE information
        # (the values are the rotated keys).
        # We can either (1) add position dependence to the value-output process,
        # or (2) strip off the RoPE information and only use the non-RoPE parts.
        # Let's try option 1!

        # Split the values into the RoPE and non-RoPE parts.
        # Input:
        #     values       [B, H, T, Dkv]
        # Output:
        #     values_rope  [B, H, T, Dr]
        #     values_pass  [B, H, T, Dkv-Dr]
        values_rope, values_pass = values[..., :Dr], values[..., Dr:]

        # Apply the inverse rotation for the destination (query) position.
        # Inverse rotation: R_{-theta} x = (x * cos) - (rotate_half(x) * sin)
        # Input:
        #     values_rope   [B, H, T, Dr]
        #     cos, sin      [1, 1, T, Dr]
        # Output:
        #     values_unrot  [B, H, T, Dr]
        values_unrot = (values_rope * cos) - (rotate_half(values_rope) * sin)

        # Now the values have the offset information in their rope dimensions,
        # and the output heads can learn to use it.
        values = torch.cat((values_unrot, values_pass), dim=-1)  # [B, H, T, Dkv]

        # =========================
        #     Output Projection
        # =========================

        # Output down projection heads.
        # Input:
        #     values           [B, H, T, Dkv]
        #     o_priv_a_weight  [H, Dkv, Do_c]
        # Output:
        #     outputs_c        [B, H, T, Do_c]
        outputs_c = torch.einsum("bhtd,hdc->bhtc", values, self.o_priv_a_weight)

        # For the up projection, we can concatenate the 'outputs_c' vectors by head
        # (in the same way we would usually concatenate the value vectors).
        # Input:
        #     outputs_c  [B, H, T, Do_c]
        # Output:
        #     outputs_c  [B, T, H*Do_c]
        outputs_c = outputs_c.permute(0, 2, 1, 3).contiguous().view(B, T, H * Do_c)

        # Project up to the shared output space and sum across the output heads.
        # Input:
        #     outputs_c      [B, T, H*Do_c]
        #     o_priv_b_proj  [H*Do_c, Do_s]
        # Output:
        #     output_s       [B, T, Do_s]
        output_s = self.o_priv_b_proj(outputs_c)

        # Apply normalization to the output latents
        # (identity if no shared output space is used).
        output_s = self.o_shared_norm(output_s)

        # Re-project the output latent representation back to model space
        # (identity if no shared output space is used).
        # Input:
        #     output_s       [B, T, Do_s]
        #     o_shared_proj  [Do_s, D]
        # Output:
        #     attn_output    [B, T, D]
        attn_output = self.o_shared_proj(output_s)

        return attn_output