from typing import Callable, Optional, Tuple
import torch
from torch import nn
from transformers.models.qwen3.modeling_qwen3 import (
ALL_ATTENTION_FUNCTIONS,
Cache,
FlashAttentionKwargs,
Qwen3Attention,
Qwen3Config,
Qwen3DecoderLayer,
Qwen3ForCausalLM,
Qwen3Model,
eager_attention_forward,
rotate_half,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging
logger = logging.get_logger(__name__)
def custom_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, q_start_idx=0):
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos[..., q_start_idx:, :]) + (
rotate_half(q) * sin[..., q_start_idx:, :]
)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class CustomQwen3Attention(Qwen3Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__(config, layer_idx=layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
q_start_idx: int = 0, # > 0: decoder pass w/encoder inputs in hidden_states
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
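        # When q_start_idx > 0, hidden_states holds the encoder positions followed by
        # the decoder positions, and only the decoder positions produce queries
        # (and therefore attention outputs).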
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
        # Queries come only from the positions at and after q_start_idx; keys and
        # values below still use the full hidden_states.
        sa_hidden_states = hidden_states[:, q_start_idx:, :]
        query_input_shape = sa_hidden_states.shape[:-1]
        query_hidden_shape = (*query_input_shape, -1, self.head_dim)
        query_states = self.q_norm(
            self.q_proj(sa_hidden_states).reshape(query_hidden_shape)
        ).transpose(1, 2)
key_states = self.k_norm(
self.k_proj(hidden_states).view(hidden_shape)
).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = custom_apply_rotary_pos_emb(
query_states, key_states, cos, sin, q_start_idx=q_start_idx
)
if past_key_value is not None:
# sin and cos are specific to RoPE models
# cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# NOTE: downcast for flex-attention compatibility
query_states, key_states = (
query_states.to(value_states.dtype),
key_states.to(value_states.dtype),
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window, # diff with Llama
**kwargs,
)
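        # The attention output covers only the query positions,
        # i.e. seq_len - q_start_idx tokens.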
attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
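    """Qwen3 decoder layer built on CustomQwen3Attention: when q_start_idx > 0 it
    returns hidden states only for positions q_start_idx onward while still
    attending over the full input sequence."""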
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__(config, layer_idx=layer_idx)
self.self_attn = CustomQwen3Attention(config=config, layer_idx=layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
q_start_idx: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
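        # Keep only the residual for the query positions; the attention output below
        # has length seq_len - q_start_idx, so the shapes line up.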
residual = hidden_states[:, q_start_idx:, ...]
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
q_start_idx=q_start_idx,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class CustomQwen3Model(Qwen3Model):
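    """Qwen3Model whose decoder layers are replaced by CustomQwen3DecoderLayer."""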
def __init__(self, config: Qwen3Config):
super().__init__(config)
self.layers = nn.ModuleList(
[
CustomQwen3DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
# Initialize weights and apply final processing
self.post_init()
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
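    """Qwen3ForCausalLM whose backbone is CustomQwen3Model."""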
def __init__(self, config: Qwen3Config):
super().__init__(config)
# Initialize a new model with custom layers
self.model = CustomQwen3Model(config)
# Initialize weights and apply final processing
self.post_init()
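

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). Assumptions: any
    # standard Qwen3 checkpoint loads into CustomQwen3ForCausalLM because the
    # parameter names are unchanged, and with the default q_start_idx=0 the model
    # behaves like stock Qwen3. The checkpoint name below is illustrative.
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = CustomQwen3ForCausalLM.from_pretrained(checkpoint)

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    generated = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))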