# Qwen2.5-7B-Instruct-train-Quasar-1002 / configuration_fp8_qwen2.py
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2 model configuration"""
import torch
from dataclasses import dataclass, asdict
from enum import Enum
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from quasar.kernel.configs import QuantType
logger = logging.get_logger(__name__)
@dataclass
class FP8Config:
"""
Configuration for FP8 quantization.
"""
float8_dtype: torch.dtype = torch.float8_e4m3fn
quant_type: QuantType = QuantType.DIV
layer_name: str = ""
act_block_size: int = 16
mm_block_size: int = 128
training_mode: bool = True
"""
If True, the linear layer will use high-precision weight.
If False, the linear layer will use per-block quantized weight.
"""
class FP8Qwen2Config(Qwen2Config):
model_type = "fp8_qwen2"
fp8_config: FP8Config = FP8Config()
model_name_orig: str = ""
"""Pass the name of the BF16 model"""
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
layer_types=None,
attention_dropout=0.0,
# Customized configs begins here
fp8_config=None,
model_name_orig="",
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
tie_word_embeddings=tie_word_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
use_sliding_window=use_sliding_window,
sliding_window=sliding_window,
max_window_layers=max_window_layers,
layer_types=layer_types,
attention_dropout=attention_dropout,
**kwargs,
)
        # Normalize fp8_config: accept either an FP8Config instance or a plain dict (e.g. from a serialized config)
if fp8_config is not None:
self.fp8_config = fp8_config if isinstance(fp8_config, FP8Config) else FP8Config(**fp8_config)
else:
self.fp8_config = FP8Config()
self.model_name_orig = model_name_orig
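
    # Illustrative sketch (not part of the original file): the two custom arguments
    # ride on top of the standard Qwen2 ones; fp8_config may be passed either as an
    # FP8Config instance or as a plain dict. The model name below is a placeholder.
    #
    #   config = FP8Qwen2Config(
    #       fp8_config={"act_block_size": 16, "mm_block_size": 128},
    #       model_name_orig="Qwen/Qwen2.5-7B-Instruct",
    #   )
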
def to_dict(self):
output = super().to_dict()
if hasattr(self.fp8_config, "__dataclass_fields__"):
cfg_dict = asdict(self.fp8_config)
for k, v in cfg_dict.items():
if isinstance(v, torch.dtype): # float8_dtype
cfg_dict[k] = str(v) # save as 'torch.float8_e4m3fn'
elif isinstance(v, Enum): # quant_type
cfg_dict[k] = v.name # save as 'DIV'
output["fp8_config"] = cfg_dict
else:
output["fp8_config"] = self.fp8_config
return output
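
    # Illustrative sketch (not part of the original file): after to_dict(), the nested
    # fp8_config entry is JSON-serializable; with the default FP8Config it is roughly:
    #
    #   output["fp8_config"] == {
    #       "float8_dtype": "torch.float8_e4m3fn",  # torch.dtype -> str
    #       "quant_type": "DIV",                    # Enum -> its .name
    #       "layer_name": "",
    #       "act_block_size": 16,
    #       "mm_block_size": 128,
    #       "training_mode": True,
    #   }
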
@classmethod
def from_dict(cls, config_dict, **kwargs):
config = super().from_dict(config_dict, **kwargs)
        fp8_config = dict(config_dict.get("fp8_config", {}))  # copy so the caller's dict is not mutated
        for k, v in fp8_config.items():
            if k == "float8_dtype" and isinstance(v, str):
                assert v.startswith("torch."), f"Invalid float8_dtype: {v}"
                fp8_config[k] = getattr(torch, v[len("torch."):])  # e.g. "torch.float8_e4m3fn" -> torch.float8_e4m3fn
            elif k == "quant_type" and isinstance(v, str):
                fp8_config[k] = getattr(QuantType, v)  # e.g. "DIV" -> QuantType.DIV
        config.fp8_config = FP8Config(**fp8_config)
return config
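
# Illustrative sketch (not part of the original file): rebuilding a config from a
# plain dict such as a loaded config.json; only the custom keys are shown here.
#
#   restored = FP8Qwen2Config.from_dict({
#       "model_type": "fp8_qwen2",
#       "fp8_config": {"float8_dtype": "torch.float8_e4m3fn", "quant_type": "DIV"},
#       "model_name_orig": "",
#   })
#   restored.fp8_config.float8_dtype  # -> torch.float8_e4m3fn
#   restored.fp8_config.quant_type    # -> QuantType.DIV
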
__all__ = ["FP8Qwen2Config"]
FP8Qwen2Config.register_for_auto_class()
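
# Minimal usage sketch (not part of the original file). register_for_auto_class()
# records this class in the checkpoint's auto_map, so a repo that ships this file
# can be loaded with AutoConfig.from_pretrained(<repo_or_path>, trust_remote_code=True).
# The guarded self-test below only exercises the local to_dict/from_dict round trip
# and assumes the `quasar` package is importable.
if __name__ == "__main__":
    cfg = FP8Qwen2Config()
    restored = FP8Qwen2Config.from_dict(cfg.to_dict())
    assert restored.fp8_config.float8_dtype is torch.float8_e4m3fn
    assert restored.fp8_config.quant_type is QuantType.DIV
    print(restored.model_type, restored.fp8_config)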