From 91c82a0fbf9fa7f43eaeedee55013ecb4d033557 Mon Sep 17 00:00:00 2001 From: cicirori Date: Wed, 11 Mar 2026 00:00:07 -0700 Subject: [PATCH] [Bug Fix] Fix exported checkpoint containing random embed_tokens and wrong dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Include embed_tokens from FSDP checkpoint instead of skipping it (the FSDP checkpoint already contains the correct target embeddings) - Optionally override embed_tokens from target model via --target-model-path (for compatibility or verification) - Add --embedding-key option for models with non-standard embedding key - Infer dtype from checkpoint weights and cast HF model before loading to avoid bf16→fp16 precision loss - Set torch_dtype in config to match actual weight dtype Fixes #38 --- tools/convert_to_hf.py | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/tools/convert_to_hf.py b/tools/convert_to_hf.py index 9a82eba..122603a 100644 --- a/tools/convert_to_hf.py +++ b/tools/convert_to_hf.py @@ -5,7 +5,7 @@ # Basic conversion (no vocab pruning): python tools/convert_to_hf.py --input-dir - # Auto-generate draft config from target model: + # With target model (loads correct embed_tokens + auto-generates config): python tools/convert_to_hf.py --input-dir \ --target-model-path moonshotai/Kimi-K2.5 --trust-remote-code @@ -182,7 +182,7 @@ def _extract_model_weights( for k, v in state_dict.items(): if not isinstance(v, torch.Tensor): continue - if "draft_model." not in k or "embed" in k.lower(): + if "draft_model." not in k: skipped_keys.append(k) continue new_key = k.split("draft_model.")[-1] @@ -191,7 +191,7 @@ def _extract_model_weights( model_state[new_key] = v logger.info( - "Extracted %d model weight keys (skipped %d non-draft/embedding keys)", + "Extracted %d model weight keys (skipped %d non-draft keys)", len(model_state), len(skipped_keys), ) @@ -366,6 +366,8 @@ def _convert_fsdp_to_hf( config_path: str, input_dir: str, output_dir: str, + target_model_path: Optional[str] = None, + embedding_key: str = "model.embed_tokens", prune_vocab: bool = False, dataset_path: Optional[str] = None, draft_vocab_size: Optional[int] = None, @@ -390,12 +392,35 @@ def _convert_fsdp_to_hf( config = AutoDraftModelConfig.from_file(config_path) hf_model = AutoEagle3DraftModel.from_config(config) + # Infer dtype from checkpoint weights so the HF model matches (avoids silent precision changes) + ckpt_dtype = Counter(v.dtype for v in model_state.values()).most_common(1)[0][0] + logger.info("Checkpoint dtype: %s, casting HF model to match", ckpt_dtype) + hf_model = hf_model.to(ckpt_dtype) + missing, unexpected = hf_model.load_state_dict(model_state, strict=False) if missing: logger.warning("Missing keys: %s", missing) if unexpected: logger.warning("Unexpected keys: %s", unexpected) + # Optionally override embed_tokens from target model. + # Useful for older checkpoints where embed_tokens may not have been saved correctly. + if target_model_path: + logger.info("Loading embed_tokens from target model: %s", target_model_path) + embed_key = embedding_key + ".weight" + try: + hf_model.load_embedding(target_model_path, embedding_key=embed_key) + except KeyError as e: + raise ValueError( + f"Embedding key '{embed_key}' not found in target model. " + f"Use --embedding-key to specify the correct key prefix." + ) from e + logger.info( + "Loaded embed_tokens: shape=%s, dtype=%s", + list(hf_model.embed_tokens.weight.shape), + hf_model.embed_tokens.weight.dtype, + ) + os.makedirs(output_dir, exist_ok=True) with open(config_path) as f: @@ -462,7 +487,13 @@ def _parse_args() -> argparse.Namespace: type=str, default=None, help="Target model (HF hub id or local path). " - "Used to auto-generate config.json when not found in checkpoint dir", + "Used to load embed_tokens and auto-generate config.json if missing", + ) + parser.add_argument( + "--embedding-key", + type=str, + default="model.embed_tokens", + help="Key prefix for embedding weights in target model (default: model.embed_tokens)", ) parser.add_argument( "--trust-remote-code", @@ -603,6 +634,8 @@ def _validate_args(args: argparse.Namespace) -> None: config_path=config_path, input_dir=model_dir, output_dir=output_dir, + target_model_path=args.target_model_path, + embedding_key=args.embedding_key, prune_vocab=args.prune_vocab, dataset_path=args.dataset_path, draft_vocab_size=args.draft_vocab_size,