diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index 70a80aeec..137a6a8d0 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -38,6 +38,7 @@ from .model_config import (
     KV_CACHE_FP8,
+    KV_CACHE_NVFP4,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PB_WO,
@@ -326,7 +327,6 @@ def save_pretrained(
         state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict
         quantization_format = get_quantization_format(self.model)
         quantization = None
-        kv_cache_quantization = None
         if quantization_format in (
             QUANTIZATION_FP8_PB_REAL,
@@ -338,6 +338,11 @@ def save_pretrained(
         elif quantization_format == QUANTIZATION_NVFP4:
             quantization = "NVFP4"
+        kv_cache_quantization = None
+        kv_cache_dtype = get_kv_cache_dtype(self.model)
+        if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
+            # Only FP8 KV Cache is supported in VLLM for now
+            kv_cache_quantization = kv_cache_dtype
         # We use the last PP rank and the 1st EP rank to write the config because
         # medusa_heads and eagle_module only exist in the last stage.
         if is_last_stage_main_rank:
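
A minimal sketch of the new KV-cache branch in isolation, for review purposes. It mirrors the diff's logic: the KV-cache dtype is resolved independently of the weight quantization format and passed through only when it is FP8 or NVFP4. The string constant values and the `build_quant_config` helper below are assumptions for illustration, not modelopt's actual definitions; in the real code the dtype comes from `get_kv_cache_dtype(self.model)` and the constants come from `.model_config`.

```python
# Illustrative sketch only: constant values and build_quant_config are assumed,
# not taken from modelopt.torch.export.model_config.
KV_CACHE_FP8 = "FP8"       # assumed string tag
KV_CACHE_NVFP4 = "NVFP4"   # assumed string tag


def build_quant_config(quantization: str | None, kv_cache_dtype: str | None) -> dict:
    """Assemble the quantization-related fields of the exported config (hypothetical helper)."""
    kv_cache_quantization = None
    if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
        # Pass the KV-cache dtype through so downstream runtimes can pick it up;
        # any other dtype leaves the field as None.
        kv_cache_quantization = kv_cache_dtype

    return {
        "quantization": quantization,
        "kv_cache_quantization": kv_cache_quantization,
    }


if __name__ == "__main__":
    # NVFP4 weights with an NVFP4-quantized KV cache.
    print(build_quant_config("NVFP4", KV_CACHE_NVFP4))
    # -> {'quantization': 'NVFP4', 'kv_cache_quantization': 'NVFP4'}
```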