
[bug] KL Divergence explodes to billions in GRPO training #664

@gitlost-murali

Description


Hello,

🐛 Describe the bug

I'm seeing very high KL divergence values on the main branch when running

python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

with the following settings:

group_size: 8
local_batch_size: 2 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 16384

(complete config file: qwen3_1_7b_kl_checks.yaml)

Here are the logged metrics, including the KL divergence values:

grpo_loss/kl_divergence_mean: 1,568,987.96
grpo_loss/kl_divergence_max: 6,821,680,128.0
grpo_loss/policy_gradient_loss: -0.198
grpo_loss/total_loss: 1.77
buffer/sample/avg_sampled_policy_age: 1.0 

Setting beta=1e-6 suppresses these high values in the total loss, but something seems really off.
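For reference, here is a minimal sketch of how values like these can arise, assuming the loss uses the standard k3 KL estimator (exp(log_ratio) - log_ratio - 1) with a beta-weighted penalty; I haven't verified this against the repo's actual grpo_loss implementation, and the function names below are just illustrative:

import torch

def k3_kl_penalty(logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(log_ratio) - log_ratio - 1, always >= 0.
    # It is exponential in the per-token log-prob gap, so a gap of ~23 nats on a
    # single token already yields exp(23) ~= 1e10, i.e. a per-token max in the billions.
    log_ratio = ref_logprobs - logprobs
    return torch.exp(log_ratio) - log_ratio - 1.0

def total_loss(pg_loss: torch.Tensor, logprobs: torch.Tensor,
               ref_logprobs: torch.Tensor, beta: float = 1e-6) -> torch.Tensor:
    # Hypothetical combination: beta scales the KL penalty, so beta=1e-6 turns a
    # KL mean of ~1.6e6 into a contribution of only ~1.6 to the total loss, which
    # is why the total loss stays small even while the KL itself explodes.
    return pg_loss + beta * k3_kl_penalty(logprobs, ref_logprobs).mean()

In other words, the small beta hides the problem in the total loss, but the underlying log-prob gap between the trained policy and the reference looks far too large.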

Versions

Tested on the latest main at commit e940fd89ad9b262e690988f2f1a5ee6cb0d25574

Config: apps/grpo/qwen3_1_7b.yaml with 16k response length and batch size 2

Step: 9
torch: 2.9.0+cu128
torchtitan: 0.2.0
vllm: 0.10.1.dev0+g6d8d0a24c.d20251219
