From a55616e6bf90a0ebf19fea1e943860beee04b8c1 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 20 Jan 2026 16:20:22 -0800
Subject: [PATCH 1/3] add token_sum reduction

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/utils/ppo_utils.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index d814aedd7..e7d32a24a 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -566,9 +566,10 @@ def ppo_policy_loss(
     loss_reduction = config.loss_reduction
     assert loss_reduction in [
         "token_mean",
+        "token_sum",
         "sequence_mean",
         "seq_mean_token_sum_norm",
-    ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"
+    ], "loss_reduction must be one of 'token_mean', 'token_sum', 'sequence_mean', or 'seq_mean_token_sum_norm'"
 
     ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
     surr1 = ratio * advantages
@@ -881,12 +882,18 @@ def compute_policy_loss_kl_cov(
 def reduce_loss(
     loss: torch.Tensor,
     loss_mask: Optional[torch.Tensor],
-    loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
+    loss_reduction: Literal["token_mean", "token_sum", "sequence_mean", "seq_mean_token_sum_norm"],
     max_seq_len: Optional[int] = None,
 ) -> torch.Tensor:
     if loss_reduction == "token_mean":
         # sum over *all* valid tokens, divide by total valid-token count
         loss = masked_mean(loss, loss_mask)
+    elif loss_reduction == "token_sum":
+        # sum over *all* valid tokens without averaging
+        if loss_mask is not None:
+            loss = (loss * loss_mask).sum()
+        else:
+            loss = loss.sum()
     elif loss_reduction == "sequence_mean":
         # per-sequence token-mean (dim=-1), then batch-mean
         loss = masked_mean(loss, loss_mask, dim=-1).mean()
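
Note (not part of the patch): a minimal sketch of what the new reduction computes next to the existing token_mean, assuming masked_mean divides the masked sum by the number of valid tokens; the helper name and tensors below are made up for illustration.

    import torch

    def reduce_loss_sketch(loss, loss_mask, loss_reduction):
        # token_mean: masked sum divided by the number of valid tokens (what masked_mean does)
        if loss_reduction == "token_mean":
            return (loss * loss_mask).sum() / loss_mask.sum()
        # token_sum: masked sum with no normalization
        elif loss_reduction == "token_sum":
            return (loss * loss_mask).sum()
        raise ValueError(loss_reduction)

    loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # 5 valid tokens
    print(reduce_loss_sketch(loss, mask, "token_sum"))   # tensor(18.)
    print(reduce_loss_sketch(loss, mask, "token_mean"))  # tensor(3.6000)

Because token_sum grows with the number of valid tokens, the later patches rescale the accumulated gradients by a token count at optim_step.
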
From 02c4125a22987f90f67fda940426d91c4c44dbec Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 20 Jan 2026 17:16:11 -0800
Subject: [PATCH 2/3] add hacky token_mean impl

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 29 +++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 7fc253d6c..58a9bcc52 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -657,6 +657,7 @@ def _normalize_mini_batch_size(self):
 
         # Track micro batches for gradient scaling at optim_step
         self._micro_batches_accumulated = 0
+        self._total_tokens_accumulated = 0
 
     def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
         """
@@ -677,6 +678,10 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
         for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
             metrics = self._forward_backward_micro(micro_batch)
             self._micro_batches_accumulated += 1
+
+            # Track total tokens accumulated to compute average loss per token during optim_step
+            self._total_tokens_accumulated += micro_batch.attention_mask.sum().item()
+
             for k, v in metrics.items():
                 all_metrics[k].append(v)
 
@@ -783,9 +788,28 @@ def optim_step(self) -> float:
         Returns:
             The gradient norm (before scaling, after clipping)
         """
-        # Scale accumulated gradients by 1/N to get correct average
         if self._micro_batches_accumulated > 0:
-            scale = 1.0 / self._micro_batches_accumulated
+            if self.cfg.trainer.algorithm.loss_reduction == "token_sum":
+                # Scale by the total number of tokens accumulated across all workers.
+                total_tokens_accumulated_tensor = torch.tensor(
+                    self._total_tokens_accumulated,
+                    dtype=torch.long,
+                    device=torch.cuda.current_device(),
+                )
+                global_tokens_accumulated = self.strategy.all_reduce(total_tokens_accumulated_tensor, op="sum").item()
+
+                # When loss_reduction="token_sum", each worker computes the
+                # sum of token losses across microbatches.
+                # FSDP all-reduce averages this local sum across workers,
+                # which divides the global sum by the number of workers.
+                # To counteract this, we multiply by the number of workers
+                # to recover the global token loss sum, and divide by the number of tokens
+                # to get the average token loss.
+                scale = self.strategy.world_size / global_tokens_accumulated
+            else:
+                # Scale accumulated gradients by 1/N to get correct average
+                scale = 1.0 / self._micro_batches_accumulated
+
             for param in self.model.parameters():
                 if param.grad is not None:
                     param.grad.mul_(scale)
@@ -795,6 +819,7 @@
 
         # Reset counter for next accumulation cycle
         self._micro_batches_accumulated = 0
+        self._total_tokens_accumulated = 0
 
         if grad_norm is not None:
            grad_norm = grad_norm.detach().cpu().item()
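
Note (not part of the patch): why the extra bookkeeping matters. With the existing 1/N scaling, each microbatch's token_mean loss counts equally regardless of how many valid tokens it contains, so short microbatches are over-weighted. Summing token losses per microbatch and dividing by the accumulated token count at optim_step recovers the true per-token average. A single-process sketch with made-up numbers (world_size is 1 here, so the FSDP averaging factor drops out):

    # Two microbatches with different numbers of valid tokens.
    mb1 = [1.0, 2.0]            # 2 tokens
    mb2 = [3.0, 4.0, 5.0, 6.0]  # 4 tokens

    # Per-microbatch token_mean, then 1/N scaling across microbatches.
    per_mb_mean = (sum(mb1) / len(mb1) + sum(mb2) / len(mb2)) / 2
    print(per_mb_mean)  # (1.5 + 4.5) / 2 = 3.0, biased toward the short microbatch

    # token_sum per microbatch, then divide by the global token count at optim_step.
    token_sum_scaled = (sum(mb1) + sum(mb2)) / (len(mb1) + len(mb2))
    print(token_sum_scaled)  # 21 / 6 = 3.5, the true per-token mean
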
From edc91f311b25807bf53137b6aeee4fbfe9a8d5ef Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 20 Jan 2026 17:31:11 -0800
Subject: [PATCH 3/3] add token_sum vs token_mean_v2

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/utils/ppo_utils.py |  9 ++++++---
 skyrl-train/skyrl_train/workers/worker.py  |  7 +++++--
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index e7d32a24a..a1d01a585 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -566,15 +566,17 @@ def ppo_policy_loss(
     loss_reduction = config.loss_reduction
     assert loss_reduction in [
         "token_mean",
+        "token_mean_v2",
         "token_sum",
         "sequence_mean",
         "seq_mean_token_sum_norm",
-    ], "loss_reduction must be one of 'token_mean', 'token_sum', 'sequence_mean', or 'seq_mean_token_sum_norm'"
+    ], "loss_reduction must be one of 'token_mean', 'token_mean_v2', 'token_sum', 'sequence_mean', or 'seq_mean_token_sum_norm'"
 
     ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
     surr1 = ratio * advantages
     surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages
     loss = -torch.min(surr1, surr2)
+    clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()
 
     clip_pg_losses1 = loss
     if config.policy_loss_type == "dual_clip":
@@ -882,14 +884,15 @@ def compute_policy_loss_kl_cov(
 def reduce_loss(
     loss: torch.Tensor,
     loss_mask: Optional[torch.Tensor],
-    loss_reduction: Literal["token_mean", "token_sum", "sequence_mean", "seq_mean_token_sum_norm"],
+    loss_reduction: Literal["token_mean", "token_mean_v2", "token_sum", "sequence_mean", "seq_mean_token_sum_norm"],
     max_seq_len: Optional[int] = None,
 ) -> torch.Tensor:
     if loss_reduction == "token_mean":
         # sum over *all* valid tokens, divide by total valid-token count
         loss = masked_mean(loss, loss_mask)
-    elif loss_reduction == "token_sum":
+    elif loss_reduction == "token_sum" or loss_reduction == "token_mean_v2":
         # sum over *all* valid tokens without averaging
+        # token_mean_v2 will divide the gradient by the total number of tokens before optim_step
         if loss_mask is not None:
             loss = (loss * loss_mask).sum()
         else:
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 58a9bcc52..8d968aaf7 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -789,7 +789,7 @@ def optim_step(self) -> float:
             The gradient norm (before scaling, after clipping)
         """
         if self._micro_batches_accumulated > 0:
-            if self.cfg.trainer.algorithm.loss_reduction == "token_sum":
+            if self.cfg.trainer.algorithm.loss_reduction in ["token_sum", "token_mean_v2"]:
                 # Scale by the total number of tokens accumulated across all workers.
                 total_tokens_accumulated_tensor = torch.tensor(
                     self._total_tokens_accumulated,
@@ -805,7 +805,10 @@ def optim_step(self) -> float:
                 # To counteract this, we multiply by the number of workers
                 # to recover the global token loss sum, and divide by the number of tokens
                 # to get the average token loss.
-                scale = self.strategy.world_size / global_tokens_accumulated
+                scale = self.strategy.world_size
+
+                if self.cfg.trainer.algorithm.loss_reduction == "token_mean_v2":
+                    scale /= global_tokens_accumulated
             else:
                 # Scale accumulated gradients by 1/N to get correct average
                 scale = 1.0 / self._micro_batches_accumulated
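
Note (not part of the patch): the scaling that optim_step ends up applying after this series, condensed into a standalone sketch. The function name grad_scale, the parameter per_worker_token_counts, and the example numbers are made up; the all-reduced token count is stood in for by a plain sum over workers.

    def grad_scale(loss_reduction, world_size, per_worker_token_counts, num_micro_batches):
        # FSDP averages gradients over workers, so per-worker token-loss sums arrive
        # divided by world_size; multiply it back in for the sum-style reductions.
        if loss_reduction in ("token_sum", "token_mean_v2"):
            scale = float(world_size)
            if loss_reduction == "token_mean_v2":
                # Normalize by the global number of valid tokens to get a per-token mean.
                scale /= sum(per_worker_token_counts)
            return scale
        # token_mean and the other reductions: microbatch losses are already means,
        # so average them over the number of accumulated microbatches.
        return 1.0 / num_micro_batches

    print(grad_scale("token_mean_v2", 4, [100, 120, 80, 100], 2))  # 4 / 400 = 0.01
    print(grad_scale("token_sum", 4, [100, 120, 80, 100], 2))      # 4.0
    print(grad_scale("token_mean", 4, [100, 120, 80, 100], 2))     # 0.5

With this scaling, token_mean_v2 weights every valid token equally across microbatches and workers, which appears to be the point of deferring the normalization to optim_step rather than averaging per microbatch.
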