diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index d814aedd7..2e1de2858 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -19,7 +19,7 @@
 from collections import defaultdict
 from enum import StrEnum
 from functools import wraps
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import ray
@@ -430,6 +430,7 @@ class AdvantageEstimator(StrEnum):
     GRPO = "grpo"
     RLOO = "rloo"
     REINFORCE_PP = "reinforce++"
+    GDPO = "gdpo"
 
 
 class AdvantageEstimatorRegistry(BaseFunctionRegistry):
@@ -456,6 +457,7 @@ def repopulate_registry(cls):
            "gae": [AdvantageEstimator.GAE, compute_gae_advantage_return],
            "rloo": [AdvantageEstimator.RLOO, compute_rloo_outcome_advantage],
            "reinforce++": [AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage],
+           "gdpo": [AdvantageEstimator.GDPO, compute_gdpo_outcome_advantage],
        }
 
        for ae_name, (ae_type, ae_func) in ae_types.items():
@@ -1079,6 +1081,76 @@ def compute_grpo_outcome_advantage(
     return scores, scores
 
 
+@register_advantage_estimator(AdvantageEstimator.GDPO)
+def compute_gdpo_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    grpo_norm_by_std: bool = True,
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute advantage for GDPO (see https://arxiv.org/pdf/2601.05242 for details).
+
+    Expects:
+    - token_level_rewards: Float[torch.Tensor, "batch_size seqlen num_objectives"]
+    - response_mask: Float[torch.Tensor, "batch_size seqlen"]
+    - index: np.ndarray (batch_size)
+    - epsilon: float
+    - grpo_norm_by_std: bool (used for GDPO as well)
+
+    Returns:
+    - advantages: Float[torch.Tensor, "batch_size seqlen"]
+    - returns: Float[torch.Tensor, "batch_size seqlen"]
+    """
+
+    gdpo_norm_by_std = grpo_norm_by_std
+
+    # assumes per-objective token-level rewards, i.e. a single scalar outcome reward per response for each objective
+    scores = token_level_rewards.sum(dim=-2)  # (batch_size, num_objectives)
+    id2score = defaultdict[Any, list](list)
+    id2pos = defaultdict(list)
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        advantage_propagated = torch.empty((bsz,), device=scores.device, dtype=scores.dtype)
+        for i in range(bsz):
+            id2score[index[i]].append(scores[i])  # prompt index -> rollouts x objectives
+            id2pos[index[i]].append(i)  # prompt index -> rollout indices for batch norm
+
+        for idx in id2score:  # per prompt group
+            reward_fn_scores = torch.stack(id2score[idx])  # all objective scores for the group -> rollouts x objectives
+            if reward_fn_scores.shape[0] == 1:
+                rwd_fn_means = torch.zeros(reward_fn_scores.shape[1], device=scores.device, dtype=scores.dtype)
+                rwd_fn_stds = torch.ones(reward_fn_scores.shape[1], device=scores.device, dtype=scores.dtype)
+            elif reward_fn_scores.shape[0] > 1:
+                rwd_fn_means = torch.mean(reward_fn_scores, dim=0)
+                rwd_fn_stds = torch.std(reward_fn_scores, dim=0, unbiased=False)
+            if gdpo_norm_by_std:
+                reward_fn_scores = (reward_fn_scores - rwd_fn_means) / (rwd_fn_stds + epsilon)
+            else:
+                reward_fn_scores = reward_fn_scores - rwd_fn_means
+
+            id2score[idx] = reward_fn_scores.sum(
+                dim=-1
+            )  # sum over objectives to get unnormalized advantage per rollout
+
+            positions = torch.tensor(id2pos[idx], device=scores.device, dtype=torch.long)
+            advantage_propagated.index_copy_(0, positions, id2score[idx])
+
+        # batch norm the advantages
+        batch_mean = advantage_propagated.mean()
+        if gdpo_norm_by_std:
+            batch_std = advantage_propagated.std(unbiased=False)
+            a_hat = (advantage_propagated - batch_mean) / (batch_std + epsilon)
+        else:
+            a_hat = advantage_propagated - batch_mean
+        advantages = a_hat.unsqueeze(-1) * response_mask
+
+    return advantages, advantages
+
+
 def repopulate_all_registries():
     PolicyLossRegistry.repopulate_registry()
     AdvantageEstimatorRegistry.repopulate_registry()
diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
index fb69ce15e..5b6070efc 100644
--- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py
+++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
@@ -11,10 +11,12 @@
     compute_approx_kl,
     compute_gae_advantage_return,
     compute_grpo_outcome_advantage,
+    compute_gdpo_outcome_advantage,
     compute_advantages_and_returns,
     AdaptiveKLController,
     FixedKLController,
     AdvantageEstimatorRegistry,
+    AdvantageEstimator,
     register_advantage_estimator,
     PolicyLossRegistry,
     register_policy_loss,
@@ -172,6 +174,65 @@ def test_compute_grpo_outcome_advantage_norm_std_false():
     assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"
 
 
+def test_compute_gdpo_outcome_advantage_shape_and_alias_norm_flag():
+    """
+    GDPO expects multi-objective rewards (B, T, K) and should be configurable via the existing
+    `grpo_norm_by_std` flag (aliased internally for GDPO to avoid a redundant config option).
+    """
+
+    token_level_rewards = torch.zeros((4, 3, 2), dtype=torch.float)
+    token_level_rewards[0, -1] = torch.tensor([1.0, 2.0])
+    token_level_rewards[1, -1] = torch.tensor([2.0, 0.0])
+    token_level_rewards[2, -1] = torch.tensor([0.0, 1.0])
+    token_level_rewards[3, -1] = torch.tensor([1.0, 1.0])
+
+    response_mask = torch.ones((4, 3), dtype=torch.float)
+    index = np.array([0, 0, 1, 1])
+
+    adv_direct, ret_direct = compute_gdpo_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+        grpo_norm_by_std=False,
+    )
+    assert adv_direct.shape == response_mask.shape
+    assert torch.allclose(adv_direct, ret_direct), "Advantages and returns should be equal with GDPO"
+
+    expected_adv = torch.tensor(
+        [
+            [0.5, 0.5, 0.5],
+            [-0.5, -0.5, -0.5],
+            [-0.5, -0.5, -0.5],
+            [0.5, 0.5, 0.5],
+        ]
+    )
+    assert torch.allclose(adv_direct, expected_adv, atol=1e-5), f"Expected {expected_adv}, got {adv_direct}"
+
+    from omegaconf import OmegaConf
+
+    cfg = OmegaConf.create({})
+    adv_false, _ = compute_advantages_and_returns(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+        adv_estimator=AdvantageEstimator.GDPO,
+        config=cfg,
+        grpo_norm_by_std=False,
+    )
+    adv_true, _ = compute_advantages_and_returns(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+        adv_estimator=AdvantageEstimator.GDPO,
+        config=cfg,
+        grpo_norm_by_std=True,
+    )
+    assert adv_false.shape == response_mask.shape
+    assert adv_true.shape == response_mask.shape
+    assert torch.allclose(adv_false, expected_adv, atol=1e-5), "Direct call should match compute_advantages_and_returns"
+    assert not torch.allclose(adv_false, adv_true), "Toggling grpo_norm_by_std should affect GDPO advantages"
+
+
 def test_compute_gae_advantage_return(advantage_test_data):
     rewards, values, response_mask, index = advantage_test_data
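For reference, the sketch below walks through the GDPO advantage computation on the toy batch used in `test_compute_gdpo_outcome_advantage_shape_and_alias_norm_flag`, assuming `grpo_norm_by_std=False` (mean-centering only). It is a standalone illustration, not part of the diff: the hard-coded `scores` and `groups` values mirror the test data, and the reimplementation is simplified (no token broadcasting, no std division).

```python
# Illustrative sketch: reproduce the expected GDPO advantages from the test above.
import torch

# Per-response, per-objective outcome scores (token-level rewards summed over the sequence).
# Rows = rollouts, columns = objectives; prompt grouping is [0, 0, 1, 1] as in the test.
scores = torch.tensor([[1.0, 2.0], [2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
groups = [[0, 1], [2, 3]]  # rollout indices per prompt group

adv = torch.empty(scores.shape[0])
for rows in groups:
    group = scores[rows]
    # Step 1: center each objective within the prompt group (std division skipped here).
    centered = group - group.mean(dim=0)
    # Step 2: sum per-objective advantages into one scalar per rollout.
    adv[rows] = centered.sum(dim=-1)

# Step 3: batch-level centering across all rollouts.
adv = adv - adv.mean()

print(adv)  # tensor([ 0.5000, -0.5000, -0.5000,  0.5000]) -> matches expected_adv (per token)
```

Broadcasting `adv` over the response mask then yields the `(4, 3)` tensor asserted in the test; with `grpo_norm_by_std=True` both the per-group and batch steps additionally divide by the standard deviation, which is why the test asserts the two settings produce different advantages.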