# RLM-GemmaS-Code-v0 / modeling_regresslm.py
"""Hugging Face wrapper around a T5 Gemma backbone with a numeric decoder."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from torch import nn
from transformers import T5GemmaConfig, T5GemmaForConditionalGeneration
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers import PreTrainedModel, GenerationMixin
from .configuration_regresslm import RegressLMConfig
from .tokenization_p10 import IEEEFloatTokenizer, P10Tokenizer

@dataclass
class RegressLMOutput(Seq2SeqLMOutput):
"""Extends the default seq2seq output with optional regression logits."""
regression_logits: Optional[torch.Tensor] = None

class _NumericConstraintHelper:
"""Utility that mirrors the `DecoderVocab` logic for numeric decoding."""
def __init__(self, tokenizer) -> None:
self.tokenizer = tokenizer
self.num_tokens_per_obj = tokenizer.num_tokens_per_obj
self.pad_token_id = tokenizer.pad_token_id
def allowed_token_ids(self, prev_token_ids: Sequence[int]) -> list[int]:
return self.tokenizer.possible_next_token_ids(prev_token_ids)
def decode(self, token_ids: Sequence[int]) -> list[float]:
return self.tokenizer.token_ids_to_floats(token_ids)

class _NumericLogitsProcessor(LogitsProcessor):
"""Constrains generation so only valid numeric tokens appear."""
def __init__(self, helper: _NumericConstraintHelper):
self.helper = helper
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # type: ignore[override]
batch, _ = input_ids.shape
updated_scores = scores.clone()
for row in range(batch):
prev_ids = input_ids[row].tolist()
allowed = self.helper.allowed_token_ids(prev_ids)
vocab = updated_scores.shape[-1]
if any(i < 0 or i >= vocab for i in allowed):
raise ValueError(
f"Numeric constraint produced out-of-range id(s): "
f"max={max(allowed)}, vocab={vocab}. "
f"Check tokenizer <-> decoder vocab alignment."
)
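            # Additive mask: allowed ids keep their original score, every other
            # id is pushed to -inf and gets zero probability after the softmax.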
mask = torch.full_like(updated_scores[row], float("-inf"))
mask[allowed] = 0.0
updated_scores[row] = updated_scores[row] + mask
return updated_scores

class RegressLMForConditionalGeneration(PreTrainedModel, GenerationMixin):
"""Drop-in Hugging Face model that mirrors ``PyTorchModel`` for inference."""
config_class = RegressLMConfig
base_model_prefix = "model"
def __init__(self, config: RegressLMConfig) -> None:
super().__init__(config)
backbone_cfg = T5GemmaConfig(**config.backbone_config)
self.model = T5GemmaForConditionalGeneration(backbone_cfg)
# Encoder vocabulary: optionally resize the shared embedding.
if config.encoder_vocab_size is not None:
cur = self.model.get_input_embeddings().num_embeddings
if cur != config.encoder_vocab_size:
self.model.resize_token_embeddings(config.encoder_vocab_size)
# Decoder vocabulary: always detach from the shared embedding so we can
# host the numeric tokens.
if config.decoder_vocab_size is not None:
self._resize_decoder_vocab(config.decoder_vocab_size)
        hidden_size = getattr(self.model.config.encoder, "d_model", None)
        if hidden_size is None:
            hidden_size = getattr(self.model.config.encoder, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Unable to infer hidden size from backbone config.")
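        # Optional regression head: a single linear layer applied to the
        # mean-pooled encoder states instead of decoding numeric tokens.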
self.use_regression_head = config.use_regression_head
if self.use_regression_head:
self.regression_head = nn.Linear(hidden_size, 1)
else:
self.regression_head = None
decoder_spec = getattr(config, "decoder_tokenizer", "P10").upper()
if decoder_spec.startswith("IEEE"):
mantissa_digits = getattr(config, "ieee_mantissa_digits", None)
exponent_digits = getattr(config, "ieee_exponent_digits", None)
if mantissa_digits is None or exponent_digits is None:
raise ValueError(
"Config missing IEEE tokenizer parameters: `ieee_mantissa_digits` and `ieee_exponent_digits`."
)
tokenizer = IEEEFloatTokenizer(
base=getattr(config, "ieee_base", 10),
num_mantissa_digits=mantissa_digits,
num_exponent_digits=exponent_digits,
)
else:
tokenizer = P10Tokenizer(
num_digits=getattr(config, "num_digits", 6),
exponent_range=getattr(config, "exponent_range", 10),
)
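        # (Assumption) Both tokenizers are expected to emit a fixed-length token
        # pattern per float (sign, mantissa digits, exponent), which is what
        # `num_tokens_per_obj` and the constraint helper below rely on.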
self.constraint_helper = _NumericConstraintHelper(tokenizer)
# Sanity-check: decoder vocab size must match numeric tokenizer size
if (config.decoder_vocab_size is not None and
config.decoder_vocab_size != self.constraint_helper.tokenizer.decoder_vocab_size):
raise ValueError(
f"Decoder vocab mismatch: model={config.decoder_vocab_size} "
f"tokenizer={self.constraint_helper.tokenizer.decoder_vocab_size}. "
"Make sure the tokenizer does NOT add PAD and preserves training order."
)
self.num_tokens_per_obj = config.num_tokens_per_obj
self.max_num_objs = config.max_num_objs
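        # T5-style convention: decoding starts from the pad token id, and
        # `generate` below defaults to the same start token.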
self.decoder_start_token_id = config.pad_token_id
self.post_init()
# ------------------------------------------------------------------
# Helpers mirroring PyTorchModel utilities
# ------------------------------------------------------------------
def _resize_decoder_vocab(self, vocab_size: int) -> None:
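        # Swap in a fresh decoder embedding and an untied lm_head sized to the
        # numeric vocabulary; the encoder embedding is left untouched.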
decoder = self.model.model.decoder.embed_tokens
if decoder.num_embeddings != vocab_size:
self.model.model.decoder.embed_tokens = nn.Embedding(
vocab_size,
decoder.embedding_dim,
padding_idx=decoder.padding_idx,
)
        hidden = getattr(self.model.config.encoder, "d_model", None)
        if hidden is None:
            hidden = getattr(self.model.config.encoder, "hidden_size", None)
        if hidden is None:
            raise ValueError("Unable to infer hidden size from backbone config.")
if hasattr(self.model.lm_head, "out_proj"):
self.model.lm_head.out_proj = nn.Linear(hidden, vocab_size, bias=False)
else:
self.model.lm_head = nn.Linear(hidden, vocab_size, bias=False)
self.model.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))
if hasattr(self.model.config, "tie_word_embeddings"):
self.model.config.tie_word_embeddings = False
def _pool_encoder_outputs(self, memory: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
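        # Mean-pool encoder states over non-padding positions; the denominator
        # is clamped so fully padded rows do not divide by zero.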
mask = (~pad_mask).unsqueeze(-1).type_as(memory)
denom = mask.sum(dim=1).clamp(min=1.0)
return (memory * mask).sum(dim=1) / denom
def get_encoder(self): # type: ignore[override]
return self.model.get_encoder()
def get_decoder(self): # type: ignore[override]
return self.model.get_decoder()
def get_input_embeddings(self): # type: ignore[override]
return self.model.get_input_embeddings()
def set_input_embeddings(self, new_embeddings): # type: ignore[override]
self.model.set_input_embeddings(new_embeddings)
def tie_weights(self): # type: ignore[override]
# Word embeddings are intentionally untied once the decoder vocab is replaced.
pass
# ------------------------------------------------------------------
# Forward pass mirroring the training-time behaviour
# ------------------------------------------------------------------
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
regression_targets: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> RegressLMOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
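        # Regression-head path: pool the encoder output and predict a scalar
        # directly, bypassing the decoder.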
if self.use_regression_head:
if input_ids is None:
raise ValueError("`input_ids` must be provided for regression inference.")
if attention_mask is None:
attention_mask = (input_ids != self.config.pad_token_id).long()
encoder_out = self.model.get_encoder()(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
memory = encoder_out.last_hidden_state
pad_mask = attention_mask == 0
pooled = self._pool_encoder_outputs(memory, pad_mask)
preds = self.regression_head(pooled).squeeze(-1)
loss = None
target = regression_targets if regression_targets is not None else labels
if target is not None:
loss = F.mse_loss(preds, target.to(preds.dtype))
if not return_dict:
return (loss, preds) if loss is not None else (preds,)
return RegressLMOutput(
loss=loss,
logits=None,
regression_logits=preds,
encoder_last_hidden_state=memory,
)
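        # Replace pad ids with -100 so the backbone's cross-entropy loss
        # ignores padded label positions.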
proc_labels = None
if labels is not None:
proc_labels = labels.clone()
proc_labels[proc_labels == self.config.pad_token_id] = -100
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
labels=proc_labels,
return_dict=True,
**kwargs,
)
if not return_dict:
if outputs.loss is None:
return (outputs.logits,)
return (outputs.loss, outputs.logits)
return RegressLMOutput(
loss=outputs.loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
)
# ------------------------------------------------------------------
# Generation helpers
# ------------------------------------------------------------------
def prepare_inputs_for_generation(self, *args, **kwargs): # type: ignore[override]
return self.model.prepare_inputs_for_generation(*args, **kwargs)
def _get_logits_processor( # type: ignore[override]
self,
generation_config,
input_ids_seq_length=None,
encoder_input_ids=None,
prefix_allowed_tokens_fn=None,
logits_processor=None,
device=None,
model_kwargs=None,
negative_prompt_ids=None,
negative_prompt_attention_mask=None,
):
processors = super()._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=encoder_input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
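        # Always append the numeric constraint so sampled or beamed tokens stay
        # inside the tokenizer's valid float encoding.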
processors.append(_NumericLogitsProcessor(self.constraint_helper))
return processors
def generate(self, *args, **kwargs): # type: ignore[override]
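        # Defaults: start decoding from the pad token and budget enough new
        # tokens for `max_num_objs` floats.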
if "decoder_start_token_id" not in kwargs:
kwargs["decoder_start_token_id"] = self.config.pad_token_id
if "max_new_tokens" not in kwargs and "max_length" not in kwargs:
kwargs["max_new_tokens"] = self.config.max_num_objs * self.num_tokens_per_obj
return super().generate(*args, **kwargs)
# ------------------------------------------------------------------
# Convenience helper used after generation
# ------------------------------------------------------------------
def decode_to_floats(self, sequences: torch.Tensor | Sequence[Sequence[int]]) -> list[list[float]]:
if isinstance(sequences, torch.Tensor):
iterable = sequences.cpu().tolist()
else:
iterable = sequences
return [self.constraint_helper.decode(seq) for seq in iterable]
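

# ----------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the exported model). The repo
# id below and the presence of an encoder-side AutoTokenizer are assumptions,
# so adjust both to the actual checkpoint before running:
#
#   from transformers import AutoTokenizer
#
#   repo_id = "akhauriyash/RLM-GemmaS-Code-v0"  # assumed repository id
#   tok = AutoTokenizer.from_pretrained(repo_id)
#   model = RegressLMForConditionalGeneration.from_pretrained(repo_id)
#
#   enc = tok(["def add(a, b): return a + b"], return_tensors="pt")
#   seqs = model.generate(**enc)          # constrained numeric decoding
#   print(model.decode_to_floats(seqs))   # list of decoded floats per row
# ----------------------------------------------------------------------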