15 changes: 9 additions & 6 deletions skyrl-train/skyrl_train/trainer.py
@@ -1067,26 +1067,28 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
"""
Run the training step for the policy and critic models.

For Megatron strategy: uses ppo_train (training loop inside worker)
For FSDP strategy: uses forward_backward + optim_step (training loop in trainer)
For Megatron: Uses ppo_train via dispatch.
For FSDP/FSDP2: Uses forward_backward + optim_step via dispatch.

Dispatch handles offload/backload automatically when colocate_all=True.
"""
data.metadata["global_step"] = self.global_step
critic_status = None

if self.cfg.trainer.strategy == "megatron":
# Megatron: training loop inside worker via ppo_train
# Megatron: use ppo_train via dispatch
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self.dispatch.ppo_train("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self.dispatch.ppo_train("policy", data)
else:
# FSDP: training loop in trainer via forward_backward + optim_step
# FSDP/FSDP2: use forward_backward + optim_step via dispatch
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
critic_status = self._execute_training_step("critic", data, "critic")
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)
policy_status = self._execute_training_step("policy", data, "policy")
Comment on lines 1088 to 1091
Contributor

critical

There appears to be a critical issue with the arguments passed to _execute_training_step for both the critic and policy models. The function signature for _execute_training_step is (self, model: str, data: TrainingInputBatch), but it's being called with three arguments here (e.g., self._execute_training_step("critic", data, "critic")). This will result in a TypeError at runtime.

While the intent seems to be to pass a loss_fn, the implementation appears incomplete. Specifically:

  1. The signature of _execute_training_step hasn't been updated to accept a third argument.
  2. Even if it were updated, the critic training path would likely fail. The loss_fn would be "critic", which is not handled by PolicyWorkerBase._get_loss_fn, and CriticWorkerBase doesn't have a comparable method to handle a parameterized loss function.

To fix this, you'll need to update the signature of _execute_training_step and ensure that both policy and critic workers can correctly handle the new loss_fn parameter. For the critic, you might want to pass None as the loss_fn if it's not meant to be parameterized, and handle that case in _execute_training_step.
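
One possible shape for the fix, sketched under the assumption that `_execute_training_step` drives `dispatch.forward_backward` followed by `dispatch.optim_step`; the `loss_fn` plumbing through the dispatch layer is hypothetical and not part of this diff:

```python
from typing import Dict, Optional


def _execute_training_step(
    self, model: str, data: "TrainingInputBatch", loss_fn: Optional[str] = None
) -> Dict[str, float]:
    """Forward/backward plus optimizer step for one model via dispatch.

    loss_fn is only meaningful for the policy, which resolves it through
    PolicyWorkerBase._get_loss_fn; the critic path leaves it as None.
    """
    if loss_fn is not None and model != "policy":
        raise ValueError(f"loss_fn is not supported for model '{model}'")
    # Hypothetical plumbing: dispatch.forward_backward would also have to
    # accept and forward loss_fn for it to reach the policy worker.
    status = self.dispatch.forward_backward(model, data, loss_fn=loss_fn)
    self.dispatch.optim_step(model)
    return status
```

The call sites above would then read, for example, `self._execute_training_step("critic", data)` and `self._execute_training_step("policy", data, loss_fn="ppo")`.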


# Update metrics
if critic_status is not None:
@@ -1096,6 +1098,7 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
for k, v in policy_status.items():
self.all_metrics.update({f"policy/{k}": v})

# Empty cache after training
self.dispatch.empty_cache()

return policy_status
@@ -552,7 +552,6 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
# TODO: Convert this into 2 loops for minibatches and microbatches.
micro_buffer = []
for local_step, experience in enumerate(pbar):
# BatchIterator now yields Experience objects directly
experience.to_device(torch.cuda.current_device())
sequences = experience.sequences
attention_mask = experience.attention_mask
47 changes: 39 additions & 8 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -630,6 +630,27 @@ async def async_run_method(


class PolicyWorkerBase(Worker):
# TODO(tgriggs): Remove once loss function naming is unified.
# Tinker loss_fn names -> SkyRL PolicyLossRegistry names
TINKER_LOSS_FN_MAP = {"ppo": "regular"}

@staticmethod
def convert_tinker_loss_config(loss_fn_config: Dict[str, Any]) -> Dict[str, Any]:
"""Convert Tinker loss_fn_config to SkyRL algorithm config format.

Tinker uses absolute ratio bounds (e.g., 0.9, 1.1).
SkyRL uses offsets from 1.0 (e.g., 0.1, 0.1).
"""
skyrl_config = {}
for k, v in loss_fn_config.items():
if k == "clip_low_threshold":
skyrl_config["eps_clip_low"] = 1.0 - v # 0.9 -> 0.1
elif k == "clip_high_threshold":
skyrl_config["eps_clip_high"] = v - 1.0 # 1.1 -> 0.1
else:
skyrl_config[k] = v
return skyrl_config

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model: nn.Module = None
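
A quick standalone illustration of the `convert_tinker_loss_config` helper added above; the config values are made-up examples and the import path is inferred from the file location in this diff:

```python
from skyrl_train.workers.worker import PolicyWorkerBase

# Tinker-style config with absolute ratio bounds around 1.0; the extra
# "kl_coef" key is only here to show that unrecognized keys pass through.
tinker_cfg = {"clip_low_threshold": 0.9, "clip_high_threshold": 1.1, "kl_coef": 0.01}

skyrl_cfg = PolicyWorkerBase.convert_tinker_loss_config(tinker_cfg)
# Clip bounds become offsets from 1.0 (up to float rounding):
# {"eps_clip_low": ~0.1, "eps_clip_high": ~0.1, "kl_coef": 0.01}
print(skyrl_cfg)
```
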
@@ -647,17 +668,30 @@ def _normalize_mini_batch_size(self):
The worker no longer needs to know mini batch size - it processes whatever
batch it receives, breaking it into micro batches. Gradient scaling happens
at optim_step time based on how many micro batches were accumulated.

TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
requires mini-batch normalization in its override. The name is kept for
backwards compatibility with Megatron which still does actual normalization.
"""
if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")

# Track micro batches for gradient scaling at optim_step
self._micro_batches_accumulated = 0

dp_size = self.mesh_rank.dp_size
self.policy_mini_batch_size_per_gpu = (
self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
)

def _get_loss_fn(self, loss_fn: Optional[str] = None) -> Callable:
"""Get loss function from Tinker name or fall back to config."""
if loss_fn is None:
name = self.cfg.trainer.algorithm.policy_loss_type
elif loss_fn in self.TINKER_LOSS_FN_MAP:
name = self.TINKER_LOSS_FN_MAP[loss_fn]
else:
raise ValueError(
f"loss_fn '{loss_fn}' not yet supported. Supported: {list(self.TINKER_LOSS_FN_MAP.keys())}"
)
return PolicyLossRegistry.get(name)

def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
"""
Perform forward and backward passes for a batch, handling micro-batching internally.
@@ -758,6 +792,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term
# NO loss scaling here - gradient scaling happens at optim_step
self.strategy.backward(loss, self.model, self.optimizer)

status = {
Expand Down Expand Up @@ -894,10 +929,6 @@ def _normalize_mini_batch_size(self):
The worker no longer needs to know mini batch size - it processes whatever
batch it receives, breaking it into micro batches. Gradient scaling happens
at optim_step time based on how many micro batches were accumulated.

TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
requires mini-batch normalization in its override. The name is kept for
backwards compatibility with Megatron which still does actual normalization.
"""
if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")
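
For context on the "gradient scaling happens at optim_step" notes in the docstrings and the `# NO loss scaling here` comment above, a generic sketch of the accumulation pattern; this is illustrative only, since the actual `optim_step` implementation is not part of this diff:

```python
import torch


def scaled_optim_step(model: torch.nn.Module,
                      optimizer: torch.optim.Optimizer,
                      micro_batches_accumulated: int) -> None:
    # Average the accumulated gradients over the number of micro batches that
    # contributed, then step and reset. The real workers track the count in
    # self._micro_batches_accumulated and step through their strategy object.
    if micro_batches_accumulated > 1:
        scale = 1.0 / micro_batches_accumulated
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(scale)
    optimizer.step()
    optimizer.zero_grad()
```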
4 changes: 3 additions & 1 deletion skyrl-train/skyrl_train/workers/worker_dispatch.py
@@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import ray
from omegaconf import DictConfig
@@ -153,6 +153,8 @@ def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch:
output = concatenate_outputs_after_mesh_dispatch(self._actor_groups[model].actor_infos, results)
return output

# === Training ===

def forward_backward(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
"""Run forward/backward pass. Needs model + optimizer."""
self._ensure_on_gpu(model, need_optimizer=True, need_model=True)
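
Putting the dispatch pieces together, the FSDP/FSDP2 path described in the trainer docstring presumably reduces to something like the following per model; this assumes `optim_step` mirrors `forward_backward`'s model-name interface, since only `forward_backward` and `empty_cache` appear in this diff:

```python
from typing import Dict


def train_one_model(dispatch, model: str, data) -> Dict[str, float]:
    # Sketch of the per-model flow the trainer drives for FSDP/FSDP2; in the real
    # trainer, dispatch.empty_cache() runs once after both critic and policy train.
    status = dispatch.forward_backward(model, data)  # _ensure_on_gpu backloads model + optimizer
    dispatch.optim_step(model)  # assumed counterpart that applies the accumulated gradients
    return status
```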