74 changes: 73 additions & 1 deletion skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -19,7 +19,7 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np
import ray
@@ -430,6 +430,7 @@ class AdvantageEstimator(StrEnum):
GRPO = "grpo"
RLOO = "rloo"
REINFORCE_PP = "reinforce++"
GDPO = "gdpo"


class AdvantageEstimatorRegistry(BaseFunctionRegistry):
@@ -456,6 +457,7 @@ def repopulate_registry(cls):
"gae": [AdvantageEstimator.GAE, compute_gae_advantage_return],
"rloo": [AdvantageEstimator.RLOO, compute_rloo_outcome_advantage],
"reinforce++": [AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage],
"gdpo": [AdvantageEstimator.GDPO, compute_gdpo_outcome_advantage],
}

for ae_name, (ae_type, ae_func) in ae_types.items():
@@ -1079,6 +1081,76 @@ def compute_grpo_outcome_advantage(
return scores, scores


@register_advantage_estimator(AdvantageEstimator.GDPO)
def compute_gdpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
grpo_norm_by_std: bool = True,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
    Compute advantage for GDPO (see https://arxiv.org/pdf/2601.05242 for details).

Expects:
- token_level_rewards: Float[torch.Tensor, "batch_size seqlen num_objectives"]
- response_mask: Float[torch.Tensor, "batch_size seqlen"]
- index: np.ndarray (batch_size)
- epsilon: float
- grpo_norm_by_std: bool (used for GDPO as well)

Returns:
- advantages: Float[torch.Tensor, "batch_size seqlen"]
- returns: Float[torch.Tensor, "batch_size seqlen"]
"""

gdpo_norm_by_std = grpo_norm_by_std

    # works for dense token-level rewards as well as a single scalar outcome reward per response:
    # summing over the sequence yields one outcome score per objective for each response
scores = token_level_rewards.sum(dim=-2)
id2score = defaultdict[Any, list](list)
id2pos = defaultdict(list)

with torch.no_grad():
bsz = scores.shape[0]
        advantage_propagated = torch.empty((bsz,), device=scores.device, dtype=scores.dtype)
for i in range(bsz):
id2score[index[i]].append(scores[i]) # prompt index -> rollouts x objectives
id2pos[index[i]].append(i) # prompt index -> rollout indices for batch norm

for idx in id2score: # per batch
reward_fn_scores = torch.stack(id2score[idx]) # all objective scores for batch -> rollouts x objectives
            if reward_fn_scores.shape[0] == 1:
                # single rollout for this prompt: no group statistics to normalize with
                rwd_fn_means = torch.zeros_like(reward_fn_scores[0])
                rwd_fn_stds = torch.ones_like(reward_fn_scores[0])
elif reward_fn_scores.shape[0] > 1:
rwd_fn_means = torch.mean(reward_fn_scores, dim=0)
rwd_fn_stds = torch.std(reward_fn_scores, dim=0, unbiased=False)
if gdpo_norm_by_std:
reward_fn_scores = (reward_fn_scores - rwd_fn_means) / (rwd_fn_stds + epsilon)
else:
reward_fn_scores = reward_fn_scores - rwd_fn_means

id2score[idx] = reward_fn_scores.sum(
dim=-1
) # sum over objectives to get unnormalized advantage per rollout

            # scatter the per-group advantages back to their original positions in the batch
            positions = torch.tensor(id2pos[idx], device=scores.device, dtype=torch.long)
            advantage_propagated.index_copy_(0, positions, id2score[idx])

# batch norm the advantages
        batch_mean = advantage_propagated.mean()
        if gdpo_norm_by_std:
            batch_std = advantage_propagated.std(unbiased=False)
            a_hat = (advantage_propagated - batch_mean) / (batch_std + epsilon)
        else:
            a_hat = advantage_propagated - batch_mean
advantages = a_hat.unsqueeze(-1) * response_mask

return advantages, advantages


def repopulate_all_registries():
PolicyLossRegistry.repopulate_registry()
AdvantageEstimatorRegistry.repopulate_registry()
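A quick way to sanity-check the new estimator is to run it directly on a toy multi-objective batch. The sketch below mirrors the values used in the unit test added in this PR; the import path `skyrl_train.utils.ppo_utils` is assumed from the file touched above, and the shapes are illustrative.

import numpy as np
import torch

from skyrl_train.utils.ppo_utils import compute_gdpo_outcome_advantage

# two prompts, two rollouts each, three response tokens, two reward objectives (B=4, T=3, K=2)
token_level_rewards = torch.zeros((4, 3, 2))
token_level_rewards[0, -1] = torch.tensor([1.0, 2.0])
token_level_rewards[1, -1] = torch.tensor([2.0, 0.0])
token_level_rewards[2, -1] = torch.tensor([0.0, 1.0])
token_level_rewards[3, -1] = torch.tensor([1.0, 1.0])
response_mask = torch.ones((4, 3))
index = np.array([0, 0, 1, 1])  # rollouts 0-1 share prompt 0, rollouts 2-3 share prompt 1

advantages, _ = compute_gdpo_outcome_advantage(
    token_level_rewards=token_level_rewards,
    response_mask=response_mask,
    index=index,
    grpo_norm_by_std=False,
)
# each objective is mean-centred within its prompt group, summed over objectives,
# then centred across the batch; here that yields rows of +0.5 / -0.5 / -0.5 / +0.5
print(advantages)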
61 changes: 61 additions & 0 deletions skyrl-train/tests/cpu/utils/test_ppo_utils.py
@@ -11,10 +11,12 @@
compute_approx_kl,
compute_gae_advantage_return,
compute_grpo_outcome_advantage,
compute_gdpo_outcome_advantage,
compute_advantages_and_returns,
AdaptiveKLController,
FixedKLController,
AdvantageEstimatorRegistry,
AdvantageEstimator,
register_advantage_estimator,
PolicyLossRegistry,
register_policy_loss,
@@ -172,6 +174,65 @@ def test_compute_grpo_outcome_advantage_norm_std_false():
assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_gdpo_outcome_advantage_shape_and_alias_norm_flag():
"""
GDPO expects multi-objective rewards (B, T, K) and should be configurable via the existing
`grpo_norm_by_std` flag (aliased internally to GDPO to avoid redundant config).

"""
token_level_rewards = torch.zeros((4, 3, 2), dtype=torch.float)
token_level_rewards[0, -1] = torch.tensor([1.0, 2.0])
token_level_rewards[1, -1] = torch.tensor([2.0, 0.0])
token_level_rewards[2, -1] = torch.tensor([0.0, 1.0])
token_level_rewards[3, -1] = torch.tensor([1.0, 1.0])

response_mask = torch.ones((4, 3), dtype=torch.float)
index = np.array([0, 0, 1, 1])

adv_direct, ret_direct = compute_gdpo_outcome_advantage(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
grpo_norm_by_std=False,
)
assert adv_direct.shape == response_mask.shape
assert torch.allclose(adv_direct, ret_direct), "Advantages and returns should be equal with GDPO"

expected_adv = torch.tensor(
[
[0.5, 0.5, 0.5],
[-0.5, -0.5, -0.5],
[-0.5, -0.5, -0.5],
[0.5, 0.5, 0.5],
]
)
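    # expected_adv derivation: each objective is mean-centred within its prompt group,
    #   group 0 (rows 0-1): objective means (1.5, 1.0) -> row 0 sums to +0.5, row 1 to -0.5
    #   group 1 (rows 2-3): objective means (0.5, 1.0) -> row 2 sums to -0.5, row 3 to +0.5
    # the batch mean of (0.5, -0.5, -0.5, 0.5) is zero, so batch centring leaves them unchanged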
assert torch.allclose(adv_direct, expected_adv, atol=1e-5), f"Expected {expected_adv}, got {adv_direct}"

from omegaconf import OmegaConf

cfg = OmegaConf.create({})
adv_false, _ = compute_advantages_and_returns(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
adv_estimator=AdvantageEstimator.GDPO,
config=cfg,
grpo_norm_by_std=False,
)
adv_true, _ = compute_advantages_and_returns(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
adv_estimator=AdvantageEstimator.GDPO,
config=cfg,
grpo_norm_by_std=True,
)
assert adv_false.shape == response_mask.shape
assert adv_true.shape == response_mask.shape
assert torch.allclose(adv_false, expected_adv, atol=1e-5), "Direct call should match compute_advantages_and_returns"
assert not torch.allclose(adv_false, adv_true), "Toggling grpo_norm_by_std should affect GDPO advantages"


def test_compute_gae_advantage_return(advantage_test_data):
rewards, values, response_mask, index = advantage_test_data
