from typing import Callable, Optional, Tuple

import torch
from torch import nn
from transformers.models.qwen3.modeling_qwen3 import (
    ALL_ATTENTION_FUNCTIONS,
    Cache,
    FlashAttentionKwargs,
    Qwen3Attention,
    Qwen3Config,
    Qwen3DecoderLayer,
    Qwen3ForCausalLM,
    Qwen3Model,
    eager_attention_forward,
    rotate_half,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging

logger = logging.get_logger(__name__)


def custom_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, q_start_idx=0):
    """Applies Rotary Position Embedding to the query and key tensors.

    Unlike the stock implementation, the query may cover only the positions
    from ``q_start_idx`` onward, so the matching slice of the cos/sin tables
    is applied to it, while the keys use the full tables.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos[..., q_start_idx:, :]) + (
        rotate_half(q) * sin[..., q_start_idx:, :]
    )
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class CustomQwen3Attention(Qwen3Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        q_start_idx: int = 0,  # > 0: decoder pass w/ encoder inputs in hidden_states
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Queries are computed only for positions from q_start_idx onward;
        # keys and values still cover the full sequence.
        sa_hidden_states = hidden_states[:, q_start_idx:, :]
        query_input_shape = sa_hidden_states.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)

        query_states = self.q_norm(
            self.q_proj(sa_hidden_states).reshape(query_hidden_shape)
        ).transpose(1, 2)
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = custom_apply_rotary_pos_emb(
            query_states, key_states, cos, sin, q_start_idx=q_start_idx
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models;
            # cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # NOTE: downcast for flex-attention compatibility
        query_states, key_states = (
            query_states.to(value_states.dtype),
            key_states.to(value_states.dtype),
        )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        self.self_attn = CustomQwen3Attention(config=config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        q_start_idx: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        # The residual must match the attention output, which only covers
        # positions from q_start_idx onward.
        residual = hidden_states[:, q_start_idx:, ...]

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            q_start_idx=q_start_idx,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class CustomQwen3Model(Qwen3Model):
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                CustomQwen3DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # Initialize weights and apply final processing
        self.post_init()


class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        # Initialize a new model with custom layers
        self.model = CustomQwen3Model(config)
        # Initialize weights and apply final processing
        self.post_init()
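

# ---------------------------------------------------------------------------
# Minimal usage sketch (an addition, not part of the original module). It shows
# how the custom classes can be instantiated from a pretrained Qwen3 checkpoint;
# the checkpoint name "Qwen/Qwen3-0.6B" is an illustrative placeholder, and the
# default q_start_idx=0 makes the forward pass behave like the stock model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Qwen/Qwen3-0.6B"  # placeholder checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Subclassing Qwen3ForCausalLM lets from_pretrained load the stock weights
    # into the custom layers, since parameter names are unchanged.
    model = CustomQwen3ForCausalLM.from_pretrained(model_name)
    model.eval()

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)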