Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions tools/convert_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Basic conversion (no vocab pruning):
python tools/convert_to_hf.py --input-dir <checkpoint_dir>

# Auto-generate draft config from target model:
# With target model (loads correct embed_tokens + auto-generates config):
python tools/convert_to_hf.py --input-dir <checkpoint_dir> \
--target-model-path moonshotai/Kimi-K2.5 --trust-remote-code

Expand Down Expand Up @@ -182,7 +182,7 @@ def _extract_model_weights(
for k, v in state_dict.items():
if not isinstance(v, torch.Tensor):
continue
if "draft_model." not in k or "embed" in k.lower():
if "draft_model." not in k:
skipped_keys.append(k)
continue
new_key = k.split("draft_model.")[-1]
Expand All @@ -191,7 +191,7 @@ def _extract_model_weights(
model_state[new_key] = v

logger.info(
"Extracted %d model weight keys (skipped %d non-draft/embedding keys)",
"Extracted %d model weight keys (skipped %d non-draft keys)",
len(model_state),
len(skipped_keys),
)
Expand Down Expand Up @@ -366,6 +366,8 @@ def _convert_fsdp_to_hf(
config_path: str,
input_dir: str,
output_dir: str,
target_model_path: Optional[str] = None,
embedding_key: str = "model.embed_tokens",
prune_vocab: bool = False,
dataset_path: Optional[str] = None,
draft_vocab_size: Optional[int] = None,
Expand All @@ -390,12 +392,35 @@ def _convert_fsdp_to_hf(
config = AutoDraftModelConfig.from_file(config_path)
hf_model = AutoEagle3DraftModel.from_config(config)

# Cast the HF model to the most common dtype among the checkpoint weights so it matches the checkpoint (avoids silent precision changes)
ckpt_dtype = Counter(v.dtype for v in model_state.values()).most_common(1)[0][0]
logger.info("Checkpoint dtype: %s, casting HF model to match", ckpt_dtype)
hf_model = hf_model.to(ckpt_dtype)

missing, unexpected = hf_model.load_state_dict(model_state, strict=False)
if missing:
logger.warning("Missing keys: %s", missing)
if unexpected:
logger.warning("Unexpected keys: %s", unexpected)

# Optionally override embed_tokens with the embedding from the target model; this runs after the checkpoint state dict has been loaded.
# Useful for older checkpoints where embed_tokens may not have been saved correctly.
if target_model_path:
logger.info("Loading embed_tokens from target model: %s", target_model_path)
embed_key = embedding_key + ".weight"
try:
hf_model.load_embedding(target_model_path, embedding_key=embed_key)
except KeyError as e:
raise ValueError(
f"Embedding key '{embed_key}' not found in target model. "
f"Use --embedding-key to specify the correct key prefix."
) from e
logger.info(
"Loaded embed_tokens: shape=%s, dtype=%s",
list(hf_model.embed_tokens.weight.shape),
hf_model.embed_tokens.weight.dtype,
)

os.makedirs(output_dir, exist_ok=True)

with open(config_path) as f:
Expand Down Expand Up @@ -462,7 +487,13 @@ def _parse_args() -> argparse.Namespace:
type=str,
default=None,
help="Target model (HF hub id or local path). "
"Used to auto-generate config.json when not found in checkpoint dir",
"Used to load embed_tokens and auto-generate config.json if missing",
)
parser.add_argument(
"--embedding-key",
type=str,
default="model.embed_tokens",
help="Key prefix for embedding weights in target model (default: model.embed_tokens)",
)
parser.add_argument(
"--trust-remote-code",
Expand Down Expand Up @@ -603,6 +634,8 @@ def _validate_args(args: argparse.Namespace) -> None:
config_path=config_path,
input_dir=model_dir,
output_dir=output_dir,
target_model_path=args.target_model_path,
embedding_key=args.embedding_key,
prune_vocab=args.prune_vocab,
dataset_path=args.dataset_path,
draft_vocab_size=args.draft_vocab_size,
Expand Down