kokolamba committed
Commit c5d2ebc · 1 Parent(s): e8e025f

Add 2nd patch to fix DecoderLM head not having generate()

Files changed (1)
  1. task_heads.py +17 -1
task_heads.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from typing import Optional, Union
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerationMixin
 
 from .shared_space_config import SharedSpaceDecoderConfig
 from .shared_space_decoder import (
@@ -34,7 +35,7 @@ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.
     raise ValueError(f"Unknown norm_type: {config.norm_type}")
 
 
-class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
+class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
     """
     Subspace Decoder model with a causal language modeling head.
 
@@ -207,4 +208,19 @@ class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
             hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
             attentions=None,
         )
+
+    # ---- Add this minimal bridge for generation: PATCH 2 ----
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        # If a KV cache is added later: when past_key_values is not None, slice to input_ids[:, -1:]
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    # Optional; harmless no-op while there is no cache yet
+    def _reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
 
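With GenerationMixin in the class hierarchy and the prepare_inputs_for_generation bridge above, text generation should go through the standard generate() path. A minimal sketch of a smoke test, assuming the repo's package is importable as shown; the import path, checkpoint path, and dummy token ids are placeholders, not part of this commit:

import torch
from task_heads import SharedSpaceDecoderForCausalLM  # adjust to the repo's package layout

model = SharedSpaceDecoderForCausalLM.from_pretrained("path/to/checkpoint")  # placeholder path
model.eval()

# Dummy prompt ids; a real run would come from the repo's tokenizer.
input_ids = torch.tensor([[1, 42, 7, 99]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,
        do_sample=False,  # greedy decoding
        use_cache=False,  # no KV cache is implemented yet, so keep caching off
    )

print(output_ids.shape)  # (1, prompt_len + up to 20 new tokens)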
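The comment inside prepare_inputs_for_generation hints at the follow-up once a KV cache exists. A rough sketch of what that cache-aware bridge could look like, assuming forward() would both return past_key_values and accept them back (neither is implemented in this commit):

# Sketch only: cache-aware variant of the bridge added above; belongs on the model class.
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    **kwargs,
):
    if past_key_values is not None:
        # With a cache, only the newest token needs to be fed through the model.
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
    }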