imlixinyang committed
Commit 854d14d · 1 Parent(s): ef4e385
Files changed (1)
  1. quant.py +1 -1
quant.py CHANGED
@@ -178,7 +178,7 @@ def FluxFp8GeMMProcessor(model: torch.nn.Module):
     )
     named_modules = list(model.named_modules())
     for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
-        if isinstance(linear, torch.nn.Linear):
+        if isinstance(linear, torch.nn.Linear) and "blocks" in name:
             quant_weight, weight_scale = per_tensor_quantize(linear.weight)
             bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
             quant_linear = FP8DynamicLinear(
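The change narrows the quantization pass: only torch.nn.Linear layers whose qualified module name contains "blocks" (i.e. the transformer blocks) are converted to fp8, while other Linear layers keep their original precision. Below is a minimal, self-contained sketch of that filtering logic. The body of per_tensor_quantize is an assumption here, written in the common AutoFP8-style convention (single per-tensor scale, float8_e4m3fn storage, dequantization as quant_weight * scale), and the toy model is hypothetical; the repo's actual implementation may differ.

import torch

def per_tensor_quantize(tensor: torch.Tensor):
    """Sketch: quantize to float8_e4m3fn with one per-tensor scale (assumed convention)."""
    tensor = tensor.detach()  # weights only; no autograd tracking needed
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Choose the scale so the largest-magnitude element maps onto the fp8 range.
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = amax / finfo.max
    quant_weight = (tensor / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Dequantization would be quant_weight.float() * scale.
    return quant_weight, scale

# Hypothetical toy model: one Linear inside "blocks", one outside.
model = torch.nn.ModuleDict({
    "blocks": torch.nn.ModuleList([torch.nn.Linear(16, 16)]),
    "final_layer": torch.nn.Linear(16, 16),
})

for name, linear in model.named_modules():
    # The new name filter: "blocks.0" matches, "final_layer" is skipped.
    if isinstance(linear, torch.nn.Linear) and "blocks" in name:
        quant_weight, weight_scale = per_tensor_quantize(linear.weight)
        print(name, quant_weight.dtype, weight_scale.item())

Skipping layers outside the transformer blocks is a common quality safeguard in weight-only quantization: embedding and final projection layers are few in number (so the memory savings are small) but tend to be sensitive to precision loss.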