from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2MLP,
    Qwen2Attention,
    apply_rotary_pos_emb,
    eager_attention_forward,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2Model,
    Qwen2ForCausalLM,
)

from .configuration_fp8_qwen2 import FP8Qwen2Config

from torchao.float8.float8_training_tensor import Float8TrainingTensor

from quasar.module import (
    FP8Quant,
    FP8RMSNorm,
    FP8DSLinearWithCoat,
    FP8DSLinearWithCoatWeightBlock,
    FP8FusedSiLUMul,
    FP8Identity,
)
from quasar.kernel.configs import FP8RMSNormConfig, QuantType, FP8MulConfig, FP8DSLinearWithCoatConfig, FP8QuantConfig
from quasar.kernel.quant.quantize_hp2pb import fp8_quantize_hp2pb
from quasar.kernel.quant.dequantize_pb2hp import fp8_dequantize_pb2hp

logger = logging.get_logger(__name__)


class FP8Qwen2MLP(Qwen2MLP):
    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.gate_proj = linear_module(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="gate_proj", scale_dtype=torch.float32),
        )
        self.up_proj = linear_module(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="up_proj", scale_dtype=torch.float32),
        )
        self.down_proj = linear_module(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="down_proj", scale_dtype=torch.float32),
        )
        if config.hidden_act == "silu":
            mul_config = FP8MulConfig(
                quant_type=QuantType.MUL,
            )
            self.act_fn = FP8FusedSiLUMul(mul_config)
        else:
            raise ValueError(f"Unsupported activation function: {config.hidden_act}")

    def forward(self, x):
        gate_x = self.gate_proj(x)
        up_x = self.up_proj(x)
        mul_x = self.act_fn(gate_x, up_x)
        down_proj = self.down_proj(mul_x)
        return down_proj
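

# The FP8 MLP above fuses the SwiGLU activation (SiLU(gate) * up) with FP8 quantization of the
# result inside `FP8FusedSiLUMul`. The helper below is a hedged, illustrative reference only
# (it is not part of the quasar kernels): `_reference_swiglu_mlp` is a hypothetical function that
# computes the standard high-precision Qwen2 MLP, useful for numerics checks against the fused path.
def _reference_swiglu_mlp(
    x: torch.Tensor,
    gate_weight: torch.Tensor,
    up_weight: torch.Tensor,
    down_weight: torch.Tensor,
) -> torch.Tensor:
    """Reference (unquantized) SwiGLU MLP: down(SiLU(gate(x)) * up(x))."""
    gate = torch.nn.functional.linear(x, gate_weight)
    up = torch.nn.functional.linear(x, up_weight)
    return torch.nn.functional.linear(torch.nn.functional.silu(gate) * up, down_weight)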


class FP8Qwen2Attention(Qwen2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.q_proj = linear_module(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="q_proj", scale_dtype=torch.float32),
        )
        self.k_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="k_proj", scale_dtype=torch.float32),
        )
        self.v_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="v_proj", scale_dtype=torch.float32),
        )
        # In both training and inference, we quantize the output of the attention layer.
        self.o_proj_quant = FP8Quant(
            quant_config=FP8QuantConfig(
                float8_dtype=config.fp8_config.float8_dtype,
                quant_type=QuantType.DIV,
                fwd_block_size=config.fp8_config.mm_block_size,
                layer_name="o_proj_quant",
                scale_dtype=torch.float32,
            )
        )
        self.o_proj = linear_module(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(
                fwd_input_quant_type=QuantType.DIV,
                layer_name="o_proj",
                scale_dtype=torch.float32,
            ),
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if isinstance(hidden_states, Float8TrainingTensor):
            # Float8Tensor's last dim is quantize group size, not hidden size.
            input_shape = hidden_states.shape[:-2]
        else:
            input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # TODO: Add quantization

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Quantize the output of the attention layer.
        attn_output = self.o_proj_quant(attn_output)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
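

# `FP8Quant` (used for `o_proj_quant`) and `fp8_quantize_hp2pb` quantize tensors block-wise to FP8
# inside fused kernels. The two helpers below are a hedged, framework-agnostic sketch only, not
# quasar's actual algorithm or API: they illustrate block-wise e4m3 quantization along the last
# dimension with per-block amax scaling, and the matching dequantization.
def _reference_blockwise_fp8_quant(x: torch.Tensor, block_size: int = 128):
    """Quantize the last dim in blocks of `block_size`; returns FP8 values and per-block scales."""
    orig_shape = x.shape
    assert orig_shape[-1] % block_size == 0
    blocks = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size).float()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # `scale` plays the role of `weight_scale_inv`-style metadata: multiply by it to dequantize.
    return q.reshape(orig_shape), scale.squeeze(-1)


def _reference_blockwise_fp8_dequant(q: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Inverse of the sketch above: expand the per-block scales and multiply back."""
    orig_shape = q.shape
    blocks = q.float().reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
    return (blocks * scale.unsqueeze(-1)).reshape(orig_shape)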


class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = FP8Qwen2MLP(config)
        self.input_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
            ),
        )
        self.post_attention_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
            ),
        )
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class FP8Qwen2PreTrainedModel(PreTrainedModel):
    config_class = FP8Qwen2Config
    config: FP8Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FP8Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": FP8Qwen2DecoderLayer,
        "attentions": FP8Qwen2Attention,
    }


@auto_docstring
class FP8Qwen2Model(FP8Qwen2PreTrainedModel):
    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [FP8Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    forward = Qwen2Model.forward


@auto_docstring
class FP8Qwen2ForCausalLM(FP8Qwen2PreTrainedModel, GenerationMixin):
    config_class = FP8Qwen2Config
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = FP8Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM

        >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FP8Qwen2ForSequenceClassification(GenericForSequenceClassification, FP8Qwen2PreTrainedModel):
    pass


class FP8Qwen2ForTokenClassification(GenericForTokenClassification, FP8Qwen2PreTrainedModel):
    pass


class FP8Qwen2ForQuestionAnswering(GenericForQuestionAnswering, FP8Qwen2PreTrainedModel):
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`


__all__ = [
    "FP8Qwen2PreTrainedModel",
    "FP8Qwen2Model",
    "FP8Qwen2ForCausalLM",
    "FP8Qwen2ForSequenceClassification",
    "FP8Qwen2ForTokenClassification",
    "FP8Qwen2ForQuestionAnswering",
]

FP8Qwen2Model.register_for_auto_class("AutoModel")
FP8Qwen2ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
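

# The helpers below convert between high-precision checkpoints and the block-quantized FP8 layout
# used by the FP8 linear layers. For every selected linear weight, the converted state dict is
# assumed to hold two tensors: the FP8 `weight` itself and a per-block `weight_scale_inv` used to
# dequantize it. This description is inferred from the key handling in
# `make_state_dict_compatible_with_hf` below, not from a separate format specification.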


def make_state_dict_compatible_with_hf(
    state_dict: dict[str, torch.Tensor],
    linear_keys: list[str],
    undesired_linear_keys: list[str],
    config: FP8Qwen2Config = FP8Qwen2Config(),
    already_fp8: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Make the state dict compatible with HuggingFace.
    """
    # Assert linear keys and undesired linear keys are non-overlapping
    assert set(linear_keys).isdisjoint(set(undesired_linear_keys))

    compatible_state_dict = {}
    for key in state_dict.keys():
        if any(k in key for k in linear_keys):
            weight = state_dict[key]
            if already_fp8:
                # The name (either weight or weight_scale_inv) is the same as the original key.
                compatible_state_dict[key] = weight
            else:
                # We need to use float32 for the scale, since we are using DeepGEMM.
                tmp_quant_cfg = FP8QuantConfig(
                    float8_dtype=config.fp8_config.float8_dtype,
                    quant_type=config.fp8_config.quant_type,
                    fwd_block_size=config.fp8_config.mm_block_size,
                    scale_dtype=torch.float32,
                )
                quant_weight, scale_weight = fp8_quantize_hp2pb(
                    weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
                )
                name_quant = key  # the quantized weight keeps the original "weight" name
                name_scale = key.replace("weight", "weight_scale_inv")
                compatible_state_dict[name_quant] = quant_weight
                compatible_state_dict[name_scale] = scale_weight
        elif any(k in key for k in undesired_linear_keys):
            # Dequantize the weight
            if already_fp8:
                # We only do the dequantization once. When encountering the weight, we skip it.
                if "weight_scale_inv" in key:
                    name_quant = key.replace("weight_scale_inv", "weight")
                    quant_weight = state_dict[name_quant]
                    scale_weight = state_dict[key]
                    weight = fp8_dequantize_pb2hp(
                        quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
                    )
                    compatible_state_dict[name_quant] = weight
            else:
                # Do not quantize the weight.
                compatible_state_dict[key] = state_dict[key]
        else:
            compatible_state_dict[key] = state_dict[key]

    return compatible_state_dict


def set_named_weight_to_fp8(
    model: FP8Qwen2ForCausalLM,
    linear_keys: list[str] = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
):
    """
    Set the dtype of the weight of the linear layers to FP8. Also set layer name for debugging.
    """
    for name, module in model.named_modules():
        if any(k in name for k in linear_keys):
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
            module.layer_name = name
    return model
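

# A hedged end-to-end sketch of how these helpers might be used to turn an existing high-precision
# Qwen2 checkpoint into the FP8 layout expected by this model. The checkpoint path, config values,
# and key selections are placeholders, not part of this module's API.
if __name__ == "__main__":
    config = FP8Qwen2Config()  # placeholder config; in practice load the config matching the checkpoint
    model = FP8Qwen2ForCausalLM(config)

    # Hypothetical source checkpoint in bf16/fp32.
    hp_state_dict = torch.load("qwen2_checkpoint.pt", map_location="cpu")  # placeholder path

    fp8_state_dict = make_state_dict_compatible_with_hf(
        hp_state_dict,
        linear_keys=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],  # placeholder selection
        undesired_linear_keys=[],
        config=config,
        already_fp8=False,
    )

    # Cast the matching linear weights of the freshly built model to FP8 so the converted tensors load cleanly.
    set_named_weight_to_fp8(model)
    # strict=False because exact key coverage depends on the config and on which layers were quantized.
    model.load_state_dict(fp8_state_dict, strict=False)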