Hello,
🐛 Describe the bug
I found very high KL divergence scores on the main branch when running
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
with the following config overrides:
group_size: 8
local_batch_size: 2 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 16384
(complete config file: qwen3_1_7b_kl_checks.yaml)
Here are the logged numbers, including the KL divergence scores:
grpo_loss/kl_divergence_mean: 1,568,987.96
grpo_loss/kl_divergence_max: 6,821,680,128.0
grpo_loss/policy_gradient_loss: -0.198
grpo_loss/total_loss: 1.77
buffer/sample/avg_sampled_policy_age: 1.0
beta=1e-6 suppresses these high values, but something seems really off.
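For scale: GRPO-style losses commonly penalize a per-token KL estimate of the form exp(d) - d - 1 with d = ref_logprob - policy_logprob (the k3 estimator); I'm assuming something along those lines is used here, so this is only a sketch, not the repo's actual loss code. It shows how quickly that estimator explodes once the trainer and reference logprobs disagree:

```python
import torch

def k3_kl(policy_logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    """Per-token k3 KL estimator: exp(d) - d - 1, where d = ref - policy (in nats).

    Non-negative, and grows roughly like exp(d) once the two logprob
    streams disagree by more than a few nats.
    """
    d = ref_logprobs - policy_logprobs
    return torch.exp(d) - d - 1.0

# A 1-nat per-token gap already gives a KL of ~0.72:
print(k3_kl(torch.tensor([-2.0]), torch.tensor([-1.0])))   # tensor([0.7183])

# The reported max of ~6.8e9 corresponds to a gap of about log(6.8e9) ~= 22.6 nats,
# i.e. a probability ratio of several billion on a single token:
print(k3_kl(torch.tensor([-25.0]), torch.tensor([-2.4])))  # roughly 6.5e9
```

If that assumption holds, the reported kl_divergence_max would imply that some token's logprob differs by roughly 22-23 nats between the trainer and reference streams, which seems far beyond what ordinary one-step policy staleness should produce.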
Versions
Tested the latest main at commit e940fd89ad9b262e690988f2f1a5ee6cb0d25574
Config: apps/grpo/qwen3_1_7b.yaml with 16k response length and batch size 2
Step: 9
torch: 2.9.0+cu128
torchtitan: 0.2.0
vllm: 0.10.1.dev0+g6d8d0a24c.d20251219