# -*- coding: utf-8 -*-
"""
modeling_shared_subspace_decoder.py

SharedSpaceDecoder model implementation for HuggingFace Transformers.
"""

from typing import Optional

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

from layers.mla import MultiheadLatentAttention, RotaryEmbedding
from layers.feedforward import SubspaceFeedForward

from .configuration_shared_subspace_decoder import SharedSpaceDecoderConfig


"""`RMSNorm`

From: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py

TODO - May not need?
"""

class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config norm_type.

    Args:
        hidden_size: The dimension to normalize over
        config: Configuration containing norm_type and epsilon values

    Returns:
        Either a LayerNorm or RMSNorm layer
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    elif config.norm_type == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    else:
        # This should be caught by config validation, but being defensive
        raise ValueError(f"Unknown norm_type: {config.norm_type}")
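
# Note on the two norm options: DeepseekV3RMSNorm rescales each hidden vector
# by the inverse root-mean-square of its elements (computed in float32 for
# numerical stability),
#
#     y = weight * x / sqrt(mean(x^2) + eps)
#
# whereas nn.LayerNorm also subtracts the per-vector mean and adds a learned
# bias. Which one is used is controlled by ``config.norm_type``.
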

"""#### *PreTrainedModel"""

class SharedSpaceDecoderPreTrainedModel(PreTrainedModel):
    """
    The **PreTrainedModel object:
      - Is subclassed by every model in this file.
      - Initializes the weights of all submodules via ``_init_weights``.
      - Provides the ``config_class`` and ``base_model_prefix`` hooks used by
        the HuggingFace loading / saving machinery.
      - Executes no forward pass itself; see :class:`SharedSpaceDecoderModel`.
    """

    config_class = SharedSpaceDecoderConfig
    base_model_prefix = "model"

    def _init_weights(self, module: nn.Module) -> None:
        """Weight initialization hook used by :class:`PreTrainedModel`.

        ``PreTrainedModel.post_init`` will recursively apply this function to
        every submodule right after construction. HuggingFace models override
        it so that creating a model from scratch yields the same
        initialization as ``from_pretrained`` when no checkpoint is supplied.

        This decoder-specific initialization covers:
          - Linear layers and embeddings: normal(0, ``initializer_range``),
            with biases and any padding embedding zeroed.
          - Configurable normalization layers (LayerNorm or RMSNorm): weights
            set to 1.0, LayerNorm biases set to 0.
          - Both the dense and the decomposed vocabulary embeddings (the
            decomposition is just an ``nn.Embedding`` plus an ``nn.Linear``).
        """
        if isinstance(module, nn.Linear):
            # Standard linear layer initialization
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Initialize embeddings with a normal distribution
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, DeepseekV3RMSNorm):
            # RMSNorm initialization: weight to 1.0, no bias term
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm initialization: bias to 0, weight to 1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


"""# ▂▂▂▂▂▂▂▂▂▂▂▂
# Classes
"""

"""#### `*Layer`"""

class SharedSpaceDecoderLayer(nn.Module):
    """
    The **Layer object:
      - Is instantiated by :class:`SharedSpaceDecoderModel` for each
        Transformer block in the decoder.
      - Initializes:
          - ``self_attn`` – multi-head latent attention implementing either
            dense or latent projections depending on the configuration.
          - ``ffn`` – a :class:`SubspaceFeedForward` block.
          - Normalization layers (LayerNorm or RMSNorm, per ``config.norm_type``)
            for pre-attention and pre-FFN normalization.
      - Provides access to the attention and feed-forward submodules via the
        attributes ``self_attn`` and ``ffn``.
      - Executes a single decoder block in :meth:`forward`.
    """

    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:
        super().__init__()

        # Norm applied prior to attention.
        self.attn_input_norm = create_norm_layer(config.hidden_size, config)

        # Attention block
        self.self_attn = MultiheadLatentAttention(config, layer_idx)

        # Norm applied prior to the FFN
        self.ffn_input_norm = create_norm_layer(config.hidden_size, config)

        # Feed-forward network used after attention
        self.ffn = SubspaceFeedForward(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # RoPE embeddings
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:

        # ========================
        #     Self Attention
        # ========================
        residual_strm = hidden_states

        # Normalize the hidden states to create the input to attention.
        attn_input = self.attn_input_norm(hidden_states)

        # Evaluate attention.
        attn_output = self.self_attn(
            attn_input,
            position_embeddings,
            attention_mask,
        )

        # Add the attention output (the residual) back to the non-normalized
        # hidden_states.
        hidden_states = residual_strm + attn_output

        # ===========================
        #    Feed-Forward Network
        # ===========================
        residual_strm = hidden_states

        # Normalize the updated hidden states prior to the FFN.
        ffn_input = self.ffn_input_norm(hidden_states)

        # Evaluate the FFN.
        ffn_output = self.ffn(ffn_input)

        # Add the output to the un-normalized hidden states.
        hidden_states = residual_strm + ffn_output

        return hidden_states
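
# The layer above follows the standard pre-norm residual pattern. Written as
# pseudocode (for reference; the actual computation is in
# SharedSpaceDecoderLayer.forward):
#
#     h = h + SelfAttn(Norm(h))   # attention sub-block
#     h = h + FFN(Norm(h))        # feed-forward sub-block
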

"""#### *Model"""

class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel):
    """
    The **Model object:
      - Initializes:
          - The vocabulary embeddings (and their optional decomposition)
          - Position embeddings (calculated in :class:`RotaryEmbedding`)
          - All of the **Layer objects.
      - Provides an interface to the vocabulary embeddings via :meth:`embed`.
      - Executes the whole decoder model in `forward` with causal attention.

    This is the base decoder without the language modeling head. Use
    SubspaceDecoderForCausalLM for language modeling tasks.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        # ============================
        #    Vocabulary Embeddings
        # ============================
        # Decomposing the vocabulary (if enabled) defines a shared projection
        # which constrains the model to store semantic information (and
        # whatever other static token knowledge) in a limited set of
        # feature directions.

        # If we're decomposing the token embeddings,
        if config.vocab_subspace:
            # Create the embedding table. Vocabulary embeddings are learned
            # in a lower-dimensional latent space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size,   # Number of tokens
                config.vocab_rank,   # Subspace dimension
            )

            # Create the shared up-projection: selected token latents are
            # projected up to model size.
            # vocab_proj has shape [vocab_rank x hidden_size]
            self.vocab_proj = nn.Linear(
                config.vocab_rank,    # Size of the latents
                config.hidden_size,   # Model size
                bias=False,
            )

        # Otherwise, for a dense vocabulary,
        else:
            # Create the dense embedding table in model space.
            self.vocab_embed = nn.Embedding(
                config.vocab_size,    # Number of tokens
                config.hidden_size,   # Model size
            )
            self.vocab_proj = None

        # =====================
        #    RoPE Embeddings
        # =====================
        # Pre-computes the table of RoPE embeddings, leaving them in
        # GPU memory.
        self.rope = RotaryEmbedding(config)

        # ===================
        #    Create Layers
        # ===================
        layers = []

        # For each layer,
        for i in range(config.num_hidden_layers):
            # Create a **Layer, providing the config and its index.
            layers.append(
                SharedSpaceDecoderLayer(
                    config,
                    layer_idx=i,
                )
            )

        # Wrap in a torch ModuleList.
        self.layers = nn.ModuleList(layers)

        # Run HuggingFace post-construction hooks (weight init, etc.).
        self.post_init()

    # Agents: Do not define boilerplate helpers, e.g., get/set_input_embeddings

    def embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Return token embeddings for the input ids.

        This performs the up-projection to model space if the vocabulary is
        decomposed.

        input_ids have shape [batch_size, seq_len].
        """

        # If the vocabulary is decomposed,
        if self.vocab_proj is not None:
            # Retrieve the latents.
            #   input_ids: [batch_size, seq_len]
            #   x:         [batch_size, seq_len, vocab_rank]
            x = self.vocab_embed(input_ids)

            # Project the latents back to model space and return.
            return self.vocab_proj(x)

        # If the vocabulary is dense,
        else:
            # Just return the embeddings.
            return self.vocab_embed(input_ids)
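
    # Illustrative parameter comparison for the decomposed vocabulary. The
    # counts below assume hypothetical sizes (vocab_size=32_000,
    # hidden_size=768, vocab_rank=128), not values from any particular config:
    #
    #     dense:       vocab_size * hidden_size                ~ 24.6M params
    #     decomposed:  vocab_size * vocab_rank
    #                  + vocab_rank * hidden_size              ~  4.2M params
    #
    # Either way, `embed` returns model-size vectors.
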
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the full decoder stack with causal attention.

        Inputs:
            input_ids       [batch_size, seq_len]
            attention_mask  [batch_size, seq_len] - 1 for real tokens, 0 for padding

        Returns:
            Final decoder layer output  [batch_size, seq_len, hidden_size]
        """

        # Retrieve the token embeddings for this sequence.
        # These are model size, regardless of whether the vocab is decomposed.
        hidden_states = self.embed(input_ids)

        # Retrieve the rotary position embeddings for all of the positions in
        # our current input sequence.
        seq_len = hidden_states.size(1)

        # Slice out just the entries needed for the sequence length of the
        # input. These are vectors, two per position, whose length is the
        # number of head dimensions we apply RoPE to.
        #   Stored tables:
        #     cos: [max_seq_len, rope_dims]
        #     sin: [max_seq_len, rope_dims]
        #   Outputs:
        #     R_cos: [seq_len, rope_dims]
        #     R_sin: [seq_len, rope_dims]
        R_cos = self.rope.cos[:seq_len]
        R_sin = self.rope.sin[:seq_len]

        # ===============================
        #    Attention Mask Conversion
        # ===============================
        # Expand a 2D padding mask into the 4D additive mask expected by SDPA.
        # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
        if attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask,
                hidden_states.dtype,
                tgt_len=seq_len,
            )

        # Run the model!
        # For each decoder layer,
        for layer in self.layers:
            # Evaluate the layer.
            hidden_states = layer(
                hidden_states,    # Token embeddings
                (R_cos, R_sin),   # RoPE embeddings, passed as a tuple
                attention_mask,   # Attention mask
            )

        # Return the final output of the decoder stack.
        return hidden_states
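

# -----------------------------------------------------------------------------
# Minimal smoke test: a sketch of how the base decoder is exercised. This
# assumes SharedSpaceDecoderConfig can be constructed with default arguments;
# if your config requires explicit values (vocab_size, hidden_size, ...),
# pass them below.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = SharedSpaceDecoderConfig()
    model = SharedSpaceDecoderModel(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        hidden_states = model(input_ids, attention_mask=attention_mask)

    # Expect [batch_size, seq_len, hidden_size].
    print(hidden_states.shape)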