[skyrl-train] Fix loss reduction by moving normalization to the advantage computation #925
base: main
Conversation
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
```python
# Undo DDP's default mean all-reduce over data-parallel workers by scaling
# gradients back up by the world size.
for param in self.model.parameters():
    if param.grad is not None:
        param.grad.mul_(self.strategy.world_size)
```
We could do this at the advantage computation level, but it seemed a bit weird to have DDP all-reduce implementation details there, so I kept it separate here.
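For context, a minimal sketch of what the advantage-level alternative might look like, assuming the `token_mean` branch of `normalize_minibatch_advantages` and a hypothetical `dp_world_size` value; since the policy loss is linear in the advantages, this would be essentially equivalent (for the policy loss term) to scaling the gradients after `backward()`:

```python
# Hypothetical alternative (not what this PR does): fold the DDP mean all-reduce
# correction into the advantage normalization instead of scaling param.grad.
data["advantages"] = advantages * dp_world_size / loss_mask.sum()
```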
Yeah, I agree that this is the right separation.
Code Review
This pull request effectively addresses the 'mean of means' bias in PPO policy loss reduction by moving the normalization logic from the loss function to the advantage computation. However, a potential division-by-zero vulnerability was identified in the new normalize_minibatch_advantages function in trainer.py. This could lead to numerical instability (NaNs) and training failure if a mini-batch contains only masked-out sequences; a fix using .clamp(min=1.0) is recommended. Additionally, I have one suggestion to improve the robustness of the configuration validation.
```python
# assert cfg.trainer.algorithm.loss_reduction in (
#     "token_mean",
#     "sequence_mean",
#     "seq_mean_token_sum_norm",
# ), (
#     f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
#     f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
# )
```
This assertion for loss_reduction has been commented out. While the normalization logic has moved to trainer.py, this validation is still crucial. If an invalid loss_reduction value is provided in the configuration, normalize_minibatch_advantages will silently fail to normalize the advantages, as it lacks an else block for unknown values. This would result in an un-normalized sum for the loss, which could be very large and lead to training instability. It's safer to fail fast with an explicit error.
I recommend re-enabling this assertion to ensure only valid loss_reduction options are accepted.
Suggested change:

```diff
-# assert cfg.trainer.algorithm.loss_reduction in (
-#     "token_mean",
-#     "sequence_mean",
-#     "seq_mean_token_sum_norm",
-# ), (
-#     f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
-#     f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
-# )
+assert cfg.trainer.algorithm.loss_reduction in (
+    "token_mean",
+    "sequence_mean",
+    "seq_mean_token_sum_norm",
+), (
+    f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
+    f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
+)
```
```python
# Option 1: token mean
if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
    data["advantages"] = advantages / loss_mask.sum()

# Option 2: sequence mean
elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
    batch_size = len(data)
    data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True))
```
The `normalize_minibatch_advantages` function divides by `loss_mask.sum()` (line 1036) and `loss_mask.sum(dim=-1, keepdim=True)` (line 1041) without checking whether the divisor is zero. In reinforcement learning training, if a mini-batch consists entirely of sequences that are masked out (e.g., due to filtering or empty responses), the sum of the loss mask will be zero. Dividing by zero produces `inf` or `nan` values in the advantages tensor, which propagate to the gradients and corrupt the model weights during the optimizer step. This effectively causes a denial of service (DoS) on the training process.
Recommendation: Use `.clamp(min=1.0)` on the divisor to ensure it is never zero, consistent with the implementation of `masked_mean` in `skyrl_train/utils/ppo_utils.py`.
Suggested change:

```diff
-# Option 1: token mean
-if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
-    data["advantages"] = advantages / loss_mask.sum()
-# Option 2: sequence mean
-elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
-    batch_size = len(data)
-    data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True))
+# Option 1: token mean
+if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
+    data["advantages"] = advantages / loss_mask.sum().clamp(min=1.0)
+# Option 2: sequence mean
+elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
+    batch_size = len(data)
+    data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0))
```
Summary
The previous implementation of the PPO policy loss reduction had a "mean of means" bias: when computing a token-mean loss across micro-batches and workers with varying token counts, naive averaging gave incorrect results. For example:

- Micro-batch 1: 100 tokens, average loss = 0.5
- Micro-batch 2: 900 tokens, average loss = 0.3
- Naive mean: (0.5 + 0.3) / 2 = 0.4
- Correct token mean: (100 × 0.5 + 900 × 0.3) / 1000 = 0.32
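A small sketch of the same arithmetic (illustrative numbers only, matching the example above):

```python
# Averaging per-micro-batch means ignores token counts, while weighting by
# token count recovers the true token mean.
token_counts = [100, 900]
mean_losses = [0.5, 0.3]

naive_mean = sum(mean_losses) / len(mean_losses)  # 0.4
token_mean = sum(n * l for n, l in zip(token_counts, mean_losses)) / sum(token_counts)  # 0.32
```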
After this PR, `ppo_policy_loss` used within `forward_backward` now just sums the per-token loss for all sequences and relies on the advantages passed in by the user to handle the loss normalization. This aligns with Tinker semantics:
Example for `loss_reduction="token_mean"`: move the `1/num_minibatch_tokens` normalization into the advantage, so that

```
loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
```

becomes

```
loss = sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )
```
DDP all-reduce
DDP/FSDP defaults to a mean all-reduce for gradients across data-parallel workers. This PR counteracts that by multiplying the gradients by the DP world size.
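A small sketch of the arithmetic being counteracted, using hypothetical per-worker gradients rather than actual DDP buckets:

```python
import torch

# DDP's mean all-reduce yields the average of per-worker gradients; multiplying
# by the world size recovers their sum, matching the summed-loss convention.
world_size = 4
per_worker_grads = [torch.randn(3) for _ in range(world_size)]

mean_reduced = torch.stack(per_worker_grads).mean(dim=0)  # what a mean all-reduce produces
assert torch.allclose(mean_reduced * world_size, torch.stack(per_worker_grads).sum(dim=0))
```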
Additional details
This was the first attempt: #909
That approach tracked the total token count and then applied one big normalization at `optim_step` in order to get an average per-token loss. But we decided to align with Tinker's way of just summing up the loss at the end and pushing any loss normalization to the user's advantage calculation.
The benefit is that users have full control over customizing their loss reduction strategy, rather than having it happen inside our opaque `forward_backward` / `optim_step` implementation, which would require some configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:
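For illustration only, a hypothetical sketch of what such a config-driven reduction inside `forward_backward` could have looked like; the function name, tensor names, and config field here are placeholders, not the actual skyrl-train API:

```python
import torch

def forward_backward_with_config(per_token_loss: torch.Tensor, loss_mask: torch.Tensor, loss_reduction: str):
    """Hypothetical reduction branch that a config-based design would have required."""
    if loss_reduction == "token_mean":
        loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
    elif loss_reduction == "sequence_mean":
        per_seq = (per_token_loss * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1).clamp(min=1.0)
        loss = per_seq.mean()
    else:  # plain sum, matching the Tinker-style default this PR adopts
        loss = (per_token_loss * loss_mask).sum()
    loss.backward()
    return loss
```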
Follow-up work
The `ppo_critic_loss` has the same problem, but it is less important than the policy loss.