14 changes: 12 additions & 2 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -566,14 +566,17 @@ def ppo_policy_loss(
loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"token_mean_v2",
"token_sum",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"
], "loss_reduction must be one of 'token_mean', 'token_mean_v2', 'token_sum', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages
loss = -torch.min(surr1, surr2)

clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()
clip_pg_losses1 = loss
if config.policy_loss_type == "dual_clip":
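Both the loss and the clip-fraction metric above rely on a `masked_mean` helper that is not part of this diff. A minimal, illustrative sketch of what such a helper typically looks like (the actual implementation in skyrl_train may differ in details):

```python
# Illustrative sketch of a masked_mean helper consistent with how it is used
# above; not the actual skyrl_train implementation.
from typing import Optional

import torch


def masked_mean(values: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
    """Mean of `values` over positions where `mask` is nonzero (all positions if mask is None)."""
    if mask is None:
        return values.mean() if dim is None else values.mean(dim=dim)
    mask = mask.to(values.dtype)
    if dim is None:
        return (values * mask).sum() / mask.sum().clamp(min=1e-8)
    return (values * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1e-8)
```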
@@ -881,12 +884,19 @@ def compute_policy_loss_kl_cov(
def reduce_loss(
loss: torch.Tensor,
loss_mask: Optional[torch.Tensor],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
loss_reduction: Literal["token_mean", "token_mean_v2", "token_sum", "sequence_mean", "seq_mean_token_sum_norm"],
max_seq_len: Optional[int] = None,
) -> torch.Tensor:
if loss_reduction == "token_mean":
# sum over *all* valid tokens, divide by total valid-token count
loss = masked_mean(loss, loss_mask)
elif loss_reduction == "token_sum" or loss_reduction == "token_mean_v2":
# sum over *all* valid tokens without averaging
# token_mean_v2 will divide the gradient by total number of tokens before optim_step
if loss_mask is not None:
loss = (loss * loss_mask).sum()
else:
loss = loss.sum()
elif loss_reduction == "sequence_mean":
# per-sequence token-mean (dim=-1), then batch-mean
loss = masked_mean(loss, loss_mask, dim=-1).mean()
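To make the new reduction modes concrete, here is a toy comparison of the token-level reductions on a small batch (illustrative only, not part of the PR):

```python
# Illustrative only: contrasts token_mean, token_sum, and token_mean_v2 on a toy batch.
import torch

loss = torch.tensor([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0],   # 2 valid tokens in sequence 0
                          [1.0, 1.0, 1.0]])  # 3 valid tokens in sequence 1

valid = loss * loss_mask
token_sum = valid.sum()                      # 1 + 2 + 4 + 5 + 6 = 18
token_mean = valid.sum() / loss_mask.sum()   # 18 / 5 = 3.6

# token_mean_v2 returns the same *loss value* as token_sum here (18);
# the division by the global token count happens later, when optim_step
# rescales the accumulated gradients (see worker.py below).
print(token_sum.item(), token_mean.item())
```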
32 changes: 30 additions & 2 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -657,6 +657,7 @@ def _normalize_mini_batch_size(self):

# Track micro batches for gradient scaling at optim_step
self._micro_batches_accumulated = 0
self._total_tokens_accumulated = 0

def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
"""
@@ -677,6 +678,10 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
metrics = self._forward_backward_micro(micro_batch)
self._micro_batches_accumulated += 1

# Track total tokens accumulated to compute average loss per token during optim_step
self._total_tokens_accumulated += micro_batch.attention_mask.sum().item()
Review comment (Contributor, PR author): this should actually be loss_mask

for k, v in metrics.items():
all_metrics[k].append(v)

@@ -783,9 +788,31 @@ def optim_step(self) -> float:
Returns:
The gradient norm (before scaling, after clipping)
"""
# Scale accumulated gradients by 1/N to get correct average
if self._micro_batches_accumulated > 0:
scale = 1.0 / self._micro_batches_accumulated
if self.cfg.trainer.algorithm.loss_reduction in ["token_sum", "token_mean_v2"]:
# Scale by the total number of tokens accumulated across all workers.
total_tokens_accumulated_tensor = torch.tensor(
self._total_tokens_accumulated,
dtype=torch.long,
device=torch.cuda.current_device(),
)
global_tokens_accumulated = self.strategy.all_reduce(total_tokens_accumulated_tensor, op="sum").item()

# When loss_reduction="token_sum", each worker computes the
# sum of token losses across microbatches.
# FSDP all-reduce averages this local sum across workers,
# which divides the global sum by the number of workers.
# To counteract this, we multiply by the number of workers
# to recover the global token loss sum, and divide by the number of tokens
# to get the average token loss.
scale = self.strategy.world_size

if self.cfg.trainer.algorithm.loss_reduction == "token_mean_v2":
scale /= global_tokens_accumulated
else:
# Scale accumulated gradients by 1/N to get correct average
scale = 1.0 / self._micro_batches_accumulated

for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(scale)
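The scaling logic above can be summarized as a small standalone sketch, under the assumption (stated in the PR's comments) that the data-parallel backend averages gradients across workers; the helper name here is hypothetical:

```python
# Hypothetical helper summarizing the gradient-scale math in optim_step above.
# Assumes the data-parallel backend *averages* gradients across workers, so
# token-sum style reductions must be multiplied back by world_size.
def grad_scale(loss_reduction: str, micro_batches: int, global_tokens: int, world_size: int) -> float:
    if loss_reduction == "token_sum":
        # Recover the global sum of token losses from the per-worker average.
        return float(world_size)
    if loss_reduction == "token_mean_v2":
        # Global sum of token losses divided by the global valid-token count.
        return world_size / global_tokens
    # Default: average of the per-micro-batch losses accumulated on this worker.
    return 1.0 / micro_batches


# Example: 4 workers, 8 micro-batches per worker, 2048 valid tokens globally.
assert grad_scale("token_mean", 8, 2048, 4) == 0.125
assert grad_scale("token_mean_v2", 8, 2048, 4) == 4 / 2048
```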
@@ -795,6 +822,7 @@

# Reset counter for next accumulation cycle
self._micro_batches_accumulated = 0
self._total_tokens_accumulated = 0

if grad_norm is not None:
grad_norm = grad_norm.detach().cpu().item()