Commit c5997a3

Update
[ghstack-poisoned]

1 parent 0d35a30, commit c5997a3

1 file changed: +55 -104 lines

torchrl/objectives/llm/grpo.py (55 additions, 104 deletions)
@@ -24,7 +24,6 @@
     ProbabilisticTensorDictSequential,
     set_composite_lp_aggregate,
 )
-from tensordict.utils import expand_as_right
 from torch import distributions as d
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
@@ -78,8 +77,10 @@ class GRPOLoss(LossModule):
     The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.
 
     Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
+            - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
+            - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
+              recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -113,6 +114,12 @@ class GRPOLoss(LossModule):
             - "generic": Use attention masking (all valid tokens)
             Defaults to "sft" since we can't guarantee assistant masks are available.
 
+    .. note:: DAPO defaults (for reference):
+        - Clip-Higher asymmetric thresholds: (eps_low, eps_high) = (0.20, 0.28)
+        - Token-level policy gradient loss is recommended for long CoT
+        - Dynamic sampling filters trivial advantage groups
+        See DAPO [arXiv](https://arxiv.org/html/2503.14476).
+
     .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
         the storages match the ones that are passed to other components, such as data collectors.
     """
@@ -136,7 +143,7 @@ def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
         *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -165,7 +172,28 @@ def __init__(
         device = getattr(
             torch, "get_default_device", lambda: torch.device("cpu")
         )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept symmetric or asymmetric thresholds
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))
 
         self.masking_strategy = masking_strategy
         # Defaults for keys
@@ -178,7 +206,11 @@ def __init__(
 
     @property
     def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) for clamping log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )
 
     def _set_in_keys(self):
         keys = []
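
For context only, not part of the commit: clamping the log-weight between log(1 - eps_low) and log(1 + eps_high) and then exponentiating gives the same result as clamping the importance ratio into [1 - eps_low, 1 + eps_high], which is what `_clip_bounds` relies on. A standalone check with made-up values:

```python
import math

import torch

eps_low, eps_high = 0.20, 0.28
log_weight = torch.tensor([-0.5, -0.1, 0.0, 0.1, 0.5])  # log(pi_new / pi_old) per token

# Clamp in log-space (as the loss does), then exponentiate ...
ratio_via_log = log_weight.clamp(math.log(1 - eps_low), math.log(1 + eps_high)).exp()
# ... versus clamping the ratio directly.
ratio_direct = log_weight.exp().clamp(1 - eps_low, 1 + eps_high)

assert torch.allclose(ratio_via_log, ratio_direct)
print(ratio_direct)  # tensor([0.8000, 0.9048, 1.0000, 1.1052, 1.2800])
```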
@@ -321,6 +353,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
 
+        # Token-level objective: compute min over clipped/unclipped at the token level
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         td_out = TensorDict({"loss_objective": -gain})
         td_out.set("clip_fraction", clip_fraction)
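
For context only, not part of the commit: a standalone sketch of the token-level pessimistic objective, mirroring the `torch.stack(...).min(...)` line above. Here `gain1`/`gain2` stand for the unclipped and clipped per-token gains, and all values are made up for illustration.

```python
import torch

advantage = torch.tensor([1.0, -1.0, 1.0])          # per-token advantages
ratio_unclipped = torch.tensor([1.50, 1.50, 0.70])  # pi_new / pi_old per token
ratio_clipped = ratio_unclipped.clamp(1 - 0.20, 1 + 0.28)

gain1 = ratio_unclipped * advantage
gain2 = ratio_clipped * advantage
# Per-token minimum of unclipped and clipped gains; the loss is its negation.
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
loss_objective = -gain  # tensor([-1.2800,  1.5000, -0.7000])
```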
@@ -407,109 +440,27 @@ def _get_entropy(
             entropy.batch_size = adv_shape
         return entropy.unsqueeze(-1)
 
-    def _kl_to_ref(
-        self,
-        tensordict: TensorDictBase,
-        key: NestedKey = ("next", "ref_log_prob"),
-        ref_log_prob: torch.Tensor | None = None,
-        coeff: float | None = None,
-        mask: torch.Tensor | None = None,
-        dist: d.Distribution | None = None,
-    ):
-        if coeff is None:
-            coeff = self.kl_to_ref_coeff
-        # TODO: customize this
-        if ref_log_prob is None:
-            ref_log_prob = tensordict.get(
-                key,
-                as_padded_tensor=True,
-                padding_side="left",
-                padding_value=0.0,
-            )
-            if ref_log_prob is None:
-                raise KeyError(
-                    f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
-                )
-            ref_log_prob = ref_log_prob.squeeze(-1)
-        cur_log_prob = tensordict.get("_cur_log_prob")
-        # TODO: remove this
-        if cur_log_prob.shape != ref_log_prob.shape:
-            raise ValueError(
-                f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
-            )
-        if mask is not None:
-            ref_log_prob = torch.where(
-                expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
-            )
-            cur_log_prob = torch.where(
-                expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
-            )
-        diff = ref_log_prob - cur_log_prob
-        kl_penalty = (diff.expm1() - diff).mean()
-        return coeff * kl_penalty, kl_penalty
-
-    def _log_weight(
-        self, tensordict: TensorDictBase, adv_shape: torch.Size
-    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
-        cur_log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).
 
-        prev_log_prob = tensordict.get(
-            self.tensor_keys.sample_log_prob,
-            as_padded_tensor=True,
-            padding_side="left",
-            padding_value=0.0,
-        )
-
-        if prev_log_prob is None:
-            raise KeyError(
-                f"Couldn't find the log-prob {self.tensor_keys.sample_log_prob} in the input data."
-            )
-        if prev_log_prob.requires_grad:
-            raise RuntimeError(
-                f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
-            )
-
-        # Check for shape mismatches and provide helpful error messages
-        if cur_log_prob.shape != prev_log_prob.shape:
-            # Try to provide helpful debugging information
-            error_msg = (
-                f"Shape mismatch detected in GRPOLoss: current log-prob shape {cur_log_prob.shape} "
-                f"!= previous log-prob shape {prev_log_prob.shape}. "
-                f"This usually indicates a mismatch between the masking strategy used for "
-                f"advantage computation and the masking strategy used for loss computation.\n"
-                f"Current masking strategy: '{self.masking_strategy}'\n"
-                f"Possible solutions:\n"
-                f"1. If using RLHF (multi-turn conversations), set masking_strategy='rlhf'\n"
-                f"2. If using SFT (single-turn conversations), set masking_strategy='sft'\n"
-                f"3. If using generic scenarios, set masking_strategy='generic'\n"
-                f"4. Ensure the advantage was computed with the same masking strategy as the loss"
-            )
-            raise ValueError(error_msg)
+    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
+    [arXiv](https://arxiv.org/html/2503.14476).
+    """
 
-        attention_mask = dist.mask
-        cur_log_prob = torch.where(
-            expand_as_right(attention_mask, cur_log_prob), cur_log_prob, 0.0
-        )
-        prev_log_prob = torch.where(
-            expand_as_right(attention_mask, prev_log_prob), prev_log_prob, 0.0
+    def __init__(
+        self,
+        actor_network: LLMWrapperBase | None = None,
+        *,
+        clip_epsilon: tuple[float, float] = (0.20, 0.28),
+        **kwargs,
+    ):
+        if not (isinstance(clip_epsilon, (tuple, list)) and len(clip_epsilon) == 2):
+            raise ValueError("DAPO requires clip_epsilon=(eps_low, eps_high).")
+        super().__init__(
+            actor_network=actor_network, clip_epsilon=clip_epsilon, **kwargs
         )
 
-        if is_composite:
-            raise NotImplementedError
-        log_weight = (cur_log_prob - prev_log_prob).unsqueeze(-1)
-        if is_tensor_collection(log_weight):
-            log_weight = _sum_td_features(log_weight)
-        log_weight = log_weight.view(adv_shape).unsqueeze(-1)
-
-        kl_approx = (prev_log_prob - cur_log_prob).unsqueeze(-1)
-        if is_tensor_collection(kl_approx):
-            kl_approx = _sum_td_features(kl_approx)
-
-        tensordict.set("_cur_log_prob", cur_log_prob)
-
-        return log_weight, dist, kl_approx
-
 
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
