from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    ModelOutput,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging

from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM

try:
    from torch.nn.attention.flex_attention import BlockMask
except ImportError:
    BlockMask = None


logger = logging.get_logger(__name__)


@dataclass
class EncoderBaseModelOutputWithPast(ModelOutput):
    """Custom (encoder) model output.

    Stores the previous decoder cache, the updated encoder cache, and the
    encoder's last hidden state.
    """

    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None


@dataclass
class DecoderCausalLMOutputWithPast(ModelOutput):
    """Custom (decoder) model output.

    Stores the previous encoder cache, the updated decoder cache, and the
    decoder logits.
    """

    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
        None
    )
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None


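# LLMasEncoderDecoder wraps a causal LM backbone as an encoder-decoder: the
# encoder's last hidden states are prepended to the decoder's token embeddings,
# and the custom Qwen3 decoder layers are expected to compute queries only from
# `q_start_idx` onward, so the encoder prefix is used purely as keys/values.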
class LLMasEncoderDecoder(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

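        # Encoder: either freshly initialized with `num_encoder_layers` layers, or
        # loaded from the pretrained checkpoint and truncated to the bottom (or top)
        # `num_encoder_layers` layers.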
        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = CustomQwen3ForCausalLM(encoder_config)
        else:
            self.encoder = CustomQwen3ForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

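        # Optionally freeze the encoder; the token embeddings stay trainable because
        # they also embed the decoder inputs.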
        if freeze_encoder:
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

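        # Decoder: either reuse the encoder's weights (its top `num_decoder_layers`
        # layers double as the decoder), or build/load a separate stack. A separate
        # decoder never needs its own embeddings, since inputs are embedded with the
        # encoder's embedding table.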
        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )

            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]
        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                self.decoder = CustomQwen3ForCausalLM(decoder_config)
            else:
                self.decoder = CustomQwen3ForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            del self.decoder.model.embed_tokens

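        # Reuse the checkpoint's tied LM head if it shares storage with the embedding
        # table; otherwise keep only the decoder's head and drop the encoder's.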
        if (
            self.encoder.lm_head.weight.data_ptr()
            == self.encoder.model.embed_tokens.weight.data_ptr()
        ):
            self.decoder.lm_head = self.encoder.lm_head
        else:
            del self.encoder.lm_head
        if use_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    def freeze_encoder(self):
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        for p in self.encoder.model.parameters():
            p.requires_grad = True

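    # forward() encodes `encoder_input_ids` (if given) and, unless
    # `return_updated_cache` is set, embeds `input_ids`, prepends
    # `encoder_last_hidden_state`, and runs the decoder layers to produce logits.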
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        fix_cache_length: bool = True,
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[DecoderCausalLMOutputWithPast, EncoderBaseModelOutputWithPast]:
        new_seen_tokens = (
            0
            if encoder_last_hidden_state is None
            else encoder_last_hidden_state.shape[1]
        )

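        # Encoder pass: run the encoder over `encoder_input_ids` and update the
        # encoder cache. With `return_updated_cache` the caller gets the caches back
        # without running the decoder.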
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            encoder_output = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=encoder_past_key_values,
                cache_position=encoder_cache_position,
            )
            if return_updated_cache:
                return EncoderBaseModelOutputWithPast(
                    encoder_last_hidden_state=encoder_output.last_hidden_state,
                    encoder_past_key_values=encoder_output.past_key_values,
                    past_key_values=past_key_values,
                )
            encoder_last_hidden_state = encoder_output.last_hidden_state

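        # Build the decoder inputs: embed the new tokens and, when an encoder prefix
        # is present, prepend it; `cache_position` then covers the prefix slots first
        # and the new-token slots after them.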
        if encoder_last_hidden_state is None:
            q_start_idx = 0
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.arange(
                        past_seen_tokens,
                        past_seen_tokens + decoder_hidden_states.shape[1],
                        device=decoder_hidden_states.device,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )
        else:
            q_start_idx = encoder_last_hidden_state.shape[1]
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            decoder_hidden_states = torch.cat(
                [
                    encoder_last_hidden_state,
                    decoder_hidden_states,
                ],
                dim=1,
            )
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.cat(
                        [
                            torch.arange(
                                past_seen_tokens,
                                past_seen_tokens + encoder_last_hidden_state.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                            torch.arange(
                                past_seen_tokens + new_seen_tokens,
                                past_seen_tokens + new_seen_tokens + input_ids.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                        ],
                        dim=-1,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )

        if hasattr(self.decoder.model, "_update_causal_mask"):
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
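        # Decoder pass: iterate over the decoder layers manually so each layer's
        # cache can be trimmed. With tied weights, only the top `decoder_layer_idxs`
        # layers act as the decoder.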
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                continue

            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]
            else:
                prev_cache_len = 0
            cache_len = prev_cache_len + new_seen_tokens

            if self.decoder.model.gradient_checkpointing and self.training:
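                # Positional args mirror the keyword call in the branch below:
                # hidden_states, attention_mask, position_ids, past_key_value,
                # output_attentions=False, use_cache=True, cache_position,
                # position_embeddings, q_start_idx.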
                decoder_hidden_states = self.decoder._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    decoder_hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    False,
                    True,
                    cache_position,
                    decoder_position_embeddings,
                    q_start_idx,
                )[0]
            else:
                decoder_hidden_states = decoder_layer(
                    hidden_states=decoder_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=False,
                    use_cache=True,
                    cache_position=cache_position,
                    position_embeddings=decoder_position_embeddings,
                    q_start_idx=q_start_idx,
                    **flash_attn_kwargs,
                )[0]

            if q_start_idx > 0:
                decoder_hidden_states = torch.cat(
                    [
                        encoder_last_hidden_state,
                        decoder_hidden_states,
                    ],
                    dim=1,
                )

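            # Trim this layer's cache back to `cache_len` (previous entries plus the
            # encoder prefix); any keys/values appended for the current decoder
            # tokens are discarded.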
            if past_key_values is not None:
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(
            decoder_hidden_states[:, q_start_idx:, :]
        )
        logits = self.decoder.lm_head(decoder_hidden_states)
        return DecoderCausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
            encoder_past_key_values=encoder_past_key_values,
        )


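# LLMasEncoderDecoderShareKV is a variant in which the encoder writes its
# keys/values directly into the decoder's cache (`past_key_values`), so the
# decoder attends over that shared cache rather than over prepended hidden
# states; the encoder's final hidden state is discarded.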
class LLMasEncoderDecoderShareKV(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = AutoModelForCausalLM.from_config(encoder_config)
        else:
            self.encoder = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            assert num_encoder_layers <= len(self.encoder.model.layers), (
                f"Cannot keep {num_encoder_layers} layers. "
                f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
            )
            num_encoder_layers = (
                len(self.encoder.model.layers)
                if num_encoder_layers == -1
                else num_encoder_layers
            )
            if keep_top_encoder_layers:
                self.encoder.model.layers = self.encoder.model.layers[
                    -num_encoder_layers:
                ]
            else:
                self.encoder.model.layers = self.encoder.model.layers[
                    :num_encoder_layers
                ]

        if freeze_encoder:
            for name, param in self.encoder.named_parameters():
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )

            self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
                -num_decoder_layers:
            ]
        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                self.decoder = AutoModelForCausalLM.from_config(decoder_config)
            else:
                self.decoder = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                assert num_decoder_layers <= len(self.decoder.model.layers), (
                    f"Cannot keep {num_decoder_layers} layers. "
                    f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
                )
                if keep_top_decoder_layers:
                    self.decoder.model.layers = self.decoder.model.layers[
                        -num_decoder_layers:
                    ]
                else:
                    self.decoder.model.layers = self.decoder.model.layers[
                        :num_decoder_layers
                    ]
            del self.decoder.model.embed_tokens

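        # The encoder's final hidden state is never read in this variant (only its
        # KV cache is), so freeze the pieces of the last encoder layer and the final
        # norm that contribute only to that output.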
        self.encoder.model.embed_tokens.requires_grad_(True)
        unused_self_attn_params = ["o_proj", "q_norm", "q_proj"]
        unused_layernorm_params = ["input_layernorm", "post_attention_layernorm"]
        for unused_param in unused_self_attn_params:
            if hasattr(self.encoder.model.layers[-1].self_attn, unused_param):
                getattr(
                    self.encoder.model.layers[-1].self_attn, unused_param
                ).requires_grad_(False)
        self.encoder.model.layers[-1].mlp.requires_grad_(False)
        self.encoder.model.norm.requires_grad_(False)
        for unused_param in unused_layernorm_params:
            if hasattr(self.encoder.model.layers[-1], unused_param):
                getattr(self.encoder.model.layers[-1], unused_param).requires_grad_(
                    False
                )

        if (
            self.encoder.lm_head.weight.data_ptr()
            == self.encoder.model.embed_tokens.weight.data_ptr()
        ):
            self.decoder.lm_head = self.encoder.lm_head
        else:
            del self.encoder.lm_head
        if use_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    def freeze_encoder(self):
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        for p in self.encoder.model.parameters():
            p.requires_grad = True

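    # forward(): an encoder pass over `encoder_input_ids` appends keys/values to the
    # shared `past_key_values`; the decoder pass then embeds `input_ids` and attends
    # over that cache, restoring its length after every layer so decoder tokens are
    # never cached.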
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        fix_cache_length: bool = True,
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[CausalLMOutputWithPast, BaseModelOutputWithPast]:
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            past_key_values = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=encoder_cache_position,
            ).past_key_values
            if return_updated_cache:
                return BaseModelOutputWithPast(
                    past_key_values=past_key_values,
                )

        decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
        if cache_position is None:
            if position_ids is not None:
                cache_position = position_ids[0]
            else:
                cache_position = torch.arange(
                    decoder_hidden_states.shape[1],
                    device=decoder_hidden_states.device,
                )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        decoder_position_embeddings = self.decoder.model.rotary_emb(
            decoder_hidden_states, position_ids
        )

        if hasattr(self.decoder.model, "_update_causal_mask"):
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                continue

            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]
            else:
                prev_cache_len = 0

            decoder_hidden_states = decoder_layer(
                hidden_states=decoder_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=position_ids[0],
                position_embeddings=decoder_position_embeddings,
                **flash_attn_kwargs,
            )[0]

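            # Restore the shared cache to its pre-layer length so the keys/values
            # written for the current decoder tokens are not kept.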
            if past_key_values is not None:
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
        decoder_hidden_states = self.decoder.model.norm(decoder_hidden_states)
        logits = self.decoder.lm_head(decoder_hidden_states)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
        )