From c5997a38eed422059e81ec121eaf13f4c62955ac Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 16 Oct 2025 21:40:42 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/objectives/llm/grpo.py | 159 ++++++++++++---------------------
 1 file changed, 55 insertions(+), 104 deletions(-)

diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 08d93ae4c58..5fd95dcb3b9 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -24,7 +24,6 @@
     ProbabilisticTensorDictSequential,
     set_composite_lp_aggregate,
 )
-from tensordict.utils import expand_as_right
 from torch import distributions as d
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
@@ -78,8 +77,10 @@ class GRPOLoss(LossModule):
         The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.
 
     Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
+            - float ``x``: symmetric clipping ``[1 - x, 1 + x]`` (default: ``0.2``).
+            - tuple ``(eps_low, eps_high)``: asymmetric clipping ``[1 - eps_low, 1 + eps_high]``, as in
+              DAPO's Clip-Higher; recommended defaults from DAPO are ``(0.20, 0.28)``, see Eq. (10) in the paper.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -113,6 +114,12 @@ class GRPOLoss(LossModule):
             - "generic": Use attention masking (all valid tokens)
             Defaults to "sft" since we can't guarantee assistant masks are available.
 
+    .. note:: DAPO defaults (for reference):
+        - Clip-Higher asymmetric thresholds: ``(eps_low, eps_high) = (0.20, 0.28)``
+        - Token-level policy-gradient loss is recommended for long chain-of-thought rollouts
+        - Dynamic sampling filters out groups whose advantages are trivially zero
+        See the DAPO paper: https://arxiv.org/html/2503.14476.
+
     .. note:: Parameters and buffers from the policy / critic will not be cast to that device to
         ensure that the storages match the ones that are passed to other components, such as
         data collectors.
     """
@@ -136,7 +143,7 @@ def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
         *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -165,7 +172,28 @@ def __init__(
             device = getattr(
                 torch, "get_default_device", lambda: torch.device("cpu")
             )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept a symmetric threshold (float) or asymmetric thresholds (tuple)
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))
 
         self.masking_strategy = masking_strategy
         # Defaults for keys
@@ -178,7 +206,11 @@ def __init__(
 
     @property
     def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) for clamping the log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )
 
     def _set_in_keys(self):
         keys = []
@@ -321,6 +353,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             ratio = log_weight_clip.exp()
             gain2 = ratio * advantage
 
+            # Token-level objective: compute min over clipped/unclipped at the token level
             gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
             td_out = TensorDict({"loss_objective": -gain})
             td_out.set("clip_fraction", clip_fraction)
@@ -407,109 +440,27 @@ def _get_entropy(
             entropy.batch_size = adv_shape
         return entropy.unsqueeze(-1)
 
-    def _kl_to_ref(
-        self,
-        tensordict: TensorDictBase,
-        key: NestedKey = ("next", "ref_log_prob"),
-        ref_log_prob: torch.Tensor | None = None,
-        coeff: float | None = None,
-        mask: torch.Tensor | None = None,
-        dist: d.Distribution | None = None,
-    ):
-        if coeff is None:
-            coeff = self.kl_to_ref_coeff
-        # TODO: customize this
-        if ref_log_prob is None:
-            ref_log_prob = tensordict.get(
-                key,
-                as_padded_tensor=True,
-                padding_side="left",
-                padding_value=0.0,
-            )
-            if ref_log_prob is None:
-                raise KeyError(
-                    f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
-                )
-            ref_log_prob = ref_log_prob.squeeze(-1)
-        cur_log_prob = tensordict.get("_cur_log_prob")
-        # TODO: remove this
-        if cur_log_prob.shape != ref_log_prob.shape:
-            raise ValueError(
-                f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
-            )
-        if mask is not None:
-            ref_log_prob = torch.where(
-                expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
-            )
-            cur_log_prob = torch.where(
-                expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
-            )
-        diff = ref_log_prob - cur_log_prob
-        kl_penalty = (diff.expm1() - diff).mean()
-        return coeff * kl_penalty, kl_penalty
-
-    def _log_weight(
-        self, tensordict: TensorDictBase, adv_shape: torch.Size
-    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-        cur_log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).
 
-        prev_log_prob = tensordict.get(
-            self.tensor_keys.sample_log_prob,
-            as_padded_tensor=True,
-            padding_side="left",
-            padding_value=0.0,
-        )
-
-        if prev_log_prob is None:
-            raise KeyError(
-                f"Couldn't find the log-prob {self.tensor_keys.sample_log_prob} in the input data."
-            )
-        if prev_log_prob.requires_grad:
-            raise RuntimeError(
-                f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
-            )
-
-        # Check for shape mismatches and provide helpful error messages
-        if cur_log_prob.shape != prev_log_prob.shape:
-            # Try to provide helpful debugging information
-            error_msg = (
-                f"Shape mismatch detected in GRPOLoss: current log-prob shape {cur_log_prob.shape} "
-                f"!= previous log-prob shape {prev_log_prob.shape}. "
-                f"This usually indicates a mismatch between the masking strategy used for "
-                f"advantage computation and the masking strategy used for loss computation.\n"
-                f"Current masking strategy: '{self.masking_strategy}'\n"
-                f"Possible solutions:\n"
-                f"1. If using RLHF (multi-turn conversations), set masking_strategy='rlhf'\n"
-                f"2. If using SFT (single-turn conversations), set masking_strategy='sft'\n"
-                f"3. If using generic scenarios, set masking_strategy='generic'\n"
-                f"4. Ensure the advantage was computed with the same masking strategy as the loss"
-            )
-            raise ValueError(error_msg)
+    Validates asymmetric clip thresholds; the recommended values are ``(0.20, 0.28)``,
+    see Eq. (10) of the DAPO paper: https://arxiv.org/html/2503.14476.
+    """
 
-        attention_mask = dist.mask
-        cur_log_prob = torch.where(
-            expand_as_right(attention_mask, cur_log_prob), cur_log_prob, 0.0
-        )
-        prev_log_prob = torch.where(
-            expand_as_right(attention_mask, prev_log_prob), prev_log_prob, 0.0
+    def __init__(
+        self,
+        actor_network: LLMWrapperBase | None = None,
+        *,
+        clip_epsilon: tuple[float, float] = (0.20, 0.28),
+        **kwargs,
+    ):
+        if not (isinstance(clip_epsilon, (tuple, list)) and len(clip_epsilon) == 2):
+            raise ValueError("DAPO requires clip_epsilon=(eps_low, eps_high).")
+        super().__init__(
+            actor_network=actor_network, clip_epsilon=clip_epsilon, **kwargs
         )
-        if is_composite:
-            raise NotImplementedError
-        log_weight = (cur_log_prob - prev_log_prob).unsqueeze(-1)
-        if is_tensor_collection(log_weight):
-            log_weight = _sum_td_features(log_weight)
-        log_weight = log_weight.view(adv_shape).unsqueeze(-1)
-
-        kl_approx = (prev_log_prob - cur_log_prob).unsqueeze(-1)
-        if is_tensor_collection(kl_approx):
-            kl_approx = _sum_td_features(kl_approx)
-
-        tensordict.set("_cur_log_prob", cur_log_prob)
-
-        return log_weight, dist, kl_approx
-
 
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
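
Note on the math (illustration only, not part of the patch): ``_clip_bounds`` clamps the
log-weight to ``[log(1 - eps_low), log(1 + eps_high)]`` via ``log1p``; because ``exp`` is
monotonic, this is equivalent to clipping the importance ratio to ``[1 - eps_low, 1 + eps_high]``,
and stacking the two gains before taking ``min`` is a per-token minimum, matching the
"token-level objective" comment. A minimal standalone sketch checking both identities,
assuming only ``torch``:

    import torch

    eps_low, eps_high = 0.20, 0.28  # DAPO Clip-Higher thresholds
    log_weight = torch.randn(4, 7)  # stand-in for per-token log(pi_new / pi_old)
    advantage = torch.randn(4, 7)   # stand-in for group-normalized advantages

    # Log-space bounds, as in _clip_bounds: (-eps_low).log1p() == log(1 - eps_low)
    lo = torch.tensor(-eps_low).log1p()
    hi = torch.tensor(eps_high).log1p()
    ratio_from_log = log_weight.clamp(lo, hi).exp()

    # The same clipping applied directly to the ratio
    ratio_direct = log_weight.exp().clamp(1 - eps_low, 1 + eps_high)
    assert torch.allclose(ratio_from_log, ratio_direct)

    # Token-level objective: stack/min equals an elementwise minimum per token
    gain1 = log_weight.exp() * advantage
    gain2 = ratio_from_log * advantage
    gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
    assert torch.allclose(gain, torch.minimum(gain1, gain2))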
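
Usage is unchanged apart from the wider ``clip_epsilon`` type. A hedged sketch
(``policy`` stands for any ``LLMWrapperBase`` instance and is assumed here, as is
importing directly from the module this patch touches):

    from torchrl.objectives.llm.grpo import DAPO, GRPOLoss

    # Symmetric PPO-style clipping, as before
    loss = GRPOLoss(policy, clip_epsilon=0.2)

    # Asymmetric Clip-Higher thresholds, passed directly...
    loss = GRPOLoss(policy, clip_epsilon=(0.20, 0.28))

    # ...or via the DAPO subclass, whose default is (0.20, 0.28)
    loss = DAPO(policy)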