|
|
|
|
|
|
|
|
""" |
|
|
modeling_shared_subspace_decoder.py |
|
|
|
|
|
SharedSpaceDecoder model implementation for HuggingFace Transformers. |
|
|
""" |
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa |
|
|
|
|
|
from layers.mla import MultiheadLatentAttention, RotaryEmbedding |
|
|
from layers.feedforward import SubspaceFeedForward |
|
|
from .configuration_shared_subspace_decoder import SharedSpaceDecoderConfig |
|
|
|
|
|
"""`RMSNorm` |
|
|
|
|
|
From: |
|
|
https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py |
|
|
|
|
|
TODO - May not need? |
|
|
""" |
|
|
|
|
|
class DeepseekV3RMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
""" |
|
|
DeepseekV3RMSNorm is equivalent to T5LayerNorm |
|
|
""" |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
input_dtype = hidden_states.dtype |
|
|
hidden_states = hidden_states.to(torch.float32) |
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
return self.weight * hidden_states.to(input_dtype) |
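# Note: RMSNorm rescales each hidden vector by its root-mean-square,
#   y = weight * x / sqrt(mean(x^2) + eps),
# with statistics computed in float32 before casting back to the input dtype.
# Unlike LayerNorm there is no mean subtraction and no bias term.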
|
|
|
|
|
def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module: |
|
|
""" |
|
|
Create a normalization layer based on the config norm_type. |
|
|
|
|
|
Args: |
|
|
hidden_size: The dimension to normalize over |
|
|
config: Configuration containing norm_type and epsilon values |
|
|
|
|
|
Returns: |
|
|
Either a LayerNorm or RMSNorm layer |
|
|
""" |
|
|
if config.norm_type == "layernorm": |
|
|
return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) |
|
|
elif config.norm_type == "rmsnorm": |
|
|
return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps) |
|
|
else: |
|
|
|
|
|
raise ValueError(f"Unknown norm_type: {config.norm_type}") |
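# For example, a config with norm_type="rmsnorm" gives every decoder block
# DeepseekV3RMSNorm modules (using rms_norm_eps), while norm_type="layernorm"
# gives standard nn.LayerNorm modules (using layer_norm_eps); both normalize
# over the last (hidden) dimension.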
|
|
|
|
|
"""#### *PreTrainedModel""" |
|
|
|
|
|
class SharedSpaceDecoderPreTrainedModel(PreTrainedModel): |
|
|
""" |
|
|
The **PreTrainedModel object: |
|
|
- Is instantiated when TODO |
|
|
- Initializes: |
|
|
- TODO |
|
|
- Provides access to TODO |
|
|
- Executes TODO |
|
|
""" |
|
|
|
|
|
config_class = SharedSpaceDecoderConfig |
|
|
base_model_prefix = "model" |
|
|
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
|
"""Weight initialization hook used by :class:`PreTrainedModel`. |
|
|
|
|
|
``PreTrainedModel.post_init`` will recursively apply this function to |
|
|
every submodule right after construction. HuggingFace models override |
|
|
it so that creating a model from scratch yields the same initialization |
|
|
as ``from_pretrained`` when no checkpoint is supplied. |
|
|
|
|
|
This decoder-specific initialization strategy includes: |
|
|
- Proper handling of configurable normalization layers (LayerNorm or RMSNorm) |
|
|
- Special initialization for language modeling heads |
|
|
- Considerations for causal attention and autoregressive modeling |
|
|
- Support for both dense and decomposed vocabulary embeddings |
|
|
""" |
|
|
|
|
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
|
|
|
elif isinstance(module, nn.Embedding): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
elif isinstance(module, DeepseekV3RMSNorm): |
|
|
|
|
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
|
|
|
module.bias.data.zero_() |
|
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
"""# ββββββββββββ |
|
|
|
|
|
# Classes |
|
|
""" |
|
|
|
|
|
"""#### `*Layer`""" |
|
|
|
|
|
class SharedSpaceDecoderLayer(nn.Module): |
|
|
""" |
|
|
The **Layer object: |
|
|
- Is instantiated by :class:`SharedSpaceDecoderModel` for each |
|
|
Transformer block in the decoder. |
|
|
- Initializes: |
|
|
- ``self_attn`` β multi-head latent attention implementing either |
|
|
dense or latent projections depending on the configuration. |
|
|
- ``ffn`` β a :class:`SubspaceFeedForward` block. |
|
|
- RMSNorm layers for pre-attention and pre-FFN normalization. |
|
|
- Provides access to the attention and feed-forward submodules via the |
|
|
attributes ``self_attn`` and ``ffn``. |
|
|
- Executes a single decoder block in :meth:`forward`. |
|
|
""" |
|
|
|
|
|
    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:

        super().__init__()

        # Pre-attention normalization (LayerNorm or RMSNorm, per the config)
        self.attn_input_norm = create_norm_layer(config.hidden_size, config)

        # Multi-head latent attention
        self.self_attn = MultiheadLatentAttention(config, layer_idx)

        # Pre-FFN normalization
        self.ffn_input_norm = create_norm_layer(config.hidden_size, config)

        # Feed-forward block operating in the shared subspace
        self.ffn = SubspaceFeedForward(config, layer_idx)
|
|
|
|
|
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:

        # ==== Self-Attention ====

        residual_strm = hidden_states

        # Pre-norm
        attn_input = self.attn_input_norm(hidden_states)

        attn_output = self.self_attn(
            attn_input,
            position_embeddings,
            attention_mask,
        )

        # Residual connection
        hidden_states = residual_strm + attn_output

        # ==== Feed-Forward ====

        residual_strm = hidden_states

        # Pre-norm
        ffn_input = self.ffn_input_norm(hidden_states)

        ffn_output = self.ffn(ffn_input)

        # Residual connection
        hidden_states = residual_strm + ffn_output

        return hidden_states
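    # Dataflow sketch for one decoder block (pre-norm residual form):
    #
    #   x = x + SelfAttn(Norm(x))
    #   x = x + FFN(Norm(x))
    #
    # where Norm is the configured LayerNorm/RMSNorm and x has shape
    # [batch_size, seq_len, hidden_size].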
|
|
|
|
|
"""#### *Model""" |
|
|
|
|
|
class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel): |
|
|
""" |
|
|
The **Model object: |
|
|
- Initializes: |
|
|
- The vocabulary embeddings (and optional decomposition) |
|
|
- Position embeddings (calculated in RotaryEmbedding) |
|
|
- All of the **Layer objects. |
|
|
- Provides interface to vocab embeddings. |
|
|
- Executes the whole decoder model in `forward` with causal attention. |
|
|
|
|
|
    This is the base decoder without the language modeling head.
    Use ``SharedSpaceDecoderForCausalLM`` for language modeling tasks.
|
|
""" |
|
|
|
|
|
def __init__(self, config: SharedSpaceDecoderConfig) -> None: |
|
|
super().__init__(config) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Vocabulary embeddings. With `vocab_subspace` enabled, the embedding
        # matrix is factorized: tokens are embedded into a low-rank space of
        # size `vocab_rank` and then projected up to `hidden_size`.
        if config.vocab_subspace:

            # Low-rank token embeddings  [vocab_size x vocab_rank]
            self.vocab_embed = nn.Embedding(
                config.vocab_size,
                config.vocab_rank
            )

            # Up projection into model space  [vocab_rank x hidden_size]
            self.vocab_proj = nn.Linear(
                config.vocab_rank,
                config.hidden_size,
                bias=False
            )

        else:
            # Dense token embeddings  [vocab_size x hidden_size]
            self.vocab_embed = nn.Embedding(
                config.vocab_size,
                config.hidden_size
            )

            self.vocab_proj = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Rotary position embeddings; the cos/sin tables are computed once
        # here and shared across all decoder layers.
        self.rope = RotaryEmbedding(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Build the stack of decoder layers.
        layers = []

        for i in range(config.num_hidden_layers):
            layers.append(
                SharedSpaceDecoderLayer(
                    config,
                    layer_idx=i
                )
            )

        self.layers = nn.ModuleList(layers)

        # Initialize weights and apply any final HuggingFace setup.
        self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed(self, input_ids: torch.LongTensor) -> torch.Tensor: |
|
|
""" |
|
|
Return token embeddings for input ids. |
|
|
This will perform the up projection to model space if the vocabulary is |
|
|
decomposed. |
|
|
|
|
|
input_ids have shape [batch_size, seq_len] |
|
|
""" |
|
|
|
|
|
|
|
|
        if self.vocab_proj is not None:
            # Decomposed vocabulary: look up the low-rank embeddings,
            # [batch_size, seq_len, vocab_rank] ...
            x = self.vocab_embed(input_ids)

            # ... then project up to model space,
            # [batch_size, seq_len, hidden_size]
            return self.vocab_proj(x)

        else:
            # Dense vocabulary: embeddings are already in model space.
            return self.vocab_embed(input_ids)
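    # Rough parameter comparison for the decomposed vocabulary (illustrative
    # numbers, not taken from any particular config): with vocab_size=32,000,
    # hidden_size=768, and vocab_rank=128, the factorized embedding costs
    # 32,000*128 + 128*768 ≈ 4.2M parameters, versus 32,000*768 ≈ 24.6M for a
    # dense embedding table.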
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
**kwargs, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Run the full decoder stack with causal attention. |
|
|
|
|
|
Inputs: |
|
|
input_ids [batch_size, seq_len] |
|
|
attention_mask [batch_size, seq_len] - 1 for real tokens, 0 for padding |
|
|
|
|
|
Returns: |
|
|
Final decoder layer output [batch_size, seq_len, model_size] |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
        # Look up token embeddings: [batch_size, seq_len, hidden_size]
        hidden_states = self.embed(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq_len = hidden_states.size(1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
R_cos = self.rope.cos[:seq_len] |
|
|
R_sin = self.rope.sin[:seq_len] |
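        # The `cos` / `sin` tables sliced above are assumed to be precomputed
        # buffers of shape [max_position_embeddings, rope_dims] owned by
        # `RotaryEmbedding`; every layer receives the same slice for the
        # current sequence length.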
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
use_sdpa_attention_masks = ( |
|
|
self.attn_implementation == "sdpa" |
|
|
and self.position_embedding_type == "absolute" |
|
|
and head_mask is None |
|
|
and not output_attentions |
|
|
) |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
if True: |
|
|
|
|
|
|
|
|
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( |
|
|
attention_mask, |
|
|
hidden_states.dtype, |
|
|
tgt_len = seq_len |
|
|
) |
|
|
attention_mask = extended_attention_mask |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Run each decoder layer in sequence. Every layer receives the same
        # rotary tables and the expanded attention mask.
        for layer_i, layer in enumerate(self.layers):

            hidden_states = layer(
                hidden_states,
                (R_cos, R_sin),
                attention_mask,
            )
|
|
|
|
|
|
|
|
return hidden_states |
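# Example usage (illustrative sketch; the constructor arguments for
# SharedSpaceDecoderConfig are assumptions and depend on how the config class
# is defined in configuration_shared_subspace_decoder.py):
#
#   config = SharedSpaceDecoderConfig()
#   model = SharedSpaceDecoderModel(config)
#   input_ids = torch.randint(0, config.vocab_size, (2, 16))
#   attention_mask = torch.ones(2, 16, dtype=torch.long)
#   hidden_states = model(input_ids, attention_mask)
#   # hidden_states: [2, 16, config.hidden_size]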
|
|
|
|
|
|