File size: 9,117 Bytes
d56eb1d
 
 
 
 
 
 
c5d2ebc
d56eb1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d2ebc
d56eb1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e025f
 
 
 
 
 
 
 
 
 
 
 
 
 
d56eb1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d2ebc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d56eb1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin

from .shared_space_config import SharedSpaceDecoderConfig
from .shared_space_decoder import (
    SharedSpaceDecoderPreTrainedModel,
    SharedSpaceDecoderModel,
    DeepseekV3RMSNorm
)

def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Create a normalization layer based on ``config.norm_type``.

    Args:
        hidden_size: The dimension to normalize over.
        config: Configuration providing ``norm_type`` plus the matching
            epsilon (``layer_norm_eps`` for "layernorm", ``rms_norm_eps``
            for "rmsnorm").

    Returns:
        ``nn.LayerNorm`` for "layernorm", ``DeepseekV3RMSNorm`` for "rmsnorm".

    Raises:
        ValueError: If ``config.norm_type`` is neither "layernorm" nor "rmsnorm".
    """
    norm_type = config.norm_type
    if norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    if norm_type == "rmsnorm":
        # DeepseekV3RMSNorm is already imported at module level; no local re-import needed.
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    # Defensive: presumably config validation rejects unknown values earlier
    # (not verifiable from this module).
    raise ValueError(f"Unknown norm_type: {config.norm_type}")


class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
    """
    Subspace Decoder model with a causal language modeling head.

    This model extends the SharedSpaceDecoderModel with:
    - A final normalization layer (LayerNorm or RMSNorm, chosen by ``config.norm_type``)
    - A language modeling head that projects hidden states to vocabulary logits
    - Support for computing cross-entropy loss for language modeling
    - HuggingFace ``GenerationMixin`` compatibility (no KV cache yet)
    - Decoder-specific initialization strategies

    The model can be used for:
    - Text generation
    - Language modeling pretraining
    - Fine-tuning on downstream tasks
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        # Initialize the base decoder model
        self.model = SharedSpaceDecoderModel(config)

        # Final layer norm before the language modeling head
        self.norm = create_norm_layer(config.hidden_size, config)

        # Language modeling head: projects hidden_size -> vocab_size logits per token
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,  # Following common practice in modern LMs
        )

        # post_init() runs _init_weights over the modules and then calls
        # tie_weights(); whether tying actually happens is decided inside
        # tie_weights() (config.tie_word_embeddings and vocab_proj).
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Decoder-specific weight initialization.

        Every module is first initialized by the base class. The language
        modeling head is then re-drawn from a zero-mean normal distribution:
        - std = ``initializer_range`` when the backbone has no ``vocab_proj``
          (embeddings live in model space and the head may be tied to them);
        - std = ``0.5 * initializer_range`` when ``vocab_proj`` is present
          (vocab-subspace model, head is never tied) to keep initial logits small.
        """
        # Use the base class initialization for most modules
        super()._init_weights(module)

        if module is self.lm_head:
            std = self.config.initializer_range
            # getattr mirrors tie_weights() so a backbone without a vocab_proj
            # attribute is treated as "no projection" instead of raising.
            if getattr(self.model, "vocab_proj", None) is not None:
                std *= 0.5
            # Sample once with the final std.
            module.weight.data.normal_(mean=0.0, std=std)

    def get_input_embeddings(self):
        """Return the input embedding layer for compatibility with HuggingFace."""
        return self.model.vocab_embed

    def set_input_embeddings(self, value):
        """Set the input embedding layer for compatibility with HuggingFace."""
        self.model.vocab_embed = value

    def get_output_embeddings(self):
        """Return the output embedding layer (lm_head) for compatibility."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer for compatibility."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Tie the language modeling head's weight to the input embedding weight.

        Tying happens only when BOTH conditions hold:
        - ``config.tie_word_embeddings`` is truthy. A missing attribute counts
          as True, matching the default HF ``tie_weights`` behavior.
        - The backbone has no ``vocab_proj``. When input embeddings go through
          a projection layer (vocab-subspace models), the head is left untied.
        """
        # Honour the config flag; this override replaces the HF default, which
        # would otherwise perform this check.
        if not getattr(self.config, "tie_word_embeddings", True):
            return
        if getattr(self.model, "vocab_proj", None) is None:
            # Use HF utility for correct tying/cloning semantics
            self._tie_or_clone_weights(self.lm_head, self.model.vocab_embed)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, tuple]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].
            attention_mask: Attention mask of shape [batch_size, seq_len]
                (1 for real tokens, 0 for padding). If None, an all-ones mask
                of dtype long is created.
            labels: Target token ids with the same shape as ``input_ids``.
                Do NOT shift them yourself. ``self.loss_function`` (HF's causal-LM
                loss for ``*ForCausalLM`` classes) shifts labels internally, so
                pass ``labels=input_ids`` for next-token training. Positions set to
                the loss's ignore index (-100 by default in HF) are skipped.
            **kwargs: Forwarded unchanged to both the base decoder and the loss
                function. ``output_hidden_states=True`` makes the output carry
                the final normalized hidden states.

        Returns:
            CausalLMOutputWithPast with:
            - logits: [batch_size, seq_len, vocab_size]
            - loss: loss from ``self.loss_function`` if labels were given, else None
            - hidden_states: the final normalized hidden states tensor
              [batch_size, seq_len, hidden_size] when ``output_hidden_states``
              is set in kwargs, else None. This is a single tensor, not the
              per-layer tuple used by most HF models.
            - past_key_values / attentions: always None (no KV cache yet)
        """
        # Ensure a mask exists so downstream (SDPA) mask preparation does not
        # fail on None; long and bool masks are both accepted by HF mask utils.
        if attention_mask is None and input_ids is not None:
            attention_mask = torch.ones(
                (input_ids.size(0), input_ids.size(1)),
                dtype=torch.long,
                device=input_ids.device,
            )

        # Run the base decoder model (all transformer layers).
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Final normalization before the language modeling head
        hidden_states = self.norm(hidden_states)

        # Project to vocabulary logits: [batch_size, seq_len, vocab_size]
        logits = self.lm_head(hidden_states)

        # Standard HuggingFace loss; label shifting is done inside loss_function.
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,  # Not implementing KV cache yet
            hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ) -> dict:
        """
        Build the forward() inputs for each ``generate()`` step.

        Without a KV cache the full ``input_ids`` sequence is re-fed every step.
        ``past_key_values`` and any extra kwargs are ignored.
        """
        # If a KV cache is added later: when past_key_values is not None,
        # slice to input_ids[:, -1:].
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        """No-op cache reordering for beam search; returns ``past_key_values`` unchanged."""
        return past_key_values