"""Qwen2 model configuration""" |
|
|
|
|
|
import torch |
|
|
from dataclasses import dataclass, asdict |
|
|
from enum import Enum |
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig, layer_type_validation |
|
|
from transformers.modeling_rope_utils import rope_config_validation |
|
|
from transformers.utils import logging |
|
|
|
|
|
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config |
|
|
|
|
|
from quasar.kernel.configs import QuantType |
|
|
|
|
|
logger = logging.get_logger(__name__) |


@dataclass
class FP8Config:
    """
    Configuration for FP8 quantization.
    """

    float8_dtype: torch.dtype = torch.float8_e4m3fn
    quant_type: QuantType = QuantType.DIV
    layer_name: str = ""

    # Block sizes for per-block quantization (activation blocks and matmul/weight blocks, respectively).
    act_block_size: int = 16
    mm_block_size: int = 128

    training_mode: bool = True
    """
    If True, the linear layer will use the high-precision weights.
    If False, the linear layer will use the per-block quantized weights.
    """


class FP8Qwen2Config(Qwen2Config):
    model_type = "fp8_qwen2"

    # Class-level defaults; both are overwritten per instance in `__init__`.
    fp8_config: FP8Config = FP8Config()
    model_name_orig: str = ""
    """Name of the original BF16 model."""

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0.0,
        fp8_config=None,
        model_name_orig="",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            layer_types=layer_types,
            attention_dropout=attention_dropout,
            **kwargs,
        )

        # Accept either an FP8Config instance or a plain dict (e.g. when loaded from JSON).
        if fp8_config is not None:
            self.fp8_config = fp8_config if isinstance(fp8_config, FP8Config) else FP8Config(**fp8_config)
        else:
            self.fp8_config = FP8Config()

        self.model_name_orig = model_name_orig

    def to_dict(self):
        output = super().to_dict()
        if hasattr(self.fp8_config, "__dataclass_fields__"):
            # Serialize the FP8Config dataclass into a JSON-friendly dict:
            # torch dtypes become "torch.<name>" strings and enums become their member names.
            cfg_dict = asdict(self.fp8_config)
            for k, v in cfg_dict.items():
                if isinstance(v, torch.dtype):
                    cfg_dict[k] = str(v)
                elif isinstance(v, Enum):
                    cfg_dict[k] = v.name
            output["fp8_config"] = cfg_dict
        else:
            output["fp8_config"] = self.fp8_config
        return output
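
    # For the defaults above, `to_dict()` serializes the FP8 fields roughly as:
    #     "fp8_config": {"float8_dtype": "torch.float8_e4m3fn", "quant_type": "DIV",
    #                    "layer_name": "", "act_block_size": 16, "mm_block_size": 128,
    #                    "training_mode": True}
    # `from_dict()` below reverses exactly this mapping.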

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config = super().from_dict(config_dict, **kwargs)

        # Rebuild FP8Config from its serialized form (see `to_dict`) without mutating the
        # caller's dict; values may already be deserialized, so only convert strings.
        fp8_config = dict(config_dict.get("fp8_config") or {})
        for k, v in fp8_config.items():
            if k == "float8_dtype" and isinstance(v, str):
                assert v.startswith("torch."), f"Invalid float8_dtype: {v}"
                fp8_config[k] = getattr(torch, v[len("torch."):])
            elif k == "quant_type" and isinstance(v, str):
                fp8_config[k] = getattr(QuantType, v)
        config.fp8_config = FP8Config(**fp8_config)
        return config


__all__ = ["FP8Qwen2Config"]

FP8Qwen2Config.register_for_auto_class()
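
# Usage sketch (illustrative only; the checkpoint name and save path are placeholders):
#
#     cfg = FP8Qwen2Config(
#         model_name_orig="path/to/bf16-qwen2-checkpoint",
#         fp8_config={"training_mode": False},
#     )
#     cfg.save_pretrained("./fp8-qwen2")  # fp8_config is JSON-serialized via to_dict()
#     cfg = FP8Qwen2Config.from_pretrained("./fp8-qwen2")  # FP8Config restored via from_dict()
#
# Because of register_for_auto_class(), checkpoints that ship this module as custom code
# can also be loaded with AutoConfig.from_pretrained(..., trust_remote_code=True).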