# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FP8 LLaMA model configuration"""

import torch
from dataclasses import dataclass, asdict
from enum import Enum

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.llama.configuration_llama import LlamaConfig

from quasar.kernel.configs import QuantType

logger = logging.get_logger(__name__)


@dataclass
class FP8Config:
    """Configuration for FP8 quantization."""

    float8_dtype: torch.dtype = torch.float8_e4m3fn
    quant_type: QuantType = QuantType.DIV
    layer_name: str = ""
    act_block_size: int = 16
    mm_block_size: int = 128
    training_mode: bool = True
    """
    If True, the linear layer will use the high-precision weight.
    If False, the linear layer will use the per-block quantized weight.
    """


class FP8LlamaConfig(LlamaConfig):
    model_type = "fp8_llama"
    fp8_config: FP8Config = FP8Config()
    model_name_orig: str = ""
    """Name of the original BF16 model."""

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        # Customized configs begin here
        fp8_config=None,
        model_name_orig="",
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pretraining_tp=pretraining_tp,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            mlp_bias=mlp_bias,
            head_dim=head_dim,
            **kwargs,
        )
        # Accept either an FP8Config instance or a plain dict (e.g. when loading from JSON).
        if fp8_config is not None:
            self.fp8_config = fp8_config if isinstance(fp8_config, FP8Config) else FP8Config(**fp8_config)
        else:
            self.fp8_config = FP8Config()
        self.model_name_orig = model_name_orig

    def to_dict(self):
        output = super().to_dict()
        if hasattr(self.fp8_config, "__dataclass_fields__"):
            cfg_dict = asdict(self.fp8_config)
            for k, v in cfg_dict.items():
                if isinstance(v, torch.dtype):  # float8_dtype
                    cfg_dict[k] = str(v)  # saved as 'torch.float8_e4m3fn'
                elif isinstance(v, Enum):  # quant_type
                    cfg_dict[k] = v.name  # saved as 'DIV'
            output["fp8_config"] = cfg_dict
        else:
            output["fp8_config"] = self.fp8_config
        return output

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        config = super().from_dict(config_dict, **kwargs)
        # Copy so the caller's dict is not mutated in place.
        fp8_config = dict(config_dict.get("fp8_config", {}))
        for k, v in fp8_config.items():
            if k == "float8_dtype":
                assert v.startswith("torch."), f"Invalid float8_dtype: {v}"
                fp8_config[k] = getattr(torch, v[len("torch."):])
            elif k == "quant_type":
                fp8_config[k] = getattr(QuantType, v)
        config.fp8_config = FP8Config(**fp8_config)
        return config


__all__ = ["FP8LlamaConfig"]

FP8LlamaConfig.register_for_auto_class()
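

if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative only, not part of the library API):
    # serialize the config with to_dict() and rebuild it with from_dict(), checking
    # that the FP8 fields survive. It relies only on the imports above; the model
    # name below is a placeholder, not a real checkpoint.
    cfg = FP8LlamaConfig(
        fp8_config={"quant_type": QuantType.DIV, "act_block_size": 32},
        model_name_orig="path/to/bf16-llama",  # placeholder
    )

    serialized = cfg.to_dict()
    # float8_dtype is stored as the string 'torch.float8_e4m3fn' and quant_type
    # as the enum member name 'DIV', so the dict stays JSON-serializable.
    print(serialized["fp8_config"])

    restored = FP8LlamaConfig.from_dict(serialized)
    assert isinstance(restored.fp8_config, FP8Config)
    assert restored.fp8_config.float8_dtype is torch.float8_e4m3fn
    assert restored.fp8_config.quant_type is QuantType.DIV
    assert restored.fp8_config.act_block_size == 32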