# gemma3_biencoder.py
from __future__ import annotations

import copy
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3PreTrainedModel,
    Gemma3TextModel,
)


class Gemma3EncoderModel(Gemma3PreTrainedModel):
    config_class = Gemma3TextConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        cfg = copy.deepcopy(config)
        if hasattr(cfg, "use_bidirectional_attention"):
            cfg.use_bidirectional_attention = True
        cfg.use_cache = False
        super().__init__(cfg)
        self.encoder = Gemma3TextModel(cfg)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **kwargs,
    ):
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            is_causal=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
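# Usage sketch (illustrative addition, not part of the original module): mean-pooled sentence
# embeddings from the bidirectional encoder above. The checkpoint path is assumed to point at
# weights saved by one of the classes in this file; a stock causal Gemma3 checkpoint would first
# need its weights remapped onto the "encoder." prefix used here. The helper name is hypothetical.
def _example_encode_sentences(model_name_or_path: str, sentences: list[str]) -> torch.Tensor:
    """Return (num_sentences, hidden_size) mean-pooled embeddings for a list of strings."""
    from transformers import AutoTokenizer  # local import keeps the sketch self-contained

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = Gemma3EncoderModel.from_pretrained(model_name_or_path).eval()
    batch = tokenizer(sentences, padding=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state          # (batch, seq_len, hidden)
    mask = batch["attention_mask"].unsqueeze(-1).float()   # (batch, seq_len, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)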
class Gemma3EncoderForMaskedLM(Gemma3PreTrainedModel):
    config_class = Gemma3TextConfig
    base_model_prefix = "encoder"
    _tied_weights_keys = ["lm_head.weight"]
    _keys_to_ignore_on_load_missing = [r"lm_head\.weight"]

    def __init__(self, config: Gemma3TextConfig):
        cfg = copy.deepcopy(config)
        if hasattr(cfg, "use_bidirectional_attention"):
            cfg.use_bidirectional_attention = True
        cfg.use_cache = False
        super().__init__(cfg)
        self.encoder = Gemma3TextModel(cfg)
        self.vocab_size = cfg.vocab_size
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        self.post_init()  # calls tie_weights()

    # Embeddings / head
    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.encoder.embed_tokens = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_head: nn.Module):
        self.lm_head = new_head

    # Keep vocab_size in sync; ensure pointer-tying
    def tie_weights(self):
        if hasattr(self.config, "vocab_size"):
            self.config.vocab_size = self.get_input_embeddings().num_embeddings
            self.vocab_size = self.config.vocab_size
        if getattr(self.config, "tie_word_embeddings", True):
            self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings())

    # Ensure 'lm_head.weight' exists when saving (avoids resume warnings)
    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        if "lm_head.weight" not in sd and getattr(self.config, "tie_word_embeddings", True):
            emb_key = f"{self.base_model_prefix}.embed_tokens.weight"
            if emb_key in sd:
                sd["lm_head.weight"] = sd[emb_key]
        return sd

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            is_causal=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))

        if not return_dict:
            out = (logits, hidden_states)
            if output_hidden_states:
                out += (outputs.hidden_states,)
            if output_attentions:
                out += (outputs.attentions,)
            if loss is not None:
                out = (loss,) + out
            return out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class Gemma3EncoderForSequenceClassification(Gemma3PreTrainedModel):
    """Gemma3 Encoder with a sequence classification head (mean pooling + linear)."""

    config_class = Gemma3TextConfig
    base_model_prefix = "encoder"

    def __init__(self, config: Gemma3TextConfig):
        cfg = copy.deepcopy(config)
        if hasattr(cfg, "use_bidirectional_attention"):
            cfg.use_bidirectional_attention = True
        cfg.use_cache = False
        super().__init__(cfg)
        self.num_labels = getattr(cfg, "num_labels", 2)
        self.encoder = Gemma3TextModel(cfg)
        classifier_dropout = getattr(cfg, "classifier_dropout", 0.0)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(cfg.hidden_size, self.num_labels)
        self.post_init()

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.encoder.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            is_causal=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Mean pooling over non-padded tokens
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden_states.mean(dim=1)

        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
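# Usage sketch (illustrative addition, not part of the original module): attaching a fresh
# classification head for fine-tuning. `num_labels` is set on the config before instantiation;
# the path is assumed to point at encoder weights saved by this module's classes, and the
# classifier itself starts randomly initialised. The helper name is hypothetical.
def _example_build_classifier(model_name_or_path: str, num_labels: int) -> Gemma3EncoderForSequenceClassification:
    """Load a bidirectional Gemma3 encoder and put a new sequence-classification head on top."""
    config = Gemma3TextConfig.from_pretrained(model_name_or_path)
    config.num_labels = num_labels
    return Gemma3EncoderForSequenceClassification.from_pretrained(model_name_or_path, config=config)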
class Gemma3EncoderForTokenClassification(Gemma3PreTrainedModel):
    """Gemma3 Encoder with a token classification head for NER/POS tagging."""

    config_class = Gemma3TextConfig
    base_model_prefix = "encoder"

    def __init__(self, config: Gemma3TextConfig):
        cfg = copy.deepcopy(config)
        if hasattr(cfg, "use_bidirectional_attention"):
            cfg.use_bidirectional_attention = True
        cfg.use_cache = False
        super().__init__(cfg)
        self.num_labels = getattr(cfg, "num_labels", 2)
        self.encoder = Gemma3TextModel(cfg)
        classifier_dropout = getattr(cfg, "classifier_dropout", 0.0)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(cfg.hidden_size, self.num_labels)
        self.post_init()

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.encoder.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]:
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            is_causal=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
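# Smoke-test sketch (illustrative addition, not part of the original module): builds tiny,
# randomly initialised models so each head's forward pass can be exercised without downloading
# a checkpoint. The config sizes are arbitrary small values, not real Gemma3 dimensions, and the
# sketch assumes a transformers version whose Gemma3TextModel accepts the call arguments used above.
if __name__ == "__main__":
    tiny_cfg = Gemma3TextConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=8,
    )
    input_ids = torch.randint(0, 128, (2, 10))
    attention_mask = torch.ones_like(input_ids)

    # Masked-LM head: positions labelled -100 are ignored by the loss.
    mlm = Gemma3EncoderForMaskedLM(tiny_cfg)
    labels = input_ids.clone()
    labels[:, ::2] = -100
    print("MLM loss:", mlm(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss.item())

    # Sequence classification head: mean-pooled representation -> 3-way logits.
    tiny_cfg.num_labels = 3
    seq_cls = Gemma3EncoderForSequenceClassification(tiny_cfg)
    print("Sequence logits:", seq_cls(input_ids=input_ids, attention_mask=attention_mask).logits.shape)

    # Token classification head: per-token logits over 3 labels.
    tok_cls = Gemma3EncoderForTokenClassification(tiny_cfg)
    print("Token logits:", tok_cls(input_ids=input_ids, attention_mask=attention_mask).logits.shape)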