7 changes: 6 additions & 1 deletion modelopt/torch/export/unified_export_megatron.py
@@ -38,6 +38,7 @@

 from .model_config import (
     KV_CACHE_FP8,
+    KV_CACHE_NVFP4,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PB_WO,
@@ -326,7 +327,6 @@ def save_pretrained(
         state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict
         quantization_format = get_quantization_format(self.model)
         quantization = None
-        kv_cache_quantization = None

         if quantization_format in (
             QUANTIZATION_FP8_PB_REAL,
@@ -338,6 +338,11 @@
         elif quantization_format == QUANTIZATION_NVFP4:
             quantization = "NVFP4"

+        kv_cache_quantization = None
+        kv_cache_dtype = get_kv_cache_dtype(self.model)
+        if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
+            # FP8 and NVFP4 KV cache quantization formats are exported
+            kv_cache_quantization = kv_cache_dtype
         # We use the last PP rank and the 1st EP rank to write the config because
         # medusa_heads and eagle_module only exist in the last stage.
         if is_last_stage_main_rank:
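
For reference, a minimal sketch of how the KV cache detection reads after this change, assembled outside the diff context. Only get_quantization_format, get_kv_cache_dtype, the KV_CACHE_* constants, and the model_config module are taken from the diff; the import path for get_kv_cache_dtype and the helper name detect_kv_cache_quantization are assumptions for illustration, not the module's actual layout.

# Hedged sketch, not the actual unified_export_megatron.py code.
# Assumes get_kv_cache_dtype() returns one of the KV_CACHE_* string
# constants from model_config (or None) for the given model.
from modelopt.torch.export.model_config import KV_CACHE_FP8, KV_CACHE_NVFP4
from modelopt.torch.export.quant_utils import get_kv_cache_dtype  # assumed import path


def detect_kv_cache_quantization(model):
    """Return the KV cache quantization tag to record in the exported config, or None."""
    kv_cache_dtype = get_kv_cache_dtype(model)
    if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
        # Both FP8 and NVFP4 KV cache dtypes are propagated into the exported config.
        return kv_cache_dtype
    return None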