    ProbabilisticTensorDictSequential,
    set_composite_lp_aggregate,
)
-from tensordict.utils import expand_as_right
from torch import distributions as d
from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.envs.transforms.transforms import Transform
@@ -78,8 +77,10 @@ class GRPOLoss(LossModule):
    The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.

    Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate objective.
+            - float x: symmetric clipping to [1 - x, 1 + x] (default: 0.2).
+            - tuple (eps_low, eps_high): asymmetric clipping to [1 - eps_low, 1 + eps_high], as in DAPO's Clip-Higher.
+              The recommended defaults from DAPO are (0.20, 0.28); see Eq. (10) in the paper.
        entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
            loss to favour exploratory policies.
        samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -113,6 +114,12 @@ class GRPOLoss(LossModule):
            - "generic": Use attention masking (all valid tokens)
            Defaults to "sft" since we can't guarantee assistant masks are available.

+    .. note:: DAPO defaults (for reference):
+        - Clip-Higher asymmetric thresholds: (eps_low, eps_high) = (0.20, 0.28).
+        - Token-level policy-gradient loss is recommended for long chain-of-thought (CoT) generations.
+        - Dynamic sampling filters out groups whose advantages are trivial (all identical).
+        See the DAPO paper: https://arxiv.org/html/2503.14476.
+
    .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
        the storages match the ones that are passed to other components, such as data collectors.
    """
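Not part of the diff, but as a usage illustration of the new `clip_epsilon` argument: a minimal sketch, assuming the import path below and using `policy` as a placeholder for any `LLMWrapperBase`-compatible actor.

```python
from torchrl.objectives.llm import GRPOLoss  # import path assumed


def build_losses(policy):
    """`policy` is a placeholder for an LLMWrapperBase-compatible actor."""
    # Symmetric, PPO-style clipping: the importance ratio is kept in [0.8, 1.2]
    loss_symmetric = GRPOLoss(actor_network=policy, clip_epsilon=0.2)
    # Asymmetric DAPO Clip-Higher clipping: the ratio is kept in [0.80, 1.28]
    loss_clip_higher = GRPOLoss(actor_network=policy, clip_epsilon=(0.20, 0.28))
    return loss_symmetric, loss_clip_higher
```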
@@ -136,7 +143,7 @@ def __init__(
        self,
        actor_network: LLMWrapperBase | None = None,
        *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
        entropy_bonus: bool = True,
        samples_mc_entropy: int = 1,
        entropy_coeff: float = 0.01,
@@ -165,7 +172,28 @@ def __init__(
            device = getattr(
                torch, "get_default_device", lambda: torch.device("cpu")
            )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept symmetric or asymmetric thresholds
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))

        self.masking_strategy = masking_strategy
        # Defaults for keys
@@ -178,7 +206,11 @@ def __init__(

    @property
    def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) for clamping the log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )

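As an aside (not part of the patch): clamping the log-weight to these bounds is equivalent to clamping the importance ratio itself to [1 - eps_low, 1 + eps_high]. A quick self-contained check:

```python
import torch

eps_low, eps_high = 0.20, 0.28
log_weight = torch.tensor([-1.0, -0.1, 0.0, 0.1, 1.0])

# Clamp in log-space, mirroring _clip_bounds ...
low = torch.tensor(-eps_low).log1p()    # log(1 - eps_low)
high = torch.tensor(eps_high).log1p()   # log(1 + eps_high)
ratio_from_log = log_weight.clamp(low, high).exp()

# ... which matches clamping the ratio directly
ratio_direct = log_weight.exp().clamp(1 - eps_low, 1 + eps_high)
assert torch.allclose(ratio_from_log, ratio_direct)
```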
    def _set_in_keys(self):
        keys = []
@@ -321,6 +353,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
        ratio = log_weight_clip.exp()
        gain2 = ratio * advantage

+        # Token-level objective: compute min over clipped/unclipped at the token level
        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
        td_out = TensorDict({"loss_objective": -gain})
        td_out.set("clip_fraction", clip_fraction)
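Outside the diff, a minimal sketch of this token-level objective on dummy tensors. Shapes and the hard-coded clamp bounds are illustrative only; in the module the bounds come from `_clip_bounds` and reduction over valid tokens happens downstream.

```python
import torch

# Dummy per-token quantities: batch=2, tokens=3, trailing singleton dim as in the loss
advantage = torch.randn(2, 3, 1)
log_weight = 0.1 * torch.randn(2, 3, 1)
# Approximate bounds for clip_epsilon=(0.20, 0.28): (log(0.8), log(1.28))
log_weight_clip = log_weight.clamp(-0.2231, 0.2469)

gain1 = log_weight.exp() * advantage        # unclipped surrogate term
gain2 = log_weight_clip.exp() * advantage   # clipped surrogate term

# Pessimistic per-token choice, negated into a loss
loss_objective = -torch.stack([gain1, gain2], -1).min(dim=-1).values
```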
@@ -407,109 +440,27 @@ def _get_entropy(
            entropy.batch_size = adv_shape
        return entropy.unsqueeze(-1)

-    def _kl_to_ref(
-        self,
-        tensordict: TensorDictBase,
-        key: NestedKey = ("next", "ref_log_prob"),
-        ref_log_prob: torch.Tensor | None = None,
-        coeff: float | None = None,
-        mask: torch.Tensor | None = None,
-        dist: d.Distribution | None = None,
-    ):
-        if coeff is None:
-            coeff = self.kl_to_ref_coeff
-        # TODO: customize this
-        if ref_log_prob is None:
-            ref_log_prob = tensordict.get(
-                key,
-                as_padded_tensor=True,
-                padding_side="left",
-                padding_value=0.0,
-            )
-            if ref_log_prob is None:
-                raise KeyError(
-                    f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
-                )
-            ref_log_prob = ref_log_prob.squeeze(-1)
-        cur_log_prob = tensordict.get("_cur_log_prob")
-        # TODO: remove this
-        if cur_log_prob.shape != ref_log_prob.shape:
-            raise ValueError(
-                f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
-            )
-        if mask is not None:
-            ref_log_prob = torch.where(
-                expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
-            )
-            cur_log_prob = torch.where(
-                expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
-            )
-        diff = ref_log_prob - cur_log_prob
-        kl_penalty = (diff.expm1() - diff).mean()
-        return coeff * kl_penalty, kl_penalty
-
-    def _log_weight(
-        self, tensordict: TensorDictBase, adv_shape: torch.Size
-    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:

-        cur_log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).

-        prev_log_prob = tensordict.get(
-            self.tensor_keys.sample_log_prob,
-            as_padded_tensor=True,
-            padding_side="left",
-            padding_value=0.0,
-        )
-
-        if prev_log_prob is None:
-            raise KeyError(
-                f"Couldn't find the log-prob {self.tensor_keys.sample_log_prob} in the input data."
-            )
-        if prev_log_prob.requires_grad:
-            raise RuntimeError(
-                f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
-            )
-
-        # Check for shape mismatches and provide helpful error messages
-        if cur_log_prob.shape != prev_log_prob.shape:
-            # Try to provide helpful debugging information
-            error_msg = (
-                f"Shape mismatch detected in GRPOLoss: current log-prob shape {cur_log_prob.shape} "
-                f"!= previous log-prob shape {prev_log_prob.shape}. "
-                f"This usually indicates a mismatch between the masking strategy used for "
-                f"advantage computation and the masking strategy used for loss computation.\n"
-                f"Current masking strategy: '{self.masking_strategy}'\n"
-                f"Possible solutions:\n"
-                f"1. If using RLHF (multi-turn conversations), set masking_strategy='rlhf'\n"
-                f"2. If using SFT (single-turn conversations), set masking_strategy='sft'\n"
-                f"3. If using generic scenarios, set masking_strategy='generic'\n"
-                f"4. Ensure the advantage was computed with the same masking strategy as the loss"
-            )
-            raise ValueError(error_msg)
+    Enforces asymmetric clip thresholds. The recommended values are (0.20, 0.28);
+    see Eq. (10) in the DAPO paper: https://arxiv.org/html/2503.14476.
+    """

-        attention_mask = dist.mask
-        cur_log_prob = torch.where(
-            expand_as_right(attention_mask, cur_log_prob), cur_log_prob, 0.0
-        )
-        prev_log_prob = torch.where(
-            expand_as_right(attention_mask, prev_log_prob), prev_log_prob, 0.0
+    def __init__(
+        self,
+        actor_network: LLMWrapperBase | None = None,
+        *,
+        clip_epsilon: tuple[float, float] = (0.20, 0.28),
+        **kwargs,
+    ):
+        if not (isinstance(clip_epsilon, (tuple, list)) and len(clip_epsilon) == 2):
+            raise ValueError("DAPO requires clip_epsilon=(eps_low, eps_high).")
+        super().__init__(
+            actor_network=actor_network, clip_epsilon=clip_epsilon, **kwargs
        )

-        if is_composite:
-            raise NotImplementedError
-        log_weight = (cur_log_prob - prev_log_prob).unsqueeze(-1)
-        if is_tensor_collection(log_weight):
-            log_weight = _sum_td_features(log_weight)
-        log_weight = log_weight.view(adv_shape).unsqueeze(-1)
-
-        kl_approx = (prev_log_prob - cur_log_prob).unsqueeze(-1)
-        if is_tensor_collection(kl_approx):
-            kl_approx = _sum_td_features(kl_approx)
-
-        tensordict.set("_cur_log_prob", cur_log_prob)
-
-        return log_weight, dist, kl_approx
-

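Not part of the diff: a brief usage sketch of the DAPO subclass added above, assuming the import path below and using `policy` as a placeholder for an `LLMWrapperBase`-compatible actor.

```python
from torchrl.objectives.llm import DAPO  # import path assumed


def build_dapo_loss(policy):
    """`policy` is a placeholder for an LLMWrapperBase-compatible actor."""
    # Defaults to the recommended asymmetric thresholds (0.20, 0.28)
    loss = DAPO(actor_network=policy)
    # Passing a scalar is rejected, since DAPO is defined by its asymmetric clipping:
    # DAPO(actor_network=policy, clip_epsilon=0.2)  # raises ValueError
    return loss
```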
class MCAdvantage(Transform):
    """Monte-Carlo advantage computation engine.