Add 2nd patch to fix DecoderLM head not having generate()
task_heads.py: CHANGED (+17, -1)
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from typing import Optional, Union
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerationMixin
 
 from .shared_space_config import SharedSpaceDecoderConfig
 from .shared_space_decoder import (
@@ -34,7 +35,7 @@ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.
     raise ValueError(f"Unknown norm_type: {config.norm_type}")
 
 
-class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
+class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
     """
     Subspace Decoder model with a causal language modeling head.
 
@@ -207,4 +208,19 @@ class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
             hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
             attentions=None,
         )
+
+    # ---- Add this minimal bridge for generation: PATCH 2 ----
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        # If you add a KV cache later: if past_key_values is not None, slice to input_ids[:, -1:]
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    # Optional; harmless no-op while there is no cache yet
+    def _reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
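With GenerationMixin added to the bases and the prepare_inputs_for_generation bridge in place, the stock generate() loop should be usable on this head (recent transformers releases only treat a model as generation-capable when it explicitly inherits the mixin). A minimal smoke-test sketch follows; the import paths, the default SharedSpaceDecoderConfig() construction, and the "gpt2" tokenizer are placeholder assumptions, not part of this commit:

# Hedged sketch, not part of the commit. Adjust imports, config arguments,
# and the tokenizer to whatever this repo actually uses.
import torch
from transformers import AutoTokenizer

from task_heads import SharedSpaceDecoderForCausalLM        # placeholder import path
from shared_space_config import SharedSpaceDecoderConfig    # placeholder import path

tokenizer = AutoTokenizer.from_pretrained("gpt2")                   # placeholder tokenizer
model = SharedSpaceDecoderForCausalLM(SharedSpaceDecoderConfig())   # or .from_pretrained(...)
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=20,
        do_sample=False,   # greedy decoding for a deterministic smoke test
        use_cache=False,   # the model has no KV cache yet, so keep caching off
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))

Since the model has no KV cache yet, keeping use_cache=False (together with the no-op _reorder_cache) makes each decoding step a plain full-sequence forward pass, which matches what prepare_inputs_for_generation returns.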