From b7b494c922d32174fb9a67e0101dcbd3e73b981e Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 17:49:51 -0800 Subject: [PATCH 01/20] normalize the advantages instead Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/trainer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 8292c5d3a..1acc970a7 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1026,6 +1026,22 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]: "pass_through", "broadcast_to_inference_engines", self.inference_engine_client ) + def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: + """Normalize the advantages in the mini-batch.""" + advantages = data["advantages"] + loss_mask = data["loss_mask"] + + # NOTE: Do not modify the tensor in place! + # Otherwise subsequent epochs will keep dividing the same tensor. + + # Option 1: token mean + data["advantages"] = advantages / loss_mask.sum() + + # Option 2: sequence mean + # data["advantages"] = advantages / (data.batch_size * loss_mask.sum(dim=-1, keepdim=True)) + + return data + def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ Execute training step for FSDP strategy using forward_backward + optim_step. @@ -1057,6 +1073,8 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s end_idx = (local_step + 1) * mini_batch_size mini_batch = data[start_idx:end_idx] + mini_batch = self._normalize_minibatch_advantages(mini_batch) + status = self.dispatch.forward_backward(model, mini_batch) for k, v in status.items(): all_metrics[k].append(v) From da24f6b3cd8716d383c94306e9b7378cd66709e7 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 17:52:41 -0800 Subject: [PATCH 02/20] sum reduction for the loss Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/utils/ppo_utils.py | 10 ++++++++-- skyrl-train/skyrl_train/utils/utils.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index d814aedd7..ec5ff15d4 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -565,10 +565,11 @@ def ppo_policy_loss( assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" loss_reduction = config.loss_reduction assert loss_reduction in [ + "sum", "token_mean", "sequence_mean", "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" + ], "loss_reduction must be either 'sum', 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages @@ -881,7 +882,7 @@ def compute_policy_loss_kl_cov( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], + loss_reduction: Literal["sum", "token_mean", "sequence_mean", "seq_mean_token_sum_norm"], max_seq_len: Optional[int] = None, ) -> torch.Tensor: if loss_reduction == "token_mean": @@ -901,6 +902,11 @@ def reduce_loss( # If no mask, assume all tokens are valid seq_losses = torch.sum(loss, dim=-1) / max_seq_len loss = torch.mean(seq_losses) + elif loss_reduction == 
"sum": + if loss_mask is not None: + loss = torch.sum(loss * loss_mask) + else: + loss = torch.sum(loss) else: raise ValueError(f"Invalid loss reduction type: {loss_reduction}") return loss diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 2c4dc8f7e..072c61686 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -254,14 +254,14 @@ def validate_cfg(cfg: DictConfig): f"Must be one of {available_advantage_estimators}" ) - assert cfg.trainer.algorithm.loss_reduction in ( - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ), ( - f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. " - f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`" - ) + # assert cfg.trainer.algorithm.loss_reduction in ( + # "token_mean", + # "sequence_mean", + # "seq_mean_token_sum_norm", + # ), ( + # f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. " + # f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`" + # ) # add field to algorithm config needed for loss functions # create a new config to make it modifiable From cbed7ff69555473aba0e91f6c47ce48f3f7a2a74 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 18:09:34 -0800 Subject: [PATCH 03/20] add a few options Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/trainer.py | 13 ++++++++++--- skyrl-train/skyrl_train/workers/worker.py | 13 ++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 1acc970a7..69358c0d2 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1027,7 +1027,11 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]: ) def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: - """Normalize the advantages in the mini-batch.""" + """Normalize the advantages in the mini-batch. + + This normalization results in calculating the correct minibatch loss for the + given loss reduction type when reducing the loss with a sum. + """ advantages = data["advantages"] loss_mask = data["loss_mask"] @@ -1035,10 +1039,13 @@ def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingI # Otherwise subsequent epochs will keep dividing the same tensor. # Option 1: token mean - data["advantages"] = advantages / loss_mask.sum() + if self.cfg.trainer.algorithm.loss_reduction == "token_mean": + data["advantages"] = advantages / loss_mask.sum() # Option 2: sequence mean - # data["advantages"] = advantages / (data.batch_size * loss_mask.sum(dim=-1, keepdim=True)) + elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": + batch_size = len(data) + data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True)) return data diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 7fc253d6c..427ff772d 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -778,17 +778,16 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: def optim_step(self) -> float: """ - Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + Perform optimizer step. 
Returns: The gradient norm (before scaling, after clipping) """ - # Scale accumulated gradients by 1/N to get correct average - if self._micro_batches_accumulated > 0: - scale = 1.0 / self._micro_batches_accumulated - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(scale) + # Scale gradients by data parallelism size to undo the DDP all-reduce mean. + scale = 1.0 / self.strategy.world_size + for param in self.model.parameters(): + if param.grad is not None: + param.grad.mul_(scale) # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") From fc7b7757a3ec724bb05878a35aaeaa5a7f1dc3b0 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 18:13:42 -0800 Subject: [PATCH 04/20] always sum on the loss calculation side Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/utils/ppo_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index ec5ff15d4..9c30c2815 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -591,7 +591,13 @@ def ppo_policy_loss( tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) loss = loss * tis_imp_ratio - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + # NOTE: We scaled the advantages to handle the loss normalization in the trainer. + # So we just need to sum the token-level losses here. + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.sum() + # loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + return loss, clip_ratio @@ -882,7 +888,7 @@ def compute_policy_loss_kl_cov( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["sum", "token_mean", "sequence_mean", "seq_mean_token_sum_norm"], + loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], max_seq_len: Optional[int] = None, ) -> torch.Tensor: if loss_reduction == "token_mean": @@ -902,11 +908,6 @@ def reduce_loss( # If no mask, assume all tokens are valid seq_losses = torch.sum(loss, dim=-1) / max_seq_len loss = torch.mean(seq_losses) - elif loss_reduction == "sum": - if loss_mask is not None: - loss = torch.sum(loss * loss_mask) - else: - loss = torch.sum(loss) else: raise ValueError(f"Invalid loss reduction type: {loss_reduction}") return loss From c4115860de6e7595bc079492e6636f14aa81c70a Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 18:17:07 -0800 Subject: [PATCH 05/20] fix some bugs Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/utils/ppo_utils.py | 3 +-- skyrl-train/skyrl_train/workers/worker.py | 14 +++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 9c30c2815..9252af17e 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -565,11 +565,10 @@ def ppo_policy_loss( assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" loss_reduction = config.loss_reduction assert loss_reduction in [ - "sum", "token_mean", "sequence_mean", "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'sum', 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" + ], "loss_reduction must be either 
'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 427ff772d..708fa6895 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -784,10 +784,9 @@ def optim_step(self) -> float: The gradient norm (before scaling, after clipping) """ # Scale gradients by data parallelism size to undo the DDP all-reduce mean. - scale = 1.0 / self.strategy.world_size for param in self.model.parameters(): if param.grad is not None: - param.grad.mul_(scale) + param.grad.mul_(self.strategy.world_size) # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") @@ -984,17 +983,14 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: def optim_step(self) -> float: """ - Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + Perform optimizer step. Returns: The gradient norm (before scaling, after clipping) """ - # Scale accumulated gradients by 1/N to get correct average - if self._micro_batches_accumulated > 0: - scale = 1.0 / self._micro_batches_accumulated - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(scale) + for param in self.model.parameters(): + if param.grad is not None: + param.grad.mul_(self.strategy.world_size) # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="critic") From 1bf83d8bff3634de21fff6d6fed40c17381c885c Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 18:24:00 -0800 Subject: [PATCH 06/20] revert the critic worker changes Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/utils/ppo_utils.py | 1 + skyrl-train/skyrl_train/workers/worker.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 9252af17e..9df1b37dc 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -192,6 +192,7 @@ def ppo_critic_loss( clipfrac = None loss = (values - returns) ** 2 + # TODO: We separately run into the "mean of means" problem here. 
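+ # For example, with per-token losses [[1, 2, 3], [4, 5, 6]] and a loss_mask of
+ # [[1, 1, 1], [1, 0, 0]] (the same values used in tests/cpu/utils/test_ppo_utils.py),
+ # the token mean is 10 / 4 = 2.5 while the mean of per-sequence means is
+ # (2.0 + 4.0) / 2 = 3.0, so the two reductions diverge whenever sequences have
+ # different numbers of valid tokens.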
loss = masked_mean(loss, loss_mask, dim=-1).mean() return 0.5 * loss, clipfrac diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 708fa6895..d73add566 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -988,9 +988,12 @@ def optim_step(self) -> float: Returns: The gradient norm (before scaling, after clipping) """ - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(self.strategy.world_size) + # Scale accumulated gradients by 1/N to get correct average + if self._micro_batches_accumulated > 0: + scale = 1.0 / self._micro_batches_accumulated + for param in self.model.parameters(): + if param.grad is not None: + param.grad.mul_(scale) # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="critic") From dc838e10b7094bf03b7354804e1dd3029a207bf2 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 22 Jan 2026 18:27:02 -0800 Subject: [PATCH 07/20] more revert Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index d73add566..0e42f417b 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -983,7 +983,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: def optim_step(self) -> float: """ - Perform optimizer step. + Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. Returns: The gradient norm (before scaling, after clipping) From 62591745eb00be1815632476434111fe21d9c365 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 28 Jan 2026 22:31:38 +0000 Subject: [PATCH 08/20] fix conflict with main --- skyrl-train/skyrl_train/trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 618b87f33..6e802689f 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1072,19 +1072,25 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples all_metrics: Dict[str, List[float]] = defaultdict(list) + num_mini_batches = len(data) // mini_batch_size + + # iterate over mini-batches to do mini batch level normalization + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + mini_batch = data[start_idx:end_idx] + mini_batch = self._normalize_minibatch_advantages(mini_batch) + data[start_idx:end_idx] = mini_batch # Stage full batch in object store ONCE to avoid repeated serialization data_ref = self.dispatch.stage_data(data) # Training loop over epochs and mini-batches for _epoch in range(self.cfg.trainer.update_epochs_per_batch): - num_mini_batches = len(data) // mini_batch_size for local_step in range(num_mini_batches): start_idx = local_step * mini_batch_size end_idx = (local_step + 1) * mini_batch_size - mini_batch = self._normalize_minibatch_advantages(mini_batch) - # Workers fetch from object store and slice locally status = self.dispatch.forward_backward_from_staged(model, data_ref, start_idx, end_idx) for k, v in status.items(): From 4b8e556d54b0f7c3778ba27ad19b4368e20e3e56 Mon 
Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 28 Jan 2026 22:35:38 +0000 Subject: [PATCH 09/20] x --- skyrl-train/skyrl_train/trainer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 6e802689f..dfb78a5a0 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1019,11 +1019,6 @@ def apply_reward_kl_penalty( return data - def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]: - return self.policy_model.async_run_ray_method( - "pass_through", "broadcast_to_inference_engines", self.inference_engine_client - ) - def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: """Normalize the advantages in the mini-batch. From 4547ae1da30f9524227652f0c41308ca483d30c0 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 29 Jan 2026 22:12:30 +0000 Subject: [PATCH 10/20] fix normalization --- skyrl-train/skyrl_train/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index dfb78a5a0..7f97e7e3e 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1075,7 +1075,8 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s end_idx = (local_step + 1) * mini_batch_size mini_batch = data[start_idx:end_idx] mini_batch = self._normalize_minibatch_advantages(mini_batch) - data[start_idx:end_idx] = mini_batch + # Copy normalized advantages back to original batch + data["advantages"][start_idx:end_idx] = mini_batch["advantages"] # Stage full batch in object store ONCE to avoid repeated serialization data_ref = self.dispatch.stage_data(data) From 635a9c1e853e2939422166c8b462f72936361a84 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 30 Jan 2026 22:50:53 +0000 Subject: [PATCH 11/20] change reduce_loss to just sums everywhere --- skyrl-train/docs/configuration/config.rst | 2 +- .../main_on_policy_distill.py | 2 +- .../run_on_policy_distill_math_qwen3_1.7b.sh | 1 + .../run_on_policy_distill_math_qwen3_4b.sh | 1 + skyrl-train/skyrl_train/trainer.py | 10 +++- skyrl-train/skyrl_train/utils/ppo_utils.py | 55 +++---------------- skyrl-train/tests/cpu/utils/test_ppo_utils.py | 32 ++++------- 7 files changed, 30 insertions(+), 73 deletions(-) diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 1871613cb..0094149c6 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -510,7 +510,7 @@ It can be helpful to understand the final loss formulation to see how the differ pg_losses3 = -advantages * config.clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - loss = reduce_loss(loss, loss_mask, config.loss_reduction) + loss = reduce_loss(loss, loss_mask) return loss, clip_ratio diff --git a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py index c7030cbcc..eea7da5b4 100644 --- a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py +++ b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py @@ -51,7 +51,7 @@ def compute_importance_sampling_policy_loss( # as defined here: https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling loss = -torch.exp(log_probs - 
old_log_probs) * advantages - loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len) + loss = reduce_loss(loss, loss_mask) # return loss and a dummy clip ratio value as we aren't clipping here return loss, 0.0 diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh index 250e8252a..6be4ab3ee 100644 --- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh +++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh @@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy trainer.policy.optimizer_config.weight_decay=0.1 \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \ + trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \ generator.backend=vllm \ generator.run_engines_locally=true \ generator.async_engine=false \ diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh index 670a35152..fb60594ae 100644 --- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh +++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh @@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy trainer.policy.optimizer_config.weight_decay=0.1 \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \ + trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \ generator.backend=vllm \ generator.run_engines_locally=true \ generator.async_engine=false \ diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 7f97e7e3e..73666c6c3 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1019,7 +1019,7 @@ def apply_reward_kl_penalty( return data - def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: + def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: """Normalize the advantages in the mini-batch. This normalization results in calculating the correct minibatch loss for the @@ -1040,6 +1040,12 @@ def _normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingI batch_size = len(data) data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True)) + # option 3: Dr. 
GRPO style loss reduction to avoid length bias by normalizing by a constant + elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": + batch_size = len(data) + max_seq_len = self.cfg.trainer.algorithm.max_seq_len + data["advantages"] = advantages / (batch_size * max_seq_len) + return data def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: @@ -1074,7 +1080,7 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s start_idx = local_step * mini_batch_size end_idx = (local_step + 1) * mini_batch_size mini_batch = data[start_idx:end_idx] - mini_batch = self._normalize_minibatch_advantages(mini_batch) + mini_batch = self.normalize_minibatch_advantages(mini_batch) # Copy normalized advantages back to original batch data["advantages"][start_idx:end_idx] = mini_batch["advantages"] diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 8edf72d0d..aa4f04964 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import ray @@ -593,12 +593,7 @@ def ppo_policy_loss( tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) loss = loss * tis_imp_ratio - # NOTE: We scaled the advantages to handle the loss normalization in the trainer. - # So we just need to sum the token-level losses here. - if loss_mask is not None: - loss = loss * loss_mask loss = loss.sum() - # loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, clip_ratio @@ -663,7 +658,7 @@ def gate_function(x, tau): loss = -gates * advantages # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) # SAPO does not use clipping, so we set clip_ratio to 0.0 for compatibility clip_ratio = 0.0 @@ -726,7 +721,7 @@ def gspo_policy_loss( # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, clip_ratio @@ -756,7 +751,7 @@ def compute_policy_loss_cispo( is_clipped = (ratio < 1 - config.cispo.cispo_eps_clip_low) | (ratio > 1 + config.cispo.cispo_eps_clip_high) clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item() - loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, clip_ratio @@ -817,12 +812,7 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) return pg_loss, clip_frac.item() @@ -876,12 +866,7 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, 
loss_mask=loss_mask) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 return pg_loss, 0.0 @@ -920,10 +905,7 @@ def cross_entropy_loss( elementwise_loss = -log_probs # Apply loss mask and sum (matching Tinker's SUM reduction semantics) - if loss_mask is not None: - loss = (elementwise_loss * loss_mask).sum() - else: - loss = elementwise_loss.sum() + loss = reduce_loss(elementwise_loss, loss_mask) # No clipping in cross-entropy loss return loss, 0.0 @@ -932,29 +914,8 @@ def cross_entropy_loss( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], - max_seq_len: Optional[int] = None, ) -> torch.Tensor: - if loss_reduction == "token_mean": - # sum over *all* valid tokens, divide by total valid-token count - loss = masked_mean(loss, loss_mask) - elif loss_reduction == "sequence_mean": - # per-sequence token-mean (dim=-1), then batch-mean - loss = masked_mean(loss, loss_mask, dim=-1).mean() - elif loss_reduction == "seq_mean_token_sum_norm": - # per-sequence token-sum, normalized by the max sequence length, then batch mean - # this is the Dr. GRPO loss reduction to avoid length bias by normalizing by a constant - assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction" - # NOTE: max_seq_len is computed as cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length by default - if loss_mask is not None: - seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len - else: - # If no mask, assume all tokens are valid - seq_losses = torch.sum(loss, dim=-1) / max_seq_len - loss = torch.mean(seq_losses) - else: - raise ValueError(f"Invalid loss reduction type: {loss_reduction}") - return loss + return (loss * loss_mask).sum() if loss_mask is not None else loss.sum() # NOTE (erictang000): below ported from verl diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py index fb69ce15e..ef36fa3eb 100644 --- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py +++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py @@ -243,29 +243,17 @@ def test_compute_gae_advantage_return_lam(advantage_test_data): def test_reduce_loss(): - """Test the reduce_loss function with different reduction types.""" - # Test data: 2x3 loss tensor with different valid token counts per sequence + """Test that reduce_loss computes the masked sum correctly.""" loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) # seq0 has 3 tokens, seq1 has 1 token - - # Test token_mean: sum all valid losses / count valid tokens - # Valid losses: [1.0, 2.0, 3.0, 4.0], mean = 10.0/4 = 2.5 - result_token = reduce_loss(loss, loss_mask, "token_mean") - expected_token = torch.tensor(2.5) - assert torch.allclose(result_token, expected_token), f"Expected {expected_token}, got {result_token}" - - # Test sequence_mean: mean of per-sequence means - # Seq 0: (1.0 + 2.0 + 3.0) / 3 = 2.0, Seq 1: 4.0 / 1 = 4.0, batch mean = (2.0 + 4.0) / 2 = 3.0 - result_seq = reduce_loss(loss, loss_mask, "sequence_mean") - expected_seq = torch.tensor(3.0) - assert torch.allclose(result_seq, expected_seq), f"Expected {expected_seq}, got {result_seq}" - - # Test seq_mean_token_sum_norm: sum per sequence / max_len, then batch mean - # Seq 0: (1.0 + 2.0 + 3.0) / 4 = 1.5, Seq 1: 4.0 / 4 = 1.0, batch mean = (1.5 + 1.0) / 2 = 1.25 - max_seq_len = 4 - result_max = 
reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", max_seq_len) - expected_max = torch.tensor(1.25) - assert torch.allclose(result_max, expected_max), f"Expected {expected_max}, got {result_max}" + loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) + + # With mask: sum of valid losses = 1.0 + 2.0 + 3.0 + 4.0 = 10.0 + result = reduce_loss(loss, loss_mask) + assert torch.allclose(result, torch.tensor(10.0)) + + # Without mask: sum of all losses = 1+2+3+4+5+6 = 21.0 + result_no_mask = reduce_loss(loss, None) + assert torch.allclose(result_no_mask, torch.tensor(21.0)) def test_adaptive_kl_controller_update(): From cd7506bb5c6578ffc5616a6b0ec7fcdea09805a3 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Sat, 31 Jan 2026 00:42:14 +0000 Subject: [PATCH 12/20] fix tests --- skyrl-train/skyrl_train/utils/ppo_utils.py | 28 +-- .../tests/cpu/algorithms/test_losses.py | 189 +----------------- 2 files changed, 11 insertions(+), 206 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index aa4f04964..5fe06ff93 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -566,12 +566,6 @@ def ppo_policy_loss( rollout_logprobs: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, float]: assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" - loss_reduction = config.loss_reduction - assert loss_reduction in [ - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = _safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages @@ -615,16 +609,7 @@ def sapo_policy_loss( See https://arxiv.org/pdf/2511.20347 for more details. """ - # SAPO must use sequence_mean reduction - loss_reduction = config.loss_reduction - if loss_reduction != "sequence_mean": - # The SAPO paper uses sequence_mean reduction; there's no reason - # why a user couldn't use token_mean reduction, but - # it's not clear whether it would be stable or not. - from loguru import logger as logger_ # have to do lazy import to avoid pickling error - - logger_.warning(f"With SAPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}") - + # SAPO should use sequence_mean reduction to avoid length bias # temperature for positive and negative token updates tau_pos = torch.as_tensor(config.sapo.tau_pos, dtype=advantages.dtype, device=advantages.device) tau_neg = torch.as_tensor(config.sapo.tau_neg, dtype=advantages.dtype, device=advantages.device) @@ -687,16 +672,7 @@ def gspo_policy_loss( The variant of GSPO used here is GSPO-token, a generalization which allows for token-level advantages [equations 14 and 15 in the paper]. """ - # GSPO must use sequence_mean reduction - loss_reduction = config.loss_reduction - if loss_reduction != "sequence_mean": - # The GSPO paper uses sequence_mean reduction; there's no reason - # why a user couldn't use token_mean reduction, but - # it's not clear whether it would be stable or not. 
- from loguru import logger as logger_ # have to do lazy import to avoid pickling error - - logger_.warning(f"With GSPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}") - + # GSPO should use sequence_mean reduction to avoid length bias # Compute log ratios log_ratio = log_probs - old_log_probs diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index f5904b595..2a95b07fa 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -31,7 +31,6 @@ def test_policy_loss_dual_clip(): "eps_clip_high": 0.2, "clip_ratio_c": 3.0, "policy_loss_type": "dual_clip", - "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, } @@ -56,7 +55,7 @@ def test_policy_loss_dual_clip(): # For negative advantages, use dual clipped loss final_loss = torch.where(advantages < 0, min_loss, max_loss) # [-0.5, 1.0, 12.0] assert torch.allclose(final_loss, torch.tensor([[-0.5, 1.0, 12.0]], device=device), rtol=1e-3) - expected_loss = final_loss.mean() # -(-12.5/3) = 4.1667 + expected_loss = final_loss.sum() # 12.5 # Calculate actual loss actual_loss, _ = loss_fn(log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, config=config) @@ -64,7 +63,7 @@ def test_policy_loss_dual_clip(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4) + assert actual_loss.item() == pytest.approx(12.5, abs=1e-4) def test_policy_loss_cispo(): @@ -84,7 +83,6 @@ def test_policy_loss_cispo(): { "cispo": {"cispo_eps_clip_low": 0.2, "cispo_eps_clip_high": 0.2}, "policy_loss_type": "cispo", - "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, } @@ -106,9 +104,9 @@ def test_policy_loss_cispo(): # loss_per_token[0] = -(1.0 * 0.8 * -1.69315) = 1.35452 # loss_per_token[1] = -(-1.0 * 1.0 * -1.0) = -1.0 # loss_per_token[2] = -(-4.0 * 1.2 * -0.69741) = -3.347568 - # mean(loss) = (1.35452 - 1.0 - 3.347568) / 3 = -0.99768266666 + # sum(loss) = 1.35452 - 1.0 - 3.347568 = -2.99 loss = -ratio.clamp(1 - 0.2, 1 + 0.2) * advantages * log_probs - expected_loss = loss.mean() + expected_loss = loss.sum() # Calculate actual loss actual_loss, _ = loss_fn( @@ -121,167 +119,7 @@ def test_policy_loss_cispo(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(-0.99768266666, abs=1e-4) - - -def test_policy_loss_reduction_modes(): - """Tests different loss_reduction modes in PolicyLoss function. - - Note: token_mean and sequence_mean give the same result when all sequences - have the same length and no mask is applied, but differ when masking creates - different effective sequence lengths. 
- """ - - device = "cpu" - - clip_eps_low = 0.2 - clip_eps_high = 0.2 - - advantages = torch.tensor( - [ - [2.0, 2.0, 2.0], # sequence 1: consistently higher advantages - [1.0, 1.0, 1.0], # sequence 2: consistently lower advantages - ], - device=device, - ) - - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) - - log_probs = torch.tensor( - [[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]], # ratios ≈ [[0.61, 1.65, 0.83],[1.22, 0.74, 1.11]] - device=device, - ) - - # Create masks to test sequences with different numbers of valid tokens - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device) - - # Create configs for different reduction modes - config_token = DictConfig( - { - "eps_clip_low": clip_eps_low, - "eps_clip_high": clip_eps_high, - "clip_ratio_c": 3.0, - "policy_loss_type": "regular", - "loss_reduction": "token_mean", - "max_seq_len": 4, - "use_tis": False, - } - ) - - config_seq = DictConfig( - { - "eps_clip_low": clip_eps_low, - "eps_clip_high": clip_eps_high, - "clip_ratio_c": 3.0, - "policy_loss_type": "regular", - "loss_reduction": "sequence_mean", - "max_seq_len": 4, - "use_tis": False, - } - ) - - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - # Test token_mean without mask - loss_token_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - - # Test token_mean with mask - loss_token_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - - # Test sequence_mean without mask - loss_seq_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # Test sequence_mean with mask - loss_seq_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Manual calculations to verify (using default PolicyLoss parameters) - ratio = torch.exp(log_probs - old_log_probs) - surr1 = ratio * advantages - surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages # clip_eps_low=0.2, clip_eps_high=0.2 - loss_per_token = -torch.min(surr1, surr2) - - # Expected token_mean without mask: mean of all tokens - expected_token_no_mask = loss_per_token.mean() - - # Expected token_mean with mask: masked mean of all tokens - expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) - - # Expected sequence_mean without mask: mean of sequence means - expected_seq_no_mask = loss_per_token.mean(dim=1).mean() - - # Expected sequence_mean with mask: mean of masked sequence means - seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8) - expected_seq_with_mask = seq_means_masked.mean() - - # Verify results - torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8) - - # Verify that the two reduction modes give the same results when sequences have equal length and no mask - assert torch.allclose( - loss_token_no_mask, loss_seq_no_mask, rtol=1e-5 - ), "token_mean and sequence_mean should give same results when sequences have equal length and no mask" - # But they should give different results when mask creates different effective sequence lengths - assert not torch.allclose( - loss_token_with_mask, loss_seq_with_mask, rtol=1e-3 - ), "token_mean and 
sequence_mean with mask should give different results" - - -def test_policy_loss_reduction_edge_cases(): - """Tests edge cases for loss_reduction modes.""" - - device = "cpu" - - # Test with single sequence (should give same result for both modes) - advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device) - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) - log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device) - - # Create configs for different reduction modes - config_token = DictConfig( - { - "eps_clip_low": 0.2, - "eps_clip_high": 0.2, - "clip_ratio_c": 3.0, - "policy_loss_type": "regular", - "loss_reduction": "token_mean", - "max_seq_len": 4, - "use_tis": False, - } - ) - - config_seq = DictConfig( - { - "eps_clip_low": 0.2, - "eps_clip_high": 0.2, - "clip_ratio_c": 3.0, - "policy_loss_type": "regular", - "loss_reduction": "sequence_mean", - "max_seq_len": 4, - "use_tis": False, - } - ) - - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - loss_token, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - loss_seq, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # With single sequence, both modes should give same result - torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8) - - # Test with completely masked sequence - loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) - loss_token_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - loss_seq_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Should handle zero mask gracefully (due to +1e-8 in denominator) - assert torch.isfinite(loss_token_masked) - assert torch.isfinite(loss_seq_masked) + assert actual_loss.item() == pytest.approx(-2.9930, abs=1e-4) def test_gspo_importance_sampling_levels(): @@ -345,7 +183,6 @@ def test_gspo_importance_sampling_levels(): "eps_clip_high": clip_eps_high, "clip_ratio_c": 3.0, "policy_loss_type": "regular", - "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, } @@ -360,7 +197,6 @@ def test_gspo_importance_sampling_levels(): "eps_clip_high": clip_eps_high, "clip_ratio_c": 3.0, "policy_loss_type": "gspo", - "loss_reduction": "sequence_mean", # GSPO recommended reduction "max_seq_len": 4, "use_tis": False, } @@ -374,7 +210,7 @@ def test_gspo_importance_sampling_levels(): surr1_token = ratio_token * advantages surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_token = -torch.min(surr1_token, surr2_token) - expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) + expected_token = (loss_per_token_token * loss_mask).sum() # Calculate token-level clipping ratio is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool()) @@ -390,8 +226,7 @@ def test_gspo_importance_sampling_levels(): surr1_sequence = ratio_sequence * advantages surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence) - # GSPO uses sequence_mean reduction - expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean() + expected_sequence = loss_per_token_sequence.sum() # Calculate sequence-level clipping ratio is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool()) @@ -466,7 +301,6 @@ def test_clip_cov_policy_loss(): "eps_clip_low": 0.2, "eps_clip_high": 0.2, "policy_loss_type": "clip_cov", - "loss_reduction": "token_mean", 
"max_seq_len": 4, "clip_cov": {"clip_ratio": 0.5, "clip_cov_lb": -5.0, "clip_cov_ub": 5.0}, # Large ratio for testing } @@ -488,7 +322,6 @@ def test_clip_cov_policy_loss(): "eps_clip_low": 0.2, "eps_clip_high": 0.2, "policy_loss_type": "regular", - "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, } @@ -528,7 +361,6 @@ def test_kl_cov_policy_loss(): config = DictConfig( { "policy_loss_type": "kl_cov", - "loss_reduction": "token_mean", "max_seq_len": 4, "kl_cov": {"kl_cov_frac": 0.5, "ppo_kl_coef": 1.0}, # Apply KL to 50% of tokens } @@ -550,7 +382,6 @@ def test_kl_cov_policy_loss(): "eps_clip_low": 0.2, "eps_clip_high": 0.2, "policy_loss_type": "regular", - "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, } @@ -578,11 +409,10 @@ def test_sapo_policy_loss_basic(): # Ratios ≈ [exp(-0.5), exp(0.2), exp(-0.1)] ≈ [0.6065, 1.2214, 0.9048] log_probs = torch.tensor([[-1.5, -0.8, -1.1]], device=device) - # SAPO config: uses sequence_mean reduction and distinct tau_pos / tau_neg + # SAPO config: distinct tau_pos / tau_neg config = DictConfig( { "policy_loss_type": "sapo", - "loss_reduction": "sequence_mean", "max_seq_len": 4, "sapo": {"tau_pos": 1.0, "tau_neg": 2.0}, } @@ -614,8 +444,7 @@ def gate_function(x, tau): gates = gate_function(ratio, taus) loss_per_token = -gates * advantages - # sequence_mean reduction: per-sequence token mean, then batch mean - expected_loss = loss_per_token.mean(dim=-1).mean() + expected_loss = loss_per_token.sum() torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) From ddeaae48154d467373dd894515071f1d7fa708a0 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 2 Feb 2026 02:10:27 +0000 Subject: [PATCH 13/20] scale by num_microbatches in megatron loss --- .../megatron/run_megatron_dapo_qwen3_1.7b.sh | 120 ++++++++++++++++++ .../megatron/megatron_model_wrapper.py | 8 ++ .../workers/megatron/megatron_worker.py | 3 + 3 files changed, 131 insertions(+) create mode 100644 skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh new file mode 100644 index 000000000..2cd754cbb --- /dev/null +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh @@ -0,0 +1,120 @@ +set -x + +# Colocated DAPO training+generation for Qwen3-1.7B-Base on DAPO training data with Megatron. 
+# bash examples/algorithms/dapo/prepare_dapo_data.sh +# bash examples/megatron/run_megatron_dapo_qwen3_1.7b.sh + +MODEL_NAME="Qwen/Qwen3-1.7B-Base" +DATA_DIR="$HOME/data/dapo" +TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet" +TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet" +NUM_NODES=1 +NUM_GPUS_PER_NODE=8 +NUM_INFERENCE_ENGINES=8 +INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=1 +LOGGER="wandb" # change to "console" to print to stdout + +CLIP_RATIO_LOW=0.2 +CLIP_RATIO_HIGH=0.28 +# use token mean loss reduction +LOSS_REDUCTION="token_mean" +# applies overlong filtering (but not soft overlong punishment) +APPLY_OVERLONG_FILTERING=true +# apply soft overlong punishment with custom trainer impl in main_dapo.py +OVERLONG_BUFFER_LEN=$((1024 * 4)) +OVERLONG_BUFFER_PENALTY_FACTOR=1.0 + +# other DAPO parameters +USE_KL_LOSS=false +TEMPERATURE=1.0 +TOP_P=1.0 +EVAL_TOP_P=0.7 +CLIP_RATIO_C=10.0 +MAX_PROMPT_LENGTH=$((1024 * 2)) +MAX_RESPONSE_LENGTH=$((1024 * 8)) + +# repro run parameters +TRAIN_BATCH_SIZE=512 +MINI_BATCH_SIZE=32 +N_SAMPLES_PER_PROMPT=16 +EVAL_N_SAMPLES_PER_PROMPT=32 +ENFORCE_EAGER=true # cuda graphs can cause some instability +LR=1e-6 + +# megatron config +MEGATRON_TP=1 +MEGATRON_PP=1 +MEGATRON_CP=1 +MEGATRON_EP=1 +MEGATRON_ETP=null + +# TIS parameters +TIS_IMP_RATIO_CAP=2.0 +USE_TIS=true + +uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ + data.train_data="['$TRAIN_FILE']" \ + data.val_data="['$TEST_FILE']" \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.algorithm.policy_loss_type="dual_clip" \ + +trainer.algorithm.overlong_buffer.len=$OVERLONG_BUFFER_LEN \ + +trainer.algorithm.overlong_buffer.penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \ + trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ + generator.enforce_eager=$ENFORCE_EAGER \ + generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ + generator.sampling_params.temperature=$TEMPERATURE \ + generator.sampling_params.top_p=$TOP_P \ + generator.eval_sampling_params.top_p=$EVAL_TOP_P \ + generator.eval_sampling_params.temperature=$TEMPERATURE \ + trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ + trainer.policy.model.path="$MODEL_NAME" \ + trainer.placement.colocate_all=true \ + trainer.strategy=megatron \ + trainer.placement.policy_num_nodes=$NUM_NODES \ + trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \ + generator.num_inference_engines=$NUM_INFERENCE_ENGINES \ + generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \ + trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \ + trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \ + trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ + trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ + trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ + trainer.algorithm.use_tis=$USE_TIS \ + trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.epochs=20 \ + trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ + trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ + trainer.eval_batch_size=1024 \ + trainer.eval_before_train=true \ + trainer.eval_interval=5 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=$TRAIN_BATCH_SIZE \ + trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + 
trainer.max_prompt_length=$MAX_PROMPT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \ + trainer.policy.optimizer_config.lr=$LR \ + trainer.policy.optimizer_config.num_warmup_steps=160 \ + trainer.policy.optimizer_config.weight_decay=0.1 \ + trainer.policy.optimizer_config.max_grad_norm=1.0 \ + generator.backend=vllm \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=false \ + generator.batched=true \ + environment.env_class=aime \ + generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \ + generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \ + generator.gpu_memory_utilization=0.8 \ + trainer.logger="$LOGGER" \ + trainer.project_name="dapo_aime" \ + trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ + trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ + trainer.hf_save_interval=300 \ + trainer.resume_mode=latest \ + trainer.max_ckpts_to_keep=3 \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ + $@ \ No newline at end of file diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index 55067dab4..b8efec79a 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -216,10 +216,16 @@ def loss_func(logits, data): loss_mask = data["loss_mask"] rollout_action_logprobs = data["rollout_action_logprobs"] action_mask = data.get("action_mask") + num_microbatches = data.get("num_microbatches") tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 + # we want to maintain a sum of losses across all micro batches, so we reverse this division. 
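+ # For example, with num_microbatches = 4 and per-microbatch summed losses L1..L4,
+ # the schedule backpropagates (L1 + L2 + L3 + L4) / 4; multiplying each microbatch
+ # loss by 4 recovers gradients for the plain sum L1 + L2 + L3 + L4.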
+ loss_scale = num_microbatches + # temperature normalization if temperature != 1.0: logits.div_(temperature) @@ -246,6 +252,8 @@ def loss_func(logits, data): rollout_logprobs=rollout_action_logprobs, ) + policy_loss = policy_loss * loss_scale + # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": loss = policy_loss diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 60c591d34..1073fd395 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -557,6 +557,9 @@ def forward_backward( } ) + for m_batch in micro_buffer: + m_batch["num_microbatches"] = len(micro_buffer) + if not micro_buffer: return {} From ef24d325d168383ca8b6d808781e20a8e28fbc7e Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 2 Feb 2026 17:54:26 +0000 Subject: [PATCH 14/20] x --- .../examples/megatron/run_megatron_dapo_qwen3_1.7b.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh index 2cd754cbb..8d08ae008 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh @@ -111,10 +111,10 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ generator.gpu_memory_utilization=0.8 \ trainer.logger="$LOGGER" \ trainer.project_name="dapo_aime" \ - trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ - trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ + trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ + trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ trainer.hf_save_interval=300 \ trainer.resume_mode=latest \ trainer.max_ckpts_to_keep=3 \ - trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_average_in_collective" \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ $@ \ No newline at end of file From 89145e5be8d592202d4357bd870943ce5f2baa3f Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 2 Feb 2026 19:58:00 +0000 Subject: [PATCH 15/20] add dp size scaling to megatron and move fsdp loss scaling to loss rather than grad --- .../workers/megatron/megatron_model_wrapper.py | 6 +++++- skyrl-train/skyrl_train/workers/worker.py | 9 ++++----- skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py | 5 +++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index b8efec79a..2a28113ff 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -218,13 +218,16 @@ def loss_func(logits, data): action_mask = data.get("action_mask") num_microbatches = 
data.get("num_microbatches") + dp_size = mpu.get_data_parallel_world_size() tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 # we want to maintain a sum of losses across all micro batches, so we reverse this division. - loss_scale = num_microbatches + # we additionally multiply by the data parallelism size to undo the DDP all-reduce mean + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285 + loss_scale = num_microbatches * dp_size # temperature normalization if temperature != 1.0: @@ -263,6 +266,7 @@ def loss_func(logits, data): elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs batch_size = action_log_probs.shape[0] diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index f37ebd7d6..aef8aaa04 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -788,6 +788,9 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) + loss_scale = self.mesh_rank.dp_size + policy_loss = policy_loss * loss_scale + # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": loss = policy_loss @@ -798,6 +801,7 @@ def _forward_backward_micro( elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs (matches Tinker's ForwardBackwardOutput structure) # Trim to actual response length per sample (Tinker expects variable-length arrays @@ -887,11 +891,6 @@ def optim_step(self) -> float: Returns: The gradient norm (before scaling, after clipping) """ - # Scale gradients by data parallelism size to undo the DDP all-reduce mean. - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(self.strategy.world_size) - # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 674c63abd..a9aa33891 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -449,7 +449,8 @@ async def test_megatron_train( Full test: initialize actor group, send dummy experience to training_step, validate output. 
""" cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME) - batch = get_test_training_batch(batch_size=gpus_per_node) + batch_size = gpus_per_node * 2 + batch = get_test_training_batch(batch_size=batch_size) cfg.trainer.strategy = "megatron" cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node @@ -474,7 +475,7 @@ async def test_megatron_train( cfg.trainer.policy.megatron_config.transformer_config_kwargs = transformer_config_kwargs # set batch sizes correctly - cfg.trainer.train_batch_size = gpus_per_node + cfg.trainer.train_batch_size = batch_size cfg.trainer.policy_mini_batch_size = gpus_per_node cfg.generator.n_samples_per_prompt = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 From 49babe4bd6309b227bd1dea08c1c6148620e3ce4 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 3 Feb 2026 20:13:19 +0000 Subject: [PATCH 16/20] debugging megatron fsdp loss diff --- skyrl-train/examples/async/async_trainer.py | 3 -- .../skyrl_train/fully_async_trainer.py | 4 --- skyrl-train/skyrl_train/trainer.py | 28 ++++++++++++++----- skyrl-train/skyrl_train/utils/ppo_utils.py | 22 --------------- skyrl-train/skyrl_train/workers/worker.py | 10 ++++++- .../skyrl_train/workers/worker_utils.py | 5 +++- .../tests/gpu/gpu_ci/test_megatron_worker.py | 9 ++++-- skyrl-train/tests/gpu/test_grpo_sp_sanity.py | 4 --- 8 files changed, 40 insertions(+), 45 deletions(-) diff --git a/skyrl-train/examples/async/async_trainer.py b/skyrl-train/examples/async/async_trainer.py index 43268b5f2..36cf3234d 100644 --- a/skyrl-train/examples/async/async_trainer.py +++ b/skyrl-train/examples/async/async_trainer.py @@ -5,7 +5,6 @@ from skyrl_train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.generators.base import GeneratorOutput from skyrl_train.utils.trainer_utils import ResumeMode @@ -145,8 +144,6 @@ async def _run_training(self, generation_buffer): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) if self.cfg.trainer.dump_data_batch: # dump data to file diff --git a/skyrl-train/skyrl_train/fully_async_trainer.py b/skyrl-train/skyrl_train/fully_async_trainer.py index 1894e1e38..f5582911a 100644 --- a/skyrl-train/skyrl_train/fully_async_trainer.py +++ b/skyrl-train/skyrl_train/fully_async_trainer.py @@ -20,7 +20,6 @@ from skyrl_train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.generators.base import GeneratorOutput from skyrl_train.utils.trainer_utils import ResumeMode, build_dataloader @@ -511,9 +510,6 @@ async def _run_training(self, training_input: TrainingInputBatch): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 2f0360838..2b83a7b90 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -54,7 +54,6 @@ compute_approx_kl, get_kl_controller, masked_mean, - 
normalize_advantages_dict, ) from skyrl_train.utils.tracking import Tracking from skyrl_train.utils.trainer_utils import ( @@ -269,9 +268,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): @@ -1023,15 +1019,28 @@ def apply_reward_kl_penalty( def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: """Normalize the advantages in the mini-batch. - This normalization results in calculating the correct minibatch loss for the - given loss reduction type when reducing the loss with a sum. + This function handles two types of normalization: + 1. Batch normalization (z-score): if advantage_batch_normalize is True, + normalizes advantages to have zero mean and unit variance. + 2. Loss reduction normalization: scales advantages based on the loss_reduction + type to calculate the correct minibatch loss when reducing with a sum. """ advantages = data["advantages"] loss_mask = data["loss_mask"] + response_mask = data["response_mask"] # NOTE: Do not modify the tensor in place! # Otherwise subsequent epochs will keep dividing the same tensor. + # Step 1: Z-score normalization (if enabled) + if self.cfg.trainer.algorithm.advantage_batch_normalize: + num_actions = response_mask.sum() + mean = advantages.mean() + std = ((advantages - mean).pow(2) * response_mask).sum() + rstd = (std / num_actions).clamp(min=1e-8).rsqrt() + advantages = (advantages - mean) * rstd + + # Step 2: Loss reduction normalization # Option 1: token mean if self.cfg.trainer.algorithm.loss_reduction == "token_mean": data["advantages"] = advantages / loss_mask.sum() @@ -1041,12 +1050,17 @@ def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingIn batch_size = len(data) data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True)) - # option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant + # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": batch_size = len(data) max_seq_len = self.cfg.trainer.algorithm.max_seq_len data["advantages"] = advantages / (batch_size * max_seq_len) + else: + # No loss reduction normalization, but still apply batch normalization if it was done + if self.cfg.trainer.algorithm.advantage_batch_normalize: + data["advantages"] = advantages + return data def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index a4f97a948..92b0d16a7 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -127,28 +127,6 @@ def compute_approx_kl( kld = kld * loss_mask return kld - -@torch.no_grad() -def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch: - """Normalizes the advantages in the data batch. 
- - Expects: - - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"] - - `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"] - """ - advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"] - response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"] - num_actions: float = response_masks.sum() - # mean - mean: float = advantages.mean() - # std - std: float = ((advantages - mean).pow(2) * response_masks).sum() - rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt() - - data["advantages"] = (advantages - mean) * rstd - return data - - def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 88b116a11..e8d38327b 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -685,7 +685,7 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - + result = reduce_metrics(dict(all_metrics)) # Add back loss_fn_outputs (concatenated across micro-batches) @@ -890,7 +890,15 @@ def _forward_backward_micro( loss_fn_outputs = status.pop("loss_fn_outputs", None) # All-reduce metrics across DP workers + # hacky work aroudn to all reduce sum for loss while keeping mean for other metrics for now + loss_status = { + "final_loss": status["final_loss"], + "policy_loss": status["policy_loss"], + } + loss_status = self.strategy.all_reduce(loss_status, op="sum") status = self.strategy.all_reduce(status) + status["final_loss"] = loss_status["final_loss"] + status["policy_loss"] = loss_status["policy_loss"] # Add back loss_fn_outputs after all_reduce if loss_fn_outputs is not None: diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 3fe22be4c..9235b714a 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -12,7 +12,10 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: for k, v in metrics.items(): assert len(v) > 0, f"No metrics for key {k}" assert all(isinstance(x, (int, float)) for x in v), f"Metrics for key {k} are not all numbers" - reduced_metrics[k] = sum(v) / len(v) + if k.endswith("_loss"): + reduced_metrics[k] = sum(v) + else: + reduced_metrics[k] = sum(v) / len(v) return reduced_metrics diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index d4931b094..575b0f9fe 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -422,7 +422,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus ("policy", 4, 1, 1, 4, 1, 4, True, False, True), ], ids=[ - "tp2_pp2_policy_seq_packing", + "x", "tp2_pp2_policy_seq_packing_with_entropy_loss", "tp2_pp2_policy_lora", "tp2_pp2_policy_unpacked", @@ -523,15 +523,18 @@ async def test_megatron_train( # Both FSDP and Megatron use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 - results_fsdp = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", batch)) + results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) # Get learning rate from worker lr_results = 
ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) for i, result in enumerate(results_fsdp): result["policy_lr"] = lr_results[i] - + print("megatron results: ", results_megatron) + print("\n\n") print("megatron results: ", results_megatron[0]) print("\n\n") + print("fsdp results: ", results_fsdp) + print("\n\n") print("fsdp results: ", results_fsdp[0]) keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "final_loss"] diff --git a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py index eff1ce650..47d16dbea 100644 --- a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py +++ b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py @@ -15,7 +15,6 @@ from skyrl_train.config import SkyRLConfig from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict import asyncio @@ -122,9 +121,6 @@ def train(self): # remove some unwanted keys data.pop(batch_keys=["rewards"]) - if self.cfg.trainer.algorithm.advantage_batch_normalize: - data = normalize_advantages_dict(data) - # 4. train policy/critic model with Timer("train_critic_and_policy", self.all_timings): status = self.train_critic_and_policy(data) From eb241dedb7666457140dfd2844dc2df5297d087b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Feb 2026 02:04:01 +0000 Subject: [PATCH 17/20] start trying to make metrics invariant to dp size for megatron --- skyrl-train/examples/async/async_trainer.py | 1 - .../examples/megatron/run_fsdp_baseline.sh | 8 ++++---- skyrl-train/examples/megatron/run_megatron.sh | 13 +++++++++---- skyrl-train/skyrl_train/distributed/strategy.py | 17 ++++++++++------- skyrl-train/skyrl_train/trainer.py | 1 - skyrl-train/skyrl_train/utils/ppo_utils.py | 2 +- .../workers/megatron/megatron_worker.py | 12 ++++-------- skyrl-train/skyrl_train/workers/worker.py | 2 +- skyrl-train/skyrl_train/workers/worker_utils.py | 11 +++++++---- .../tests/gpu/gpu_ci/test_megatron_worker.py | 6 +++++- 10 files changed, 41 insertions(+), 32 deletions(-) diff --git a/skyrl-train/examples/async/async_trainer.py b/skyrl-train/examples/async/async_trainer.py index 36cf3234d..18e7685f0 100644 --- a/skyrl-train/examples/async/async_trainer.py +++ b/skyrl-train/examples/async/async_trainer.py @@ -144,7 +144,6 @@ async def _run_training(self, generation_buffer): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-train/examples/megatron/run_fsdp_baseline.sh b/skyrl-train/examples/megatron/run_fsdp_baseline.sh index 5291b5d96..e1121ad3d 100644 --- a/skyrl-train/examples/megatron/run_fsdp_baseline.sh +++ b/skyrl-train/examples/megatron/run_fsdp_baseline.sh @@ -25,7 +25,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=5 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=128 \ @@ -35,8 +35,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas trainer.ckpt_interval=10 \ trainer.max_prompt_length=512 \ generator.sampling_params.max_generate_length=1024 \ - trainer.policy.optimizer_config.lr=4.0e-6 \ - trainer.algorithm.use_kl_loss=true \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.algorithm.use_kl_loss=false \ 
generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ generator.weight_sync_backend=nccl \ @@ -47,7 +47,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas generator.gpu_memory_utilization=0.8 \ trainer.logger="$LOGGER" \ trainer.project_name="gsm8k_megatron" \ - trainer.run_name="gsm8k_fsdp1_4gpus" \ + trainer.run_name="gsm8k_fsdp1_4gpus_loss_sum" \ trainer.resume_mode=null \ trainer.ckpt_path="$HOME/ckpts/gsm8k_fsdp_ckpt" \ $@ \ No newline at end of file diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index cf0d9f9ed..bf341222a 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -13,8 +13,8 @@ MODEL_NAME="Qwen/Qwen3-0.6B" INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron -MEGATRON_TP=2 -MEGATRON_PP=2 +MEGATRON_TP=1 +MEGATRON_PP=1 MEGATRON_CP=1 # torch profiler config @@ -22,6 +22,9 @@ ENABLE_TORCH_PROFILER=false RANKS_TO_PROFILE="[0]" SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" +TIS_TYPE="token" +TIS_RATIO_CLIP_HIGH=2.0 + uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -42,6 +45,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.ref.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \ trainer.ref.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.ref.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_CLIP_HIGH \ trainer.use_sample_packing=true \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ @@ -56,7 +61,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.max_prompt_length=512 \ generator.sampling_params.max_generate_length=1024 \ trainer.policy.optimizer_config.lr=1.0e-6 \ - trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.use_kl_loss=false \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ generator.weight_sync_backend=nccl \ @@ -67,7 +72,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="gsm8k_megatron" \ - trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" \ + trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}_loss_sum_with_tis" \ trainer.resume_mode=null \ trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \ $@ \ No newline at end of file diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index fb8269e97..c638a4c3c 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -67,11 +67,11 @@ def get_rank(self) -> int: """Get current process rank""" return dist.get_rank() - def all_reduce(self, data: DataT, op="mean") -> DataT: + def all_reduce(self, data: DataT, op="mean", group=None) -> DataT: """Perform all_reduce across all processes""" assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - return {k: self.all_reduce(v, op) for k, v in data.items()} + return {k: self.all_reduce(v, op, group) for k, v 
in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): @@ -82,14 +82,17 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: if is_cpu_tensor: data = data.to(torch.cuda.current_device()) if op == "mean": - data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.SUM) + if group is None: + data /= self.world_size + else: + data /= group.size() + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) elif op == "max": - dist.all_reduce(data, op=dist.ReduceOp.MAX) + dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group) elif op == "min": - dist.all_reduce(data, op=dist.ReduceOp.MIN) + dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group) elif op == "sum": - dist.all_reduce(data, op=dist.ReduceOp.SUM) + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index a8e29d021..4642c2b31 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -55,7 +55,6 @@ get_kl_controller, masked_mean, ) -from skyrl_train.utils.torch_utils import masked_mean from skyrl_train.utils.tracking import Tracking from skyrl_train.utils.trainer_utils import ( GLOBAL_STEP_PREFIX, diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 66bbed775..c75455ae1 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -29,7 +29,6 @@ from omegaconf import DictConfig from skyrl_train.config import AlgorithmConfig -from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.utils.off_policy_correction_utils import apply_off_policy_correction from skyrl_train.utils.torch_utils import masked_mean, safe_exp_delta @@ -123,6 +122,7 @@ def compute_approx_kl( kld = kld * loss_mask return kld + def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 003285248..21ba1134a 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -583,9 +583,6 @@ def forward_backward( if self.empty_cuda_cache: torch.cuda.empty_cache() - # Track number of micro-batches for metrics - self._micro_batches_accumulated += len(micro_buffer) - # Aggregate metrics across micro-batches all_loss_fn_outputs = [] # Handle separately from scalar metrics for metrics in metrics_list: @@ -595,10 +592,12 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # Reduce and all-reduce metrics + # Reduce and all-reduce metrics across DP ranks only + # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) status = reduce_metrics(dict(all_metrics)) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] - status = all_reduce_metrics(status, self.strategy) + group = mpu.get_data_parallel_group(with_context_parallel=True) + status = all_reduce_metrics(status, self.strategy, group=group) # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: @@ -618,9 +617,6 @@ def optim_step(self) -> Optional[float]: """ grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") - # Reset counter for next accumulation cycle - 
self._micro_batches_accumulated = 0 - if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() if hasattr(grad_norm, "item") else grad_norm return grad_norm diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 6fb9e5f73..21fe3285f 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -685,7 +685,7 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - + result = reduce_metrics(dict(all_metrics)) # Add back loss_fn_outputs (concatenated across micro-batches) diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 2e13acddc..93b133f7b 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -24,16 +24,19 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: return reduced_metrics -def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: +def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy, group=None) -> Dict[str, float]: """All reduce metrics across all processes.""" min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} - status_mean = strategy.all_reduce(mean_metrics, op="mean") - status_min = strategy.all_reduce(min_metrics, op="min") - status_max = strategy.all_reduce(max_metrics, op="max") + sum_metrics = {k: v for k, v in metrics.items() if k.endswith("_loss")} + status_mean = strategy.all_reduce(mean_metrics, op="mean", group=group) + status_min = strategy.all_reduce(min_metrics, op="min", group=group) + status_max = strategy.all_reduce(max_metrics, op="max", group=group) + status_sum = strategy.all_reduce(sum_metrics, op="sum", group=group) status_mean.update(status_min) status_mean.update(status_max) + status_mean.update(status_sum) return status_mean diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 3a86a60a5..4916c5645 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -413,7 +413,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus @pytest.mark.parametrize( ("worker_type", "tp", "pp", "cp", "ep", "etp", "gpus_per_node", "use_sample_packing", "use_entropy_loss", "lora"), [ - ("policy", 2, 2, 1, 1, 1, 4, True, False, False), + ("policy", 2, 1, 1, 1, 1, 4, True, False, False), ("policy", 2, 2, 1, 1, 1, 4, True, True, False), ("policy", 2, 2, 1, 1, 1, 4, True, False, True), ("policy", 2, 2, 1, 1, 1, 4, False, False, False), @@ -504,6 +504,10 @@ async def test_megatron_train( for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" + print("megatron results: ", results_megatron) + print("\n\n") + print("megatron results: ", results_megatron[0]) + ray.shutdown() ray_init_for_tests() From 772141a03772bba6964f455188085ae59407f12f Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Feb 2026 02:07:41 +0000 Subject: [PATCH 18/20] x --- .../megatron/run_megatron_dapo_qwen3_1.7b.sh | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git 
a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh index 8d08ae008..b558113d3 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh @@ -42,15 +42,16 @@ ENFORCE_EAGER=true # cuda graphs can cause some instability LR=1e-6 # megatron config -MEGATRON_TP=1 -MEGATRON_PP=1 +MEGATRON_TP=4 +MEGATRON_PP=2 MEGATRON_CP=1 MEGATRON_EP=1 MEGATRON_ETP=null + # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -80,8 +81,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ @@ -91,8 +92,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=$TRAIN_BATCH_SIZE \ trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ - trainer.micro_forward_batch_size_per_gpu=1 \ - trainer.micro_train_batch_size_per_gpu=1 \ + trainer.micro_forward_batch_size_per_gpu=8 \ + trainer.micro_train_batch_size_per_gpu=8 \ trainer.ckpt_interval=10 \ trainer.max_prompt_length=$MAX_PROMPT_LENGTH \ generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \ @@ -111,10 +112,10 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ generator.gpu_memory_utilization=0.8 \ trainer.logger="$LOGGER" \ trainer.project_name="dapo_aime" \ - trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ - trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ + trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ + trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ trainer.hf_save_interval=300 \ trainer.resume_mode=latest \ trainer.max_ckpts_to_keep=3 \ - trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_num_micro_batches_average_in_collective" \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ $@ \ No newline at end of file From f66736aa0a8ef1392ab49875fef12c8b393abb21 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Feb 2026 22:51:39 +0000 Subject: [PATCH 19/20] metrics now (mostly) matching for megatron vs fsdp --- skyrl-train/examples/megatron/run_megatron.sh | 18 +++++++-------- .../megatron/megatron_model_wrapper.py | 14 ++++++------ skyrl-train/skyrl_train/workers/worker.py | 22 +++++-------------- .../skyrl_train/workers/worker_utils.py | 4 +++- 
.../tests/gpu/gpu_ci/test_megatron_worker.py | 8 +++---- 5 files changed, 28 insertions(+), 38 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index bf341222a..c31657c86 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -17,13 +17,15 @@ MEGATRON_TP=1 MEGATRON_PP=1 MEGATRON_CP=1 -# torch profiler config -ENABLE_TORCH_PROFILER=false -RANKS_TO_PROFILE="[0]" -SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" +# # torch profiler config +# ENABLE_TORCH_PROFILER=false +# RANKS_TO_PROFILE="[0]" +# SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" -TIS_TYPE="token" -TIS_RATIO_CLIP_HIGH=2.0 +# TIS_TYPE="token" +# TIS_RATIO_CLIP_HIGH=2.0 + # trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + # trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_CLIP_HIGH \ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -45,8 +47,6 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.ref.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \ trainer.ref.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.ref.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \ - trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_CLIP_HIGH \ trainer.use_sample_packing=true \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ @@ -72,7 +72,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="gsm8k_megatron" \ - trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}_loss_sum_with_tis" \ + trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}_loss_sum_no_scaling_num_microbatches" \ trainer.resume_mode=null \ trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \ $@ \ No newline at end of file diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index a98c4a0cf..a02bd5009 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -257,11 +257,10 @@ def loss_func(logits, data): rollout_logprobs=rollout_action_logprobs, ) - policy_loss = policy_loss * loss_scale - # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * loss_scale # Compute elementwise loss for Tinker API (per-token NLL) with torch.no_grad(): @@ -290,7 +289,7 @@ def loss_func(logits, data): ) metrics = { - "loss": loss.detach().item(), + "loss": unscaled_loss.detach().item(), "response_length": num_actions, "loss_fn_outputs": loss_fn_outputs, } @@ -320,11 +319,12 @@ def loss_func(logits, data): kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * loss_config.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * loss_scale metrics = { - "final_loss": loss.detach().item(), - "policy_loss": 
policy_loss.detach().item(), + "final_loss": unscaled_loss.detach().item() * dp_size, + "policy_loss": policy_loss.detach().item() * dp_size, "policy_entropy": entropy.detach().item(), "policy_kl": kl_loss.detach().item(), } diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 21fe3285f..bb6a7e833 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -686,8 +686,13 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) + # reduce metrics across micro batches (sum, mean, min, max) result = reduce_metrics(dict(all_metrics)) + # all reduce metrics across DP workers + dp_group = self.device_mesh.get_group("dp") + result = all_reduce_metrics(result, self.strategy, group=dp_group) + # Add back loss_fn_outputs (concatenated across micro-batches) if all_loss_fn_outputs: result["loss_fn_outputs"] = all_loss_fn_outputs @@ -887,23 +892,6 @@ def _forward_backward_micro( if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() - loss_fn_outputs = status.pop("loss_fn_outputs", None) - - # All-reduce metrics across DP workers - # hacky work aroudn to all reduce sum for loss while keeping mean for other metrics for now - loss_status = { - "final_loss": status["final_loss"], - "policy_loss": status["policy_loss"], - } - loss_status = self.strategy.all_reduce(loss_status, op="sum") - status = all_reduce_metrics(status, self.strategy) - status["final_loss"] = loss_status["final_loss"] - status["policy_loss"] = loss_status["policy_loss"] - - # Add back loss_fn_outputs after all_reduce - if loss_fn_outputs is not None: - status["loss_fn_outputs"] = loss_fn_outputs - return status def optim_step(self) -> float: diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 93b133f7b..ff5276a84 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -28,8 +28,10 @@ def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStr """All reduce metrics across all processes.""" min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} - mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} sum_metrics = {k: v for k, v in metrics.items() if k.endswith("_loss")} + mean_metrics = { + k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics and k not in sum_metrics + } status_mean = strategy.all_reduce(mean_metrics, op="mean", group=group) status_min = strategy.all_reduce(min_metrics, op="min", group=group) status_max = strategy.all_reduce(max_metrics, op="max", group=group) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 4916c5645..06fc1af93 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -413,7 +413,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus @pytest.mark.parametrize( ("worker_type", "tp", "pp", "cp", "ep", "etp", "gpus_per_node", "use_sample_packing", "use_entropy_loss", "lora"), [ - ("policy", 2, 1, 1, 1, 1, 4, True, False, False), + ("policy", 1, 1, 1, 1, 1, 4, True, False, False), ("policy", 2, 2, 1, 1, 1, 4, True, True, False), ("policy", 2, 2, 1, 1, 1, 4, True, False, True), ("policy", 
2, 2, 1, 1, 1, 4, False, False, False), @@ -439,7 +439,7 @@ async def test_megatron_train( Full test: initialize actor group, send dummy experience to training_step, validate output. """ cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME) - batch_size = gpus_per_node * 2 + batch_size = gpus_per_node * 8 batch = get_test_training_batch(batch_size=batch_size) cfg.trainer.strategy = "megatron" @@ -450,6 +450,7 @@ async def test_megatron_train( cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.use_sample_packing = use_sample_packing + cfg.trainer.algorithm.use_kl_loss = False if use_entropy_loss: cfg.trainer.algorithm.use_entropy_loss = True cfg.trainer.algorithm.entropy_loss_coef = 0.01 @@ -507,7 +508,7 @@ async def test_megatron_train( print("megatron results: ", results_megatron) print("\n\n") print("megatron results: ", results_megatron[0]) - + print("\n\n") ray.shutdown() ray_init_for_tests() @@ -551,7 +552,6 @@ async def test_megatron_train( "policy_lr", "loss_metrics/clip_ratio", "policy_entropy", - "policy_kl", "final_loss", ] if ep > 1: From 05054d175352aae6b50db1f177da3c13c05d4d4a Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 5 Feb 2026 08:30:57 +0000 Subject: [PATCH 20/20] loss scale in fsdp worker correctly --- skyrl-train/skyrl_train/workers/worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index bb6a7e833..702da84c7 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -808,11 +808,10 @@ def _forward_backward_micro( ) loss_scale = self.mesh_rank.dp_size - policy_loss = policy_loss * loss_scale # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": - loss = policy_loss + loss = policy_loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -878,11 +877,12 @@ def _forward_backward_micro( kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef loss = policy_loss + kl_loss_term - entropy_loss_term + loss = loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) status = { "final_loss": loss.item(), - "policy_loss": policy_loss.item(), + "policy_loss": policy_loss.item() * loss_scale, "policy_entropy": entropy.item(), "response_length": num_actions, "policy_lr": self.scheduler.get_last_lr()[0],
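# Illustrative sketch (not part of the patch series above; names and shapes are
# assumptions for demonstration only). It checks the invariant these patches rely
# on: if the trainer pre-divides a mini-batch's advantages by the total number of
# unmasked tokens (as normalize_minibatch_advantages does for "token_mean") and the
# worker then reduces the per-token policy loss with a plain sum, the result equals
# the usual token_mean reduction. The unclipped policy-gradient loss below is a
# simplification, not the SkyRL loss function.
import torch

torch.manual_seed(0)
log_ratio = torch.randn(4, 8)                    # stand-in for log-prob ratios
advantages = torch.randn(4, 8)
loss_mask = (torch.rand(4, 8) > 0.2).float()

per_token_loss = -log_ratio * advantages
token_mean = (per_token_loss * loss_mask).sum() / loss_mask.sum()

# What the trainer does per mini-batch before dispatching to the workers.
normalized_adv = advantages / loss_mask.sum()
sum_reduced = (-log_ratio * normalized_adv * loss_mask).sum()

assert torch.allclose(token_mean, sum_reduced)

# The loss_scale factors added in the worker patches follow the same logic: per the
# comments above, Megatron's pipeline schedule divides the loss by num_microbatches
# and DDP all-reduces gradients with a mean over the dp_size ranks, so multiplying
# the summed loss by num_microbatches * dp_size (Megatron) or by dp_size (FSDP)
# cancels those built-in means and preserves a true global sum.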