
[Bug Fix] Fix exported checkpoint with random embed_tokens and wrong dtype #39

Merged
cicirori merged 1 commit into main from fix/export-embed-tokens-dtype on Mar 11, 2026

Conversation


@cicirori cicirori commented Mar 11, 2026

Summary

Root cause: _extract_model_weights() skipped every key containing "embed", so save_pretrained wrote the HF model's randomly initialized embed_tokens into the exported checkpoint. The FSDP checkpoint itself already contains the correct target embeddings (verified via torch.equal against the target model).
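The skip behavior can be illustrated with a minimal sketch (function and key names are illustrative stand-ins for _extract_model_weights(), not the actual implementation):

```python
def extract_model_weights(state_dict, skip_embed=True):
    """Toy stand-in for _extract_model_weights(); skip_embed=True mimics the bug."""
    weights = {}
    for key, value in state_dict.items():
        if skip_embed and "embed" in key:
            continue  # buggy behavior: the trained embed_tokens is never exported
        weights[key] = value
    return weights

ckpt = {
    "model.embed_tokens.weight": "trained",
    "model.layers.0.self_attn.weight": "trained",
}
# Bug: embed_tokens dropped, so save_pretrained keeps the random HF init.
assert "model.embed_tokens.weight" not in extract_model_weights(ckpt)
# Fix: keep "embed" keys so the checkpoint's trained embedding is exported.
assert "model.embed_tokens.weight" in extract_model_weights(ckpt, skip_embed=False)
```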

Changes in convert_to_hf.py:

  • Include embed_tokens from FSDP checkpoint (no longer skip "embed" keys)
  • Optionally override embed_tokens from target model via --target-model-path (for verification or old checkpoints)
  • Add --embedding-key option for models with non-standard embedding key (e.g. language_model.model.embed_tokens for VLMs)
  • Infer dtype from checkpoint weights and cast HF model before load_state_dict to avoid bf16→fp16 precision loss
  • Set torch_dtype in config to match actual weight dtype

Note: vocab pruning path is unaffected — it only trims lm_head, not embed_tokens (embed needs full vocab for target token id lookup).
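The dtype fix can be sketched as follows (a toy nn.Linear stands in for the real HF model; the point is that casting must happen before load_state_dict, since copying bf16 weights into an fp16 module would silently clip bf16's wider exponent range):

```python
import torch

# Synthetic checkpoint: a bf16 weight whose values overflow fp16 (max ~65504).
checkpoint = {"weight": torch.full((2, 2), 1e10, dtype=torch.bfloat16)}

# Infer dtype from the checkpoint weights and cast the model *before* loading.
ckpt_dtype = next(iter(checkpoint.values())).dtype
model = torch.nn.Linear(2, 2, bias=False).to(torch.float16)  # HF default, wrong dtype
model = model.to(ckpt_dtype)
model.load_state_dict(checkpoint)

assert model.weight.dtype == torch.bfloat16
assert torch.isfinite(model.weight).all()  # value survives in bf16
# Had we loaded into the fp16 model, the same value would have overflowed:
assert torch.isinf(torch.full((2, 2), 1e10).to(torch.float16)).all()
```

load_state_dict copies values into the module's existing parameters, so the parameters' dtype at load time is what the saved weights end up in; that is why the cast has to precede the load.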

Verified on kimi_eval_test checkpoint

# FSDP checkpoint embed_tokens vs target model:
torch.equal: True   ← FSDP already has the correct embedding

# Exported checkpoint (without --target-model-path):
Draft embed:  shape=[163840, 7168], dtype=torch.bfloat16
Target embed: shape=[163840, 7168], dtype=torch.bfloat16
torch.equal:  True
config dtype:  bfloat16
All 14 weights saved in bfloat16
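The torch.equal check above amounts to the following sketch (tiny synthetic tensors stand in for the real [163840, 7168] embeddings):

```python
import torch

# Stand-ins for the target model's embed_tokens and the exported draft embedding.
target_embed = torch.randn(16, 8, dtype=torch.bfloat16)
draft_embed = target_embed.clone()  # a correct export is bit-identical

assert draft_embed.shape == target_embed.shape
assert draft_embed.dtype == target_embed.dtype == torch.bfloat16
assert torch.equal(draft_embed, target_embed)  # exact elementwise equality
```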

Test plan

  • Verify FSDP checkpoint embed_tokens matches target model
  • Export checkpoint → torch.equal returns True
  • All weights saved in bfloat16, config dtype correct
  • Load in vLLM and verify acceptance length improvement

Fixes #38

@cicirori cicirori force-pushed the fix/export-embed-tokens-dtype branch 2 times, most recently from f978c3d to d36d631 Compare March 11, 2026 06:52
yubofredwang previously approved these changes Mar 11, 2026
@cicirori cicirori force-pushed the fix/export-embed-tokens-dtype branch 2 times, most recently from f41e0d6 to 2d54fac Compare March 11, 2026 07:00
@cicirori
Collaborator Author

Verification

Tested on an existing FSDP checkpoint (Kimi-K2.5 Eagle3, 16-rank training).

1. FSDP checkpoint already contains correct embed_tokens:

FSDP embed vs Target embed: torch.equal = True

2. Export without --target-model-path (embed from FSDP checkpoint):

Extracted 14 model weight keys (skipped 0 non-draft keys)
Checkpoint dtype: torch.bfloat16, casting HF model to match
Draft embed:  shape=[163840, 7168], dtype=torch.bfloat16
Target embed: shape=[163840, 7168], dtype=torch.bfloat16
torch.equal:  True
config dtype:  bfloat16
All 14 weights saved in bfloat16

3. Export with --target-model-path (embed overridden from target model):

Loaded embed_tokens: shape=[163840, 7168], dtype=torch.bfloat16
torch.equal:  True

Both paths produce identical, correct results.

@cicirori cicirori force-pushed the fix/export-embed-tokens-dtype branch 2 times, most recently from a18454c to ed56787 Compare March 11, 2026 07:11
…wrong dtype

- Include embed_tokens from FSDP checkpoint instead of skipping it
  (the FSDP checkpoint already contains the correct target embeddings)
- Optionally override embed_tokens from target model via --target-model-path
  (for compatibility or verification)
- Add --embedding-key option for models with non-standard embedding key
- Infer dtype from checkpoint weights and cast HF model before loading to
  avoid bf16→fp16 precision loss
- Set torch_dtype in config to match actual weight dtype

Fixes #38
@cicirori cicirori force-pushed the fix/export-embed-tokens-dtype branch from ed56787 to 91c82a0 Compare March 11, 2026 07:15
@cicirori cicirori merged commit 7fe42ac into main Mar 11, 2026
1 check passed
@cicirori cicirori deleted the fix/export-embed-tokens-dtype branch March 11, 2026 07:17

Development

Successfully merging this pull request may close these issues.

Exported Eagle3 checkpoints contain random embed_tokens weights, causing poor acceptance length in vLLM

2 participants