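"""Encoder-decoder backbones built on top of a pretrained causal LLM.

Two wrappers are defined:

- ``LLMasEncoderDecoder``: an encoder pass produces hidden states for the clean
  (conditioning) tokens, which are prepended to the embedded decoder inputs so
  the decoder layers can attend to them.
- ``LLMasEncoderDecoderShareKV``: encoder and decoder share a single KV cache;
  the encoder writes its keys/values into the cache and the decoder layers
  attend to it directly, so encoder hidden states never need to be passed.
"""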
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
ModelOutput,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging
from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
try:
from torch.nn.attention.flex_attention import BlockMask
except ImportError:
BlockMask = None
logger = logging.get_logger(__name__)
@dataclass
class EncoderBaseModelOutputWithPast(ModelOutput):
"""Custom (encoder) model output.
Stores the previous decoder cache, the updated encoder cache, and the encoder's last hidden state.
"""
past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
None
)
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_past_key_values: Optional[
Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
] = None
@dataclass
class DecoderCausalLMOutputWithPast(ModelOutput):
"""Custom (decoder) model output.
Stores the previous encoder cache, the updated decoder cache, and the decoder logits.
"""
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]] = (
None
)
encoder_past_key_values: Optional[
Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
] = None
class LLMasEncoderDecoder(nn.Module):
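"""Encoder-decoder built from a pretrained ``CustomQwen3ForCausalLM``.

The encoder and decoder can be separately re-initialized, truncated to a
subset of layers (bottom or top), frozen, or weight-tied. When weights are
tied, the top ``num_decoder_layers`` of the shared stack act as the decoder.
The decoder keeps no embedding table of its own; it reuses the encoder's.
"""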
def __init__(
self,
pretrained_model_name_or_path: str,
max_length: int,
attn_backend: str = "sdpa",
freeze_encoder: bool = False,
reinit_encoder: bool = False,
reinit_decoder: bool = False,
tie_encoder_decoder_weights: bool = False,
use_encoder_causal_mask: bool = False,
num_encoder_layers: int = -1,
num_decoder_layers: int = -1,
keep_top_encoder_layers: bool = False,
keep_top_decoder_layers: bool = False,
use_gradient_checkpointing: bool = False,
**llm_init_kwargs,
):
assert not (tie_encoder_decoder_weights and reinit_decoder), (
"Cannot tie encoder-decoder weights and reinitialize decoder."
)
assert not (tie_encoder_decoder_weights and freeze_encoder), (
"Cannot freeze encoder weights when tying encoder-decoder weights."
)
super().__init__()
self.use_encoder_causal_mask = use_encoder_causal_mask
self.tie_encoder_decoder_weights = tie_encoder_decoder_weights
if reinit_encoder:
assert num_encoder_layers > 0
encoder_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
num_hidden_layers=num_encoder_layers,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
self.encoder = CustomQwen3ForCausalLM(encoder_config)
else:
self.encoder = CustomQwen3ForCausalLM.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
assert num_encoder_layers <= len(self.encoder.model.layers), (
f"Cannot keep {num_encoder_layers} layers. "
f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
)
num_encoder_layers = (
len(self.encoder.model.layers)
if num_encoder_layers == -1
else num_encoder_layers
)
if keep_top_encoder_layers:
self.encoder.model.layers = self.encoder.model.layers[
-num_encoder_layers:
]
else:
self.encoder.model.layers = self.encoder.model.layers[
:num_encoder_layers
]
if freeze_encoder:
for name, param in self.encoder.named_parameters():
if "embed_tokens" not in name:
param.requires_grad = False
if use_gradient_checkpointing:
self.encoder.gradient_checkpointing_enable()
if tie_encoder_decoder_weights:
self.decoder = self.encoder
num_decoder_layers = (
len(self.decoder.model.layers)
if num_decoder_layers == -1
else num_decoder_layers
)
assert num_decoder_layers <= len(self.decoder.model.layers), (
f"Cannot keep {num_decoder_layers} layers. "
f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
)
# Keep **top** layers when tying weights
self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
-num_decoder_layers:
]
else:
if reinit_decoder:
assert num_decoder_layers > 0
decoder_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
num_hidden_layers=num_decoder_layers,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
self.decoder = CustomQwen3ForCausalLM(decoder_config)
else:
self.decoder = CustomQwen3ForCausalLM.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
assert num_decoder_layers <= len(self.decoder.model.layers), (
f"Cannot keep {num_decoder_layers} layers. "
f"Pre-trained model only has {len(self.decoder.layers)} layers."
)
if keep_top_decoder_layers:
self.decoder.model.layers = self.decoder.model.layers[
-num_decoder_layers:
]
else:
self.decoder.model.layers = self.decoder.model.layers[
:num_decoder_layers
]
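# The decoder reuses the encoder's token embeddings in forward(), so its own
# embedding table is redundant and can be dropped.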
del self.decoder.model.embed_tokens
# If the original LM's lm_head is weight-tied to its embedding, point the
# decoder's lm_head at the encoder's (instead of keeping a separate copy)
if (
self.encoder.lm_head.weight.data_ptr()
== self.encoder.model.embed_tokens.weight.data_ptr()
):
self.decoder.lm_head = self.encoder.lm_head
else:
del self.encoder.lm_head
if use_gradient_checkpointing:
self.decoder.gradient_checkpointing_enable()
self.max_length = max_length
def freeze_encoder(self):
for p in self.encoder.model.parameters():
p.requires_grad = False
def unfreeze_encoder(self):
for p in self.encoder.model.parameters():
p.requires_grad = True
# noinspection PyUnusedLocal
def forward(
self,
# Decoder inputs
input_ids: torch.LongTensor,
attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_values: Optional[DynamicCache] = None,
encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
# Encoder inputs
encoder_input_ids: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
encoder_position_ids: Optional[torch.LongTensor] = None,
encoder_cache_position: Optional[torch.LongTensor] = None,
encoder_past_key_values: Optional[DynamicCache] = None,
# Additional args
fix_cache_length: bool = True, # Not used; compatibility with other backbones
return_updated_cache: bool = False,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[DecoderCausalLMOutputWithPast, EncoderBaseModelOutputWithPast]:
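"""Encode the clean tokens (if ``encoder_input_ids`` is given), then decode.

When ``return_updated_cache`` is True, only the refreshed encoder cache and
last hidden state are returned (used to pre-fill the encoder at generation
time). Otherwise the encoder's last hidden state is prepended to the embedded
decoder inputs, the decoder layers attend over the combined sequence, and
logits are returned for the decoder positions only.
"""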
# During training/eval, encoder_last_hidden_state is None.
# During generation, it may be passed in from a prior encoder pass.
new_seen_tokens = (
0
if encoder_last_hidden_state is None
else encoder_last_hidden_state.shape[1]
)
# Encode clean tokens
if encoder_input_ids is not None:
if self.use_encoder_causal_mask:
encoder_attention_mask = None # None --> enforces use of causal mask
if encoder_cache_position is None and encoder_position_ids is not None:
encoder_cache_position = encoder_position_ids[0]
encoder_output = self.encoder.model(
input_ids=encoder_input_ids,
attention_mask=encoder_attention_mask,
position_ids=encoder_position_ids,
use_cache=True,
past_key_values=encoder_past_key_values,
cache_position=encoder_cache_position,
)
if return_updated_cache:
# encoder_output.past_key_values now contains latest encoder input
return EncoderBaseModelOutputWithPast(
encoder_last_hidden_state=encoder_output.last_hidden_state,
encoder_past_key_values=encoder_output.past_key_values,
past_key_values=past_key_values,
)
encoder_last_hidden_state = encoder_output.last_hidden_state
# Run decoder with xattn to clean token hidden states
if encoder_last_hidden_state is None: # No new encoder tokens
q_start_idx = 0
decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
if cache_position is None:
if position_ids is not None:
cache_position = position_ids[0]
else:
past_seen_tokens = (
past_key_values.get_seq_length()
if past_key_values is not None
else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + decoder_hidden_states.shape[1],
device=decoder_hidden_states.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
decoder_position_embeddings = self.decoder.model.rotary_emb(
decoder_hidden_states, position_ids
)
else:
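# q_start_idx marks where the decoder (noisy) positions start once the
# encoder hidden states are prepended; the attention layers use it to
# return outputs for the decoder positions only.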
q_start_idx = encoder_last_hidden_state.shape[1]
decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
decoder_hidden_states = torch.cat(
[
encoder_last_hidden_state,
decoder_hidden_states,
],
dim=1,
)
if cache_position is None:
if position_ids is not None:
cache_position = position_ids[0]
else:
past_seen_tokens = (
past_key_values.get_seq_length()
if past_key_values is not None
else 0
)
cache_position = torch.cat(
[
torch.arange( # clean token position ids
past_seen_tokens,
past_seen_tokens + encoder_last_hidden_state.shape[1],
device=decoder_hidden_states.device,
),
torch.arange( # noisy position ids
past_seen_tokens + new_seen_tokens,
past_seen_tokens + new_seen_tokens + input_ids.shape[1],
device=decoder_hidden_states.device,
),
],
dim=-1,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
decoder_position_embeddings = self.decoder.model.rotary_emb(
decoder_hidden_states, position_ids
)
if hasattr(self.decoder.model, "_update_causal_mask"):  # backward compat with older transformers
# noinspection PyProtectedMember
attention_mask = self.decoder.model._update_causal_mask(
attention_mask=attention_mask,
input_tensor=decoder_hidden_states,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=False,
)
for decoder_layer in self.decoder.model.layers:
layer_idx = decoder_layer.self_attn.layer_idx
if (
self.tie_encoder_decoder_weights
and layer_idx not in self.decoder_layer_idxs
):
continue
# past_key_values gets updated in-place.
# Record previous length to re-truncate after each layer forward
if past_key_values is not None and len(past_key_values) > layer_idx:
prev_cache_len = past_key_values[layer_idx][0].shape[-2] # type: ignore
else:
prev_cache_len = 0
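# Keep the pre-existing cache plus the freshly appended encoder tokens;
# entries appended for the noisy decoder tokens during this layer's forward
# are truncated away again below.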
cache_len = prev_cache_len + new_seen_tokens
if self.decoder.model.gradient_checkpointing and self.training:
# noinspection PyProtectedMember
decoder_hidden_states = self.decoder._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
decoder_hidden_states, # hidden_states=,
attention_mask, # attention_mask=,
position_ids, # position_ids=,
past_key_values, # past_key_value=,
False, # output_attentions=,
True, # use_cache=,
cache_position, # cache_position=,
decoder_position_embeddings, # position_embeddings=,
q_start_idx, # q_start_idx=
)[0] # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
else:
decoder_hidden_states = decoder_layer(
hidden_states=decoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=cache_position,
position_embeddings=decoder_position_embeddings,
q_start_idx=q_start_idx, # Indicates where to slice output
**flash_attn_kwargs,
)[0] # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
# Update decoder_hidden_states
if q_start_idx > 0:
decoder_hidden_states = torch.cat(
[
encoder_last_hidden_state,
decoder_hidden_states,
],
dim=1,
)
if past_key_values is not None:
# DynamicCache extends along sequence dimension by default;
# truncate back to original cache len + encoder output length
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
layer_idx
][..., :cache_len, :]
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
layer_idx
][..., :cache_len, :]
decoder_hidden_states = self.decoder.model.norm(
decoder_hidden_states[:, q_start_idx:, :]
)
logits = self.decoder.lm_head(decoder_hidden_states)
return DecoderCausalLMOutputWithPast(
logits=logits,
past_key_values=past_key_values,
encoder_past_key_values=encoder_past_key_values,
# Do not need to store encoder_last_hidden_state.
# If it was passed in, then it has become part of the past_key_values cache.
)
class LLMasEncoderDecoderShareKV(nn.Module):
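"""Encoder-decoder variant in which encoder and decoder share one KV cache.

The encoder pass extends ``past_key_values`` with the clean tokens' keys and
values; the decoder embeds its inputs with the encoder's embedding table and
its layers attend to the shared cache. Cache entries created for the decoder
tokens themselves are truncated away after each layer.
"""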
def __init__(
self,
pretrained_model_name_or_path: str,
max_length: int,
attn_backend: str = "sdpa",
freeze_encoder: bool = False,
reinit_encoder: bool = False,
reinit_decoder: bool = False,
tie_encoder_decoder_weights: bool = False,
use_encoder_causal_mask: bool = False,
num_encoder_layers: int = -1,
num_decoder_layers: int = -1,
keep_top_encoder_layers: bool = False,
keep_top_decoder_layers: bool = False,
use_gradient_checkpointing: bool = False,
**llm_init_kwargs,
):
assert not (tie_encoder_decoder_weights and reinit_decoder), (
"Cannot tie encoder-decoder weights and reinitialize decoder."
)
assert not (tie_encoder_decoder_weights and freeze_encoder), (
"Cannot freeze encoder weights when tying encoder-decoder weights."
)
super().__init__()
self.use_encoder_causal_mask = use_encoder_causal_mask
self.tie_encoder_decoder_weights = tie_encoder_decoder_weights
if reinit_encoder:
assert num_encoder_layers > 0
encoder_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
num_hidden_layers=num_encoder_layers,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
self.encoder = AutoModelForCausalLM.from_config(encoder_config)
else:
self.encoder = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
assert num_encoder_layers <= len(self.encoder.model.layers), (
f"Cannot keep {num_encoder_layers} layers. "
f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
)
num_encoder_layers = (
len(self.encoder.model.layers)
if num_encoder_layers == -1
else num_encoder_layers
)
if keep_top_encoder_layers:
self.encoder.model.layers = self.encoder.model.layers[
-num_encoder_layers:
]
else:
self.encoder.model.layers = self.encoder.model.layers[
:num_encoder_layers
]
if freeze_encoder:
for name, param in self.encoder.named_parameters():
if "embed_tokens" not in name:
param.requires_grad = False
if use_gradient_checkpointing:
self.encoder.gradient_checkpointing_enable()
if tie_encoder_decoder_weights:
self.decoder = self.encoder
num_decoder_layers = (
len(self.decoder.model.layers)
if num_decoder_layers == -1
else num_decoder_layers
)
assert num_decoder_layers <= len(self.decoder.model.layers), (
f"Cannot keep {num_decoder_layers} layers. "
f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
)
# Keep **top** layers when tying weights
self.decoder_layer_idxs = list(range(len(self.encoder.model.layers)))[
-num_decoder_layers:
]
else:
if reinit_decoder:
assert num_decoder_layers > 0
decoder_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
num_hidden_layers=num_decoder_layers,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
self.decoder = AutoModelForCausalLM.from_config(decoder_config)
else:
self.decoder = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
attn_implementation=attn_backend,
**llm_init_kwargs,
)
assert num_decoder_layers <= len(self.decoder.model.layers), (
f"Cannot keep {num_decoder_layers} layers. "
f"Pre-trained model only has {len(self.decoder.layers)} layers."
)
if keep_top_decoder_layers:
self.decoder.model.layers = self.decoder.model.layers[
-num_decoder_layers:
]
else:
self.decoder.model.layers = self.decoder.model.layers[
:num_decoder_layers
]
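# The decoder reuses the encoder's token embeddings in forward(), so its own
# embedding table is dropped.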
del self.decoder.model.embed_tokens
# Even for frozen encoder, ensure embedding tokens are trainable
self.encoder.model.embed_tokens.requires_grad_(True)
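# Freeze final-encoder-layer parameters that the shared-cache decoder path
# does not train (query/output projections, MLP, layer norms, final norm).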
unused_self_attn_params = ["o_proj", "q_norm", "q_proj"]
unused_layernorm_params = ["input_layernorm", "post_attention_layernorm"]
for unused_param in unused_self_attn_params:
if hasattr(self.encoder.model.layers[-1].self_attn, unused_param):
getattr(
self.encoder.model.layers[-1].self_attn, unused_param
).requires_grad_(False)
self.encoder.model.layers[-1].mlp.requires_grad_(False)
self.encoder.model.norm.requires_grad_(False)
for unused_param in unused_layernorm_params:
if hasattr(self.encoder.model.layers[-1], unused_param):
getattr(self.encoder.model.layers[-1], unused_param).requires_grad_(
False
)
# If the original LM's lm_head is weight-tied to its embedding, point the
# decoder's lm_head at the encoder's (instead of keeping a separate copy)
if (
self.encoder.lm_head.weight.data_ptr()
== self.encoder.model.embed_tokens.weight.data_ptr()
):
self.decoder.lm_head = self.encoder.lm_head
else:
del self.encoder.lm_head
if use_gradient_checkpointing:
self.decoder.gradient_checkpointing_enable()
self.max_length = max_length
def freeze_encoder(self):
for p in self.encoder.model.parameters():
p.requires_grad = False
def unfreeze_encoder(self):
for p in self.encoder.model.parameters():
p.requires_grad = True
# noinspection PyUnusedLocal
def forward(
self,
# Decoder inputs
input_ids: torch.LongTensor,
attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_values: Optional[DynamicCache] = None,
encoder_last_hidden_state: Optional[torch.FloatTensor] = None, # Not used
# Encoder inputs
encoder_input_ids: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
encoder_position_ids: Optional[torch.LongTensor] = None,
encoder_cache_position: Optional[torch.LongTensor] = None,
encoder_past_key_values: Optional[DynamicCache] = None, # Not used
# Additional args
fix_cache_length: bool = True, # Not used; compatibility with other backbones
return_updated_cache: bool = False,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[CausalLMOutputWithPast, BaseModelOutputWithPast]:
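"""Extend the shared KV cache with the clean tokens (if ``encoder_input_ids``
is given), then run the decoder layers against that cache.

When ``return_updated_cache`` is True, only the updated cache is returned;
otherwise logits are returned for the decoder positions.
"""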
# Encode clean tokens
if encoder_input_ids is not None:
if self.use_encoder_causal_mask:
encoder_attention_mask = None # None --> enforces use of causal mask
if encoder_cache_position is None and encoder_position_ids is not None:
encoder_cache_position = encoder_position_ids[0]
past_key_values = self.encoder.model(
input_ids=encoder_input_ids,
attention_mask=encoder_attention_mask,
position_ids=encoder_position_ids,
use_cache=True,
past_key_values=past_key_values,
cache_position=encoder_cache_position,
).past_key_values
if return_updated_cache:
# past_key_values now contains the latest encoder input
return BaseModelOutputWithPast(
past_key_values=past_key_values,
)
# Run decoder with xattn to clean token hidden states
decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
if cache_position is None:
if position_ids is not None:
cache_position = position_ids[0]
else: # During training / validation position_ids are not provided
cache_position = torch.arange(
decoder_hidden_states.shape[1],
device=decoder_hidden_states.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
decoder_position_embeddings = self.decoder.model.rotary_emb(
decoder_hidden_states, position_ids
)
if hasattr(self.decoder.model, "_update_causal_mask"):  # backward compat with older transformers
# noinspection PyProtectedMember
attention_mask = self.decoder.model._update_causal_mask(
attention_mask=attention_mask,
input_tensor=decoder_hidden_states,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=False,
)
for decoder_layer in self.decoder.model.layers:
layer_idx = decoder_layer.self_attn.layer_idx
if (
self.tie_encoder_decoder_weights
and layer_idx not in self.decoder_layer_idxs
):
continue
# past_key_values gets updated in-place.
# Record previous length to truncate after each layer forward
if past_key_values is not None and len(past_key_values) > layer_idx:
prev_cache_len = past_key_values[layer_idx][0].shape[-2] # type: ignore
else:
prev_cache_len = 0
decoder_hidden_states = decoder_layer(
hidden_states=decoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=position_ids[0],
position_embeddings=decoder_position_embeddings,
**flash_attn_kwargs,
)[0] # Shape: (input_ids.shape[0], input_ids.shape[1], hidden_dim)
if past_key_values is not None:
# DynamicCache extends along the sequence dimension by default;
# truncate back to the pre-layer cache length so only encoder tokens remain
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
layer_idx
][..., :prev_cache_len, :]
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
layer_idx
][..., :prev_cache_len, :]
decoder_hidden_states = self.decoder.model.norm(decoder_hidden_states)
logits = self.decoder.lm_head(decoder_hidden_states)
return CausalLMOutputWithPast(
logits=logits,
past_key_values=past_key_values,
)
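# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The checkpoint name below is a placeholder; any Qwen3-style causal LM that
# CustomQwen3ForCausalLM can load should work. Attention masks are omitted, so
# plain causal masking is used. LLMasEncoderDecoderShareKV takes the same
# constructor arguments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LLMasEncoderDecoder(
        pretrained_model_name_or_path="Qwen/Qwen3-0.6B",  # placeholder checkpoint
        max_length=512,
        num_decoder_layers=4,  # keep only the top 4 layers as the decoder
        keep_top_decoder_layers=True,
    )
    src_ids = torch.randint(0, 1000, (1, 16))  # clean (conditioning) tokens
    tgt_ids = torch.randint(0, 1000, (1, 16))  # noisy / target tokens
    out = model(input_ids=tgt_ids, encoder_input_ids=src_ids)
    print(out.logits.shape)  # (1, 16, vocab_size)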