from typing import Literal

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    DynamicCache,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM

# FlexAttention ships only with recent PyTorch releases; fall back to None so
# the BlockMask annotation below stays importable everywhere.
try:
    from torch.nn.attention.flex_attention import BlockMask
except ImportError:
    BlockMask = None

# Maps the string selector accepted by AutoModelFromPreTrained to the
# corresponding Hugging Face auto class.
AUTO_MODEL_CLS = {
    "AutoModel": AutoModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "AutoModelForMaskedLM": AutoModelForMaskedLM,
}


class AutoModelFromPreTrained(nn.Module):
    """Thin wrapper that loads a Hugging Face auto-class model.

    Supports loading pre-trained weights (or re-initializing a
    CustomQwen3ForCausalLM from config alone) and truncating the
    transformer stack to a subset of its layers.
    """

    def __init__(
        self,
        automodel_cls: Literal[
            "AutoModel",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
        ],
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        num_layers: int = -1,
        keep_top_layers: bool = False,
        reinit_model: bool = False,
        use_causal_mask: bool = False,
        **automodel_init_kwargs,
    ):
        super().__init__()
        self.use_causal_mask = use_causal_mask
        if reinit_model:
            # Build the architecture from config only: weights are randomly
            # initialized rather than loaded from the checkpoint.
            auto_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                num_hidden_layers=num_layers,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            self.model = CustomQwen3ForCausalLM(auto_config)
        else:
            self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
        # num_layers == -1 keeps every layer; otherwise truncate the stack,
        # keeping either its top or its bottom.
        num_layers = (
            len(self.model.model.layers) if num_layers == -1 else num_layers
        )
        if keep_top_layers:
            self.model.model.layers = self.model.model.layers[-num_layers:]
        else:
            self.model.model.layers = self.model.model.layers[:num_layers]
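        # Illustrative example (hypothetical sizes): with num_layers=4 on a
        # 36-layer backbone, keep_top_layers=True keeps layers 32..35, while
        # the default keeps layers 0..3.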

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | BlockMask | None = None,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        fix_cache_length: bool = False,
        return_updated_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
        # Record each layer's current cache length so the cache can be
        # trimmed back to it after this forward pass.
        prev_cache_len = None
        if past_key_values is not None and fix_cache_length:
            prev_cache_len = [
                past_key_values[i][0].shape[-2]
                for i in range(len(past_key_values))
            ]
        if self.use_causal_mask:
            # Discard the explicit mask and let the backbone construct its
            # own causal mask.
            attention_mask = None
        model_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            **kwargs,
        )
        if return_updated_cache:
            return BaseModelOutputWithPast(
                past_key_values=model_output.past_key_values
            )
        if (
            prev_cache_len is not None
            and model_output.get("past_key_values", None) is not None
        ):
            # Trim every layer's KV-cache back to its pre-forward length so
            # the tokens processed in this call are not appended to it.
            for i, cache_len in enumerate(prev_cache_len):
                model_output.past_key_values.key_cache[i] = (
                    model_output.past_key_values.key_cache[i][..., :cache_len, :]
                )
                model_output.past_key_values.value_cache[i] = (
                    model_output.past_key_values.value_cache[i][..., :cache_len, :]
                )
        return model_output
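

# Minimal usage sketch (illustrative, not part of this module): the checkpoint
# name and tensor shapes below are assumptions, chosen only to show the
# prefill-then-query pattern that return_updated_cache / fix_cache_length
# enable.
#
#     backbone = AutoModelFromPreTrained(
#         "AutoModelForCausalLM",
#         "Qwen/Qwen3-0.6B",  # hypothetical checkpoint
#         num_layers=4,       # keep only the bottom 4 layers
#     )
#     cache = DynamicCache()
#     prompt = torch.randint(0, 1000, (1, 16))
#     # Prefill: return only the updated cache.
#     out = backbone(prompt, past_key_values=cache, return_updated_cache=True)
#     cache = out.past_key_values
#     # Query pass: the cache is trimmed back afterwards, so these tokens
#     # attend to the cached prefix without being appended to it.
#     query = torch.randint(0, 1000, (1, 4))
#     logits = backbone(
#         query, past_key_values=cache, fix_cache_length=True
#     ).logits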