diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index d814aedd7..a1d01a585 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -566,14 +566,17 @@ def ppo_policy_loss(
     loss_reduction = config.loss_reduction
     assert loss_reduction in [
         "token_mean",
+        "token_mean_v2",
+        "token_sum",
         "sequence_mean",
         "seq_mean_token_sum_norm",
-    ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"
+    ], "loss_reduction must be one of 'token_mean', 'token_mean_v2', 'token_sum', 'sequence_mean', or 'seq_mean_token_sum_norm'"
 
     ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
     surr1 = ratio * advantages
     surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages
     loss = -torch.min(surr1, surr2)
+    clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()
 
     clip_pg_losses1 = loss
     if config.policy_loss_type == "dual_clip":
@@ -881,12 +884,19 @@ def compute_policy_loss_kl_cov(
 def reduce_loss(
     loss: torch.Tensor,
     loss_mask: Optional[torch.Tensor],
-    loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
+    loss_reduction: Literal["token_mean", "token_mean_v2", "token_sum", "sequence_mean", "seq_mean_token_sum_norm"],
     max_seq_len: Optional[int] = None,
 ) -> torch.Tensor:
     if loss_reduction == "token_mean":
         # sum over *all* valid tokens, divide by total valid-token count
         loss = masked_mean(loss, loss_mask)
+    elif loss_reduction == "token_sum" or loss_reduction == "token_mean_v2":
+        # sum over *all* valid tokens without averaging;
+        # for token_mean_v2, the gradient is divided by the global token count at optim_step
+        if loss_mask is not None:
+            loss = (loss * loss_mask).sum()
+        else:
+            loss = loss.sum()
     elif loss_reduction == "sequence_mean":
         # per-sequence token-mean (dim=-1), then batch-mean
         loss = masked_mean(loss, loss_mask, dim=-1).mean()
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 7fc253d6c..8d968aaf7 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -657,6 +657,7 @@ def _normalize_mini_batch_size(self):
 
         # Track micro batches for gradient scaling at optim_step
         self._micro_batches_accumulated = 0
+        self._total_tokens_accumulated = 0
 
     def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
         """
@@ -677,6 +678,10 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
         for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
             metrics = self._forward_backward_micro(micro_batch)
             self._micro_batches_accumulated += 1
+
+            # Track total tokens accumulated to compute average loss per token during optim_step
+            self._total_tokens_accumulated += micro_batch.attention_mask.sum().item()
+
             for k, v in metrics.items():
                 all_metrics[k].append(v)
 
@@ -783,9 +788,31 @@ def optim_step(self) -> float:
         Returns:
             The gradient norm (before scaling, after clipping)
         """
-        # Scale accumulated gradients by 1/N to get correct average
        if self._micro_batches_accumulated > 0:
+            if self.cfg.trainer.algorithm.loss_reduction in ["token_sum", "token_mean_v2"]:
+                # All-reduce the locally accumulated token count to get the global total across workers.
+                total_tokens_accumulated_tensor = torch.tensor(
+                    self._total_tokens_accumulated,
+                    dtype=torch.long,
+                    device=torch.cuda.current_device(),
+                )
+                global_tokens_accumulated = self.strategy.all_reduce(total_tokens_accumulated_tensor, op="sum").item()
+
+                # With "token_sum" or "token_mean_v2", each worker accumulates the
+                # sum of token losses across its micro-batches.
+                # The FSDP gradient all-reduce averages this local sum across workers,
+                # which divides the global sum by the number of workers.
+                # To counteract this, we multiply by the number of workers to recover
+                # the global token-loss sum; for "token_mean_v2" we additionally divide
+                # by the global token count to get the average loss per token.
+                scale = self.strategy.world_size
+
+                if self.cfg.trainer.algorithm.loss_reduction == "token_mean_v2":
+                    scale /= global_tokens_accumulated
+            else:
+                # Scale accumulated gradients by 1/N to get correct average
+                scale = 1.0 / self._micro_batches_accumulated
+
             for param in self.model.parameters():
                 if param.grad is not None:
                     param.grad.mul_(scale)
@@ -795,6 +822,7 @@
 
         # Reset counter for next accumulation cycle
         self._micro_batches_accumulated = 0
+        self._total_tokens_accumulated = 0
 
         if grad_norm is not None:
             grad_norm = grad_norm.detach().cpu().item()
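For reviewers, a minimal standalone sketch (toy tensors, not the repo's code) of why the new reduction differs from `token_mean` under gradient accumulation: `token_mean` averages each micro-batch over its own tokens and optim_step then divides by the number of micro-batches, so micro-batches with few tokens are up-weighted per token, whereas `token_sum`/`token_mean_v2` defer normalization and divide once by the global token count, weighting every token equally.

```python
import torch

# Toy numbers for two micro-batches with different valid-token counts (hypothetical).
loss_mb1 = torch.tensor([1.0, 2.0, 3.0])  # 3 valid tokens
loss_mb2 = torch.tensor([4.0])            # 1 valid token

# "token_mean": each micro-batch is averaged over its own tokens, and optim_step
# divides the accumulated gradient by the number of micro-batches, so each
# micro-batch carries equal weight regardless of how many tokens it has.
token_mean = (loss_mb1.mean() + loss_mb2.mean()) / 2   # (2.0 + 4.0) / 2 = 3.0

# "token_mean_v2": each micro-batch contributes its token-loss *sum* (the
# "token_sum" reduction), and optim_step divides once by the global token
# count, so every token carries equal weight.
token_mean_v2 = (loss_mb1.sum() + loss_mb2.sum()) / 4  # 10.0 / 4 = 2.5

print(token_mean.item(), token_mean_v2.item())  # 3.0 2.5
```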
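And a small arithmetic sketch of the optim_step scale factor, assuming (as the patch's comment states) that the gradient all-reduce averages per-worker gradients; the worker count and token counts below are made up for illustration.

```python
# Hypothetical setup: 2 workers, each backward() produced the gradient of its
# local token-loss sum; the gradient all-reduce then averaged over workers.
world_size = 2
local_tokens = [3, 5]              # per-worker valid-token counts (made up)
global_tokens = sum(local_tokens)  # what strategy.all_reduce(op="sum") yields

# averaged_grad = (grad_w0 + grad_w1) / world_size, so:
scale_token_sum = world_size                       # recovers the gradient of the global token-loss sum
scale_token_mean_v2 = world_size / global_tokens   # turns it into a per-token average

print(scale_token_sum, scale_token_mean_v2)  # 2 0.25
```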