# e2d2-wmt/backbone_automodel.py
from typing import Literal
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
DynamicCache,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
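# FlexAttention's BlockMask only exists in recent PyTorch releases; fall back
# to None so the `attention_mask` type hint in `forward` still evaluates.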
try:
from torch.nn.attention.flex_attention import BlockMask
except ImportError:
BlockMask = None
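# Auto classes that can be requested by name via the `automodel_cls` argument.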
AUTO_MODEL_CLS = {
"AutoModel": AutoModel,
"AutoModelForCausalLM": AutoModelForCausalLM,
"AutoModelForMaskedLM": AutoModelForMaskedLM,
}
class AutoModelFromPreTrained(nn.Module):
"""Simple wrapper class that enables using AutoModel from pre-trained."""
def __init__(
self,
automodel_cls: Literal[
"AutoModel",
"AutoModelForCausalLM",
"AutoModelForMaskedLM",
],
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
num_layers: int = -1,
keep_top_layers: bool = False,
reinit_model: bool = False,
use_causal_mask: bool = False,
**automodel_init_kwargs,
):
super().__init__()
self.use_causal_mask = use_causal_mask
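        # Two construction paths: re-initialize from config with random
        # weights (currently hard-wired to CustomQwen3ForCausalLM), or load
        # pre-trained weights through the requested auto class.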
if reinit_model:
auto_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
num_hidden_layers=num_layers,
trust_remote_code=trust_remote_code,
**automodel_init_kwargs,
)
self.model = CustomQwen3ForCausalLM(auto_config)
# self.model = AUTO_MODEL_CLS[automodel_cls].from_config(auto_config)
else:
self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**automodel_init_kwargs,
)
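        # Truncate the transformer stack to `num_layers` blocks (-1 keeps the
        # full depth): the last blocks when `keep_top_layers` is set,
        # otherwise the first ones.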
num_layers = (
len(self.model.model.layers) if num_layers == -1 else num_layers
)
if keep_top_layers:
self.model.model.layers = self.model.model.layers[-num_layers:]
else:
self.model.model.layers = self.model.model.layers[:num_layers]
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.FloatTensor | BlockMask | None = None,
position_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
fix_cache_length: bool = False, # False for AR, True for diffusion models
        return_updated_cache: bool = False,
**kwargs,
) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
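        """Run the wrapped backbone.

        Arguments mirror the underlying HF model. With
        ``fix_cache_length=True``, each layer's KV-cache length is recorded
        before the call and the cache is truncated back to it afterwards
        (diffusion-style decoding); with ``False`` the cache grows as in
        standard autoregressive decoding. If ``return_updated_cache`` is set,
        only the (possibly extended) cache is returned and the truncation
        step is skipped.
        """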
prev_cache_len = None
if past_key_values is not None and fix_cache_length:
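            # past_key_values[i][0] is layer i's key tensor with shape
            # (batch, num_heads, seq_len, head_dim); record seq_len per layer.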
prev_cache_len = [
past_key_values[i][0].shape[-2] # type: ignore
for i in range(len(past_key_values))
]
if self.use_causal_mask:
attention_mask = None # None --> enforces use of causal mask
model_output = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs,
)
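        # Early exit: hand back only the updated cache. Note this skips the
        # truncation below, so the cache keeps any newly appended positions.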
if return_updated_cache:
return BaseModelOutputWithPast(past_key_values=model_output.past_key_values)
if (
prev_cache_len is not None
and model_output.get("past_key_values", None) is not None
):
# DynamicCache extends along sequence dimension by default;
# truncate back to original cache len
for i, cache_len in enumerate(prev_cache_len):
model_output.past_key_values.key_cache[i] = (
model_output.past_key_values.key_cache[i][..., :cache_len, :]
)
model_output.past_key_values.value_cache[i] = (
model_output.past_key_values.value_cache[i][..., :cache_len, :]
)
return model_output
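

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). The
# checkpoint name and layer count are assumptions; substitute whatever
# backbone the surrounding project actually targets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    backbone = AutoModelFromPreTrained(
        automodel_cls="AutoModelForCausalLM",
        pretrained_model_name_or_path="Qwen/Qwen3-0.6B",  # hypothetical checkpoint
        num_layers=2,  # keep only the first two transformer blocks
        use_causal_mask=True,
    )
    # Random token ids stand in for a real tokenized batch.
    dummy_ids = torch.randint(0, backbone.model.config.vocab_size, (1, 16))
    output = backbone(dummy_ids)
    print(output.logits.shape)  # (1, 16, vocab_size)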