Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions examples/deepseek/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,20 +300,89 @@ def calibrate_loop(model):
# disable head that corresponds to lm_head (for the huggingface checkpoint)
mtq_cfg["quant_cfg"]["*head*"] = {"enable": False}

allowed_mla_quant = [None, "per_tensor_fp8"]
allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4_wq_a_wkv_a_wq_b_wo", "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b"]
assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}"

if not mla_quant:
mtq_cfg["quant_cfg"]["*attn*"] = {"enable": False}
elif mla_quant == "per_tensor_fp8":
mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None}
mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None}
elif mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo": # for DeepSeek-R1-0528-v3_1
# Only quantize linear layers(wq_a, wq_b, wkv_a, wo) in MLA, not BMM operations
mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] # "*wq*"
mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"]
for layer in mla_linear_layers:
if layer in mla_nvfp4_linear_layers:
mtq_cfg["quant_cfg"][layer+"_quantizer"] = {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
}
else:
mtq_cfg["quant_cfg"][layer+"_quantizer"] = {"enable": False}

# Disable BMM quantizers
mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}

elif mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b": # for DeepSeek-R1-0528-v3_2
# wq_a, wkv_a, wq_b, wo use NVFP4
# wkv_b uses FP8 per-tensor quantization (weight: normal scale, activation: scale=1)
mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"]
mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"]

for layer in mla_linear_layers:
if layer in mla_nvfp4_linear_layers:
# NVFP4 quantization
mtq_cfg["quant_cfg"][layer+"_quantizer"] = {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
}
elif layer == "*wkv_b*":
# wkv_b uses FP8 per-tensor quantization
mtq_cfg["quant_cfg"][layer+"weight_quantizer"] = {
"num_bits": (4, 3), # FP8
"axis": None,
"enable": True,
}
mtq_cfg["quant_cfg"][layer+"input_quantizer"] = {
"num_bits": (4, 3), # FP8
"axis": None,
"enable": True,
}
else:
mtq_cfg["quant_cfg"][layer+"_quantizer"] = {"enable": False}

# Disable BMM quantizers
mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}

if not args.disable_wo_quant and "FP4" in quant_cfg:
mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"]
mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"]
## ptq
transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop)

# Force wkv_b activation scale=1 for nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b
if mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b":
fp8_max_value = 448.0 # FP8 E4M3 max value

for name, module in transformer.named_modules():
# Match wkv_b layers
if "wkv_b" in name:
if hasattr(module, 'input_quantizer') and module.input_quantizer.is_enabled:
# Force activation amax = 448.0, so scale = amax/448.0 = 1.0
if int(os.environ.get("LOCAL_RANK", "0")) == 0:
old_amax = module.input_quantizer._amax.data.clone()
module.input_quantizer._amax.data.fill_(fp8_max_value)
print(f"[INFO] Forced {name}.input_quantizer amax from {old_amax.item()} to {fp8_max_value}")
else:
module.input_quantizer._amax.data.fill_(fp8_max_value)

if int(os.environ["LOCAL_RANK"]) == 0:
mtq.print_quant_summary(transformer)

Expand Down Expand Up @@ -396,11 +465,17 @@ def state_dict_filter(state_dict):
parser.add_argument("--disable_fp8_kvcache", action="store_true", help="disable fp8 kvcache.")
parser.add_argument("--disable_wo_quant", action="store_true", help="disable MLA wo quant.")
parser.add_argument("--trust_remote_code", action="store_true", help="trust remote code.")
parser.add_argument(
"--mla_quant",
type=str,
default=None,
help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4_wq_a_wkv_a_wq_b_wo, or nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b"
)

args = parser.parse_args()
model = load_deepseek_model(args.config, args.model_path, args.batch_size)
tokenizer = AutoTokenizer.from_pretrained(
args.model_path, trust_remote_code=args.trust_remote_code
)
model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant)
save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)