diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index cec2e9101..3a75a66db 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -200,6 +200,10 @@ def _validate(self):
         return test_metrics
 
     def _train_step(self, batch_dict: dict) -> dict:
+        # Check if RAFT mode is enabled
+        if self.config.algorithm.adv_estimator == "raft":
+            return self._train_step_raft(batch_dict)
+
         # Isolate in a separate method to automatically recycle the variables before validation.
         batch: DataProto = DataProto.from_single_dict(batch_dict)
         metrics = {}
@@ -388,6 +392,223 @@ def _train_step(self, batch_dict: dict) -> dict:
 
         return metrics
 
+    def _train_step_raft(self, batch_dict: dict) -> dict:
+        """
+        RAFT training step: a simplified training loop that only trains on samples with reward r=1.
+
+        RAFT (rejection-sampling fine-tuning) differs from GRPO/PPO in that it:
+        1. Uses rejection sampling: only samples with reward r=1 are kept.
+        2. Uses a simple loss: standard cross-entropy (NLL) instead of an advantage-weighted loss.
+        3. Needs no critic: no value function is estimated.
+        4. Needs no advantage: no advantage or GAE computation is required.
+        """
+        batch: DataProto = DataProto.from_single_dict(batch_dict)
+        metrics = {}
+        timing_raw = {}
+
+        with _timer("step", timing_raw):
+            # When agent mode is enabled, we read the batch as it is.
+            gen_batch = batch
+
+            # Generate rollouts and collect data
+            with _timer("gen", timing_raw):
+                self.async_rollout_manager.wake_up()
+                self.agent_mode_daemon.set_up_data_and_server(
+                    gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
+                )
+                self.agent_mode_daemon.run_until_all_finished()
+                batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
+                    max_prompt_length=self.config.data.max_prompt_length,
+                    max_response_length=self.config.data.max_response_length,
+                    device=gen_batch.batch["fake_ids"].device,
+                )
+                metrics.update(agent_metrics)
+                self.agent_mode_daemon.clear_data_and_server()
+                self.async_rollout_manager.sleep()
+
+            # RAFT Step 1: Rejection sampling - keep only samples with r=1
+            with _timer("rejection_sampling", timing_raw):
+                # Sum token_level_scores to obtain the sequence-level reward;
+                # the reward is stored at the last token position of token_level_scores.
+                sequence_rewards = batch.batch["token_level_scores"].sum(dim=-1)  # (batch_size,)
+
+                # Binary reward: 1.0 for success, 0.0 for failure.
+                # In RAFT, only samples with reward == 1.0 are kept.
+                is_positive_reward = (sequence_rewards == 1.0)
+                positive_indices = is_positive_reward.nonzero(as_tuple=True)[0]
+
+                # Log rejection-sampling statistics
+                n_total = len(batch)
+                n_positive = len(positive_indices)
+                n_rejected = n_total - n_positive
+                metrics["raft/n_total_samples"] = n_total
+                metrics["raft/n_positive_samples"] = n_positive
+                metrics["raft/n_rejected_samples"] = n_rejected
+                metrics["raft/rejection_rate"] = n_rejected / n_total if n_total > 0 else 0.0
+                metrics["raft/positive_rate"] = n_positive / n_total if n_total > 0 else 0.0
+
+                # If there are no positive samples, skip this training step
+                if n_positive == 0:
+                    metrics["raft/loss"] = 0.0
+                    metrics["raft/skipped_no_positive_samples"] = 1
+                    return metrics
+
+                # Filter the batch to keep only positive samples
+                positive_batch = batch[positive_indices.cpu().tolist()]
+
+            # RAFT Step 2: Compute the response mask for the filtered batch
+            positive_batch.batch["response_mask"] = compute_response_mask(positive_batch)
+
+            # Set uid (required by update_actor, as in GRPO).
+            # uid is used by algorithms such as GRPO and must be aligned with the data id.
+            if "data_id_list" in positive_batch.non_tensor_batch:
+                positive_batch.non_tensor_batch["uid"] = positive_batch.non_tensor_batch["data_id_list"]
+
+            # Drop samples whose prompts are too long
+            keep_indices = (~positive_batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
+            metrics["raft/n_triplets_prompt_too_long"] = (
+                positive_batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
+            )
+            if len(keep_indices) == 0:
+                metrics["raft/loss"] = 0.0
+                metrics["raft/skipped_all_dropped"] = 1
+                return metrics
+            positive_batch = positive_batch[keep_indices]
+
+            # Round down to a multiple of the mini-batch size for efficient training
+            mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
+            n_transition = len(positive_batch)
+            random_indices = list(range(n_transition))
+            random.shuffle(random_indices)
+            positive_batch.reorder(torch.tensor(random_indices).type(torch.int32))
+            n_remained_transition = n_transition // mini_batch_size * mini_batch_size
+            positive_batch = positive_batch[list(range(n_remained_transition))]
+            metrics["raft/n_triplets_dropped_remainder"] = n_transition - n_remained_transition
+
+            # Balance the batch if enabled
+            if self.config.trainer.balance_batch:
+                self._balance_batch(positive_batch, metrics=metrics)
+
+            # RAFT Step 3: Prepare the batch for RAFT loss computation.
+            # Advantage-related fields are not needed because RAFT does not use them.
+            raft_batch = positive_batch
+            max_response_length = raft_batch.batch["responses"].shape[-1]
+
+            # RAFT Step 4: Prepare the batch for the actor update.
+            # old_log_probs must be computed and the required meta_info fields set.
+            with _timer("prepare_raft_batch", timing_raw):
+                # Ensure uid is set (it may have been lost during filtering)
+                if "data_id_list" in raft_batch.non_tensor_batch:
+                    raft_batch.non_tensor_batch["uid"] = raft_batch.non_tensor_batch["data_id_list"]
+
+                # Compute global_token_num (required by update_actor)
+                raft_batch.meta_info["global_token_num"] = torch.sum(raft_batch.batch["attention_mask"], dim=-1).tolist()
+
+                # Pad the batch for distributed training before computing log probs
+                raft_batch, pad_size_prep = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)
+
+                # Compute old_log_probs (required by update_actor, as in GRPO).
+                # This is needed even for RAFT because update_actor expects this field.
+                old_log_prob = self.actor_rollout_wg.compute_log_prob(raft_batch)
+                entropys = old_log_prob.batch["entropys"]
+                response_masks = raft_batch.batch["response_mask"]
+                loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+                old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
+                metrics.update(old_log_prob_metrics)
+                old_log_prob.batch.pop("entropys")
+                raft_batch = raft_batch.union(old_log_prob)
+
+                # Set the required meta_info fields (as in GRPO)
+                raft_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
+                # Temperature is required by update_actor (taken from the config, defaulting to 0.7)
+                raft_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.get("temperature", 0.7)
+
+                # Unpad before setting advantages
+                raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep)
+
+            # RAFT Step 5: Pure SFT update.
+            # Reuse the PPO update with advantages=1.0 and clipping disabled.
+            # Note: with advantages=1.0, no clipping, and ratio == 1, the PPO loss gradient matches standard cross-entropy (SFT).
+            with _timer("update_actor_sft", timing_raw):
+                # Set advantages to 1.0 (no advantage weighting, pure SFT).
+                # This makes the PPO loss equivalent to standard cross-entropy once clipping is disabled.
+                raft_batch.batch["advantages"] = torch.ones(
+                    (len(raft_batch), max_response_length),
+                    device=raft_batch.batch["input_ids"].device,
+                    dtype=torch.float32
+                )
+                raft_batch.batch["returns"] = raft_batch.batch["advantages"].clone()
+
+                # Remove any existing values field (there is no critic in RAFT)
+                if "values" in raft_batch.batch:
+                    raft_batch.batch.pop("values")
+
+                # Pad again for distributed training before update_actor
+                raft_batch, pad_size_actor = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)
+
+                # Temporarily disable PPO clipping for the pure SFT update
+                original_clip_low = self.config.actor_rollout_ref.actor.get("clip_ratio_low", 0.2)
+                original_clip_high = self.config.actor_rollout_ref.actor.get("clip_ratio_high", 0.3)
+
+                # Disable clipping by widening the clip range as much as possible:
+                # clip_ratio_low=1 and clip_ratio_high=1000 give clip(ratio, 1 - 1, 1 + 1000) = clip(ratio, 0, 1001),
+                # which leaves the importance ratio effectively unconstrained.
+                self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1
+                self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1000
+
+                try:
+                    # Update the actor with the pure SFT loss.
+                    # With advantages=1.0 and clipping disabled, this becomes standard cross-entropy,
+                    # mimicking SFTTrainer.compute_loss() behavior.
+                    actor_output = self.actor_rollout_wg.update_actor(raft_batch)
+                    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+                    metrics.update(actor_output_metrics)
+
+                    # Extract and log the SFT loss.
+                    # Use the actor loss from the update_actor output (same as GRPO).
+                    # Note: update_actor reports "actor/pg_loss", not "actor/loss".
+                    if "actor/pg_loss" in actor_output_metrics:
+                        metrics["raft/loss"] = actor_output_metrics["actor/pg_loss"]
+                    elif "actor/loss" in actor_output_metrics:
+                        metrics["raft/loss"] = actor_output_metrics["actor/loss"]
+                    else:
+                        # Fallback: use a default value if no loss is reported
+                        metrics["raft/loss"] = 0.0
+                finally:
+                    # Restore the original clipping ratios
+                    self.config.actor_rollout_ref.actor["clip_ratio_low"] = original_clip_low
+                    self.config.actor_rollout_ref.actor["clip_ratio_high"] = original_clip_high
+
+                # Record that a pure SFT update was performed (like SFTTrainer)
+                metrics["raft/pure_sft_update"] = 1.0
+
+            # Log rollout generations if enabled
+            rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+            if rollout_data_dir:
+                with _timer("dump_rollout_generations", timing_raw):
+                    # Unpad for logging
+                    log_batch = unpad_dataproto(raft_batch, pad_size_actor)
+                    inputs = self.tokenizer.batch_decode(log_batch.batch["prompts"], skip_special_tokens=True)
+                    outputs = self.tokenizer.batch_decode(log_batch.batch["responses"], skip_special_tokens=True)
+                    # Get scores from the filtered batch
+                    log_scores = log_batch.batch["token_level_scores"].sum(dim=-1).cpu().tolist()
+                    self._dump_generations(
+                        inputs=inputs,
+                        outputs=outputs,
+                        scores=log_scores,
+                        reward_extra_infos_dict={},
+                        dump_path=rollout_data_dir,
+                    )
+
+        # Compute training metrics.
+        # Note: critic metrics are skipped for RAFT since there is no critic.
+        metrics.update(compute_timing_metrics(batch=raft_batch, timing_raw=timing_raw))
+        n_gpus = self.resource_pool_manager.get_n_gpus()
+        metrics.update(compute_throughout_metrics(batch=raft_batch, timing_raw=timing_raw, n_gpus=n_gpus))
+
+        return metrics
+
     def fit(self):
         logger = Tracking(
             project_name=self.config.trainer.project_name,
@@ -496,4 +717,4 @@ def fit(self):
                     return
 
                 progress_bar.update(1)
-                self.global_steps += 1
+                self.global_steps += 1
\ No newline at end of file
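
The only switch the new code path reads is config.algorithm.adv_estimator == "raft". The sketch below is a minimal, illustrative configuration fragment (not taken from this repository's defaults) listing the keys that _train_step_raft dereferences, assuming the usual OmegaConf-based verl config layout; all values are placeholders.

from omegaconf import OmegaConf

# Illustrative values only; the key names are the ones read by _train_step_raft.
config = OmegaConf.create(
    {
        "algorithm": {"adv_estimator": "raft"},  # routes _train_step to _train_step_raft
        "data": {"max_prompt_length": 2048, "max_response_length": 1024},
        "actor_rollout_ref": {
            "actor": {
                "ppo_mini_batch_size": 8,
                "loss_agg_mode": "token-mean",
                "clip_ratio_low": 0.2,
                "clip_ratio_high": 0.3,
            },
            "rollout": {"temperature": 0.7, "multi_turn": {"enable": False}},
        },
        "trainer": {"balance_batch": True, "rollout_data_dir": None},
    }
)
assert config.algorithm.adv_estimator == "raft"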
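
The rejection-sampling step reduces to a boolean mask over sequence-level rewards. A self-contained sketch on toy tensors (names mirror the patch; the data is made up):

import torch

# Toy token-level scores for four rollouts; the binary reward sits on the last token,
# so summing over tokens recovers the sequence-level reward.
token_level_scores = torch.tensor(
    [
        [0.0, 0.0, 1.0],  # reward 1.0 -> kept
        [0.0, 0.0, 0.0],  # reward 0.0 -> rejected
        [0.0, 0.0, 1.0],  # reward 1.0 -> kept
        [0.0, 0.0, 0.0],  # reward 0.0 -> rejected
    ]
)
sequence_rewards = token_level_scores.sum(dim=-1)             # (batch_size,)
is_positive_reward = sequence_rewards == 1.0                  # boolean keep-mask
positive_indices = is_positive_reward.nonzero(as_tuple=True)[0]

print(positive_indices.tolist())                              # [0, 2]
print(1 - len(positive_indices) / len(sequence_rewards))      # rejection rate: 0.5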
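
The claim behind RAFT Step 5 — that the unclipped PPO surrogate with advantages fixed at 1.0 and old_log_prob equal to the current policy's log-prob has the same gradient as token-level cross-entropy — can be checked numerically. The sketch below is plain PyTorch, independent of the worker-group APIs used in the patch, and assumes the first epoch over freshly generated rollouts (so the importance ratio is 1).

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)        # (batch, vocab)
labels = torch.randint(0, 10, (4,))

# Standard cross-entropy (SFT / NLL) gradient w.r.t. the logits.
ce_loss = F.cross_entropy(logits, labels)
(grad_ce,) = torch.autograd.grad(ce_loss, logits, retain_graph=True)

# Unclipped PPO surrogate with advantages = 1.0 and old_log_prob taken from the
# same policy (detached), i.e. ratio has value 1 but still carries gradient.
log_prob = F.log_softmax(logits, dim=-1).gather(1, labels[:, None]).squeeze(1)
ratio = torch.exp(log_prob - log_prob.detach())
pg_loss = -(1.0 * ratio).mean()
(grad_pg,) = torch.autograd.grad(pg_loss, logits)

print(torch.allclose(grad_ce, grad_pg, atol=1e-6))      # True: identical gradients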