Upload FP8Qwen2ForCausalLM (#6)
Upload FP8Qwen2ForCausalLM (3e721beb3dfa87636eda6249a856fec8c4dfaae8)
Co-authored-by: Neil <[email protected]>
- config.json +2 -2
- generation_config.json +1 -1
- modeling_fp8_qwen2.py +84 -27
config.json
CHANGED

@@ -12,7 +12,6 @@
     "AutoModelForTokenClassification": "modeling_fp8_qwen2.FP8Qwen2ForTokenClassification"
   },
   "bos_token_id": 151643,
-  "dtype": "bfloat16",
   "eos_token_id": 151645,
   "fp8_config": {
     "act_block_size": 16,
@@ -68,7 +67,8 @@
   "rope_theta": 1000000.0,
   "sliding_window": null,
   "tie_word_embeddings": false,
-  "
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.54.1",
   "use_cache": true,
   "use_sliding_window": false,
   "vocab_size": 152064
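Since config.json keeps the auto_map entries pointing at modeling_fp8_qwen2.py, the checkpoint loads through the custom classes in that file. A minimal loading sketch, assuming a placeholder repo id (not the actual repository name):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/FP8-Qwen2"  # placeholder repo id, for illustration only

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # matches the "torch_dtype": "bfloat16" entry added above
    trust_remote_code=True,      # required so the auto_map classes in modeling_fp8_qwen2.py are used
)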
generation_config.json
CHANGED

@@ -2,5 +2,5 @@
   "_from_model_config": true,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
-  "transformers_version": "4.
+  "transformers_version": "4.54.1"
 }
modeling_fp8_qwen2.py
CHANGED

@@ -47,6 +47,7 @@ from .configuration_fp8_qwen2 import FP8Qwen2Config
 from torchao.float8.float8_training_tensor import Float8TrainingTensor
 
 from quasar.module import (
+    FP8Quant,
     FP8RMSNorm,
     FP8DSLinearWithCoat,
     FP8DSLinearWithCoatWeightBlock,
@@ -65,9 +66,24 @@ class FP8Qwen2MLP(Qwen2MLP):
     def __init__(self, config: FP8Qwen2Config):
         super().__init__(config)
         linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
-        self.gate_proj = linear_module(
-
-
+        self.gate_proj = linear_module(
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"gate_proj", scale_dtype=torch.float32),
+        )
+        self.up_proj = linear_module(
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"up_proj", scale_dtype=torch.float32),
+        )
+        self.down_proj = linear_module(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"down_proj", scale_dtype=torch.float32),
+        )
 
         if config.hidden_act == "silu":
             mul_config = FP8MulConfig(
@@ -93,22 +109,46 @@ class FP8Qwen2Attention(Qwen2Attention):
     def __init__(self, config: FP8Qwen2Config, layer_idx: int):
         super().__init__(config, layer_idx)
         linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
-        self.q_proj = linear_module(
-
-
-
-
-
-
-
-
-
-
-
-
+        self.q_proj = linear_module(
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"q_proj", scale_dtype=torch.float32),
+        )
+        self.k_proj = linear_module(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"k_proj", scale_dtype=torch.float32),
+        )
+        self.v_proj = linear_module(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"v_proj", scale_dtype=torch.float32),
+        )
+
+        # In both training and inference, we quantize the output of the attention layer.
+        self.o_proj_quant = FP8Quant(
+            quant_config=FP8QuantConfig(
+                float8_dtype=config.fp8_config.float8_dtype,
+                quant_type=QuantType.DIV,
+                fwd_block_size=config.fp8_config.mm_block_size,
+                layer_name=f"o_proj_quant",
+                scale_dtype=torch.float32,
             )
-
-
+        )
+        self.o_proj = linear_module(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(
+                fwd_input_quant_type=QuantType.DIV,
+                layer_name=f"o_proj",
+                scale_dtype=torch.float32,
+            ),
+        )
+
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
         self,
@@ -157,6 +197,9 @@ class FP8Qwen2Attention(Qwen2Attention):
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+
+        # Quantize the output of the attention layer.
+        attn_output = self.o_proj_quant(attn_output)
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
 
@@ -169,8 +212,20 @@ class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
         self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = FP8Qwen2MLP(config)
-        self.input_layernorm = FP8RMSNorm(
-
+        self.input_layernorm = FP8RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            norm_config=FP8RMSNormConfig(
+                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
+            ),
+        )
+        self.post_attention_layernorm = FP8RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            norm_config=FP8RMSNormConfig(
+                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
+            ),
+        )
         self.attention_type = config.layer_types[layer_idx]
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
@@ -365,13 +420,13 @@ def make_state_dict_compatible_with_hf(
     """
     # Assert linear keys and undesired linear keys are non-overlapping
     assert set(linear_keys).isdisjoint(set(undesired_linear_keys))
-
+
     compatible_state_dict = {}
 
     for key in state_dict.keys():
         if any(k in key for k in linear_keys):
             weight = state_dict[key]
-
+
             if already_fp8:
                 # The name (either weight or weight_scale_inv) is the same as the original key.
                 compatible_state_dict[key] = weight
@@ -381,17 +436,17 @@
                     float8_dtype=config.fp8_config.float8_dtype,
                     quant_type=config.fp8_config.quant_type,
                     fwd_block_size=config.fp8_config.mm_block_size,
-                    scale_dtype=torch.float32,
+                    scale_dtype=torch.float32,
                 )
                 quant_weight, scale_weight = fp8_quantize_hp2pb(
                     weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
                 )
-
+
                 name_quant = key.replace("weight", "weight")
                 name_scale = key.replace("weight", "weight_scale_inv")
                 compatible_state_dict[name_quant] = quant_weight
                 compatible_state_dict[name_scale] = scale_weight
-
+
             elif any(k in key for k in undesired_linear_keys):
                 # Dequantize the weight
                 if already_fp8:
@@ -400,12 +455,14 @@
                 name_quant = key.replace("weight_scale_inv", "weight")
                 quant_weight = state_dict[name_quant]
                 scale_weight = state_dict[key]
-                weight = fp8_dequantize_pb2hp(
+                weight = fp8_dequantize_pb2hp(
+                    quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
+                )
                 compatible_state_dict[name_quant] = weight
             else:
                 # Do not quantize the weight.
                 compatible_state_dict[key] = state_dict[key]
-
+
         else:
             compatible_state_dict[key] = state_dict[key]
     return compatible_state_dict
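For context on the state-dict conversion above: fp8_quantize_hp2pb turns a high-precision linear weight into a quantized "weight" tensor plus a per-block "weight_scale_inv" tensor, and fp8_dequantize_pb2hp reverses that. The following is a minimal sketch of that block-wise layout only, not the quasar implementation; the 128-wide blocks and the e4m3 maximum of 448.0 are assumptions for illustration.

import torch

def quantize_blockwise(weight: torch.Tensor, block: int = 128):
    # Per-block absmax scaling into torch.float8_e4m3fn, keeping the inverse
    # scale so the checkpoint can store a (weight, weight_scale_inv) pair.
    out_f, in_f = weight.shape
    q = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
    scale_inv = torch.empty((out_f + block - 1) // block, (in_f + block - 1) // block, dtype=torch.float32)
    for i in range(0, out_f, block):
        for j in range(0, in_f, block):
            tile = weight[i:i + block, j:j + block].float()
            scale = 448.0 / tile.abs().max().clamp(min=1e-12)  # map the tile into the e4m3 range
            q[i:i + block, j:j + block] = (tile * scale).to(torch.float8_e4m3fn)
            scale_inv[i // block, j // block] = 1.0 / scale
    return q, scale_inv

def dequantize_blockwise(q: torch.Tensor, scale_inv: torch.Tensor, block: int = 128):
    # Reverse step, used when a weight should go back to high precision.
    w = q.float()
    for i in range(0, w.shape[0], block):
        for j in range(0, w.shape[1], block):
            w[i:i + block, j:j + block] *= scale_inv[i // block, j // block]
    return w

Storing the inverse scale per block is what lets the converted checkpoint plug into loaders that expect the weight / weight_scale_inv naming used in the diff above.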