xihc-ucb NneilNight committed
Commit 542e4d6 · verified · 1 parent: c30c97b

Upload FP8Qwen2ForCausalLM (#6)


- Upload FP8Qwen2ForCausalLM (3e721beb3dfa87636eda6249a856fec8c4dfaae8)


Co-authored-by: Neil <[email protected]>

Files changed (3)
  1. config.json +2 -2
  2. generation_config.json +1 -1
  3. modeling_fp8_qwen2.py +84 -27
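The checkpoint relies on the auto_map entries in config.json (the token-classification mapping is visible in the diff below) and is therefore loaded with trust_remote_code=True. A minimal loading sketch follows; the repo id is a placeholder, it assumes config.json also maps AutoModelForCausalLM to FP8Qwen2ForCausalLM (as the commit title suggests), and it assumes the quasar and torchao packages imported by modeling_fp8_qwen2.py are installed.

# Illustrative only: "org/fp8-qwen2" is a placeholder, not the actual repo id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/fp8-qwen2"  # placeholder

# trust_remote_code=True lets transformers resolve the auto_map entries in
# config.json to the FP8Qwen2* classes defined in modeling_fp8_qwen2.py.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))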
config.json CHANGED
@@ -12,7 +12,6 @@
     "AutoModelForTokenClassification": "modeling_fp8_qwen2.FP8Qwen2ForTokenClassification"
   },
   "bos_token_id": 151643,
-  "dtype": "bfloat16",
   "eos_token_id": 151645,
   "fp8_config": {
     "act_block_size": 16,
@@ -68,7 +67,8 @@
   "rope_theta": 1000000.0,
   "sliding_window": null,
   "tie_word_embeddings": false,
-  "transformers_version": "4.57.0",
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.54.1",
   "use_cache": true,
   "use_sliding_window": false,
   "vocab_size": 152064
generation_config.json CHANGED
@@ -2,5 +2,5 @@
   "_from_model_config": true,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
-  "transformers_version": "4.57.0"
+  "transformers_version": "4.54.1"
 }
modeling_fp8_qwen2.py CHANGED
@@ -47,6 +47,7 @@ from .configuration_fp8_qwen2 import FP8Qwen2Config
 from torchao.float8.float8_training_tensor import Float8TrainingTensor
 
 from quasar.module import (
+    FP8Quant,
     FP8RMSNorm,
     FP8DSLinearWithCoat,
     FP8DSLinearWithCoatWeightBlock,
@@ -65,9 +66,24 @@ class FP8Qwen2MLP(Qwen2MLP):
     def __init__(self, config: FP8Qwen2Config):
         super().__init__(config)
         linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
-        self.gate_proj = linear_module(self.hidden_size, self.intermediate_size, bias=False, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"gate_proj"))
-        self.up_proj = linear_module(self.hidden_size, self.intermediate_size, bias=False, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"up_proj"))
-        self.down_proj = linear_module(self.intermediate_size, self.hidden_size, bias=False, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"down_proj"))
+        self.gate_proj = linear_module(
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"gate_proj", scale_dtype=torch.float32),
+        )
+        self.up_proj = linear_module(
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"up_proj", scale_dtype=torch.float32),
+        )
+        self.down_proj = linear_module(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"down_proj", scale_dtype=torch.float32),
+        )
 
         if config.hidden_act == "silu":
             mul_config = FP8MulConfig(
@@ -93,22 +109,46 @@ class FP8Qwen2Attention(Qwen2Attention):
     def __init__(self, config: FP8Qwen2Config, layer_idx: int):
         super().__init__(config, layer_idx)
         linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
-        self.q_proj = linear_module(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"q_proj"))
-        self.k_proj = linear_module(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"k_proj"))
-        self.v_proj = linear_module(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True, dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"v_proj"))
-
-        if not config.fp8_config.training_mode:
-            # Only when doing inference, we quantize the output of the attention layer.
-            self.o_proj = linear_module(
-                config.num_attention_heads * self.head_dim, config.hidden_size, bias=False,
-                dsgemm_config=FP8DSLinearWithCoatConfig(
-                    fwd_input_quant_type=QuantType.DIV,
-                    layer_name=f"o_proj",
-                    scale_dtype=torch.float32,
-                )
+        self.q_proj = linear_module(
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"q_proj", scale_dtype=torch.float32),
+        )
+        self.k_proj = linear_module(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"k_proj", scale_dtype=torch.float32),
+        )
+        self.v_proj = linear_module(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=True,
+            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"v_proj", scale_dtype=torch.float32),
+        )
+
+        # In both training and inference, we quantize the output of the attention layer.
+        self.o_proj_quant = FP8Quant(
+            quant_config=FP8QuantConfig(
+                float8_dtype=config.fp8_config.float8_dtype,
+                quant_type=QuantType.DIV,
+                fwd_block_size=config.fp8_config.mm_block_size,
+                layer_name=f"o_proj_quant",
+                scale_dtype=torch.float32,
             )
-
-
+        )
+        self.o_proj = linear_module(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            dsgemm_config=FP8DSLinearWithCoatConfig(
+                fwd_input_quant_type=QuantType.DIV,
+                layer_name=f"o_proj",
+                scale_dtype=torch.float32,
+            ),
+        )
+
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
         self,
@@ -157,6 +197,9 @@ class FP8Qwen2Attention(Qwen2Attention):
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+
+        # Quantize the output of the attention layer.
+        attn_output = self.o_proj_quant(attn_output)
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
 
@@ -169,8 +212,20 @@ class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
         self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = FP8Qwen2MLP(config)
-        self.input_layernorm = FP8RMSNorm(config.hidden_size, eps=config.rms_norm_eps, norm_config=FP8RMSNormConfig(mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True))
-        self.post_attention_layernorm = FP8RMSNorm(config.hidden_size, eps=config.rms_norm_eps, norm_config=FP8RMSNormConfig(mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True))
+        self.input_layernorm = FP8RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            norm_config=FP8RMSNormConfig(
+                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
+            ),
+        )
+        self.post_attention_layernorm = FP8RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            norm_config=FP8RMSNormConfig(
+                mm_block_size=config.fp8_config.mm_block_size, quant_type=QuantType.MUL, save_fp8_input=True
+            ),
+        )
         self.attention_type = config.layer_types[layer_idx]
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
@@ -365,13 +420,13 @@ def make_state_dict_compatible_with_hf(
     """
     # Assert linear keys and undesired linear keys are non-overlapping
     assert set(linear_keys).isdisjoint(set(undesired_linear_keys))
-
+
     compatible_state_dict = {}
 
     for key in state_dict.keys():
         if any(k in key for k in linear_keys):
             weight = state_dict[key]
-
+
             if already_fp8:
                 # The name (either weight or weight_scale_inv) is the same as the original key.
                 compatible_state_dict[key] = weight
@@ -381,17 +436,17 @@
                     float8_dtype=config.fp8_config.float8_dtype,
                     quant_type=config.fp8_config.quant_type,
                     fwd_block_size=config.fp8_config.mm_block_size,
-                    scale_dtype=torch.float32,
+                    scale_dtype=torch.float32,
                 )
                 quant_weight, scale_weight = fp8_quantize_hp2pb(
                     weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
                 )
-
+
                 name_quant = key.replace("weight", "weight")
                 name_scale = key.replace("weight", "weight_scale_inv")
                 compatible_state_dict[name_quant] = quant_weight
                 compatible_state_dict[name_scale] = scale_weight
-
+
         elif any(k in key for k in undesired_linear_keys):
             # Dequantize the weight
             if already_fp8:
@@ -400,12 +455,14 @@
                 name_quant = key.replace("weight_scale_inv", "weight")
                 quant_weight = state_dict[name_quant]
                 scale_weight = state_dict[key]
-                weight = fp8_dequantize_pb2hp(quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size)
+                weight = fp8_dequantize_pb2hp(
+                    quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
+                )
                 compatible_state_dict[name_quant] = weight
             else:
                 # Do not quantize the weight.
                 compatible_state_dict[key] = state_dict[key]
-
+
         else:
             compatible_state_dict[key] = state_dict[key]
     return compatible_state_dict
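make_state_dict_compatible_with_hf (the last hunks above) converts between high-precision weights and weight / weight_scale_inv pairs via fp8_quantize_hp2pb and fp8_dequantize_pb2hp. The sketch below illustrates a generic block-wise FP8 scheme in plain PyTorch; it is not the quasar kernels, and the 128x128 tile size, e4m3 target dtype, and per-tile scale layout are assumptions chosen for clarity rather than values read from this repository.

# Standalone illustration (plain PyTorch): store a 2D weight as an FP8 tensor
# plus one dequantization scale per (block_size x block_size) tile, mirroring
# the weight / weight_scale_inv naming used above. Block size is illustrative.
import torch
import torch.nn.functional as F

def quantize_blockwise(weight: torch.Tensor, block_size: int = 128):
    out_f, in_f = weight.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Pad so both dimensions are multiples of the block size.
    pad_r = (-out_f) % block_size
    pad_c = (-in_f) % block_size
    w = F.pad(weight.float(), (0, pad_c, 0, pad_r))

    # View as tiles and take one scale per tile from its absolute maximum.
    r_blocks, c_blocks = w.shape[0] // block_size, w.shape[1] // block_size
    tiles = w.reshape(r_blocks, block_size, c_blocks, block_size)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax                                        # multiply-to-quantize
    q = (tiles * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    quant_weight = q.reshape(w.shape)[:out_f, :in_f]              # "weight"
    weight_scale_inv = (1.0 / scale).reshape(r_blocks, c_blocks)  # "weight_scale_inv"
    return quant_weight, weight_scale_inv

def dequantize_blockwise(quant_weight, weight_scale_inv, block_size: int = 128):
    # Expand the per-tile scales back to element resolution and multiply.
    scale = weight_scale_inv.repeat_interleave(block_size, dim=0)
    scale = scale.repeat_interleave(block_size, dim=1)
    out_f, in_f = quant_weight.shape
    return quant_weight.float() * scale[:out_f, :in_f]

# Round-trip check on a toy weight.
w = torch.randn(512, 384)
qw, s_inv = quantize_blockwise(w)
w_hat = dequantize_blockwise(qw, s_inv)
print((w - w_hat).abs().max())  # small FP8 rounding error

The round-trip at the end only demonstrates the storage convention; the actual scale granularity and rounding behavior of fp8_quantize_hp2pb may differ.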