26 commits
b7b494c
normalize the advantages instead
justinvyu Jan 23, 2026
da24f6b
sum reduction for the loss
justinvyu Jan 23, 2026
cbed7ff
add a few options
justinvyu Jan 23, 2026
fc7b775
always sum on the loss calculation side
justinvyu Jan 23, 2026
c411586
fix some bugs
justinvyu Jan 23, 2026
1bf83d8
revert the critic worker changes
justinvyu Jan 23, 2026
dc838e1
more revert
justinvyu Jan 23, 2026
84e34e2
Merge branch 'main' of https://github.com/erictang000/SkyRL into fix_…
erictang000 Jan 28, 2026
be22e03
Merge branch 'main' of https://github.com/erictang000/SkyRL into fix_…
erictang000 Jan 28, 2026
6259174
fix conflict with main
erictang000 Jan 28, 2026
4b8e556
x
erictang000 Jan 28, 2026
4547ae1
fix normalization
erictang000 Jan 29, 2026
635a9c1
change reduce_loss to just sums everywhere
erictang000 Jan 30, 2026
cd7506b
fix tests
erictang000 Jan 31, 2026
0133819
Merge branch 'main' of https://github.com/erictang000/SkyRL into fix_…
erictang000 Jan 31, 2026
ddeaae4
scale by num_microbatches in megatron loss
erictang000 Feb 2, 2026
ef24d32
x
erictang000 Feb 2, 2026
89145e5
add dp size scaling to megatron and move fsdp loss scaling to loss ra…
erictang000 Feb 2, 2026
133d9a9
Merge branch 'main' of https://github.com/erictang000/SkyRL into fix_…
erictang000 Feb 2, 2026
49babe4
debugging megatron fsdp loss diff
erictang000 Feb 3, 2026
a3fa670
Merge branch 'main' of https://github.com/erictang000/SkyRL into fix_…
erictang000 Feb 3, 2026
c0dd9ff
Merge branch 'fix_loss_reduction2' of https://github.com/justinvyu/Sk…
erictang000 Feb 3, 2026
eb241de
start trying to make metrics invariant to dp size for megatron
erictang000 Feb 4, 2026
772141a
x
erictang000 Feb 4, 2026
f66736a
metrics now (mostly) matching for megatron vs fsdp
erictang000 Feb 4, 2026
0df4467
Merge branch 'fix_loss_reduction2' of https://github.com/justinvyu/Sk…
erictang000 Feb 4, 2026
2 changes: 1 addition & 1 deletion skyrl-train/docs/configuration/config.rst
@@ -575,7 +575,7 @@ It can be helpful to understand the final loss formulation to see how the differ
pg_losses3 = -advantages * config.clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
- loss = reduce_loss(loss, loss_mask, config.loss_reduction)
+ loss = reduce_loss(loss, loss_mask)
return loss, {"clip_ratio": clip_ratio}


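For context, `reduce_loss` is now called without a per-call reduction argument. A minimal sketch of what a sum-style reduction could look like (hypothetical helper, assuming a 0/1 token mask; the actual SkyRL implementation may differ):

import torch

def reduce_loss(loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Sum the per-token loss over unmasked tokens only.
    # Any normalization (e.g. by batch size or DP world size) is assumed
    # to happen on the caller side.
    return (loss * loss_mask).sum()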
4 changes: 0 additions & 4 deletions skyrl-train/examples/async/async_trainer.py
@@ -5,7 +5,6 @@
from skyrl_train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl_train.utils import Timer
- from skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.generators.base import GeneratorOutput
from skyrl_train.utils.trainer_utils import ResumeMode
@@ -145,9 +144,6 @@ async def _run_training(self, generation_buffer):
training_input.pop(key)
training_input.metadata.pop("uids")

- if self.cfg.trainer.algorithm.advantage_batch_normalize:
-     training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
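The per-batch advantage normalization hook is removed from the trainer loop here (per the "normalize the advantages instead" commit, normalization presumably moves to the advantage computation itself). A rough sketch of what such a normalization does, assuming whitening over valid tokens (illustrative only, not the SkyRL implementation):

import torch

def normalize_advantages(advantages: torch.Tensor, loss_mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Whiten advantages across all unmasked tokens in the batch,
    # then re-apply the mask so padded positions stay zero.
    valid = advantages[loss_mask.bool()]
    mean, std = valid.mean(), valid.std()
    return (advantages - mean) / (std + eps) * loss_mask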
8 changes: 4 additions & 4 deletions skyrl-train/examples/megatron/run_fsdp_baseline.sh
@@ -25,7 +25,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
- trainer.eval_before_train=true \
+ trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=128 \
@@ -35,8 +35,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
- trainer.policy.optimizer_config.lr=4.0e-6 \
- trainer.algorithm.use_kl_loss=true \
+ trainer.policy.optimizer_config.lr=1.0e-6 \
+ trainer.algorithm.use_kl_loss=false \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
@@ -47,7 +47,7 @@ uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_bas
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k_megatron" \
trainer.run_name="gsm8k_fsdp1_4gpus" \
trainer.run_name="gsm8k_fsdp1_4gpus_loss_sum" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_fsdp_ckpt" \
$@
21 changes: 13 additions & 8 deletions skyrl-train/examples/megatron/run_megatron.sh
@@ -13,14 +13,19 @@ MODEL_NAME="Qwen/Qwen3-0.6B"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

- MEGATRON_TP=2
- MEGATRON_PP=2
+ MEGATRON_TP=1
+ MEGATRON_PP=1
MEGATRON_CP=1

- # torch profiler config
- ENABLE_TORCH_PROFILER=false
- RANKS_TO_PROFILE="[0]"
- SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}"
+ # # torch profiler config
+ # ENABLE_TORCH_PROFILER=false
+ # RANKS_TO_PROFILE="[0]"
+ # SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}"

+ # TIS_TYPE="token"
+ # TIS_RATIO_CLIP_HIGH=2.0
+ # trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ # trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_CLIP_HIGH \

uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -56,7 +61,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
- trainer.algorithm.use_kl_loss=true \
+ trainer.algorithm.use_kl_loss=false \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
@@ -67,7 +72,7 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
generator.gpu_memory_utilization=0.7 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k_megatron" \
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" \
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}_loss_sum_no_scaling_num_microbatches" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
$@
121 changes: 121 additions & 0 deletions skyrl-train/examples/megatron/run_megatron_dapo_qwen3_1.7b.sh
@@ -0,0 +1,121 @@
set -x

# Colocated DAPO training+generation for Qwen3-1.7B-Base on DAPO training data with Megatron.
# bash examples/algorithms/dapo/prepare_dapo_data.sh
# bash examples/megatron/run_megatron_dapo_qwen3_1.7b.sh

MODEL_NAME="Qwen/Qwen3-1.7B-Base"
DATA_DIR="$HOME/data/dapo"
TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet"
TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet"
NUM_NODES=1
NUM_GPUS_PER_NODE=8
NUM_INFERENCE_ENGINES=8
INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=1
LOGGER="wandb" # change to "console" to print to stdout

CLIP_RATIO_LOW=0.2
CLIP_RATIO_HIGH=0.28
# use token mean loss reduction
LOSS_REDUCTION="token_mean"
# applies overlong filtering (but not soft overlong punishment)
APPLY_OVERLONG_FILTERING=true
# apply soft overlong punishment with custom trainer impl in main_dapo.py
OVERLONG_BUFFER_LEN=$((1024 * 4))
OVERLONG_BUFFER_PENALTY_FACTOR=1.0

# other DAPO parameters
USE_KL_LOSS=false
TEMPERATURE=1.0
TOP_P=1.0
EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_PROMPT_LENGTH=$((1024 * 2))
MAX_RESPONSE_LENGTH=$((1024 * 8))

# repro run parameters
TRAIN_BATCH_SIZE=512
MINI_BATCH_SIZE=32
N_SAMPLES_PER_PROMPT=16
EVAL_N_SAMPLES_PER_PROMPT=32
ENFORCE_EAGER=true # cuda graphs can cause some instability
LR=1e-6

# megatron config
MEGATRON_TP=4
MEGATRON_PP=2
MEGATRON_CP=1
MEGATRON_EP=1
MEGATRON_ETP=null


# TIS parameters
TIS_IMP_RATIO_CAP=2.0
TIS_TYPE=token

uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
data.val_data="['$TEST_FILE']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.policy_loss_type="dual_clip" \
+trainer.algorithm.overlong_buffer.len=$OVERLONG_BUFFER_LEN \
+trainer.algorithm.overlong_buffer.penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \
trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
generator.enforce_eager=$ENFORCE_EAGER \
generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
generator.sampling_params.temperature=$TEMPERATURE \
generator.sampling_params.top_p=$TOP_P \
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
generator.eval_sampling_params.temperature=$TEMPERATURE \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.policy.model.path="$MODEL_NAME" \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=$TRAIN_BATCH_SIZE \
trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \
trainer.micro_forward_batch_size_per_gpu=8 \
trainer.micro_train_batch_size_per_gpu=8 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=$MAX_PROMPT_LENGTH \
generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.policy.optimizer_config.lr=$LR \
trainer.policy.optimizer_config.num_warmup_steps=160 \
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.policy.optimizer_config.max_grad_norm=1.0 \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=false \
generator.batched=true \
environment.env_class=aime \
generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="dapo_aime" \
trainer.run_name="dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \
trainer.export_path="$HOME/exports/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \
trainer.hf_save_interval=300 \
trainer.resume_mode=latest \
trainer.max_ckpts_to_keep=3 \
trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_1.7b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_loss_sum_dp1" \
$@
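The script above enables token-level truncated importance sampling (TIS) through trainer.algorithm.off_policy_correction.*. A hedged sketch of the underlying idea, capping the per-token importance ratio between the trainer policy and the rollout policy at TIS_IMP_RATIO_CAP (the function and argument names below are illustrative, not SkyRL APIs):

import torch

def token_tis_ratio(log_probs: torch.Tensor, rollout_log_probs: torch.Tensor, clip_high: float = 2.0) -> torch.Tensor:
    # Per-token importance ratio pi_train / pi_rollout, truncated from above
    # to limit the influence of tokens where the two policies have drifted apart.
    ratio = torch.exp(log_probs - rollout_log_probs)
    return torch.clamp(ratio, max=clip_high)

The clipped ratio would typically be multiplied into the per-token policy loss to correct for drift between generation and training.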
@@ -51,7 +51,7 @@ def compute_importance_sampling_policy_loss(
# as defined here: https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling
loss = -torch.exp(log_probs - old_log_probs) * advantages

loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
# return loss and a dummy clip ratio value as we aren't clipping here
return loss, {"clip_ratio": 0.0}

@@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \
+ trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.async_engine=false \
@@ -74,6 +74,7 @@ uv run --isolated --extra vllm -m examples.on_policy_distillation.main_on_policy
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.use_kl_in_reward=$USE_KL_IN_REWARD \
+ trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.async_engine=false \
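Both distillation scripts now pin trainer.algorithm.loss_reduction to "seq_mean_token_sum_norm". A plausible reading of that mode, sketched here as an assumption (the exact SkyRL semantics may differ): sum token losses per sequence, normalize by a fixed maximum sequence length, then average over sequences.

import torch

def seq_mean_token_sum_norm(loss: torch.Tensor, loss_mask: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    # loss, loss_mask: [batch, seq_len]; normalize each sequence's token sum
    # by a constant length so short and long responses are weighted alike.
    per_seq = (loss * loss_mask).sum(dim=-1) / max_seq_len
    return per_seq.mean()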
17 changes: 10 additions & 7 deletions skyrl-train/skyrl_train/distributed/strategy.py
@@ -67,11 +67,11 @@ def get_rank(self) -> int:
"""Get current process rank"""
return dist.get_rank()

def all_reduce(self, data: DataT, op="mean") -> DataT:
def all_reduce(self, data: DataT, op="mean", group=None) -> DataT:
"""Perform all_reduce across all processes"""
assert op in ("mean", "max", "sum", "min")
if isinstance(data, dict):
- return {k: self.all_reduce(v, op) for k, v in data.items()}
+ return {k: self.all_reduce(v, op, group) for k, v in data.items()}
else:
is_tensor = True
if not isinstance(data, torch.Tensor):
@@ -82,14 +82,17 @@ def all_reduce(self, data: DataT, op="mean") -> DataT:
if is_cpu_tensor:
data = data.to(torch.cuda.current_device())
if op == "mean":
- data /= self.world_size
- dist.all_reduce(data, op=dist.ReduceOp.SUM)
+ if group is None:
+     data /= self.world_size
+ else:
+     data /= group.size()
+ dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
elif op == "max":
- dist.all_reduce(data, op=dist.ReduceOp.MAX)
+ dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group)
elif op == "min":
- dist.all_reduce(data, op=dist.ReduceOp.MIN)
+ dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group)
elif op == "sum":
- dist.all_reduce(data, op=dist.ReduceOp.SUM)
+ dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
if is_cpu_tensor:
data = data.cpu()
return data.item() if not is_tensor else data
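The new `group` argument lets callers reduce over a sub-group (for example, only the data-parallel ranks) instead of the full world, and the "mean" path divides by that group's size accordingly. A standalone sketch of the new behavior (assumes torch.distributed is already initialized; this is not the SkyRL class itself):

import torch
import torch.distributed as dist

def all_reduce_mean(data: torch.Tensor, group=None) -> torch.Tensor:
    # Divide by the size of the reducing group, not the global world size,
    # so metrics stay invariant to how many groups exist.
    divisor = dist.get_world_size() if group is None else group.size()
    data = data / divisor
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
    return data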
4 changes: 0 additions & 4 deletions skyrl-train/skyrl_train/fully_async_trainer.py
@@ -20,7 +20,6 @@
from skyrl_train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl_train.utils import Timer
- from skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.generators.base import GeneratorOutput
from skyrl_train.utils.trainer_utils import ResumeMode, build_dataloader
@@ -511,9 +510,6 @@ async def _run_training(self, training_input: TrainingInputBatch):
training_input.pop(key)
training_input.metadata.pop("uids")

- if self.cfg.trainer.algorithm.advantage_batch_normalize:
-     training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):