# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2 model configuration"""

import torch

from dataclasses import dataclass, asdict
from enum import Enum

from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

from quasar.kernel.configs import QuantType

logger = logging.get_logger(__name__)


@dataclass
class FP8Config:
    """Configuration for FP8 quantization."""

    float8_dtype: torch.dtype = torch.float8_e4m3fn
    quant_type: QuantType = QuantType.DIV
    layer_name: str = ""
    act_block_size: int = 16
    mm_block_size: int = 128
    training_mode: bool = True
    """
    If True, the linear layer uses the high-precision weight.
    If False, the linear layer uses the per-block quantized weight.
    """


class FP8Qwen2Config(Qwen2Config):
    model_type = "fp8_qwen2"

    fp8_config: FP8Config = FP8Config()
    model_name_orig: str = ""
    """Name of the original BF16 model."""

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0.0,
        # Customized configs begin here
        fp8_config=None,
        model_name_orig="",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            layer_types=layer_types,
            attention_dropout=attention_dropout,
            **kwargs,
        )
        # Convert it from dict to FP8Config (dataclass)
        if fp8_config is not None:
            self.fp8_config = fp8_config if isinstance(fp8_config, FP8Config) else FP8Config(**fp8_config)
        else:
            self.fp8_config = FP8Config()
        self.model_name_orig = model_name_orig

    def to_dict(self):
        output = super().to_dict()
        if hasattr(self.fp8_config, "__dataclass_fields__"):
            cfg_dict = asdict(self.fp8_config)
            for k, v in cfg_dict.items():
                if isinstance(v, torch.dtype):  # float8_dtype
                    cfg_dict[k] = str(v)  # save as 'torch.float8_e4m3fn'
                elif isinstance(v, Enum):  # quant_type
                    cfg_dict[k] = v.name  # save as 'DIV'
            output["fp8_config"] = cfg_dict
        else:
            output["fp8_config"] = self.fp8_config
        return output
    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config = super().from_dict(config_dict, **kwargs)
        fp8_config = config_dict.get("fp8_config", {})
        for k, v in fp8_config.items():
            if k == "float8_dtype":
                assert v.startswith("torch."), f"Invalid float8_dtype: {v}"
                fp8_config[k] = getattr(torch, v[len("torch."):])  # 'torch.float8_e4m3fn' -> torch.float8_e4m3fn
            elif k == "quant_type":
                fp8_config[k] = getattr(QuantType, v)  # 'DIV' -> QuantType.DIV
        config.fp8_config = FP8Config(**fp8_config)
        return config


__all__ = ["FP8Qwen2Config"]

FP8Qwen2Config.register_for_auto_class()
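

# ---------------------------------------------------------------------------
# Minimal serialization round-trip sketch (illustrative only). It assumes the
# defaults above, in particular that `QuantType.DIV` is importable from
# quasar.kernel.configs; the model name below is a hypothetical example.
# Run this module directly to check that `to_dict` / `from_dict` restore the
# torch dtype and the QuantType enum from their string representations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = FP8Qwen2Config(
        num_hidden_layers=2,                  # keep the sketch small
        fp8_config={"act_block_size": 32},    # dict form is converted to FP8Config
        model_name_orig="Qwen/Qwen2-7B",      # hypothetical BF16 source model
    )

    serialized = config.to_dict()
    # torch.dtype and QuantType are stored as plain strings in the dict.
    assert serialized["fp8_config"]["float8_dtype"] == "torch.float8_e4m3fn"
    assert serialized["fp8_config"]["quant_type"] == "DIV"

    restored = FP8Qwen2Config.from_dict(serialized)
    # from_dict converts the strings back into torch.float8_e4m3fn / QuantType.DIV.
    assert restored.fp8_config.float8_dtype == torch.float8_e4m3fn
    assert restored.fp8_config.quant_type == QuantType.DIV
    assert restored.fp8_config.act_block_size == 32
    print("FP8Qwen2Config round-trip OK:", restored.fp8_config)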