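"""FP8 variant of the Hugging Face Qwen2 model for Quasar / COAT-style low-precision training.

This module mirrors the `transformers` Qwen2 model definition but replaces the dense linear
layers, RMSNorm, and SiLU-gated MLP with the FP8-aware implementations from the `quasar`
package (backed by `torchao` float8 training tensors), and provides helpers for converting
state dicts between the FP8 block-quantized layout and the standard high-precision layout.
"""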
from typing import Callable, Optional, Union
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
GenericForQuestionAnswering,
GenericForSequenceClassification,
GenericForTokenClassification,
GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2MLP,
Qwen2Attention,
apply_rotary_pos_emb,
eager_attention_forward,
Qwen2RMSNorm,
Qwen2RotaryEmbedding,
Qwen2Model,
Qwen2ForCausalLM,
)
from .configuration_fp8_qwen2 import FP8Qwen2Config
from torchao.float8.float8_training_tensor import Float8TrainingTensor
from quasar.module import (
FP8Quant,
FP8RMSNorm,
FP8DSLinearWithCoat,
FP8DSLinearWithCoatWeightBlock,
FP8FusedSiLUMul,
FP8Identity,
)
from quasar.kernel.configs import FP8RMSNormConfig, QuantType, FP8MulConfig, FP8DSLinearWithCoatConfig, FP8QuantConfig
from quasar.kernel.quant.quantize_hp2pb import fp8_quantize_hp2pb
from quasar.kernel.quant.dequantize_pb2hp import fp8_dequantize_pb2hp
logger = logging.get_logger(__name__)
class FP8Qwen2MLP(Qwen2MLP):
def __init__(self, config: FP8Qwen2Config):
super().__init__(config)
linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
self.gate_proj = linear_module(
self.hidden_size,
self.intermediate_size,
bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="gate_proj", scale_dtype=torch.float32),
)
self.up_proj = linear_module(
self.hidden_size,
self.intermediate_size,
bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="up_proj", scale_dtype=torch.float32),
)
self.down_proj = linear_module(
self.intermediate_size,
self.hidden_size,
bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="down_proj", scale_dtype=torch.float32),
)
if config.hidden_act == "silu":
mul_config = FP8MulConfig(
quant_type=QuantType.MUL,
)
self.act_fn = FP8FusedSiLUMul(mul_config)
else:
raise ValueError(f"Unsupported activation function: {config.hidden_act}")
def forward(self, x):
gate_x = self.gate_proj(x)
up_x = self.up_proj(x)
mul_x = self.act_fn(gate_x, up_x)
down_proj = self.down_proj(mul_x)
return down_proj
class FP8Qwen2Attention(Qwen2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: FP8Qwen2Config, layer_idx: int):
super().__init__(config, layer_idx)
linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
self.q_proj = linear_module(
config.hidden_size,
config.num_attention_heads * self.head_dim,
bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="q_proj", scale_dtype=torch.float32),
)
self.k_proj = linear_module(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="k_proj", scale_dtype=torch.float32),
)
self.v_proj = linear_module(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="v_proj", scale_dtype=torch.float32),
)
# In both training and inference, we quantize the output of the attention layer.
self.o_proj_quant = FP8Quant(
quant_config=FP8QuantConfig(
float8_dtype=config.fp8_config.float8_dtype,
quant_type=QuantType.DIV,
fwd_block_size=config.fp8_config.mm_block_size,
layer_name=f"o_proj_quant",
scale_dtype=torch.float32,
)
)
self.o_proj = linear_module(
config.num_attention_heads * self.head_dim,
config.hidden_size,
bias=False,
dsgemm_config=FP8DSLinearWithCoatConfig(
fwd_input_quant_type=QuantType.DIV,
layer_name=f"o_proj",
scale_dtype=torch.float32,
),
)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if isinstance(hidden_states, Float8TrainingTensor):
            # For a Float8TrainingTensor, the last dim is the quantization group size, not the hidden size.
input_shape = hidden_states.shape[:-2]
else:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# TODO: Add quantization
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
            sliding_window=self.sliding_window,  # sliding-window support is the main difference from Llama attention
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
# Quantize the output of the attention layer.
attn_output = self.o_proj_quant(attn_output)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: FP8Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)
self.mlp = FP8Qwen2MLP(config)
self.input_layernorm = FP8RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
norm_config=FP8RMSNormConfig(
mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
),
)
self.post_attention_layernorm = FP8RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
norm_config=FP8RMSNormConfig(
mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
),
)
self.attention_type = config.layer_types[layer_idx]
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@auto_docstring
class FP8Qwen2PreTrainedModel(PreTrainedModel):
config_class = FP8Qwen2Config
config: FP8Qwen2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["FP8Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": FP8Qwen2DecoderLayer,
"attentions": FP8Qwen2Attention,
}
@auto_docstring
class FP8Qwen2Model(FP8Qwen2PreTrainedModel):
def __init__(self, config: FP8Qwen2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[FP8Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
# Initialize weights and apply final processing
self.post_init()
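    # Only the submodules (attention, MLP, norms) differ from the reference model; the top-level
    # computation graph is unchanged, so the stock Qwen2Model.forward is reused as-is below.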
forward = Qwen2Model.forward
@auto_docstring
class FP8Qwen2ForCausalLM(FP8Qwen2PreTrainedModel, GenerationMixin):
config_class = FP8Qwen2Config
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = FP8Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class FP8Qwen2ForSequenceClassification(GenericForSequenceClassification, FP8Qwen2PreTrainedModel):
pass
class FP8Qwen2ForTokenClassification(GenericForTokenClassification, FP8Qwen2PreTrainedModel):
pass
class FP8Qwen2ForQuestionAnswering(GenericForQuestionAnswering, FP8Qwen2PreTrainedModel):
base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
__all__ = [
"FP8Qwen2PreTrainedModel",
"FP8Qwen2Model",
"FP8Qwen2ForCausalLM",
"FP8Qwen2ForSequenceClassification",
"FP8Qwen2ForTokenClassification",
"FP8Qwen2ForQuestionAnswering",
]
FP8Qwen2Model.register_for_auto_class("AutoModel")
FP8Qwen2ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
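# Minimal loading sketch: because the FP8 classes are registered for the auto classes above, the
# model can be loaded through the standard auto API with remote code enabled. The repo id below is
# illustrative (assumed to be the repository hosting this file).
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   repo_id = "xihc-ucb/Qwen2.5-7B-Instruct-train-Quasar-1002"  # assumed hosting repo
#   tokenizer = AutoTokenizer.from_pretrained(repo_id)
#   model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)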
def make_state_dict_compatible_with_hf(
state_dict: dict[str, torch.Tensor],
linear_keys: list[str],
undesired_linear_keys: list[str],
config: FP8Qwen2Config = FP8Qwen2Config(),
already_fp8: bool = False,
) -> dict[str, torch.Tensor]:
"""
Make the state dict compatible with HuggingFace.
"""
# Assert linear keys and undesired linear keys are non-overlapping
assert set(linear_keys).isdisjoint(set(undesired_linear_keys))
compatible_state_dict = {}
for key in state_dict.keys():
if any(k in key for k in linear_keys):
weight = state_dict[key]
if already_fp8:
# The name (either weight or weight_scale_inv) is the same as the original key.
compatible_state_dict[key] = weight
else:
# We need to use float32 for the scale, since we are using DeepGEMM.
tmp_quant_cfg = FP8QuantConfig(
float8_dtype=config.fp8_config.float8_dtype,
quant_type=config.fp8_config.quant_type,
fwd_block_size=config.fp8_config.mm_block_size,
scale_dtype=torch.float32,
)
quant_weight, scale_weight = fp8_quantize_hp2pb(
weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
)
                name_quant = key  # the FP8-quantized tensor keeps the original "weight" name
name_scale = key.replace("weight", "weight_scale_inv")
compatible_state_dict[name_quant] = quant_weight
compatible_state_dict[name_scale] = scale_weight
elif any(k in key for k in undesired_linear_keys):
# Dequantize the weight
if already_fp8:
# We only do the dequantization once. When encountering the weight, we skip it.
if "weight_scale_inv" in key:
name_quant = key.replace("weight_scale_inv", "weight")
quant_weight = state_dict[name_quant]
scale_weight = state_dict[key]
weight = fp8_dequantize_pb2hp(
quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
)
compatible_state_dict[name_quant] = weight
else:
# Do not quantize the weight.
compatible_state_dict[key] = state_dict[key]
else:
compatible_state_dict[key] = state_dict[key]
return compatible_state_dict
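# Usage sketch (illustrative key lists, not a fixed convention of this repo; `linear_keys` and
# `undesired_linear_keys` are matched as substrings of the full parameter names):
#
#   hf_state_dict = make_state_dict_compatible_with_hf(
#       state_dict,
#       linear_keys=["gate_proj", "up_proj", "down_proj"],    # store as FP8 weight + weight_scale_inv
#       undesired_linear_keys=["embed_tokens", "lm_head"],    # keep (or dequantize back to) high precision
#       config=model.config,
#       already_fp8=False,
#   )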
def set_named_weight_to_fp8(
model: FP8Qwen2ForCausalLM,
linear_keys: list[str] = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
):
"""
    Cast the weights of the linear layers whose names match `linear_keys` to torch.float8_e4m3fn.
    Also record each module's name on the module for debugging.
"""
for name, module in model.named_modules():
if any(k in name for k in linear_keys):
module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
module.layer_name = name
return model
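# Usage sketch: re-tag the matching projection weights as float8_e4m3fn in place (using the
# default `linear_keys` above) so their dtype matches the FP8 kernels; `model` is assumed to be
# an already-instantiated FP8Qwen2ForCausalLM.
#
#   model = set_named_weight_to_fp8(model)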