# e2d2-wmt / backbone_custom_modeling_qwen3.py
from typing import Callable, Optional, Tuple
import torch
from torch import nn
from transformers.models.qwen3.modeling_qwen3 import (
ALL_ATTENTION_FUNCTIONS,
Cache,
FlashAttentionKwargs,
Qwen3Attention,
Qwen3Config,
Qwen3DecoderLayer,
Qwen3ForCausalLM,
Qwen3Model,
eager_attention_forward,
rotate_half,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging
logger = logging.get_logger(__name__)
def custom_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, q_start_idx=0):
    """Applies Rotary Position Embedding to the query and key tensors.

    Unlike the stock `apply_rotary_pos_emb`, `q` may cover only the last
    `seq_len - q_start_idx` positions of the sequence. The cos/sin tables,
    which are computed for the full key sequence, are sliced accordingly so
    that queries are still rotated with their true absolute positions.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos[..., q_start_idx:, :]) + (
        rotate_half(q) * sin[..., q_start_idx:, :]
    )
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
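

# A minimal shape sanity-check (not part of the original module), assuming the
# HF convention of (batch, num_heads, seq_len, head_dim) for q/k and
# (batch, seq_len, head_dim) for the cos/sin tables. It illustrates that a
# query tensor covering only the trailing positions can share the rotary
# tables computed for the full key sequence.
def _demo_rotary_offset(q_start_idx: int = 4):
    batch, heads, kv_len, head_dim = 2, 8, 16, 64
    q = torch.randn(batch, heads, kv_len - q_start_idx, head_dim)
    k = torch.randn(batch, heads, kv_len, head_dim)
    cos = torch.randn(batch, kv_len, head_dim)
    sin = torch.randn(batch, kv_len, head_dim)
    q_embed, k_embed = custom_apply_rotary_pos_emb(
        q, k, cos, sin, q_start_idx=q_start_idx
    )
    # Queries are rotated with positions q_start_idx..kv_len-1,
    # keys with positions 0..kv_len-1.
    assert q_embed.shape == q.shape and k_embed.shape == k.shape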
class CustomQwen3Attention(Qwen3Attention):
    """Qwen3 multi-headed attention with an optional query offset.

    When `q_start_idx > 0` (a decoder pass with encoder inputs prepended to
    `hidden_states`), queries are computed only from the trailing
    `seq_len - q_start_idx` positions, while keys and values are built from
    the full sequence.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
q_start_idx: int = 0, # > 0: decoder pass w/encoder inputs in hidden_states
**kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
        sa_hidden_states = hidden_states[:, q_start_idx:, :]
        query_input_shape = sa_hidden_states.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)
        query_states = self.q_norm(
            self.q_proj(sa_hidden_states).reshape(query_hidden_shape)
        ).transpose(1, 2)
key_states = self.k_norm(
self.k_proj(hidden_states).view(hidden_shape)
).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = custom_apply_rotary_pos_emb(
query_states, key_states, cos, sin, q_start_idx=q_start_idx
)
if past_key_value is not None:
# sin and cos are specific to RoPE models
# cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# NOTE: downcast for flex-attention compatibility
query_states, key_states = (
query_states.to(value_states.dtype),
key_states.to(value_states.dtype),
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window, # diff with Llama
**kwargs,
)
attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
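

# A hedged, self-contained sketch (illustrative config values, not from the
# released checkpoint): with q_start_idx > 0 the layer attends from only the
# last `seq_len - q_start_idx` positions over keys/values built from the full
# sequence, yielding an encoder-decoder style pass without separate
# cross-attention. Constant cos/sin stand in for real rotary tables here.
def _demo_custom_attention(q_start_idx: int = 3):
    cfg = Qwen3Config(
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        intermediate_size=128,
        num_hidden_layers=1,
    )
    cfg._attn_implementation = "eager"
    attn = CustomQwen3Attention(cfg, layer_idx=0).eval()
    batch, seq = 2, 10
    hidden = torch.randn(batch, seq, cfg.hidden_size)
    cos = torch.ones(batch, seq, cfg.head_dim)  # placeholder rotary tables
    sin = torch.zeros(batch, seq, cfg.head_dim)
    out, _ = attn(
        hidden, (cos, sin), attention_mask=None, q_start_idx=q_start_idx
    )
    # Output covers only the query positions.
    assert out.shape == (batch, seq - q_start_idx, cfg.hidden_size)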
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer that routes `q_start_idx` through the custom attention."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        self.self_attn = CustomQwen3Attention(config=config, layer_idx=layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
q_start_idx: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
        # Slice the residual to match the query-side attention output, which
        # only covers positions q_start_idx onward.
        residual = hidden_states[:, q_start_idx:, ...]
        hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
q_start_idx=q_start_idx,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
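

# Hedged shape check (illustrative config; same transformers version assumed
# as the rest of this module). With q_start_idx > 0 the layer consumes the
# full sequence but emits only the trailing `seq - q_start_idx` positions,
# because the residual is sliced to match the attention output.
def _demo_decoder_layer_trim(q_start_idx: int = 3):
    cfg = Qwen3Config(
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        intermediate_size=128,
        num_hidden_layers=1,
    )
    cfg._attn_implementation = "eager"
    layer = CustomQwen3DecoderLayer(cfg, layer_idx=0).eval()
    batch, seq = 2, 10
    hidden = torch.randn(batch, seq, cfg.hidden_size)
    cos = torch.ones(batch, seq, cfg.head_dim)  # placeholder rotary tables
    sin = torch.zeros(batch, seq, cfg.head_dim)
    out = layer(
        hidden, position_embeddings=(cos, sin), q_start_idx=q_start_idx
    )[0]
    assert out.shape == (batch, seq - q_start_idx, cfg.hidden_size)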
class CustomQwen3Model(Qwen3Model):
    """Qwen3 backbone built from `CustomQwen3DecoderLayer` blocks."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
self.layers = nn.ModuleList(
[
CustomQwen3DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
# Initialize weights and apply final processing
self.post_init()
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    """Causal LM wrapper whose backbone is `CustomQwen3Model`."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
# Initialize a new model with custom layers
self.model = CustomQwen3Model(config)
# Initialize weights and apply final processing
self.post_init()
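

# Hedged end-to-end usage sketch. The config values below are illustrative
# (not the e2d2-wmt checkpoint's); loading the released weights would instead
# go through AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True).
# With the default q_start_idx of 0 everywhere, the model behaves like a
# vanilla Qwen3 causal LM.
if __name__ == "__main__":
    cfg = Qwen3Config(
        vocab_size=128,
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        intermediate_size=128,
        num_hidden_layers=2,
    )
    cfg._attn_implementation = "eager"
    model = CustomQwen3ForCausalLM(cfg).eval()
    input_ids = torch.randint(0, cfg.vocab_size, (1, 12))
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    print(logits.shape)  # torch.Size([1, 12, 128])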