from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2MLP,
    Qwen2Attention,
    apply_rotary_pos_emb,
    eager_attention_forward,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2Model,
    Qwen2ForCausalLM,
)

from .configuration_fp8_qwen2 import FP8Qwen2Config

from torchao.float8.float8_training_tensor import Float8TrainingTensor

from quasar.module import (
    FP8Quant,
    FP8RMSNorm,
    FP8DSLinearWithCoat,
    FP8DSLinearWithCoatWeightBlock,
    FP8FusedSiLUMul,
    FP8Identity,
)
from quasar.kernel.configs import FP8RMSNormConfig, QuantType, FP8MulConfig, FP8DSLinearWithCoatConfig, FP8QuantConfig
from quasar.kernel.quant.quantize_hp2pb import fp8_quantize_hp2pb
from quasar.kernel.quant.dequantize_pb2hp import fp8_dequantize_pb2hp


logger = logging.get_logger(__name__)


class FP8Qwen2MLP(Qwen2MLP):
    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        # Replace the dense projections created by Qwen2MLP with FP8 linear layers; the
        # implementation is chosen based on whether we are in FP8 training mode.
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.gate_proj = linear_module(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="gate_proj", scale_dtype=torch.float32),
        )
        self.up_proj = linear_module(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="up_proj", scale_dtype=torch.float32),
        )
        self.down_proj = linear_module(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="down_proj", scale_dtype=torch.float32),
        )

        if config.hidden_act == "silu":
            # Fused SiLU(gate) * up kernel that produces the FP8 input for down_proj.
            mul_config = FP8MulConfig(
                quant_type=QuantType.MUL,
            )
            self.act_fn = FP8FusedSiLUMul(mul_config)
        else:
            raise ValueError(f"Unsupported activation function: {config.hidden_act}")

    def forward(self, x):
        gate_x = self.gate_proj(x)
        up_x = self.up_proj(x)

        mul_x = self.act_fn(gate_x, up_x)
        down_proj = self.down_proj(mul_x)

        return down_proj


class FP8Qwen2Attention(Qwen2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the q/k/v/o projections created by Qwen2Attention with FP8 linear layers.
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.q_proj = linear_module(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="q_proj", scale_dtype=torch.float32),
        )
        self.k_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="k_proj", scale_dtype=torch.float32),
        )
        self.v_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="v_proj", scale_dtype=torch.float32),
        )

        # Quantize the attention output to FP8 before the output projection; o_proj is configured
        # (fwd_input_quant_type=QuantType.DIV) to consume this pre-quantized input.
        self.o_proj_quant = FP8Quant(
            quant_config=FP8QuantConfig(
                float8_dtype=config.fp8_config.float8_dtype,
                quant_type=QuantType.DIV,
                fwd_block_size=config.fp8_config.mm_block_size,
                layer_name="o_proj_quant",
                scale_dtype=torch.float32,
            )
        )
        self.o_proj = linear_module(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(
                fwd_input_quant_type=QuantType.DIV,
                layer_name="o_proj",
                scale_dtype=torch.float32,
            ),
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if isinstance(hidden_states, Float8TrainingTensor):
            # Float8TrainingTensor inputs carry one extra trailing dim here, so drop it when
            # recovering the (batch, seq) shape.
            input_shape = hidden_states.shape[:-2]
        else:
            input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        attn_output = self.o_proj_quant(attn_output)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)

        self.mlp = FP8Qwen2MLP(config)
        self.input_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
            ),
        )
        self.post_attention_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
            ),
        )
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention block
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully connected block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class FP8Qwen2PreTrainedModel(PreTrainedModel):
    config_class = FP8Qwen2Config
    config: FP8Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FP8Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": FP8Qwen2DecoderLayer,
        "attentions": FP8Qwen2Attention,
    }


@auto_docstring
class FP8Qwen2Model(FP8Qwen2PreTrainedModel):
    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [FP8Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    # The forward pass is unchanged relative to Qwen2; reuse the reference implementation.
    forward = Qwen2Model.forward


@auto_docstring
class FP8Qwen2ForCausalLM(FP8Qwen2PreTrainedModel, GenerationMixin):
    config_class = FP8Qwen2Config
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = FP8Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> # The checkpoint path below is a placeholder for a repo or directory that ships this module.
        >>> model = AutoModelForCausalLM.from_pretrained("path/to/fp8-qwen2-checkpoint", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/fp8-qwen2-checkpoint")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # Only compute logits for the requested positions (`logits_to_keep=0` keeps every position).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FP8Qwen2ForSequenceClassification(GenericForSequenceClassification, FP8Qwen2PreTrainedModel):
    pass


class FP8Qwen2ForTokenClassification(GenericForTokenClassification, FP8Qwen2PreTrainedModel):
    pass


class FP8Qwen2ForQuestionAnswering(GenericForQuestionAnswering, FP8Qwen2PreTrainedModel):
    base_model_prefix = "transformer"


__all__ = [
    "FP8Qwen2PreTrainedModel",
    "FP8Qwen2Model",
    "FP8Qwen2ForCausalLM",
    "FP8Qwen2ForSequenceClassification",
    "FP8Qwen2ForTokenClassification",
    "FP8Qwen2ForQuestionAnswering",
]


FP8Qwen2Model.register_for_auto_class("AutoModel")
FP8Qwen2ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
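
# The `register_for_auto_class` calls above record this module in the checkpoint's `auto_map`, so a
# checkpoint saved from these classes can be reloaded through the standard auto classes. A minimal
# sketch, assuming a placeholder checkpoint path that ships this file as remote code:
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "path/to/fp8-qwen2-checkpoint",  # hypothetical path, not a real Hub repo
#         trust_remote_code=True,          # needed so the FP8Qwen2* classes defined here are used
#     )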


def make_state_dict_compatible_with_hf(
    state_dict: dict[str, torch.Tensor],
    linear_keys: list[str],
    undesired_linear_keys: list[str],
    config: FP8Qwen2Config = FP8Qwen2Config(),
    already_fp8: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Make the state dict compatible with HuggingFace.

    When `already_fp8` is False, weights matching `linear_keys` are block-quantized to FP8 and stored
    as a quantized weight plus a `weight_scale_inv` tensor. When `already_fp8` is True, those weights
    are copied through unchanged and weights matching `undesired_linear_keys` are dequantized back to
    high precision. Everything else is copied as-is.
    """
    assert set(linear_keys).isdisjoint(set(undesired_linear_keys))

    compatible_state_dict = {}

    for key in state_dict.keys():
        if any(k in key for k in linear_keys):
            weight = state_dict[key]

            if already_fp8:
                # Already quantized: copy through as-is.
                compatible_state_dict[key] = weight
            else:
                # Quantize the high-precision weight block-wise and store both the FP8 weight and its
                # inverse scale under the HF-style key names.
                tmp_quant_cfg = FP8QuantConfig(
                    float8_dtype=config.fp8_config.float8_dtype,
                    quant_type=config.fp8_config.quant_type,
                    fwd_block_size=config.fp8_config.mm_block_size,
                    scale_dtype=torch.float32,
                )
                quant_weight, scale_weight = fp8_quantize_hp2pb(
                    weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
                )

                name_quant = key
                name_scale = key.replace("weight", "weight_scale_inv")
                compatible_state_dict[name_quant] = quant_weight
                compatible_state_dict[name_scale] = scale_weight

        elif any(k in key for k in undesired_linear_keys):
            if already_fp8:
                # Dequantize FP8 weights back to high precision, keyed off the scale entry so each
                # weight/scale pair is handled exactly once.
                if "weight_scale_inv" in key:
                    name_quant = key.replace("weight_scale_inv", "weight")
                    quant_weight = state_dict[name_quant]
                    scale_weight = state_dict[key]
                    weight = fp8_dequantize_pb2hp(
                        quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
                    )
                    compatible_state_dict[name_quant] = weight
            else:
                compatible_state_dict[key] = state_dict[key]

        else:
            compatible_state_dict[key] = state_dict[key]
    return compatible_state_dict
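
# A minimal usage sketch for the converter above, under illustrative assumptions: `linear_keys`
# names the projections that should stay FP8 in the exported checkpoint, `undesired_linear_keys`
# names projections to restore to high precision, and the output filename is an example rather
# than a fixed convention.
#
#     state_dict = model.state_dict()  # e.g. an FP8Qwen2ForCausalLM trained with quasar
#     hf_state_dict = make_state_dict_compatible_with_hf(
#         state_dict,
#         linear_keys=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
#         undesired_linear_keys=["o_proj"],
#         config=model.config,
#         already_fp8=True,
#     )
#     torch.save(hf_state_dict, "pytorch_model.bin")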


def set_named_weight_to_fp8(
    model: FP8Qwen2ForCausalLM,
    linear_keys: list[str] = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
):
    """
    Cast the weights of the matching linear layers to FP8 (`torch.float8_e4m3fn`).
    Also set `layer_name` on each matching module for debugging.
    """
    for name, module in model.named_modules():
        if any(k in name for k in linear_keys):
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
            module.layer_name = name

    return model
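
# A minimal sketch of using the helper above, assuming the checkpoint already matches the layout
# expected by the quasar FP8 linear layers. The path is hypothetical, and this call only casts the
# matching weights to `torch.float8_e4m3fn`; it does not create block-wise scale tensors if the FP8
# kernels require them.
#
#     model = FP8Qwen2ForCausalLM.from_pretrained("path/to/fp8-qwen2-checkpoint")
#     model = set_named_weight_to_fp8(model)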