diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx index 9f450425e7..957364539c 100644 --- a/docs/content/docs/configuration/config.mdx +++ b/docs/content/docs/configuration/config.mdx @@ -557,7 +557,7 @@ def ppo_policy_loss( pg_losses3 = -advantages * config.clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - loss = reduce_loss(loss, loss_mask, config.loss_reduction) + loss = reduce_loss(loss, loss_mask) return loss, {"clip_ratio": clip_ratio} ``` diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 63e0bea5d3..7573a874a4 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -575,7 +575,7 @@ It can be helpful to understand the final loss formulation to see how the differ pg_losses3 = -advantages * config.clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - loss = reduce_loss(loss, loss_mask, config.loss_reduction) + loss = reduce_loss(loss, loss_mask) return loss, {"clip_ratio": clip_ratio} diff --git a/skyrl-train/examples/async/async_trainer.py b/skyrl-train/examples/async/async_trainer.py index 43268b5f26..18e7685f09 100644 --- a/skyrl-train/examples/async/async_trainer.py +++ b/skyrl-train/examples/async/async_trainer.py @@ -5,7 +5,6 @@ from skyrl_train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.generators.base import GeneratorOutput from skyrl_train.utils.trainer_utils import ResumeMode @@ -145,9 +144,6 @@ async def _run_training(self, generation_buffer): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-train/examples/megatron/run_fsdp_baseline.sh b/skyrl-train/examples/megatron/run_fsdp_baseline.sh index 5291b5d960..e1121ad3df 100644 --- a/skyrl-train/examples/megatron/run_fsdp_baseline.sh +++ b/skyrl-train/examples/megatron/run_fsdp_baseline.sh @@ -25,7 +25,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=5 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=128 \ @@ -35,8 +35,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas trainer.ckpt_interval=10 \ trainer.max_prompt_length=512 \ generator.sampling_params.max_generate_length=1024 \ - trainer.policy.optimizer_config.lr=4.0e-6 \ - trainer.algorithm.use_kl_loss=true \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.algorithm.use_kl_loss=false \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ generator.weight_sync_backend=nccl \ @@ -47,7 +47,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas generator.gpu_memory_utilization=0.8 \ trainer.logger="$LOGGER" \ trainer.project_name="gsm8k_megatron" \ - trainer.run_name="gsm8k_fsdp1_4gpus" \ + 
trainer.run_name="gsm8k_fsdp1_4gpus_loss_sum" \ trainer.resume_mode=null \ trainer.ckpt_path="$HOME/ckpts/gsm8k_fsdp_ckpt" \ $@ \ No newline at end of file diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index cf0d9f9eda..0dd6665f3f 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -13,14 +13,19 @@ MODEL_NAME="Qwen/Qwen3-0.6B" INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron -MEGATRON_TP=2 -MEGATRON_PP=2 +MEGATRON_TP=1 +MEGATRON_PP=1 MEGATRON_CP=1 -# torch profiler config -ENABLE_TORCH_PROFILER=false -RANKS_TO_PROFILE="[0]" -SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" +# # torch profiler config +# ENABLE_TORCH_PROFILER=false +# RANKS_TO_PROFILE="[0]" +# SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" + +# TIS_TYPE="token" +# TIS_RATIO_CLIP_HIGH=2.0 + # trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + # trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_CLIP_HIGH \ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -52,11 +57,12 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy_mini_batch_size=64 \ trainer.micro_forward_batch_size_per_gpu=4 \ trainer.micro_train_batch_size_per_gpu=4 \ + trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \ trainer.ckpt_interval=10 \ trainer.max_prompt_length=512 \ generator.sampling_params.max_generate_length=1024 \ trainer.policy.optimizer_config.lr=1.0e-6 \ - trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.use_kl_loss=false \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ generator.weight_sync_backend=nccl \ @@ -67,7 +73,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="gsm8k_megatron" \ - trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" \ + trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}_seq_mean_token_sum_norm" \ trainer.resume_mode=null \ trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \ $@ \ No newline at end of file diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh new file mode 100644 index 0000000000..b558113d37 --- /dev/null +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh @@ -0,0 +1,121 @@ +set -x + +# Colocated DAPO training+generation for Qwen3-1.7B-Base on DAPO training data with Megatron. 
+# bash examples/algorithms/dapo/prepare_dapo_data.sh +# bash examples/megatron/run_megatron_dapo_qwen3_1.7b.sh + +MODEL_NAME="Qwen/Qwen3-1.7B-Base" +DATA_DIR="$HOME/data/dapo" +TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet" +TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet" +NUM_NODES=1 +NUM_GPUS_PER_NODE=8 +NUM_INFERENCE_ENGINES=8 +INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=1 +LOGGER="wandb" # change to "console" to print to stdout + +CLIP_RATIO_LOW=0.2 +CLIP_RATIO_HIGH=0.28 +# use token mean loss reduction +LOSS_REDUCTION="token_mean" +# applies overlong filtering (but not soft overlong punishment) +APPLY_OVERLONG_FILTERING=true +# apply soft overlong punishment with custom trainer impl in main_dapo.py +OVERLONG_BUFFER_LEN=$((1024 * 4)) +OVERLONG_BUFFER_PENALTY_FACTOR=1.0 + +# other DAPO parameters +USE_KL_LOSS=false +TEMPERATURE=1.0 +TOP_P=1.0 +EVAL_TOP_P=0.7 +CLIP_RATIO_C=10.0 +MAX_PROMPT_LENGTH=$((1024 * 2)) +MAX_RESPONSE_LENGTH=$((1024 * 8)) + +# repro run parameters +TRAIN_BATCH_SIZE=512 +MINI_BATCH_SIZE=32 +N_SAMPLES_PER_PROMPT=16 +EVAL_N_SAMPLES_PER_PROMPT=32 +ENFORCE_EAGER=true # cuda graphs can cause some instability +LR=1e-6 + +# megatron config +MEGATRON_TP=4 +MEGATRON_PP=2 +MEGATRON_CP=1 +MEGATRON_EP=1 +MEGATRON_ETP=null + + +# TIS parameters +TIS_IMP_RATIO_CAP=2.0 +TIS_TYPE=token + +uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ + data.train_data="['$TRAIN_FILE']" \ + data.val_data="['$TEST_FILE']" \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.algorithm.policy_loss_type="dual_clip" \ + +trainer.algorithm.overlong_buffer.len=$OVERLONG_BUFFER_LEN \ + +trainer.algorithm.overlong_buffer.penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \ + trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ + generator.enforce_eager=$ENFORCE_EAGER \ + generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ + generator.sampling_params.temperature=$TEMPERATURE \ + generator.sampling_params.top_p=$TOP_P \ + generator.eval_sampling_params.top_p=$EVAL_TOP_P \ + generator.eval_sampling_params.temperature=$TEMPERATURE \ + trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ + trainer.policy.model.path="$MODEL_NAME" \ + trainer.placement.colocate_all=true \ + trainer.strategy=megatron \ + trainer.placement.policy_num_nodes=$NUM_NODES \ + trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \ + generator.num_inference_engines=$NUM_INFERENCE_ENGINES \ + generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \ + trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \ + trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \ + trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ + trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ + trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ + trainer.epochs=20 \ + trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ + trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ + trainer.eval_batch_size=1024 \ + trainer.eval_before_train=true \ + trainer.eval_interval=5 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=$TRAIN_BATCH_SIZE \ + trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=8 \ + trainer.micro_train_batch_size_per_gpu=8 \ + 
trainer.ckpt_interval=10 \ + trainer.max_prompt_length=$MAX_PROMPT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \ + trainer.policy.optimizer_config.lr=$LR \ + trainer.policy.optimizer_config.num_warmup_steps=160 \ + trainer.policy.optimizer_config.weight_decay=0.1 \ + trainer.policy.optimizer_config.max_grad_norm=1.0 \ + generator.backend=vllm \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=false \ + generator.batched=true \ + environment.env_class=aime \ + generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \ + generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \ + generator.gpu_memory_utilization=0.8 \ + trainer.logger="$LOGGER" \ + trainer.project_name="dapo_aime" \ + trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ + trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ + trainer.hf_save_interval=300 \ + trainer.resume_mode=latest \ + trainer.max_ckpts_to_keep=3 \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \ + $@ \ No newline at end of file diff --git a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py index e40797b647..94f91fb15b 100644 --- a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py +++ b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py @@ -51,7 +51,7 @@ def compute_importance_sampling_policy_loss( # as defined here: https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling loss = -torch.exp(log_probs - old_log_probs) * advantages - loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len) + loss = reduce_loss(loss, loss_mask) # return loss and a dummy clip ratio value as we aren't clipping here return loss, {"clip_ratio": 0.0} diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh index 250e8252af..6be4ab3eee 100644 --- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh +++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh @@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy trainer.policy.optimizer_config.weight_decay=0.1 \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \ + trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \ generator.backend=vllm \ generator.run_engines_locally=true \ generator.async_engine=false \ diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh index 670a351520..fb60594ae2 100644 --- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh +++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh @@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy trainer.policy.optimizer_config.weight_decay=0.1 \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \ + 
trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \ generator.backend=vllm \ generator.run_engines_locally=true \ generator.async_engine=false \ diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index fb8269e977..c638a4c3c4 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -67,11 +67,11 @@ def get_rank(self) -> int: """Get current process rank""" return dist.get_rank() - def all_reduce(self, data: DataT, op="mean") -> DataT: + def all_reduce(self, data: DataT, op="mean", group=None) -> DataT: """Perform all_reduce across all processes""" assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - return {k: self.all_reduce(v, op) for k, v in data.items()} + return {k: self.all_reduce(v, op, group) for k, v in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): @@ -82,14 +82,17 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: if is_cpu_tensor: data = data.to(torch.cuda.current_device()) if op == "mean": - data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.SUM) + if group is None: + data /= self.world_size + else: + data /= group.size() + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) elif op == "max": - dist.all_reduce(data, op=dist.ReduceOp.MAX) + dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group) elif op == "min": - dist.all_reduce(data, op=dist.ReduceOp.MIN) + dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group) elif op == "sum": - dist.all_reduce(data, op=dist.ReduceOp.SUM) + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl-train/skyrl_train/fully_async_trainer.py b/skyrl-train/skyrl_train/fully_async_trainer.py index 1894e1e387..f5582911aa 100644 --- a/skyrl-train/skyrl_train/fully_async_trainer.py +++ b/skyrl-train/skyrl_train/fully_async_trainer.py @@ -20,7 +20,6 @@ from skyrl_train.trainer import RayPPOTrainer from tqdm import tqdm from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.generators.base import GeneratorOutput from skyrl_train.utils.trainer_utils import ResumeMode, build_dataloader @@ -511,9 +510,6 @@ async def _run_training(self, training_input: TrainingInputBatch): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index c046ea36fe..b5ea7449ca 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -53,9 +53,8 @@ FixedKLController, compute_approx_kl, get_kl_controller, - normalize_advantages_dict, + masked_mean, ) -from skyrl_train.utils.torch_utils import masked_mean from skyrl_train.utils.tracking import Tracking from skyrl_train.utils.trainer_utils import ( GLOBAL_STEP_PREFIX, @@ -269,9 +268,6 @@ async def train(self): training_input.pop(key) training_input.metadata.pop("uids") - if self.cfg.trainer.algorithm.advantage_batch_normalize: - training_input = normalize_advantages_dict(training_input) - if self.cfg.trainer.dump_data_batch: # dump data to file with Timer("dump_data_batch"): @@ 
-1025,6 +1021,53 @@ def apply_reward_kl_penalty( return data + def normalize_minibatch_advantages(self, data: TrainingInputBatch) -> TrainingInputBatch: + """Normalize the advantages in the mini-batch. + + This function handles two types of normalization: + 1. Batch normalization (z-score): if advantage_batch_normalize is True, + normalizes advantages to have zero mean and unit variance. + 2. Loss reduction normalization: scales advantages based on the loss_reduction + type to calculate the correct minibatch loss when reducing with a sum. + """ + advantages = data["advantages"] + loss_mask = data["loss_mask"] + response_mask = data["response_mask"] + + # NOTE: Do not modify the tensor in place! + # Otherwise subsequent epochs will keep dividing the same tensor. + + # Step 1: Z-score normalization (if enabled) + if self.cfg.trainer.algorithm.advantage_batch_normalize: + num_actions = response_mask.sum() + mean = advantages.mean() + std = ((advantages - mean).pow(2) * response_mask).sum() + rstd = (std / num_actions).clamp(min=1e-8).rsqrt() + advantages = (advantages - mean) * rstd + + # Step 2: Loss reduction normalization + # Option 1: token mean + if self.cfg.trainer.algorithm.loss_reduction == "token_mean": + data["advantages"] = advantages / loss_mask.sum().clamp(min=1) + + # Option 2: sequence mean + elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean": + batch_size = len(data) + data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)) + + # Option 3: Dr. GRPO style loss reduction to avoid length bias by normalizing by a constant + elif self.cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": + batch_size = len(data) + max_seq_len = self.cfg.trainer.algorithm.max_seq_len + data["advantages"] = advantages / (batch_size * max_seq_len) + + else: + # No loss reduction normalization, but still apply batch normalization if it was done + if self.cfg.trainer.algorithm.advantage_batch_normalize: + data["advantages"] = advantages + + return data + def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """ Execute training step for FSDP strategy using forward_backward + optim_step. 
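The invariant behind `normalize_minibatch_advantages` is that the PPO surrogate is linear in the advantages (and a positive rescale commutes with the clipping `min`/`torch.where`), so dividing the advantages by the appropriate count once per mini-batch and then letting `reduce_loss` take a plain masked sum reproduces the old `token_mean`, `sequence_mean`, and `seq_mean_token_sum_norm` reductions. A minimal standalone sketch of that equivalence; tensor names here are illustrative, not SkyRL APIs:

```python
import torch

torch.manual_seed(0)
B, T, MAX_SEQ_LEN = 4, 6, 8
ratio = torch.rand(B, T) + 0.5           # stand-in for exp(log_probs - old_log_probs)
adv = torch.randn(B, T)
mask = (torch.rand(B, T) > 0.3).float()  # loss mask over valid response tokens

def per_token_loss(advantages):
    # linear in the advantages, like the (clipped) PPO surrogate
    return -ratio * advantages

def sum_reduce(advantages):
    # the new reduce_loss: plain masked sum
    return (per_token_loss(advantages) * mask).sum()

# Old in-loss reductions.
old_token_mean = (per_token_loss(adv) * mask).sum() / mask.sum()
old_seq_mean = ((per_token_loss(adv) * mask).sum(-1) / mask.sum(-1).clamp(min=1)).mean()
old_dr_grpo = ((per_token_loss(adv) * mask).sum(-1) / MAX_SEQ_LEN).mean()

# New: scale advantages once per mini-batch, then just sum.
new_token_mean = sum_reduce(adv / mask.sum().clamp(min=1))
new_seq_mean = sum_reduce(adv / (B * mask.sum(-1, keepdim=True).clamp(min=1)))
new_dr_grpo = sum_reduce(adv / (B * MAX_SEQ_LEN))

torch.testing.assert_close(new_token_mean, old_token_mean)
torch.testing.assert_close(new_seq_mean, old_seq_mean)
torch.testing.assert_close(new_dr_grpo, old_dr_grpo)
```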
@@ -1050,13 +1093,22 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples all_metrics: Dict[str, List[float]] = defaultdict(list) + num_mini_batches = len(data) // mini_batch_size + + # iterate over mini-batches to do mini batch level normalization + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + mini_batch = data[start_idx:end_idx] + mini_batch = self.normalize_minibatch_advantages(mini_batch) + # Copy normalized advantages back to original batch + data["advantages"][start_idx:end_idx] = mini_batch["advantages"] # Stage full batch in object store ONCE to avoid repeated serialization data_ref = self.dispatch.stage_data(data) # Training loop over epochs and mini-batches for _epoch in range(self.cfg.trainer.update_epochs_per_batch): - num_mini_batches = len(data) // mini_batch_size for local_step in range(num_mini_batches): start_idx = local_step * mini_batch_size end_idx = (local_step + 1) * mini_batch_size diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 1242f95a28..b1ff83fe72 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import ray @@ -29,7 +29,6 @@ from omegaconf import DictConfig from skyrl_train.config import AlgorithmConfig -from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.utils.off_policy_correction_utils import apply_off_policy_correction from skyrl_train.utils.torch_utils import masked_mean, safe_exp_delta @@ -124,27 +123,6 @@ def compute_approx_kl( return kld -@torch.no_grad() -def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch: - """Normalizes the advantages in the data batch. - - Expects: - - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"] - - `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"] - """ - advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"] - response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"] - num_actions: float = response_masks.sum() - # mean - mean: float = advantages.mean() - # std - std: float = ((advantages - mean).pow(2) * response_masks).sum() - rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt() - - data["advantages"] = (advantages - mean) * rstd - return data - - def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) @@ -189,6 +167,7 @@ def ppo_critic_loss( clipfrac = None loss = (values - returns) ** 2 + # TODO: We separately run into the "mean of means" problem here. 
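+    # (i.e. masked_mean(loss, loss_mask, dim=-1).mean() weights every sequence equally regardless of its valid-token count, unlike a global per-token mean.)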
loss = masked_mean(loss, loss_mask, dim=-1).mean() return 0.5 * loss, clipfrac @@ -554,12 +533,6 @@ def ppo_policy_loss( rollout_logprobs: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, dict[str, float]]: assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'" - loss_reduction = config.loss_reduction - assert loss_reduction in [ - "token_mean", - "sequence_mean", - "seq_mean_token_sum_norm", - ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype) surr1 = ratio * advantages @@ -580,7 +553,8 @@ ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) + return loss, loss_metrics @@ -601,16 +575,7 @@ def sapo_policy_loss( See https://arxiv.org/pdf/2511.20347 for more details. """ - # SAPO must use sequence_mean reduction - loss_reduction = config.loss_reduction - if loss_reduction != "sequence_mean": - # The SAPO paper uses sequence_mean reduction; there's no reason - # why a user couldn't use token_mean reduction, but - # it's not clear whether it would be stable or not. - from loguru import logger as logger_ # have to do lazy import to avoid pickling error - - logger_.warning(f"With SAPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}") - + # SAPO should use sequence_mean reduction to avoid length bias # temperature for positive and negative token updates tau_pos = torch.as_tensor(config.sapo.tau_pos, dtype=advantages.dtype, device=advantages.device) tau_neg = torch.as_tensor(config.sapo.tau_neg, dtype=advantages.dtype, device=advantages.device) @@ -651,7 +616,7 @@ def gate_function(x, tau): loss_metrics.update(off_policy_metrics) - # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) + # the seq-mean-token-mean aggregation for SAPO is now applied upstream via advantage scaling; reduce_loss just takes the masked token sum - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -677,16 +642,7 @@ def gspo_policy_loss( The variant of GSPO used here is GSPO-token, a generalization which allows for token-level advantages [equations 14 and 15 in the paper]. """ - # GSPO must use sequence_mean reduction - loss_reduction = config.loss_reduction - if loss_reduction != "sequence_mean": - # The GSPO paper uses sequence_mean reduction; there's no reason - # why a user couldn't use token_mean reduction, but - # it's not clear whether it would be stable or not.
- from loguru import logger as logger_ # have to do lazy import to avoid pickling error - - logger_.warning(f"With GSPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}") - + # GSPO should use sequence_mean reduction to avoid length bias # Compute log ratios log_ratio = log_probs - old_log_probs @@ -718,7 +674,7 @@ def gspo_policy_loss( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -755,7 +711,7 @@ def compute_policy_loss_cispo( ) loss_metrics.update(off_policy_metrics) - loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) + loss = reduce_loss(loss, loss_mask) return loss, loss_metrics @@ -820,8 +776,6 @@ def compute_policy_loss_clip_cov( pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, ) return pg_loss, {"clip_ratio": clip_frac.item()} @@ -876,12 +830,7 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] - pg_loss = reduce_loss( - loss=pg_losses, - loss_mask=loss_mask, - loss_reduction=config.loss_reduction, - max_seq_len=config.max_seq_len, - ) + pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 return pg_loss, {"clip_ratio": 0.0} @@ -920,10 +869,7 @@ def cross_entropy_loss( elementwise_loss = -log_probs # Apply loss mask and sum (matching Tinker's SUM reduction semantics) - if loss_mask is not None: - loss = (elementwise_loss * loss_mask).sum() - else: - loss = elementwise_loss.sum() + loss = reduce_loss(elementwise_loss, loss_mask) # No clipping in cross-entropy loss return loss, {"clip_ratio": 0.0} @@ -932,29 +878,8 @@ def cross_entropy_loss( def reduce_loss( loss: torch.Tensor, loss_mask: Optional[torch.Tensor], - loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], - max_seq_len: Optional[int] = None, ) -> torch.Tensor: - if loss_reduction == "token_mean": - # sum over *all* valid tokens, divide by total valid-token count - loss = masked_mean(loss, loss_mask) - elif loss_reduction == "sequence_mean": - # per-sequence token-mean (dim=-1), then batch-mean - loss = masked_mean(loss, loss_mask, dim=-1).mean() - elif loss_reduction == "seq_mean_token_sum_norm": - # per-sequence token-sum, normalized by the max sequence length, then batch mean - # this is the Dr. 
GRPO loss reduction to avoid length bias by normalizing by a constant - assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction" - # NOTE: max_seq_len is computed as cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length by default - if loss_mask is not None: - seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len - else: - # If no mask, assume all tokens are valid - seq_losses = torch.sum(loss, dim=-1) / max_seq_len - loss = torch.mean(seq_losses) - else: - raise ValueError(f"Invalid loss reduction type: {loss_reduction}") - return loss + return (loss * loss_mask).sum() if loss_mask is not None else loss.sum() # NOTE (erictang000): below ported from verl diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index 4def3870d3..17a711aa83 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -217,10 +217,19 @@ def loss_func(logits, data): loss_mask = data["loss_mask"] rollout_action_logprobs = data["rollout_action_logprobs"] action_mask = data.get("action_mask") + num_microbatches = data.get("num_microbatches") + dp_size = mpu.get_data_parallel_world_size() tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 + # we want to maintain a sum of losses across all micro batches, so we reverse this division. + # we additionally multiply by the data parallelism size to undo the DDP all-reduce mean + # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285 + loss_scale = num_microbatches * dp_size + # temperature normalization if temperature != 1.0: logits.div_(temperature) @@ -250,13 +259,15 @@ def loss_func(logits, data): # SFT path: cross_entropy loss (negative log likelihood) if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * loss_scale # Compute elementwise loss for Tinker API (per-token NLL) with torch.no_grad(): elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs batch_size = action_log_probs.shape[0] @@ -278,7 +289,7 @@ def loss_func(logits, data): ) metrics = { - "loss": loss.detach().item(), + "loss": unscaled_loss.detach().item(), "response_length": num_actions, "loss_fn_outputs": loss_fn_outputs, } @@ -308,10 +319,11 @@ def loss_func(logits, data): kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * loss_config.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * loss_scale metrics = { - "final_loss": loss.detach().item(), + "final_loss": unscaled_loss.detach().item(), "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), "policy_kl": kl_loss.detach().item(), diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 16ac509766..21ba1134a9 100644 --- 
a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -562,6 +562,9 @@ def forward_backward( } ) + for m_batch in micro_buffer: + m_batch["num_microbatches"] = len(micro_buffer) + if not micro_buffer: return {} @@ -580,9 +583,6 @@ def forward_backward( if self.empty_cuda_cache: torch.cuda.empty_cache() - # Track number of micro-batches for metrics - self._micro_batches_accumulated += len(micro_buffer) - # Aggregate metrics across micro-batches all_loss_fn_outputs = [] # Handle separately from scalar metrics for metrics in metrics_list: @@ -592,10 +592,12 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # Reduce and all-reduce metrics + # Reduce and all-reduce metrics across DP ranks only + # (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks) status = reduce_metrics(dict(all_metrics)) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] - status = all_reduce_metrics(status, self.strategy) + group = mpu.get_data_parallel_group(with_context_parallel=True) + status = all_reduce_metrics(status, self.strategy, group=group) # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: @@ -615,9 +617,6 @@ def optim_step(self) -> Optional[float]: """ grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") - # Reset counter for next accumulation cycle - self._micro_batches_accumulated = 0 - if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() if hasattr(grad_norm, "item") else grad_norm return grad_norm diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 1e8aa0e9df..d1e246060c 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -686,8 +686,13 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) + # reduce metrics across micro batches (sum, mean, min, max) result = reduce_metrics(dict(all_metrics)) + # all reduce metrics across DP workers + dp_group = self.device_mesh.get_group("dp") + result = all_reduce_metrics(result, self.strategy, group=dp_group) + # Add back loss_fn_outputs (concatenated across micro-batches) if all_loss_fn_outputs: result["loss_fn_outputs"] = all_loss_fn_outputs @@ -802,9 +807,12 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) + loss_scale = self.mesh_rank.dp_size + # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": - loss = policy_loss + unscaled_loss = policy_loss + loss = unscaled_loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -812,6 +820,7 @@ def _forward_backward_micro( elementwise_loss = -action_log_probs if loss_mask is not None: elementwise_loss = elementwise_loss * loss_mask + elementwise_loss = elementwise_loss * loss_scale # Build per-sequence loss_fn_outputs (matches Tinker's ForwardBackwardOutput structure) # Trim to actual response length per sample (Tinker expects variable-length arrays @@ -836,7 +845,7 @@ def _forward_backward_micro( ) status = { - "loss": loss.item(), + "loss": unscaled_loss.item(), "response_length": num_actions, "lr": self.scheduler.get_last_lr()[0], "loss_fn_outputs": loss_fn_outputs, @@ -868,11 +877,12 @@ def _forward_backward_micro( kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * 
self.cfg.trainer.algorithm.kl_loss_coef - loss = policy_loss + kl_loss_term - entropy_loss_term + unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term + loss = unscaled_loss * loss_scale self.strategy.backward(loss, self.model, self.optimizer) status = { - "final_loss": loss.item(), + "final_loss": unscaled_loss.item(), "policy_loss": policy_loss.item(), "policy_entropy": entropy.item(), "response_length": num_actions, @@ -883,31 +893,15 @@ def _forward_backward_micro( if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() - loss_fn_outputs = status.pop("loss_fn_outputs", None) - - # All-reduce metrics across DP workers - status = all_reduce_metrics(status, self.strategy) - - # Add back loss_fn_outputs after all_reduce - if loss_fn_outputs is not None: - status["loss_fn_outputs"] = loss_fn_outputs - return status def optim_step(self) -> float: """ - Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + Perform optimizer step. Returns: The gradient norm (before scaling, after clipping) """ - # Scale accumulated gradients by 1/N to get correct average - if self._micro_batches_accumulated > 0: - scale = 1.0 / self._micro_batches_accumulated - for param in self.model.parameters(): - if param.grad is not None: - param.grad.mul_(scale) - # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 4cdbc02b33..ff5276a842 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -13,7 +13,9 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: for k, v in metrics.items(): assert len(v) > 0, f"No metrics for key {k}" assert all(isinstance(x, (int, float)) for x in v), f"Metrics for key {k} are not all numbers" - if k.endswith("_max"): + if k.endswith("_loss"): + reduced_metrics[k] = sum(v) + elif k.endswith("_max"): reduced_metrics[k] = max(v) elif k.endswith("_min"): reduced_metrics[k] = min(v) @@ -22,16 +24,21 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: return reduced_metrics -def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: +def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy, group=None) -> Dict[str, float]: """All reduce metrics across all processes.""" min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} - mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} - status_mean = strategy.all_reduce(mean_metrics, op="mean") - status_min = strategy.all_reduce(min_metrics, op="min") - status_max = strategy.all_reduce(max_metrics, op="max") + sum_metrics = {k: v for k, v in metrics.items() if k.endswith("_loss")} + mean_metrics = { + k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics and k not in sum_metrics + } + status_mean = strategy.all_reduce(mean_metrics, op="mean", group=group) + status_min = strategy.all_reduce(min_metrics, op="min", group=group) + status_max = strategy.all_reduce(max_metrics, op="max", group=group) + status_sum = strategy.all_reduce(sum_metrics, op="sum", group=group) status_mean.update(status_min) status_mean.update(status_max) + 
status_mean.update(status_sum) return status_mean diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index fbc78b2b2b..818bf5103f 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -47,7 +47,6 @@ def test_policy_loss_dual_clip(): eps_clip_high=0.2, clip_ratio_c=3.0, policy_loss_type="dual_clip", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -71,7 +70,7 @@ def test_policy_loss_dual_clip(): # For negative advantages, use dual clipped loss final_loss = torch.where(advantages < 0, min_loss, max_loss) # [-0.5, 1.0, 12.0] assert torch.allclose(final_loss, torch.tensor([[-0.5, 1.0, 12.0]], device=device), rtol=1e-3) - expected_loss = final_loss.mean() # -(-12.5/3) = 4.1667 + expected_loss = final_loss.sum() # 12.5 # Calculate actual loss actual_loss, _ = loss_fn(log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, config=config) @@ -79,7 +78,7 @@ def test_policy_loss_dual_clip(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(4.1667, abs=1e-4) + assert actual_loss.item() == pytest.approx(12.5, abs=1e-4) def test_policy_loss_cispo(): @@ -98,7 +97,6 @@ def test_policy_loss_cispo(): config = AlgorithmConfig( cispo=CISPOConfig(cispo_eps_clip_low=0.2, cispo_eps_clip_high=0.2), policy_loss_type="cispo", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -119,9 +117,9 @@ def test_policy_loss_cispo(): # loss_per_token[0] = -(1.0 * 0.8 * -1.69315) = 1.35452 # loss_per_token[1] = -(-1.0 * 1.0 * -1.0) = -1.0 # loss_per_token[2] = -(-4.0 * 1.2 * -0.69741) = -3.347568 - # mean(loss) = (1.35452 - 1.0 - 3.347568) / 3 = -0.99768266666 + # sum(loss) = 1.35452 - 1.0 - 3.347568 = -2.99 loss = -ratio.clamp(1 - 0.2, 1 + 0.2) * advantages * log_probs - expected_loss = loss.mean() + expected_loss = loss.sum() # Calculate actual loss actual_loss, _ = loss_fn( @@ -134,158 +132,7 @@ def test_policy_loss_cispo(): # Verify results torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-8) # close to hand calculated value - assert actual_loss.item() == pytest.approx(-0.99768266666, abs=1e-4) - - -def test_policy_loss_reduction_modes(): - """Tests different loss_reduction modes in PolicyLoss function. - - Note: token_mean and sequence_mean give the same result when all sequences - have the same length and no mask is applied, but differ when masking creates - different effective sequence lengths. 
- """ - - device = "cpu" - - clip_eps_low = 0.2 - clip_eps_high = 0.2 - - advantages = torch.tensor( - [ - [2.0, 2.0, 2.0], # sequence 1: consistently higher advantages - [1.0, 1.0, 1.0], # sequence 2: consistently lower advantages - ], - device=device, - ) - - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) - - log_probs = torch.tensor( - [[-1.5, -0.5, -1.2], [-0.8, -1.3, -0.9]], # ratios ≈ [[0.61, 1.65, 0.83],[1.22, 0.74, 1.11]] - device=device, - ) - - # Create masks to test sequences with different numbers of valid tokens - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=clip_eps_low, - eps_clip_high=clip_eps_high, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - # Test token_mean without mask - loss_token_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - - # Test token_mean with mask - loss_token_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - - # Test sequence_mean without mask - loss_seq_no_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # Test sequence_mean with mask - loss_seq_with_mask, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Manual calculations to verify (using default PolicyLoss parameters) - ratio = torch.exp(log_probs - old_log_probs) - surr1 = ratio * advantages - surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages # clip_eps_low=0.2, clip_eps_high=0.2 - loss_per_token = -torch.min(surr1, surr2) - - # Expected token_mean without mask: mean of all tokens - expected_token_no_mask = loss_per_token.mean() - - # Expected token_mean with mask: masked mean of all tokens - expected_token_with_mask = (loss_per_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) - - # Expected sequence_mean without mask: mean of sequence means - expected_seq_no_mask = loss_per_token.mean(dim=1).mean() - - # Expected sequence_mean with mask: mean of masked sequence means - seq_means_masked = (loss_per_token * loss_mask).sum(dim=1) / (loss_mask.sum(dim=1) + 1e-8) - expected_seq_with_mask = seq_means_masked.mean() - - # Verify results - torch.testing.assert_close(loss_token_no_mask, expected_token_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_token_with_mask, expected_token_with_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_no_mask, expected_seq_no_mask, rtol=1e-5, atol=1e-8) - torch.testing.assert_close(loss_seq_with_mask, expected_seq_with_mask, rtol=1e-5, atol=1e-8) - - # Verify that the two reduction modes give the same results when sequences have equal length and no mask - assert torch.allclose( - loss_token_no_mask, loss_seq_no_mask, rtol=1e-5 - ), "token_mean and sequence_mean should give same results when sequences have equal length and no mask" - # But they should give different results when mask creates different effective sequence lengths - assert not torch.allclose( - loss_token_with_mask, loss_seq_with_mask, rtol=1e-3 - ), "token_mean 
and sequence_mean with mask should give different results" - - -def test_policy_loss_reduction_edge_cases(): - """Tests edge cases for loss_reduction modes.""" - - device = "cpu" - - # Test with single sequence (should give same result for both modes) - advantages = torch.tensor([[1.0, -1.0, 2.0]], device=device) - old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) - log_probs = torch.tensor([[-1.5, -0.5, -1.2]], device=device) - - # Create configs for different reduction modes - config_token = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="token_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - - config_seq = AlgorithmConfig( - eps_clip_low=0.2, - eps_clip_high=0.2, - clip_ratio_c=3.0, - policy_loss_type="regular", - loss_reduction="sequence_mean", - max_seq_len=4, - off_policy_correction=NULL_OFF_POLICY_CORR, - ) - # Get loss function - loss_fn = PolicyLossRegistry.get("regular") - - loss_token, _ = loss_fn(log_probs, old_log_probs, advantages, config_token) - loss_seq, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq) - - # With single sequence, both modes should give same result - torch.testing.assert_close(loss_token, loss_seq, rtol=1e-6, atol=1e-8) - - # Test with completely masked sequence - loss_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) - loss_token_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_token, loss_mask) - loss_seq_masked, _ = loss_fn(log_probs, old_log_probs, advantages, config_seq, loss_mask) - - # Should handle zero mask gracefully (due to +1e-8 in denominator) - assert torch.isfinite(loss_token_masked) - assert torch.isfinite(loss_seq_masked) + assert actual_loss.item() == pytest.approx(-2.9930, abs=1e-4) def test_gspo_importance_sampling_levels(): @@ -348,7 +195,6 @@ def test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -361,7 +207,6 @@ def test_gspo_importance_sampling_levels(): eps_clip_high=clip_eps_high, clip_ratio_c=3.0, policy_loss_type="gspo", - loss_reduction="sequence_mean", # GSPO recommended reduction max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -374,7 +219,7 @@ def test_gspo_importance_sampling_levels(): surr1_token = ratio_token * advantages surr2_token = ratio_token.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_token = -torch.min(surr1_token, surr2_token) - expected_token = (loss_per_token_token * loss_mask).sum() / (loss_mask.sum() + 1e-8) + expected_token = (loss_per_token_token * loss_mask).sum() # Calculate token-level clipping ratio is_clipped_token = (-surr2_token > -surr1_token) & (loss_mask.bool()) @@ -390,8 +235,7 @@ def test_gspo_importance_sampling_levels(): surr1_sequence = ratio_sequence * advantages surr2_sequence = ratio_sequence.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages loss_per_token_sequence = -torch.min(surr1_sequence, surr2_sequence) - # GSPO uses sequence_mean reduction - expected_sequence = masked_mean(loss_per_token_sequence, loss_mask, dim=-1).mean() + expected_sequence = loss_per_token_sequence.sum() # Calculate sequence-level clipping ratio is_clipped_sequence = (-surr2_sequence > -surr1_sequence) & (loss_mask.bool()) @@ -465,7 +309,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="clip_cov", - 
loss_reduction="token_mean", max_seq_len=4, clip_cov=ClipCovConfig(clip_ratio=0.5, clip_cov_lb=-5.0, clip_cov_ub=5.0), # Large ratio for testing off_policy_correction=NULL_OFF_POLICY_CORR, @@ -487,7 +330,6 @@ def test_clip_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, off_policy_correction=NULL_OFF_POLICY_CORR, ) @@ -525,7 +367,6 @@ def test_kl_cov_policy_loss(): # Create KL-Cov config config = AlgorithmConfig( policy_loss_type="kl_cov", - loss_reduction="token_mean", max_seq_len=4, kl_cov=KLCovConfig(kl_cov_frac=0.5, ppo_kl_coef=1.0), # Apply KL to 50% of tokens off_policy_correction=NULL_OFF_POLICY_CORR, @@ -546,7 +387,6 @@ def test_kl_cov_policy_loss(): eps_clip_low=0.2, eps_clip_high=0.2, policy_loss_type="regular", - loss_reduction="token_mean", max_seq_len=4, use_tis=False, off_policy_correction=NULL_OFF_POLICY_CORR, @@ -574,10 +414,9 @@ def test_sapo_policy_loss_basic(): # Ratios ≈ [exp(-0.5), exp(0.2), exp(-0.1)] ≈ [0.6065, 1.2214, 0.9048] log_probs = torch.tensor([[-1.5, -0.8, -1.1]], device=device) - # SAPO config: uses sequence_mean reduction and distinct tau_pos / tau_neg + # SAPO config: distinct tau_pos / tau_neg config = AlgorithmConfig( policy_loss_type="sapo", - loss_reduction="sequence_mean", max_seq_len=4, sapo=SAPOConfig(tau_pos=1.0, tau_neg=2.0), off_policy_correction=NULL_OFF_POLICY_CORR, @@ -609,8 +448,7 @@ def gate_function(x, tau): gates = gate_function(ratio, taus) loss_per_token = -gates * advantages - # sequence_mean reduction: per-sequence token mean, then batch mean - expected_loss = loss_per_token.mean(dim=-1).mean() + expected_loss = loss_per_token.sum() torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index 33511e247c..f79f5549b0 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -448,7 +448,11 @@ def create_test_worker(worker_class): # Mock dependencies worker.strategy = MagicMock() worker.strategy.is_rank_0.return_value = False # Disable progress bars - worker.strategy.all_reduce.return_value = {"loss": 0.5, "lr": 1e-4} + worker.strategy.all_reduce.side_effect = lambda d, op, group=None: d # Return input dict unchanged + + # Mock device_mesh for DP group access + worker.device_mesh = MagicMock() + worker.device_mesh.get_group.return_value = None # No actual process group in tests # Always set model for all worker types worker.model = MagicMock() diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py index f28b8c1f37..cb6572e8f6 100644 --- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py +++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py @@ -243,29 +243,17 @@ def test_compute_gae_advantage_return_lam(advantage_test_data): def test_reduce_loss(): - """Test the reduce_loss function with different reduction types.""" - # Test data: 2x3 loss tensor with different valid token counts per sequence + """Test that reduce_loss computes the masked sum correctly.""" loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) # seq0 has 3 tokens, seq1 has 1 token - - # Test token_mean: sum all valid losses / count valid tokens - # Valid losses: [1.0, 2.0, 3.0, 4.0], mean = 10.0/4 = 2.5 - result_token = reduce_loss(loss, loss_mask, "token_mean") - expected_token = torch.tensor(2.5) - assert 
torch.allclose(result_token, expected_token), f"Expected {expected_token}, got {result_token}" - - # Test sequence_mean: mean of per-sequence means - # Seq 0: (1.0 + 2.0 + 3.0) / 3 = 2.0, Seq 1: 4.0 / 1 = 4.0, batch mean = (2.0 + 4.0) / 2 = 3.0 - result_seq = reduce_loss(loss, loss_mask, "sequence_mean") - expected_seq = torch.tensor(3.0) - assert torch.allclose(result_seq, expected_seq), f"Expected {expected_seq}, got {result_seq}" - - # Test seq_mean_token_sum_norm: sum per sequence / max_len, then batch mean - # Seq 0: (1.0 + 2.0 + 3.0) / 4 = 1.5, Seq 1: 4.0 / 4 = 1.0, batch mean = (1.5 + 1.0) / 2 = 1.25 - max_seq_len = 4 - result_max = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", max_seq_len) - expected_max = torch.tensor(1.25) - assert torch.allclose(result_max, expected_max), f"Expected {expected_max}, got {result_max}" + loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) + + # With mask: sum of valid losses = 1.0 + 2.0 + 3.0 + 4.0 = 10.0 + result = reduce_loss(loss, loss_mask) + assert torch.allclose(result, torch.tensor(10.0)) + + # Without mask: sum of all losses = 1+2+3+4+5+6 = 21.0 + result_no_mask = reduce_loss(loss, None) + assert torch.allclose(result_no_mask, torch.tensor(21.0)) def test_adaptive_kl_controller_update(): diff --git a/skyrl-train/tests/cpu/workers/test_worker_utils.py b/skyrl-train/tests/cpu/workers/test_worker_utils.py index c75097c244..ae513ecb33 100644 --- a/skyrl-train/tests/cpu/workers/test_worker_utils.py +++ b/skyrl-train/tests/cpu/workers/test_worker_utils.py @@ -23,10 +23,16 @@ def test_reduce_metrics_min_suffix(self): assert result["is_ratio_min"] == 1.0 def test_reduce_metrics_mean_default(self): - """Keys without _max/_min suffix should use mean reduction.""" + """Keys without _max/_min/_loss suffix should use mean reduction.""" + metrics = {"entropy": [1.0, 2.0, 3.0]} + result = reduce_metrics(metrics) + assert result["entropy"] == 2.0 # mean of [1, 2, 3] + + def test_reduce_metrics_loss_sum(self): + """Keys ending in _loss should use sum reduction.""" metrics = {"policy_loss": [1.0, 2.0, 3.0]} result = reduce_metrics(metrics) - assert result["policy_loss"] == 2.0 # mean of [1, 2, 3] + assert result["policy_loss"] == 6.0 # sum of [1, 2, 3] def test_reduce_metrics_mixed(self): """Test mixed metric types are reduced correctly.""" @@ -34,11 +40,13 @@ def test_reduce_metrics_mixed(self): "is_ratio_max": [1.0, 10.0], "is_ratio_min": [0.5, 2.0], "policy_loss": [1.0, 3.0], + "entropy": [1.0, 3.0], } result = reduce_metrics(metrics) assert result["is_ratio_max"] == 10.0 assert result["is_ratio_min"] == 0.5 - assert result["policy_loss"] == 2.0 + assert result["policy_loss"] == 4.0 # sum + assert result["entropy"] == 2.0 # mean def test_reduce_metrics_single_value(self): """Test reduction with single value lists.""" @@ -65,7 +73,7 @@ def test_all_reduce_metrics_separates_by_suffix(self): strategy = MagicMock() # Mock all_reduce to return the input dict unchanged but track calls - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): return {k: v for k, v in d.items()} strategy.all_reduce.side_effect = mock_all_reduce @@ -79,8 +87,8 @@ def mock_all_reduce(d, op): _ = all_reduce_metrics(metrics, strategy) - # Verify all_reduce was called 3 times - assert strategy.all_reduce.call_count == 3 + # Verify all_reduce was called 4 times (mean, min, max, sum) + assert strategy.all_reduce.call_count == 4 # Check that the correct ops were used calls = strategy.all_reduce.call_args_list @@ -93,9 +101,13 @@ def mock_all_reduce(d, 
op): op = kwargs.get("op") if kwargs else args[1] ops_and_keys.append((op, set(data_dict.keys()))) - # Verify mean metrics (policy_loss, entropy) + # Verify mean metrics (entropy only - no suffix) mean_call = [c for c in ops_and_keys if c[0] == "mean"][0] - assert mean_call[1] == {"policy_loss", "entropy"} + assert mean_call[1] == {"entropy"} + + # Verify sum metrics (_loss suffix) + sum_call = [c for c in ops_and_keys if c[0] == "sum"][0] + assert sum_call[1] == {"policy_loss"} # Verify min metrics min_call = [c for c in ops_and_keys if c[0] == "min"][0] @@ -110,13 +122,15 @@ def test_all_reduce_metrics_returns_merged_results(self): strategy = MagicMock() # Mock all_reduce to modify values based on op - def mock_all_reduce(d, op): + def mock_all_reduce(d, op, group=None): if op == "mean": return {k: v * 2 for k, v in d.items()} # Double for mean elif op == "min": return {k: v / 2 for k, v in d.items()} # Halve for min elif op == "max": return {k: v * 3 for k, v in d.items()} # Triple for max + elif op == "sum": + return {k: v * 4 for k, v in d.items()} # Quadruple for sum return d strategy.all_reduce.side_effect = mock_all_reduce @@ -125,6 +139,7 @@ def mock_all_reduce(d, op): "is_ratio_max": 10.0, "is_ratio_min": 0.1, "policy_loss": 1.5, + "entropy": 0.5, } result = all_reduce_metrics(metrics, strategy) @@ -133,16 +148,18 @@ def mock_all_reduce(d, op): assert "is_ratio_max" in result assert "is_ratio_min" in result assert "policy_loss" in result + assert "entropy" in result # Check values were transformed correctly assert result["is_ratio_max"] == 30.0 # 10.0 * 3 (max op) assert result["is_ratio_min"] == 0.05 # 0.1 / 2 (min op) - assert result["policy_loss"] == 3.0 # 1.5 * 2 (mean op) + assert result["policy_loss"] == 6.0 # 1.5 * 4 (sum op) + assert result["entropy"] == 1.0 # 0.5 * 2 (mean op) def test_all_reduce_metrics_only_max(self): """Test with only _max metrics.""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d metrics = {"loss_max": 5.0, "ratio_max": 10.0} @@ -153,7 +170,7 @@ def test_all_reduce_metrics_only_max(self): def test_all_reduce_metrics_only_min(self): """Test with only _min metrics.""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d metrics = {"loss_min": 0.1, "ratio_min": 0.01} @@ -162,12 +179,23 @@ def test_all_reduce_metrics_only_min(self): assert result == {"loss_min": 0.1, "ratio_min": 0.01} def test_all_reduce_metrics_only_mean(self): - """Test with only mean metrics (no _max/_min suffix).""" + """Test with only mean metrics (no _max/_min/_loss suffix).""" + strategy = MagicMock() + strategy.all_reduce.side_effect = lambda d, op, group=None: d + + metrics = {"entropy": 0.5, "kl_div": 1.5} + + result = all_reduce_metrics(metrics, strategy) + + assert result == {"entropy": 0.5, "kl_div": 1.5} + + def test_all_reduce_metrics_only_sum(self): + """Test with only _loss metrics (sum reduction).""" strategy = MagicMock() - strategy.all_reduce.side_effect = lambda d, op: d + strategy.all_reduce.side_effect = lambda d, op, group=None: d - metrics = {"policy_loss": 1.5, "entropy": 0.5} + metrics = {"policy_loss": 1.5, "value_loss": 0.5} result = all_reduce_metrics(metrics, strategy) - assert result == {"policy_loss": 1.5, "entropy": 0.5} + assert result == {"policy_loss": 1.5, "value_loss": 0.5} diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py 
b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index d9730cf98c..75d52e37b0 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -439,7 +439,8 @@ async def test_megatron_train( Full test: initialize actor group, send dummy experience to training_step, validate output. """ cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME) - batch = get_test_training_batch(batch_size=gpus_per_node) + batch_size = gpus_per_node * 2 + batch = get_test_training_batch(batch_size=batch_size) cfg.trainer.strategy = "megatron" cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node @@ -449,6 +450,7 @@ async def test_megatron_train( cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.use_sample_packing = use_sample_packing + cfg.trainer.algorithm.use_kl_loss = False if use_entropy_loss: cfg.trainer.algorithm.use_entropy_loss = True cfg.trainer.algorithm.entropy_loss_coef = 0.01 @@ -466,7 +468,7 @@ async def test_megatron_train( cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 # set batch sizes correctly - cfg.trainer.train_batch_size = gpus_per_node + cfg.trainer.train_batch_size = batch_size cfg.trainer.policy_mini_batch_size = gpus_per_node cfg.generator.n_samples_per_prompt = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 @@ -527,13 +529,12 @@ async def test_megatron_train( # Both FSDP and Megatron use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 - results_fsdp = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", batch)) + results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) # Get learning rate from worker lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) for i, result in enumerate(results_fsdp): result["policy_lr"] = lr_results[i] - print("megatron results: ", results_megatron[0]) print("\n\n") print("fsdp results: ", results_fsdp[0]) @@ -543,7 +544,6 @@ "policy_lr", "loss_metrics/clip_ratio", "policy_entropy", - "policy_kl", "final_loss", ] if ep > 1: diff --git a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py index eff1ce6506..47d16dbea3 100644 --- a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py +++ b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py @@ -15,7 +15,6 @@ from skyrl_train.config import SkyRLConfig from skyrl_train.utils import Timer -from skyrl_train.utils.ppo_utils import normalize_advantages_dict import asyncio @@ -122,9 +121,6 @@ def train(self): # remove some unwanted keys data.pop(batch_keys=["rewards"]) - if self.cfg.trainer.algorithm.advantage_batch_normalize: - data = normalize_advantages_dict(data) - # 4. train policy/critic model with Timer("train_critic_and_policy", self.all_timings): status = self.train_critic_and_policy(data)
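Taken together, the gradient bookkeeping above works as follows: with `reduce_loss` returning a masked sum, micro-batch gradients can simply be accumulated, but the frameworks still inject unwanted means. Megatron's pipeline schedule divides each micro-batch loss by the number of micro-batches and DDP averages gradients across data-parallel ranks, hence the `loss_scale = num_microbatches * dp_size` pre-multiplication in the Megatron wrapper (the FSDP worker only needs `dp_size`). A standalone sketch of why this recovers the full-batch sum gradient, with the all-reduce mean simulated by a plain average; names are illustrative, not SkyRL APIs:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(8, 3)  # stand-in for the per-token losses of one mini-batch

def sum_loss(batch, weights):
    # sum reduction, mirroring the new reduce_loss
    return (batch @ weights).sum()

# Reference: gradient of the summed loss over the whole mini-batch.
ref_grad = torch.autograd.grad(sum_loss(x, w), w)[0]

dp_size, num_microbatches = 2, 2
loss_scale = num_microbatches * dp_size  # Megatron path; the FSDP path uses dp_size only
per_rank_grads = []
for shard in x.chunk(dp_size):  # each data-parallel rank sees one shard
    grad_acc = torch.zeros_like(w)
    for micro in shard.chunk(num_microbatches):
        scaled = sum_loss(micro, w) * loss_scale
        scaled = scaled / num_microbatches  # the schedule divides each micro-batch loss
        grad_acc += torch.autograd.grad(scaled, w)[0]
    per_rank_grads.append(grad_acc)

# DDP then averages gradients across ranks (simulated here with a plain mean).
ddp_grad = torch.stack(per_rank_grads).mean(dim=0)
torch.testing.assert_close(ddp_grad, ref_grad)
```

This is also why `worker.py` drops the old 1/N gradient rescale in `optim_step`: with sum-reduced losses, summing the accumulated micro-batch gradients is already the desired full-batch sum.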