import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin

from .shared_space_config import SharedSpaceDecoderConfig
from .shared_space_decoder import (
    SharedSpaceDecoderPreTrainedModel,
    SharedSpaceDecoderModel,
    DeepseekV3RMSNorm,
)


def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on the config's `norm_type`.

    Args:
        hidden_size: The dimension to normalize over.
        config: Configuration providing `norm_type` and the matching epsilon value.

    Returns:
        An `nn.LayerNorm` or `DeepseekV3RMSNorm` instance.
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    elif config.norm_type == "rmsnorm":
        # DeepseekV3RMSNorm is already imported at module level.
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    else:
        raise ValueError(f"Unknown norm_type: {config.norm_type}")
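
# A minimal usage sketch for the factory above (hedged: assumes a config whose
# `norm_type` is "rmsnorm" and whose `rms_norm_eps` is set):
#
#   norm = create_norm_layer(config.hidden_size, config)
#   x = torch.randn(2, 8, config.hidden_size)
#   y = norm(x)  # same shape as x; switching norm_type to "layernorm" needs no call-site change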


class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
    """
    Subspace Decoder model with a causal language modeling head.

    This model extends SharedSpaceDecoderModel with:
    - A language modeling head that projects hidden states to vocabulary logits
    - Cross-entropy loss computation for language modeling
    - HuggingFace compatibility for causal language modeling tasks
    - Decoder-specific weight initialization

    The model can be used for:
    - Text generation
    - Language modeling pretraining
    - Fine-tuning on downstream tasks

    See the usage sketch in the comment block at the end of this module.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        # Backbone decoder that produces the final hidden states.
        self.model = SharedSpaceDecoderModel(config)

        # Final normalization applied before the language modeling head.
        self.norm = create_norm_layer(config.hidden_size, config)

        # Language modeling head projecting hidden states to vocabulary logits.
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # Initialize weights and run HuggingFace post-processing
        # (including weight tying, when applicable).
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Decoder-specific weight initialization with special handling for the
        language modeling head.

        Key differences from the encoder initialization:
        - The language modeling head gets specialized initialization for stability
        - Configurable normalization layers (LayerNorm or RMSNorm) are handled correctly
        - Weight tying between the embeddings and the lm_head is taken into account
        """
        # Apply the shared/default initialization first.
        super()._init_weights(module)

        # Re-initialize the language modeling head explicitly.
        if module is self.lm_head:
            if self.model.vocab_proj is not None:
                # Vocab-subspace models route the embeddings through a projection,
                # so the head uses a smaller std for stability.
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range * 0.5)
            else:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
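
        # Illustrative effect (hedged: 0.02 is a common default for
        # initializer_range in HuggingFace configs; the real value comes from
        # the config at runtime):
        #   without vocab_proj: lm_head.weight ~ Normal(mean=0.0, std=0.02)
        #   with vocab_proj:    lm_head.weight ~ Normal(mean=0.0, std=0.01)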

    def get_input_embeddings(self):
        """Return the input embedding layer for compatibility with HuggingFace."""
        return self.model.vocab_embed

    def set_input_embeddings(self, value):
        """Set the input embedding layer for compatibility with HuggingFace."""
        self.model.vocab_embed = value

    def get_output_embeddings(self):
        """Return the output embedding layer (lm_head) for compatibility."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer for compatibility."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Tie the input and output embedding weights.

        This sets the language modeling head's weight to be the same tensor as
        the input embedding weight, which reduces the parameter count and is
        common practice in modern language models.

        Note: vocab-subspace models pass the input embeddings through a
        projection layer, so the embedding and lm_head weights typically have
        different shapes and are left untied.
        """
        if getattr(self.model, "vocab_proj", None) is None:
            self._tie_or_clone_weights(self.lm_head, self.model.vocab_embed)
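
        # Hedged illustration (assumes `vocab_embed` is an `nn.Embedding` and no
        # vocab projection is configured):
        #
        #   model.tie_weights()
        #   assert model.lm_head.weight is model.model.vocab_embed.weight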

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, tuple]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].
            attention_mask: Attention mask of shape [batch_size, seq_len]
                (1 for real tokens, 0 for padding). Defaults to all ones.
            labels: Ground-truth token ids for computing the loss, with the same
                shape as input_ids. If provided, the loss is computed; the
                one-token shift for next-token prediction is handled inside the
                loss function, and positions set to -100 are ignored.

        Returns:
            CausalLMOutputWithPast containing:
            - logits: Prediction logits of shape [batch_size, seq_len, vocab_size]
            - loss: Cross-entropy loss if labels were provided, else None
            - hidden_states: Final-layer hidden states of shape
              [batch_size, seq_len, hidden_size], only when output_hidden_states
              is requested
        """
        # Default to attending over every position when no mask is provided.
        if attention_mask is None and input_ids is not None:
            attention_mask = torch.ones(
                (input_ids.size(0), input_ids.size(1)),
                dtype=torch.long,
                device=input_ids.device,
            )

        # Run the decoder backbone to get the final hidden states.
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Final normalization, then project to vocabulary logits.
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        # Compute the causal language modeling loss when labels are provided.
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            # Only the final-layer hidden states are returned here.
            hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
            attentions=None,
        )
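
    # A hedged sketch of the labels convention used above (assumes the standard
    # HuggingFace causal-LM loss, which shifts labels internally and ignores
    # positions set to -100):
    #
    #   labels = input_ids.clone()
    #   labels[attention_mask == 0] = -100   # don't train on padding
    #   out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    #   out.loss.backward()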

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        # No KV cache is used, so each generation step re-encodes the full
        # sequence and only needs the token ids and the attention mask.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        # Nothing to reorder for beam search because no cache is kept.
        return past_key_values
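

# A minimal usage sketch (hedged: the config arguments below are illustrative,
# not the real SharedSpaceDecoderConfig signature, and tokenization is omitted):
#
#   config = SharedSpaceDecoderConfig(vocab_size=32000, hidden_size=512)
#   model = SharedSpaceDecoderForCausalLM(config)
#
#   input_ids = torch.randint(0, config.vocab_size, (1, 16))
#
#   # Training-style call: loss is returned when labels are supplied.
#   out = model(input_ids=input_ids, labels=input_ids)
#   print(out.loss, out.logits.shape)   # logits: [1, 16, vocab_size]
#
#   # Generation via the inherited GenerationMixin API.
#   generated = model.generate(input_ids, max_new_tokens=8)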