From 012edfdef80335f61ccab4bb452141fe244b31c5 Mon Sep 17 00:00:00 2001 From: Lukasz Pierscieniewski Date: Thu, 19 Feb 2026 12:02:59 +0100 Subject: [PATCH 1/2] feat(ppo): Implementation scaffolding --- 3rdparty/Automodel-workspace/Automodel | 2 +- examples/configs/ppo_math_1B.yaml | 444 +++++ examples/run_ppo.py | 136 ++ nemo_rl/algorithms/advantage_estimator.py | 134 ++ nemo_rl/algorithms/loss/__init__.py | 7 +- nemo_rl/algorithms/loss/loss_functions.py | 54 +- nemo_rl/algorithms/ppo.py | 1700 +++++++++++++++++ nemo_rl/algorithms/utils.py | 25 +- .../ray_actor_environment_registry.py | 1 + nemo_rl/models/automodel/setup.py | 26 +- nemo_rl/models/value/__init__.py | 19 + nemo_rl/models/value/config.py | 80 + nemo_rl/models/value/interfaces.py | 96 + nemo_rl/models/value/lm_value.py | 459 +++++ nemo_rl/models/value/workers/__init__.py | 13 + .../value/workers/dtensor_value_worker_v2.py | 596 ++++++ nemo_rl/utils/automodel_checkpoint.py | 3 + 17 files changed, 3782 insertions(+), 13 deletions(-) create mode 100644 examples/configs/ppo_math_1B.yaml create mode 100644 examples/run_ppo.py create mode 100644 nemo_rl/algorithms/ppo.py create mode 100644 nemo_rl/models/value/__init__.py create mode 100644 nemo_rl/models/value/config.py create mode 100644 nemo_rl/models/value/interfaces.py create mode 100644 nemo_rl/models/value/lm_value.py create mode 100644 nemo_rl/models/value/workers/__init__.py create mode 100644 nemo_rl/models/value/workers/dtensor_value_worker_v2.py diff --git a/3rdparty/Automodel-workspace/Automodel b/3rdparty/Automodel-workspace/Automodel index 1d42deb981..fe0184feed 160000 --- a/3rdparty/Automodel-workspace/Automodel +++ b/3rdparty/Automodel-workspace/Automodel @@ -1 +1 @@ -Subproject commit 1d42deb98169fd94b54c714c0fe4bf308fe7115a +Subproject commit fe0184feedb5cae855723f54ff5580c7c3e778e5 diff --git a/examples/configs/ppo_math_1B.yaml b/examples/configs/ppo_math_1B.yaml new file mode 100644 index 0000000000..f451c05f3f --- /dev/null +++ 
b/examples/configs/ppo_math_1B.yaml @@ -0,0 +1,444 @@ +# PPO Algorithm Configuration +# This config supports multiple advantage estimators: +# - "grpo": GRPO with leave-one-out baseline +# - "reinforce_plus_plus": Reinforce++ with optional baseline +# - "gae": Generalized Advantage Estimation (requires value model) +ppo: + num_prompts_per_step: 1024 + num_generations_per_prompt: 1 + max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) + max_num_epochs: 200 + steps_per_epoch: 1 + normalize_rewards: false + use_leave_one_out_baseline: true + val_period: 1 + val_at_start: true + val_at_end: false + overlong_filtering: false + max_val_samples: 4096 + val_batch_size: 256 + seed: 42 + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + stop_properly_penalty_coef: null + + # Advantage Estimator Configuration + # Options: "grpo", "reinforce_plus_plus", or "gae" + adv_estimator: + name: "gae" # Use GAE with value model for PPO + # GAE-specific parameters + gae_lambda: 0.95 # GAE λ parameter (decay factor, typically 0.95-0.98) + gae_gamma: 0.99 # Discount factor γ (typically 0.99) + normalize_advantages: false # Normalize advantages globally across batch + # Legacy GRPO/Reinforce++ parameters (not used with GAE) + normalize_rewards: false + use_leave_one_out_baseline: false + minus_baseline: true + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + +loss_fn: + reference_policy_kl_penalty: 0.1 + # Can be set to k1, k2, k3 + # For more details, see http://joschu.net/blog/kl-approx.html + reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation 
improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + # Async GRPO requires importance sampling correction enabled + # Set to true when async_grpo.enabled is true + use_importance_sampling_correction: false + truncated_importance_sampling_ratio: null + truncated_importance_sampling_ratio_min: null # Lower bound for ICE-POP + truncated_importance_sampling_type: tis # "tis" (clamp to max) or "icepop" (filter outside [min, max]) + sequence_level_importance_ratios: false + token_level_loss: true + force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt) + use_kl_in_reward: true # Reinforce++: add KL penalty to reward instead of loss + +checkpointing: + enabled: true + checkpoint_dir: "results/ppo" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + #model_name: "Qwen/Qwen3-4B-Instruct-2507" + #model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + model_name: "Qwen/Qwen2.5-1.5B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true + hf_config_overrides: {} + train_global_batch_size: 256 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: ${policy.train_micro_batch_size} + max_total_sequence_length: 4096 + precision: "bfloat16" + logprob_chunk_size: null + offload_optimizer_for_logprob: false # Only useful for non-colocated generation since colocated generation will always offload optimizer to cuda before refit + + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: true + sequence_parallel: false + 
activation_checkpointing: true + tensor_parallel_size: 2 + context_parallel_size: 1 + custom_parallel_plan: null + + # LoRA (Low-Rank Adaptation) Configuration + lora_cfg: + enabled: False # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) + dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64 + alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64 + dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout) + dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA) + lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1 + + megatron_cfg: + enabled: false + empty_unused_memory_level: 1 # 1 is the minimum recommendation for RL since we almost always need to offload before beginning generation. Setting to 0 is faster, but you are more likely to run out of GPU memory. 
+ activation_checkpointing: false + converter_type: "Qwen2ForCausalLM" + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + # gives ~25% training perf speedup with sequence packing and apply_rope_fusion + bias_activation_fusion: True + defer_fp32_logits: False + moe_per_layer_logging: False + moe_enable_deepep: false + moe_token_dispatcher_type: "allgather" + moe_shared_expert_overlap: false + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + fp8_cfg: null + + env_vars: null + + # See 
docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 1.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 10 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [10] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + mcore_generation_config: + buffer_size_gb: 20 # Total GPU memory (in GB) allocated for KV cache buffers + buffer_guaranteed_fraction: 0.1 # Fraction of buffer reserved for guaranteed active requests + num_cuda_graphs: 16 # Number of CUDA graphs to pre-compile for different batch sizes + block_size_tokens: 256 # Size of each KV cache block in tokens (affects memory granularity) + use_cuda_graphs_for_non_decode_steps: true # Enable CUDA 
graphs for prefill/context processing + enable_chunked_prefill: true # Split long prefills into chunks for better memory management + unified_memory_level: 0 # Unified memory usage level (0=disabled, higher values enable more aggressive paging) + max_tokens: 16384 # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens + vllm_cfg: + async_engine: false + precision: ${policy.precision} + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + # when enforce_eager is False, it is optional to set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy, + # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + enable_vllm_metrics_logger: true # Set to true to enable vLLM internal metrics logger, turn off for better performance + vllm_metrics_logger_interval: 0.5 # Interval in seconds to collect vLLM logger metrics + vllm_kwargs: {} + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +# Value Model Configuration (for GAE advantage estimation in PPO) +value: + model_name: ${policy.model_name} # Can use same model as policy or a smaller one + tokenizer: + name: ${value.model_name} + chat_template_kwargs: null + hf_config_overrides: 
{} + + # Training batch sizes + train_global_batch_size: ${policy.train_global_batch_size} + train_micro_batch_size: 1 + # Sequence length settings + max_total_sequence_length: ${policy.max_total_sequence_length} + make_sequence_length_divisible_by: ${value.dtensor_cfg.tensor_parallel_size} + + # Precision + precision: "bfloat16" + + # Reward model configuration (value models use regression head) + reward_model_cfg: + enabled: true + reward_model_type: "regression" + only_head_unfreeze: false + + # DTensor backend configuration (only DTensor V2 is supported for value models) + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: true + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 2 + context_parallel_size: 1 + custom_parallel_plan: null + + # LoRA configuration for value model + lora_cfg: + enabled: false + target_modules: [] + exclude_modules: [] + match_all_linear: true + dim: 8 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + use_triton: true + + # Batching strategies (same as policy) + dynamic_batching: + enabled: false + train_mb_tokens: ${mul:${value.max_total_sequence_length}, ${value.train_micro_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${value.max_total_sequence_length}, ${value.train_micro_batch_size}} + algorithm: "modified_first_fit_decreasing" + + # Gradient clipping + max_grad_norm: 1.0 + + # Optimizer (typically similar to policy but can use different learning rate) + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 1.0e-7 # Can tune separately from policy + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + foreach: false + fused: false + + # Scheduler (typically similar to policy) + scheduler: + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [] + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation 
occurs at vllm.max_model_len + shuffle: true + num_workers: 1 + + # dataset + train: + dataset_name: OpenMathInstruct-2 + split_validation_size: 0.05 # use 5% of the training data as validation data + seed: ${ppo.seed} # seed for train/validation split when split_validation_size > 0 + validation: null + # train: + # dataset_name: DeepScaler + # validation: + # dataset_name: AIME2024 + # repeat: 16 + # default settings for all datasets + default: + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + processor: "math_hf_data_processor" + env_name: "math" + + # You can also use multiple datasets by using a list of datasets. + # See `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example. + + # You can use custom response datasets for training and validation. For example: + # train: + # # this dataset will override input_key and use the default values for other vars + # data_path: /path/to/local/train_dataset.jsonl + # input_key: question + # validation: + # # this dataset will use the default values for other vars except data_path + # data_path: /path/to/local/val_dataset.jsonl + # default: + # # will use below vars as default values if dataset doesn't specify it + # dataset_name: ResponseDataset + # input_key: input + # output_key: output + # prompt_file: null + # system_prompt_file: null + # processor: "math_hf_data_processor" + # env_name: math + # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details. 
+ +env: + math: + num_workers: 8 + math_verify_impl: "hf_math_verify" + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + swanlab_enabled: false # Disable SwanLab logging + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + swanlab: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 4 diff --git a/examples/run_ppo.py b/examples/run_ppo.py new file mode 100644 index 0000000000..585998593b --- /dev/null +++ b/examples/run_ppo.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import os
import pprint

from omegaconf import OmegaConf

from nemo_rl.algorithms.ppo import MasterConfig, ppo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.utils import setup_response_data
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import (
    load_config,
    parse_hydra_overrides,
    register_omegaconf_resolvers,
)
from nemo_rl.utils.logger import get_next_experiment_dir


def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse command line arguments.

    Returns:
        Tuple of (known args, remaining CLI tokens forwarded as Hydra overrides).
    """
    # BUGFIX: description said "GRPO" — this runner drives PPO training.
    parser = argparse.ArgumentParser(description="Run PPO training with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # Unknown args are not errors: they are passed to parse_hydra_overrides.
    args, overrides = parser.parse_known_args()

    return args, overrides


def main() -> None:
    """Main entry point: load config, build components, run PPO training."""
    # Parse arguments
    register_omegaconf_resolvers()
    args, overrides = parse_args()

    # Fall back to the bundled default PPO config next to this script.
    if not args.config:
        args.config = os.path.join(
            os.path.dirname(__file__), "configs", "ppo_math_1B.yaml"
        )

    config = load_config(args.config)
    print(f"Loaded configuration from: {args.config}")

    if overrides:
        print(f"Overrides: {overrides}")
        config = parse_hydra_overrides(config, overrides)

    config: MasterConfig = OmegaConf.to_container(config, resolve=True)
    print("Applied CLI overrides")

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # Get the next experiment directory with incremented ID
    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
    print(f"📊 Using log directory: {config['logger']['log_dir']}")
    if config["checkpointing"]["enabled"]:
        print(
            f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
        )

    init_ray()

    # setup tokenizer
    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
    # BUGFIX: message said "GRPO" — copy-paste from the GRPO runner.
    assert config["policy"]["generation"] is not None, (
        "A generation config is required for PPO"
    )
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"], tokenizer
    )

    # setup data
    (
        dataset,
        val_dataset,
        task_to_env,
        val_task_to_env,
    ) = setup_response_data(tokenizer, config["data"], config["env"])

    (
        policy,
        policy_generation,
        value_model,
        cluster,
        dataloader,
        val_dataloader,
        loss_fn,
        value_loss_fn,
        logger,
        checkpointer,
        ppo_state,
        master_config,
    ) = setup(config, tokenizer, dataset, val_dataset)

    print("🚀 Running synchronous PPO training")

    # Run standard PPO training
    ppo_train(
        policy,
        policy_generation,
        value_model,
        dataloader,
        val_dataloader,
        tokenizer,
        loss_fn,
        value_loss_fn,
        task_to_env,
        val_task_to_env,
        logger,
        checkpointer,
        ppo_state,
        master_config,
    )


if __name__ == "__main__":
    main()
class GeneralizedAdvantageEstimator:
    """Generalized Advantage Estimation (GAE) with temporal bootstrapping.

    GAE computes advantages from temporal-difference residuals combined with
    an exponentially-weighted average:

        δ_t = r_t + γ * V(s_{t+1}) * (1 - done_t) - V(s_t)
        A_t = Σ_{l=0}^{∞} (γλ)^l * δ_{t+l}

    which is evaluated recursively backwards:

        A_t = δ_t + γλ * (1 - done_t) * A_{t+1}

    Args:
        estimator_config: dict with optional keys:
            gae_lambda: GAE λ decay factor (typically 0.95-0.98, default 0.95).
            gae_gamma: Discount factor γ (typically 0.99, default 0.99).
            normalize_advantages: If True (default), whiten advantages over the
                valid tokens of the batch.
        loss_config: dict with optional keys reference_policy_kl_penalty,
            use_kl_in_reward, and reference_policy_kl_type. When
            use_kl_in_reward is True, a per-token KL penalty is folded into the
            reward before advantages are computed.
    """

    def __init__(self, estimator_config: dict, loss_config: dict):
        self.gae_lambda = estimator_config.get("gae_lambda", 0.95)
        self.gae_gamma = estimator_config.get("gae_gamma", 0.99)
        self.normalize_advantages = estimator_config.get("normalize_advantages", True)

        self.kl_coef = loss_config.get("reference_policy_kl_penalty", 0.1)
        self.kl_enabled = loss_config.get("use_kl_in_reward", False)
        self.kl_type = loss_config.get("reference_policy_kl_type", "k3")

    def _reward_whiten(
        self,
        rewards: torch.Tensor,
        mask: torch.Tensor,
        shift_mean: bool = True,
    ) -> torch.Tensor:
        """Whiten `rewards` over the positions selected by `mask`.

        When `shift_mean` is False, the mean is added back after scaling so
        only the variance is normalized.
        """
        mean = masked_mean(rewards, mask)
        var = masked_var(rewards, mask, mean)

        whitened_rewards = (rewards - mean) * torch.rsqrt(var + 1e-8)

        if not shift_mean:
            whitened_rewards = whitened_rewards + mean
        return whitened_rewards

    def compute_advantage(
        self,
        prompt_ids,
        rewards,
        mask,
        lengths,
        values,
        reference_logprobs,
        logprobs,
        **kwargs,
    ):
        """Compute GAE advantages with temporal bootstrapping.

        Args:
            prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to (unused by GAE).
            rewards: Tensor of shape [batch_size] containing reward for each sample.
                In PPO, this is typically the final reward at the end of the trajectory.
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
            lengths: Input lengths of shape [batch_size].
            values: Value predictions of shape [batch_size, seq_len]. Required for GAE.
            reference_logprobs: Reference policy log probabilities of shape [batch_size, seq_len].
                Used only when use_kl_in_reward is enabled.
            logprobs: Current policy log probabilities of shape [batch_size, seq_len].
                Used only when use_kl_in_reward is enabled.
            **kwargs: Additional arguments (unused).

        Returns:
            Tuple of (advantages, returns), each of shape [batch_size, seq_len].
        """
        if self.kl_enabled:
            # Fold the per-token KL penalty into the reward stream.
            kl = calculate_kl(logprobs, reference_logprobs, self.kl_type) * self.kl_coef
        else:
            kl = None
        advantages, returns = self.compute_advantage_reference(
            rewards, lengths, values, kl, mask=mask
        )

        # BUGFIX: honor the `normalize_advantages` config flag. The original
        # whitened unconditionally, so normalize_advantages=False had no effect.
        if self.normalize_advantages:
            advantages = self._reward_whiten(advantages, mask)
        # Zero out padding positions so they contribute nothing to the loss.
        advantages = torch.masked_fill(advantages, ~(mask.bool()), 0)
        return advantages, returns

    def compute_advantage_reference(
        self,
        rewards,
        lengths,
        values,
        kl,
        **kwargs,
    ):
        """Reference GAE implementation for correctness validation.

        Fixes two issues in compute_advantage:
        1. Terminal state: uses V=0 at t=L (not a padding token's value).
        2. No cross-sequence contamination: accumulation resets per sequence.

        Args:
            rewards: Tensor of shape [batch_size].
            lengths: Total sequence lengths of shape [batch_size].
            values: Value predictions of shape [batch_size, seq_len].
            kl: Optional per-token KL penalty of shape [batch_size, seq_len];
                subtracted from the reward at every step when provided.

        Returns:
            advantages: Tensor of shape [batch_size, seq_len].
            returns: advantages + values, shape [batch_size, seq_len].
        """
        advantages = torch.zeros_like(values)
        if kl is None:
            kl = torch.zeros_like(values)

        for i in range(values.shape[0]):
            L = int(lengths[i])

            last_adv = 0.0
            v_next = 0.0  # V(s_L) = 0: bootstrap from zero past the terminal step
            r = rewards[i].item()  # terminal reward credited at the last token

            for t in reversed(range(L)):
                delta = r - kl[i, t] + self.gae_gamma * v_next - values[i, t]
                last_adv = delta + self.gae_gamma * self.gae_lambda * last_adv
                advantages[i, t] = last_adv

                v_next = values[i, t]
                r = 0.0  # only the final step carries the sequence reward

        returns = advantages + values
        return advantages, returns
class MseValueLossFn(LossFunction):
    """Clipped mean-squared-error loss for the PPO value model.

    Implements PPO-style value clipping: the per-token loss is the max of the
    unclipped and clipped squared errors, which keeps the value estimates from
    moving too far from the rollout-time values in a single update.
    """

    def __init__(self, loss_cfg):
        # Clip range applied to the value delta (mirrors the policy ratio clip).
        self.ratio_clip_min = loss_cfg["ratio_clip_min"]
        self.ratio_clip_max = loss_cfg["ratio_clip_max"]

    def __call__(
        self,
        values: torch.Tensor,
        data: BatchedDataDict,
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Compute the clipped Mean Squared Error value loss.

        Args:
            values: New value predictions, [batch, seq] or [batch, seq, 1]
                (regression-head output).
            data: Batch containing "token_mask", "sample_mask", "returns", and
                the rollout-time "values" — assumed [batch, seq] except
                sample_mask ([batch]); TODO confirm against the caller.
            global_valid_seqs: Global valid-sequence count (unused here).
            global_valid_toks: Global valid-token count used as the
                normalization factor for the masked mean.

        Returns:
            Tuple of (scalar loss tensor, metrics dict).
        """
        returns = data["returns"]

        # BUGFIX: squeeze the trailing singleton head dimension. The original
        # condition (`values.shape[-1] != 1`) was inverted: it left [B, S, 1]
        # outputs un-squeezed (silently broadcasting against [B, S] targets)
        # and truncated already-2D inputs to their first column.
        if values.dim() == returns.dim() + 1 and values.shape[-1] == 1:
            values = values[..., 0]

        token_mask = data["token_mask"]
        sample_mask = data["sample_mask"]
        old_values = data["values"]

        # Per-token validity: token mask gated by the per-sample mask.
        mask = token_mask * sample_mask.unsqueeze(-1)

        # Keep the new values within a trust region around the old values.
        values_clamped = values.clamp(
            old_values - self.ratio_clip_min,
            old_values + self.ratio_clip_max,
        )

        # Pessimistic (max) of clipped/unclipped squared errors, as in PPO.
        loss = torch.max(
            torch.square(values - returns),
            torch.square(values_clamped - returns),
        )

        loss = 0.5 * masked_mean(
            loss, mask, global_normalization_factor=global_valid_toks
        )

        metrics = {
            "loss": float(loss.item()),
            "num_valid_samples": int(values.shape[0]),
        }

        return loss, metrics
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import os +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +from pathlib import Path +from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast + +import numpy as np +import ray +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoProcessor +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.advantage_estimator import ( + GeneralizedAdvantageEstimator, + GRPOAdvantageEstimator, + ReinforcePlusPlusAdvantageEstimator, +) +from nemo_rl.algorithms.loss import ( + ClippedPGLossConfig, + ClippedPGLossDataDict, + ClippedPGLossFn, + MseValueLossFn, +) +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.reward_functions import ( + RewardShapingConfig, + apply_reward_shaping, +) +from nemo_rl.algorithms.utils import ( + calculate_baseline_and_std_per_prompt, + log_generation_metrics_to_wandb, + print_performance_metrics, + set_seed, +) +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import ( + batched_message_log_to_flat_message, + get_keys_from_message_log, +) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.collectives import T +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from 
nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.experience.rollouts import ( + run_async_multi_turn_rollout, + run_async_nemo_gym_rollout, + run_multi_turn_rollout, +) +from nemo_rl.models import value +from nemo_rl.models.automodel import train +from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.models.value import Value, ValueConfig +from nemo_rl.models.value.interfaces import ValueInterface +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import Logger, LoggerConfig, print_message_log_samples +from nemo_rl.utils.memory_tracker import MemoryTracker +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer +from nemo_rl.utils.venvs import create_local_venv_on_each_node + +# =============================================================================== +# Configuration +# =============================================================================== +TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) + + +class RewardScalingConfig(TypedDict): + """Configure linear reward scaling with clamping. + + When `enabled` is True, each reward is clamped to the source interval + [source_min, source_max] and linearly mapped to the target interval + [target_min, target_max]. Refer to the scale_rewards function for the implementation. 
class AsyncGRPOConfig(TypedDict):
    """Settings controlling asynchronous rollout generation."""

    enabled: bool
    # Maximum trajectory age in training steps for samples drawn from the
    # async replay buffer. Trajectories older than this are excluded during
    # sampling; buffer sizing also scales with this value.
    max_trajectory_age_steps: int
    # Does the weight synchronization as soon as the training is done
    # without waiting for the pending generations to finish.
    in_flight_weight_updates: NotRequired[bool]
    # Recomputes the KV cache after the in-flight weight updates.
    recompute_kv_cache_after_weight_updates: NotRequired[bool]


class AdvEstimatorConfig(TypedDict):
    """Configuration for the advantage estimator.

    `name` selects the estimator: "grpo", "reinforce_plus_plus", or "gae".
    """

    name: str  # "grpo", "reinforce_plus_plus", or "gae"
    # GRPO specific
    normalize_rewards: NotRequired[bool]
    use_leave_one_out_baseline: NotRequired[bool]
    # Reinforce++ specific
    minus_baseline: NotRequired[bool]


class GRPOConfig(TypedDict):
    """Top-level algorithm section (keyed "ppo" in the YAML configs)."""

    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_num_epochs: int
    max_num_steps: int
    max_rollout_turns: int
    normalize_rewards: bool
    use_leave_one_out_baseline: bool
    val_period: int
    val_batch_size: int
    val_at_start: bool
    # Whether to run validation on the last training step. Setting this to True ensures the
    # final checkpoint has validation metrics, which is required for get_best_checkpoint_path().
    val_at_end: bool
    max_val_samples: int
    skip_reference_policy_logprobs_calculation: NotRequired[bool]
    seed: int
    async_grpo: NotRequired[AsyncGRPOConfig]
    overlong_filtering: NotRequired[bool]
    # whether to enable dynamic sampling, i.e.
    # whether to discard prompts whose rewards have zero standard deviation
    use_dynamic_sampling: bool
    # When using dynamic sampling, the maximum number of batches to generate
    # before throwing an error
    dynamic_sampling_max_gen_batches: NotRequired[int]
    # When using dynamic sampling, generation prompt batch size will equal
    # num_prompts_per_step * batch_multiplier
    batch_multiplier: NotRequired[float]
    reward_shaping: RewardShapingConfig
    reward_scaling: RewardScalingConfig
    # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
    calculate_advantages_on_gpu: NotRequired[bool]
    # Advantage estimator configuration (grpo, reinforce_plus_plus, or gae)
    adv_estimator: NotRequired[AdvEstimatorConfig]


class GRPOSaveState(TypedDict):
    """Training progress persisted alongside checkpoints for resumption."""

    consumed_samples: int
    current_step: int
    current_epoch: int
    total_steps: int
    total_valid_tokens: int  # Track total number of non-padding tokens during training
    val_reward: NotRequired[
        float
    ]  # Optional field - may not be present during training


def _default_grpo_save_state() -> GRPOSaveState:
    """Return the initial save state used when no checkpoint exists yet.

    `val_reward` starts at a large negative sentinel so any real validation
    reward compares as an improvement.
    """
    return {
        "consumed_samples": 0,
        "current_step": 0,
        "current_epoch": 0,
        "total_steps": 0,
        "total_valid_tokens": 0,
        "val_reward": -99999999.0,
    }


class GRPOLoggerConfig(LoggerConfig):
    num_val_samples_to_print: int  # number of val samples to print to stdout


class MasterConfig(TypedDict):
    """Top-level config schema assembled from the YAML file.

    NOTE(fix): the algorithm section is keyed ``ppo`` in the shipped YAML
    (``examples/configs/ppo_math_1B.yaml``) and is read as
    ``master_config["ppo"]`` in ``setup()``. The old ``grpo`` key is kept as
    NotRequired for backward compatibility with legacy configs.
    """

    policy: PolicyConfig
    value: NotRequired[ValueConfig]  # Value model configuration (needed for GAE)
    loss_fn: ClippedPGLossConfig
    env: dict[str, Any]
    data: DataConfig
    ppo: GRPOConfig
    grpo: NotRequired[GRPOConfig]
    logger: GRPOLoggerConfig
    cluster: ClusterConfig
    checkpointing: CheckpointingConfig
AllTaskProcessedDataset, + val_dataset: Optional[AllTaskProcessedDataset], + processor: Optional[AutoProcessor] = None, +) -> tuple[ + ColocatablePolicyInterface, + Optional[GenerationInterface], + Optional[ValueInterface], + tuple[RayVirtualCluster, RayVirtualCluster], + StatefulDataLoader, + Optional[StatefulDataLoader], + ClippedPGLossFn, + Logger, + CheckpointManager, + GRPOSaveState, + MasterConfig, +]: + """Main entry point for running GRPO algorithm. + + Returns: + tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader + """ + # Start timing the entire setup process + setup_start_time = time.perf_counter() + + # Extract individual configs for easier access + policy_config = master_config["policy"] + value_config = master_config.get("value", None) + generation_config = master_config["policy"]["generation"] + env_configs = master_config["env"] + loss_config = master_config["loss_fn"] + ppo_config = master_config["ppo"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + assert generation_config is not None, ( + "A generation config in the PolicyConfig is required for GRPO" + ) + + # Set seed for all random number generators + set_seed(ppo_config["seed"]) + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + grpo_save_state: Optional[GRPOSaveState] = cast( + Optional[GRPOSaveState], checkpointer.load_training_info(last_checkpoint_path) + ) + if grpo_save_state is None: + grpo_save_state = _default_grpo_save_state() + + # ========================== + # Data + # ========================== + # Validate batch_multiplier + 
batch_multiplier = ppo_config["batch_multiplier"] + dataloader_batch_size = ppo_config["num_prompts_per_step"] + if not ppo_config["use_dynamic_sampling"]: + assert batch_multiplier == 1, ( + "batch_multiplier>1 can only be used if use_dynamic_sampling=True" + ) + else: + dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) + + dataloader = StatefulDataLoader( + dataset, + batch_size=dataloader_batch_size, + shuffle=data_config["shuffle"], + collate_fn=rl_collate_fn, + drop_last=True, + num_workers=data_config["num_workers"], + ) + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + + print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True) + + # Load validation dataset if provided + val_dataloader: Optional[StatefulDataLoader] = None + # If validation is enabled, load the validation dataloader + if ( + ppo_config["val_period"] > 0 + or ppo_config["val_at_start"] + or ppo_config["val_at_end"] + ): + assert val_dataset is not None, ( + "Validation dataset is required if validation is enabled" + ) + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=ppo_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + num_workers=data_config["num_workers"], + ) + print( + f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) + + # ========================== + # Loss Function + # ========================== + loss_fn = ClippedPGLossFn(loss_config) + value_loss_fn = MseValueLossFn(loss_config) + + # Validate force_on_policy_ratio + if loss_config.get("force_on_policy_ratio", False): + assert ( + ppo_config["num_prompts_per_step"] + * ppo_config["num_generations_per_prompt"] + == policy_config["train_global_batch_size"] + ), ( + "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt" + ) 
+ os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1" + print(" ✓ force_on_policy_ratio enabled") + + # ========================== + # Cluster + # ========================== + print("\n▶ Setting up compute cluster...", flush=True) + colocated_inference = generation_config["colocated"]["enabled"] + reward_model_enabled = ( + "env_name" in data_config and data_config["env_name"] == "reward_model" + ) + + total_nodes = cluster_config["num_nodes"] + if reward_model_enabled: + rm_resource = env_configs["reward_model"]["resources"] + rm_nodes = rm_resource["num_nodes"] + rm_gpus_per_node = rm_resource["gpus_per_node"] + else: + rm_nodes = 0 + rm_gpus_per_node = 0 + + if total_nodes == 1: + policy_nodes = total_nodes + else: + policy_nodes = total_nodes - rm_nodes + assert policy_nodes > 0, ( + "policy_nodes must be > 0, but got " + f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}" + ) + + if colocated_inference: + if total_nodes == 1: + policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node + assert policy_gpus_per_node > 0, ( + "policy.generation.colocated.resources.gpus_per_node must be > 0 " + "when cluster.num_nodes = 1, " + f"but got {policy_gpus_per_node}." + ) + else: + policy_gpus_per_node = cluster_config["gpus_per_node"] + + cluster = RayVirtualCluster( + name="grpo_policy_cluster", + bundle_ct_per_node_list=[policy_gpus_per_node] * policy_nodes, + use_gpus=True, + num_gpus_per_node=policy_gpus_per_node, + max_colocated_worker_groups=1 + if generation_config["backend"] == "megatron" + else 3, + ) + train_cluster = cluster + inference_cluster = cluster + value_cluster = cluster + print( + f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes", + flush=True, + ) + + else: + assert generation_config["backend"] != "megatron", ( + "Non-colocated inference is not supported for Megatron generation backends. " + "Please use vLLM backend for generation." 
+ ) + + # train resources will be updated through overall and inference resources below + train_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes = policy_nodes + + inference_resources = generation_config["colocated"]["resources"] + inference_gpus_per_node = inference_resources["gpus_per_node"] + inference_nodes = inference_resources["num_nodes"] + + # validate and configure resources + if policy_nodes == 1: + # When policy_nodes == 1, train and inference are on the same node + assert ( + inference_gpus_per_node is not None and inference_gpus_per_node > 0 + ), ( + "policy.generation.colocated.resources.gpus_per_node must be explicitly set to a value > 0 " + "when policy_nodes = 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." + ) + assert inference_nodes is None or inference_nodes == 1, ( + "policy.generation.colocated.resources.num_nodes must be 1 or set to null " + "when policy_nodes = 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + + inference_nodes = 1 + # If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node + reward_gpus_to_subtract = ( + rm_gpus_per_node if total_nodes == 1 and reward_model_enabled else 0 + ) + train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract + assert train_gpus_per_node > 0, ( + "No enough GPUs for training, " + f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}" + + ( + f" - rm_gpus_per_node:{rm_gpus_per_node}" + if total_nodes == 1 and reward_model_enabled + else "" + ) + ) + else: + # train, inference, and reward model are all on different nodes + assert inference_nodes > 0, ( + "policy.generation.colocated.resources.num_nodes must be > 0 " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_nodes}." 
+ ) + assert ( + inference_gpus_per_node is not None + and inference_gpus_per_node == cluster_config["gpus_per_node"] + ), ( + "policy.generation.colocated.resources.gpus_per_node must be explicitly set and equal to cluster.gpus_per_node " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got inference_gpus_per_node={inference_gpus_per_node}, cluster.gpus_per_node={cluster_config['gpus_per_node']}." + ) + train_nodes -= inference_nodes + + # initialize train cluster + train_cluster = RayVirtualCluster( + name="grpo_train_cluster", + bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, + use_gpus=True, + num_gpus_per_node=train_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node", + flush=True, + ) + + # initialize inference cluster + inference_cluster = RayVirtualCluster( + name="grpo_inference_cluster", + bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, + use_gpus=True, + num_gpus_per_node=inference_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node", + flush=True, + ) + + # ========================== + # Training and Inference + # ========================== + print("\n▶ Setting up model and training...", flush=True) + + # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode + backend = generation_config["backend"] + generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM + + # Dictionary to store worker initialization timing stats for logging + worker_init_timing_metrics = {} + + # Prepare checkpoint paths + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + else: + weights_path = None + 
optimizer_path = None + + if policy_config.get("megatron_cfg", {}).get("enabled", False): + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + ppo_config["max_num_steps"], + ppo_config["max_num_epochs"] * len(dataloader), + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + # Define initialization functions that will be used in all paths + def init_policy(): + """Initialize policy training workers.""" + t0 = time.perf_counter() + p = Policy( + cluster=train_cluster, + config=policy_config, + tokenizer=tokenizer, + processor=processor, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + ) + return p, time.perf_counter() - t0 + + def init_value(): + """Initialize value model training workers.""" + t0 = time.perf_counter() + # Prepare checkpoint paths for value model + if last_checkpoint_path: + value_weights_path = Path(last_checkpoint_path) / "value" / "weights" + value_optimizer_path = Path(last_checkpoint_path) / "value" / "optimizer" + else: + value_weights_path = None + value_optimizer_path = None + + # TODO: Proper implementation of the value model + v = Value( + cluster=train_cluster, + config=value_config, + tokenizer=tokenizer, + name_prefix="lm_value", + weights_path=value_weights_path, + optimizer_path=value_optimizer_path, + init_optimizer=True, + ) + # v = None + return v, time.perf_counter() - t0 + + def init_vllm(): + """Initialize vLLM generation workers.""" + t0 = time.perf_counter() + pg = VllmGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + + def init_sglang(): + """Initialize SGLang generation workers.""" + t0 = time.perf_counter() + pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + + def initialize_generation_with_policy( + init_generation_fn, + generation_name: str, + init_time_key: 
str, + colocated_inference: bool, + worker_init_timing_metrics: dict, + ): + """Generic function to initialize a generation engine (vLLM or SGLang) along with policy. + + Args: + init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang) + generation_name: Name of the generation engine ("vLLM" or "SGLang") + init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s") + colocated_inference: Whether inference is colocated with training + worker_init_timing_metrics: Dictionary to store timing metrics + + Returns: + Tuple of (policy_generation, policy) + """ + # Determine if parallel initialization is possible (non-colocated mode) + use_parallel_init = not colocated_inference + + if use_parallel_init: + # Parallel initialization: Generation engine and Policy can initialize simultaneously + print( + " ⚡ Using parallel worker initialization (non-colocated mode)", + flush=True, + ) + + # Execute both initializations in parallel + parallel_start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=2) as executor: + generation_future = executor.submit(init_generation_fn) + policy_future = executor.submit(init_policy) + policy_generation, generation_time = generation_future.result() + policy, policy_time = policy_future.result() + parallel_wall_time = time.perf_counter() - parallel_start_time + + # Store timing metrics + worker_init_timing_metrics[init_time_key] = generation_time + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time + worker_init_timing_metrics["parallel_init_enabled"] = True + + # Value model not supported in non-colocated mode yet + value_model = None + + else: + # Sequential initialization: colocated mode (GPU memory requires generation engine first) + print( + " ⚙️ Using sequential worker initialization (colocated mode)", + flush=True, + ) + + # Initialize generation 
engine first (clean GPU memory), then policy + policy_generation, generation_time = init_generation_fn() + worker_init_timing_metrics[init_time_key] = generation_time + + policy, policy_time = init_policy() + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + + # Initialize value model if configured (for GAE in colocated mode) + if value_config is not None: + print(" ⚙️ Initializing value model for GAE...", flush=True) + value_model, value_time = init_value() + worker_init_timing_metrics["value_init_time_s"] = value_time + print(f" ✓ Value model initialized in {value_time:.2f}s", flush=True) + else: + value_model = None + + return policy_generation, policy, value_model + + # Handle generation-specific setup + if backend == "megatron": + # Megatron generation: policy_generation is None, only initialize policy + policy_generation = None + print( + f" ✓ Using {backend} backend for generation with {policy_config['model_name']}", + flush=True, + ) + + policy, policy_time = init_policy() + worker_init_timing_metrics["policy_init_time_s"] = policy_time + + # Value model not supported for megatron backend yet + value_model = None + + elif backend == "vllm": + # vLLM generation: setup config, then initialize with policy + generation_config = cast(VllmConfig, generation_config) + if generation_config["vllm_cfg"]["precision"] == "fp8": + assert loss_config["use_importance_sampling_correction"] is True, ( + "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" + ) + if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): + # FP8 KV cache requires FP8 model precision + assert generation_config["vllm_cfg"]["precision"] == "fp8", ( + f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. " + "FP8 KV cache can only be used together with FP8 model weights." 
+ ) + # FP8 KV cache compatibility checks + assert policy_config["dtensor_cfg"]["enabled"] == False, ( + "DTensor backend is not supported with kv cache fp8 enabled." + ) + assert not _should_use_async_rollouts(master_config), ( + "Async rollouts is not supported with kv cache fp8 enabled." + ) + assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, ( + "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future." + ) + + ## make vllm hf overrides match the training policy + generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( + "hf_config_overrides", {} + ) + + policy_generation, policy, value_model = initialize_generation_with_policy( + init_generation_fn=init_vllm, + generation_name="vLLM", + init_time_key="vllm_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) + + print( + f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", + flush=True, + ) + + elif backend == "sglang": + generation_config = cast(SGLangConfig, generation_config) + + # Set model_path if not already set + if "model_path" not in generation_config["sglang_cfg"]: + generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"] + + policy_generation, policy, value_model = initialize_generation_with_policy( + init_generation_fn=init_sglang, + generation_name="SGLang", + init_time_key="sglang_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) + + print( + f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", + flush=True, + ) + + # Record when worker initialization completes (for calculating other setup time) + worker_init_complete_time = time.perf_counter() - setup_start_time + + # print the node IP and GPU ID of the policy workers for debugging + policy.print_node_ip_and_gpu_id() + + 
# if it is not colocated inference, initialize collective communication for update weights + if not colocated_inference: + t0 = time.perf_counter() + ip, port = train_cluster.get_master_address_and_port() + print(f"Using ip: {ip}, port: {port} for collective communication", flush=True) + # world includes all training workers and all inference workers + train_world_size = train_cluster.world_size() + inference_world_size = inference_nodes * inference_gpus_per_node + world_size = train_world_size + inference_world_size + # init collective + futures_train = policy.init_collective( + ip, port, world_size, train_world_size=train_world_size + ) + futures_inference = policy_generation.init_collective( + ip, port, world_size, train_world_size=train_world_size + ) # type: ignore + # wait for all futures to complete + ray.get(futures_train + futures_inference) + worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0 + + # prepare refit info + state_dict_info = policy.prepare_refit_info() + if policy_generation is not None: + policy_generation.prepare_refit_info(state_dict_info) + + # Calculate total setup time + total_setup_time = time.perf_counter() - setup_start_time + worker_init_timing_metrics["total_setup_time_s"] = total_setup_time + + # Log worker initialization timing metrics to logger + if worker_init_timing_metrics: + print("\n▶ Worker Initialization Timing:") + + vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0) + policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0) + total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0) + + if vllm_time: + print(f" vLLM init: {vllm_time:.1f}s") + + if policy_time: + print(f" Policy init: {policy_time:.1f}s") + + # Calculate "other" time (time after worker init completes) + other_time = total_setup - worker_init_complete_time + worker_init_timing_metrics["other_setup_time_s"] = other_time + print(f" Other setup: {other_time:.1f}s") + + print(f" Total setup: 
{total_setup:.1f}s") + + # Log all metrics to the logger for analysis + logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup") + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print(f" Total setup time: {total_setup_time:.1f}s") + print("=" * 60 + "\n", flush=True) + + return ( + policy, + policy_generation, + value_model, + (train_cluster, inference_cluster), + dataloader, + val_dataloader, + loss_fn, + value_loss_fn, + logger, + checkpointer, + grpo_save_state, + master_config, + ) + + +def dynamic_sampling( + repeated_batch: BatchedDataDict[DatumSpec], + std: torch.Tensor, + baseline: torch.Tensor, + dynamic_sampling_num_gen_batches: int, + master_config: MasterConfig, + timer: Timer, + batch_cache: BatchedDataDict[DatumSpec] = None, +) -> BatchedDataDict[DatumSpec]: + """Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. + + This function filters the current batch to retain only those prompts that have a non-zero standard deviation. + If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, + we store it in the batch_cache to be used in later iterations. + If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, + the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. + is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop + to continue sampling or proceed to training. + This approach is based on the dynamic sampling algorithm from the DAPO paper: + https://arxiv.org/pdf/2503.14476. 
def dynamic_sampling(
    repeated_batch: "BatchedDataDict[DatumSpec]",
    std: torch.Tensor,
    baseline: torch.Tensor,
    dynamic_sampling_num_gen_batches: int,
    master_config: "MasterConfig",
    timer: "Timer",
    batch_cache: "Optional[BatchedDataDict[DatumSpec]]" = None,
) -> "tuple[BatchedDataDict[DatumSpec], bool, Optional[BatchedDataDict[DatumSpec]], dict]":
    """Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

    This function filters the current batch to retain only those prompts that have a non-zero
    standard deviation. If fewer prompts than the required batch size
    (num_prompts_per_step * num_generations_per_prompt) survive the filter, the survivors are
    stored in ``batch_cache`` for later iterations and ``is_batch_complete`` is set to False so
    the training loop keeps generating. If more survive than required, the batch is sliced down
    to exactly the required size.

    This approach is based on the dynamic sampling algorithm from the DAPO paper:
    https://arxiv.org/pdf/2503.14476.

    Args:
        repeated_batch: The current batch of data containing prompts, responses, rewards,
            baselines, and std.
        std: Tensor with the reward standard deviation for each prompt group.
        baseline: Baseline values for each prompt group.
        dynamic_sampling_num_gen_batches: Number of generation batches processed at the
            current step.
        master_config: Configuration containing the algorithm ("ppo"/"grpo") and policy settings.
        timer: Timer used to measure the filtering phase.
        batch_cache: Cache storing previously selected prompts with non-zero std.

    Returns:
        tuple of:
            - the batch to train on (filtered when dynamic sampling is enabled),
            - is_batch_complete: whether enough non-zero-std samples were collected,
            - batch_cache: updated cache for future iterations,
            - dynamic_sampling_metrics: metrics dict (e.g. number of discarded valid samples).
    """
    # NOTE(fix): the algorithm section is keyed "ppo" in current configs (as read
    # by setup()); fall back to the legacy "grpo" key for old configs instead of
    # hard-coding master_config["grpo"], which raised KeyError with ppo configs.
    algo_config = (
        master_config["ppo"] if "ppo" in master_config else master_config["grpo"]
    )

    # is_batch_complete indicates whether the current batch produced enough
    # prompts with non-zero std for a full training step.
    is_batch_complete = True

    # Required batch size for training
    train_prompts_size = (
        algo_config["num_prompts_per_step"] * algo_config["num_generations_per_prompt"]
    )
    # Store the baseline, std and total_reward for the current unfiltered batch.
    repeated_batch["baseline"] = baseline
    repeated_batch["std"] = std
    total_rewards = repeated_batch["total_reward"]
    dynamic_sampling_metrics = {}

    # Dynamic sampling algorithm (used in DAPO): keep only prompt groups with
    # non-zero std; keep generating until enough survive or the batch budget
    # (dynamic_sampling_max_gen_batches) is exhausted.
    if algo_config["use_dynamic_sampling"]:
        with timer.time("dynamic_sampling"):
            # Get the prompt indices with non-zero std
            non_zero_std_mask = std != 0.0
            keep_prompt_indices = torch.arange(len(non_zero_std_mask), device=std.device)[
                non_zero_std_mask
            ].tolist()

            # Only select the inputs that have non-zero std.
            # total_reward is already part of repeated_batch so we don't add it again.
            filtered_repeated_batch = repeated_batch.select_indices(keep_prompt_indices)
            filtered_repeated_batch["std"] = std[keep_prompt_indices]
            filtered_repeated_batch["baseline"] = baseline[keep_prompt_indices]

            # Store filtered and total rewards to track them separately
            filtered_rewards = filtered_repeated_batch["total_reward"]
            filtered_repeated_batch["total_reward"] = total_rewards
            filtered_repeated_batch["filtered_reward"] = filtered_rewards

            # If no prompt in this batch has non-zero std, the batch is skipped
            # entirely and the next batch is generated.
            if filtered_repeated_batch.size > 0:
                # Accumulate survivors in the cache so partially-filled steps
                # can be completed by later batches.
                batch_cache = (
                    filtered_repeated_batch
                    if batch_cache is None
                    else BatchedDataDict.from_batches(
                        [batch_cache, filtered_repeated_batch]
                    )
                )
                filtered_repeated_batch = batch_cache

            filtered_prompts_size = filtered_repeated_batch.size
            print(
                f"Detected {filtered_prompts_size} prompts with non-zero std; "
                f"{train_prompts_size} are required and used for training."
            )

            # If still short of train_prompts_size, keep generating (next batch)
            if filtered_prompts_size < train_prompts_size:
                dynamic_sampling_max_gen_batches = algo_config[
                    "dynamic_sampling_max_gen_batches"
                ]
                assert dynamic_sampling_max_gen_batches > 0, (
                    "When using grpo.use_dynamic_sampling, grpo.dynamic_sampling_max_gen_batches must be > 0"
                )
                if dynamic_sampling_num_gen_batches <= dynamic_sampling_max_gen_batches:
                    print(
                        f"Generation sample buffer size: {filtered_prompts_size} is smaller than train_prompts_size: {train_prompts_size}. Processed {dynamic_sampling_num_gen_batches} batches so far out of {dynamic_sampling_max_gen_batches}."
                    )
                    is_batch_complete = False
                else:
                    raise ValueError(
                        f"Dynamic sampling has reached the maximum allowed number of batches ({dynamic_sampling_max_gen_batches}). Consider evaluating the complexity of your data or adjusting the num_prompts_per_step or num_generations_per_prompt parameters to enhance the diversity of the samples."
                    )
            else:
                num_discarded_valid_samples = filtered_prompts_size - train_prompts_size
                dynamic_sampling_metrics[
                    "dynamic_sampling_num_discarded_valid_samples"
                ] = num_discarded_valid_samples

                # Slice down to exactly train_prompts_size samples
                filtered_repeated_batch = filtered_repeated_batch.slice(
                    0, train_prompts_size
                )

    batch_to_return = (
        filtered_repeated_batch
        if algo_config["use_dynamic_sampling"]
        else repeated_batch
    )
    return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics
+ + If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` + is clamped to the configured source interval [source_min, source_max] and then + rescaled to the target interval [target_min, target_max]. + + Default configuration: + source_min = 0.0 + source_max = 1.0 + target_min = 0.0 + target_max = 1.0 + """ + if reward_scaling_cfg["enabled"]: + rewards = repeated_batch["total_reward"] + source_min = float(reward_scaling_cfg["source_min"]) + source_max = float(reward_scaling_cfg["source_max"]) + target_min = float(reward_scaling_cfg["target_min"]) + target_max = float(reward_scaling_cfg["target_max"]) + + # Detect out-of-range values + out_of_range_mask = (rewards < source_min) | (rewards > source_max) + if torch.any(out_of_range_mask): + print( + f"[reward_scaling] WARNING: {int(out_of_range_mask.sum())} rewards " + f"are outside the configured source range [{source_min}, {source_max}]. " + f"Values will be clipped before scaling." + ) + + # Clamp and scale + rewards = torch.clamp(rewards, min=source_min, max=source_max) + scaled_rewards = target_min + (rewards - source_min) / ( + source_max - source_min + ) * (target_max - target_min) + repeated_batch["total_reward"] = scaled_rewards + + return repeated_batch + + +def _create_advantage_estimator(master_config: MasterConfig): + """Create and return an advantage estimator based on configuration. + + Args: + master_config: The master configuration dictionary. + + Returns: + An advantage estimator instance (GRPOAdvantageEstimator, ReinforcePlusPlusAdvantageEstimator, or GAEAdvantageEstimator). + + Raises: + ValueError: If the advantage estimator name is not recognized. + """ + ppo_config = master_config["ppo"] + loss_config = master_config["loss_fn"] + + # Provide backward-compatible defaults when adv_estimator is not in config. + # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline + # which older configs still use. 
+ adv_estimator_config = ppo_config.get( + "adv_estimator", + { + "name": "grpo", + "normalize_rewards": ppo_config.get("normalize_rewards", True), + "use_leave_one_out_baseline": ppo_config.get( + "use_leave_one_out_baseline", False + ), + "minus_baseline": True, + }, + ) + + adv_estimator_name = adv_estimator_config["name"] + if adv_estimator_name == "grpo": + adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" ✓ Using GRPO advantage estimator") + elif adv_estimator_name == "reinforce_plus_plus": + adv_estimator = ReinforcePlusPlusAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(" ✓ Using Reinforce++ advantage estimator") + elif adv_estimator_name == "gae": + adv_estimator = GeneralizedAdvantageEstimator(adv_estimator_config, loss_config) + gae_lambda = adv_estimator_config.get("gae_lambda", 0.95) + gae_gamma = adv_estimator_config.get("gae_gamma", 0.99) + print(f" ✓ Using GAE advantage estimator (λ={gae_lambda}, γ={gae_gamma})") + else: + raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}") + + return adv_estimator + + +def _extract_prompt_only_messages(message_logs: list) -> list: + """Extract only prompt messages (user/system) from message logs. + + This is used to get prompt IDs for advantage estimation, excluding + any assistant responses. + + Args: + message_logs: List of message logs, where each log is a list of messages. + + Returns: + List of message logs containing only user and system messages. 
+ """ + prompt_only_message_logs = [] + for message_log in message_logs: + prompt_only_log = [] + for message in message_log: + if message["role"] == "user" or message["role"] == "system": + prompt_only_log.append(message) + prompt_only_message_logs.append(prompt_only_log) + return prompt_only_message_logs + + +def refit_policy_generation( + policy: ColocatablePolicyInterface, + policy_generation: GenerationInterface, + colocated_inference: bool, + _refit_buffer_size_gb: Optional[int] = None, + timer: Optional[Timer] = None, + kv_scales: Optional[dict[str, float]] = None, +) -> None: + """Refit the policy generation interface with the latest policy weights. + + Args: + policy: The policy to provide weights to the inference engine. + policy_generation: The inference engine to refit. + _refit_buffer_size_gb: The size of the buffer to use for refitting. + If it is None, the buffer size will be computed by the remaining memory. + This parameter is primarily used for testing. + timer: Optional Timer used to time the prepare/transfer/update phase + kv_scales: Optional dictionary of KV cache scales for FP8 quantization. + """ + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) + + # Create a context manager that does nothing when timer is None + timer_context = ( + timer.time("prepare_for_generation/transfer_and_update_weights") + if timer is not None + else nullcontext() + ) + with timer_context: + # update weights + update_success = False + if colocated_inference: + # get model param keys, which is grouped by size + if _refit_buffer_size_gb is not None: + buffer_size_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Empirically sets ratio as 30% to maximize efficiency. + # The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension. 
+ memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") + buffer_size_bytes = int( + policy.get_free_memory_bytes() * float(memory_ratio) + ) + + if isinstance(policy_generation, SGLangGeneration): + sglang_url_to_gpu_uuids = ( + policy_generation.get_sglang_url_to_gpu_uuids() + ) + # Stream weights via HTTP + flush_success = policy_generation.invalidate_kv_cache() + if not flush_success: + print("SGLang KV cache invalidation failed before weight update. ") + futures_train = policy.stream_weights_via_http( + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + # Wait for all workers to complete + ray.get(futures_train) + update_success = True + else: + # Original ZMQ IPC path for vLLM + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + else: + # update weights through nccl + # SGLang haven't implemented non-colocated inference mode. + if isinstance(policy_generation, SGLangGeneration): + raise NotImplementedError( + "SGLang haven't implemented non-colocated inference mode. 
" + ) + futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) + futures_inference = policy_generation.update_weights_from_collective() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + + # check if update is successful + if not update_success: + error_tag = "cuda-ipc" if colocated_inference else "nccl" + error_message = ( + "❌ Error: Updating weights for the generation policy failed during refit.\n" + f"This often indicates an issue with {error_tag} or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) + + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation(tags=["kv_cache"]) + + +def _log_mixed_rewards_and_advantages_information( + logger: Logger, + total_steps: int, + metrics: dict[str, Any], + baseline: torch.Tensor, + advantages: torch.Tensor, +) -> None: + # The histograms that are logged are logged with a prefix "train/" to the name, since that is what the remaining metrics will be logged with. 
+ logger.log_histogram( + baseline.numpy(), total_steps + 1, "train/baseline_reward/histogram" + ) + metrics["baseline_reward/pct_0"] = 100 * (baseline == 0).float().mean().item() + metrics["baseline_reward/pct_1"] = 100 * (baseline == 1).float().mean().item() + metrics["baseline_reward/pct_mixed"] = ( + 100 - metrics["baseline_reward/pct_0"] - metrics["baseline_reward/pct_1"] + ) + + logger.log_histogram( + advantages.numpy(), total_steps + 1, "train/advantages/histogram" + ) + metrics["advantages/sum"] = advantages.float().sum().item() + metrics["advantages/mean"] = advantages.float().mean().item() + + +# =============================================================================== +# Training & Validation +# =============================================================================== + + +def ppo_train( + policy: ColocatablePolicyInterface, + policy_generation: Optional[GenerationInterface], + value_model: ValueInterface, + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: LossFunction, + value_loss_fn: LossFunction, + task_to_env: dict[str, EnvironmentInterface], + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + logger: Logger, + checkpointer: CheckpointManager, + grpo_save_state: GRPOSaveState, + master_config: MasterConfig, +) -> None: + """Run PPO training algorithm.""" + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + memory_tracker = MemoryTracker() + + kv_scales_cache = None # Cache reused for computed kv scales + + NEED_REFIT = True + # If policy_generation is None, use the policy as the generation interface (megatron framework backend) + if policy_generation is None: + policy_generation = policy # type: ignore + NEED_REFIT = False + POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running + assert policy_generation 
is not None # for mypy type check + + if master_config["ppo"].get("skip_reference_policy_logprobs_calculation"): + assert master_config["loss_fn"]["reference_policy_kl_penalty"] == 0 + print( + "Reference policy logprob calculation will be skipped since `ppo.skip_reference_policy_logprobs_calculation` is set to True and `loss_fn.reference_policy_kl_penalty` is 0." + ) + + # Check if we need to sync KV cache scales + # When fallback to policy as the policy_generation, we use getattr to check. + sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False) + + # common config/state times + current_step = grpo_save_state["current_step"] # current step within an epoch + total_steps = grpo_save_state["total_steps"] # total steps across all epochs + current_epoch = grpo_save_state["current_epoch"] # current epoch + max_num_epochs = master_config["ppo"][ + "max_num_epochs" + ] # max number of epochs to train for + steps_per_epoch = master_config["ppo"]["steps_per_epoch"] + + consumed_samples = grpo_save_state[ + "consumed_samples" + ] # total samples consumed across all epochs + total_valid_tokens = grpo_save_state.get( + "total_valid_tokens", 0 + ) # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints + val_at_start = master_config["ppo"]["val_at_start"] + val_at_end = master_config["ppo"]["val_at_end"] + val_period = master_config["ppo"]["val_period"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + + # Initialize advantage estimator + adv_estimator = _create_advantage_estimator(master_config) + + # Run validation at the start if configured + # TODO: Add validation with kv scales if needed + if val_at_start and current_step == 0: + print("\n🔍 Running initial validation...", flush=True) + memory_tracker.snapshot_start_of_stage("Initial validation", dir()) + + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation(policy, policy_generation, 
colocated_inference) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=0, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + logger.log_metrics(val_metrics, current_step, prefix="validation") + logger.log_metrics(validation_timings, current_step, prefix="timing/validation") + + # Run PPO training loop + train_loader_iter = iter(dataloader) + val_metrics = None + for epoch in range(current_epoch, max_num_epochs): + metrics_logging_data = dict() + print(f"\n{'=' * 25} Epoch {epoch + 1}/{max_num_epochs} {'=' * 25}") + + with timer.time("total_epoch_time"): + print("▶ Preparing batch...", flush=True) + with timer.time("data_processing"): + batch = next(train_loader_iter) + repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave( + master_config["ppo"]["num_generations_per_prompt"] + ) + # Convert LLMMessageLogType to FlatMessagesType for generation + batched_flat, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + input_ids = batched_flat["token_ids"] + print( + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, + ) + with timer.time("prepare_for_generation/total"): + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + else: + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation() + + with timer.time("generation"): + print("▶ Generating responses...", flush=True) + # Clear logger metrics for each generation step + if policy_generation is not None: + policy_generation.clear_logger_metrics() + # Use NeMo-Gym rollouts if enabled. 
We cascade NeMo-Gym first since NeMo-Gym requires async rollouts. + repeated_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["ppo"]["max_rollout_turns"], + greedy=False, + ) + + policy_generation.finish_generation() + # Collect generation logger metrics for performance reporting after each generation step + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = policy_generation.get_logger_metrics() + metrics_logging_data["mean_gen_tokens_per_sample"] = rollout_metrics[ + "mean_gen_tokens_per_sample" + ] + logger.log_metrics(rollout_metrics, total_steps + 1, prefix="train") + + repeated_batch = scale_rewards( + repeated_batch, master_config["ppo"]["reward_scaling"] + ) + if master_config["ppo"]["reward_shaping"]["enabled"]: + repeated_batch = apply_reward_shaping( + repeated_batch, master_config["ppo"]["reward_shaping"] + ) + + with timer.time("reward_calculation"): + print("▶ Calculating rewards and values...", flush=True) + + for message_log in repeated_batch["message_log"]: + for _, message in enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + + messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) + + train_data = 
BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": messages["token_ids"], + "input_lengths": input_lengths, + "generation_logprobs": messages["generation_logprobs"], + "rewards": repeated_batch["total_reward"], + "sample_mask": repeated_batch["loss_multiplier"], + "token_mask": messages["token_loss_mask"], + } + ) + + values = value_model.get_values(train_data) + train_data["values"] = values["values"][..., 0] + + print( + f" • Average batch reward: {train_data['rewards'].mean().numpy():.4f}\n" + f" • Average batch response length: {input_lengths.sum() / input_lengths.shape[0]:.4f}\n" + f" • Max batch response length: {input_lengths.max():.4f}" + ) + + with timer.time("logprob_inference_prep"): + print("▶ Preparing for logprob inference...", flush=True) + policy.prepare_for_lp_inference() + + with timer.time("policy_and_reference_logprobs"): + print("▶ Computing policy and reference logprobs...", flush=True) + train_data["prev_logprobs"] = policy.get_logprobs( + train_data, timer=timer + )["logprobs"] + + train_data["reference_policy_logprobs"] = ( + policy.get_reference_policy_logprobs(train_data, timer=timer)[ + "reference_logprobs" + ] + ) + + with timer.time("compute_advantages"): + print("▶ Computing advantages...", flush=True) + advantages, returns = adv_estimator.compute_advantage( + prompt_ids=torch.arange(messages["token_ids"].shape[0]), + rewards=train_data["rewards"], + mask=train_data["token_mask"], + values=train_data["values"], + lengths=input_lengths, + reference_logprobs=train_data["reference_policy_logprobs"], + logprobs=train_data["prev_logprobs"], + ) + + train_data["advantages"] = advantages + train_data["returns"] = returns + + for step in range(steps_per_epoch): + print( + f"▶ Epoch {epoch + 1}/{max_num_epochs}, Step {step + 1}/{steps_per_epoch}...", + flush=True, + ) + + with timer.time("policy_training_prep"): + policy.prepare_for_training() + + with timer.time("policy_training"): + print(" • Training policy...", flush=True) + 
train_results = policy.train( + train_data, + loss_fn, + timer=timer, + ) + policy.finish_training() + POLICY_GENERATION_STALE = True + + with timer.time("value_training_prep"): + value_model.prepare_for_training() + + with timer.time("value_training"): + print(" • Training value...", flush=True) + value_results = value_model.train( + train_data, + value_loss_fn, + timer=timer, + ) + value_model.finish_training() + + print(" • Results:") + + print( + f" • Policy loss: {train_results['loss'].mean().item():.4f}" + ) + print( + f" • Value loss: {value_results['loss'].mean().item():.4f}" + ) + + is_last_epoch = epoch + 1 >= max_num_epochs + if (val_period > 0 and (epoch + 1) % val_period == 0) or ( + val_at_end and is_last_epoch + ): + with timer.time("validation"): + print("▶ Validating...", flush=True) + + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + else: + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=epoch + 1, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + + logger.log_metrics( + validation_timings, + current_epoch + 1, + prefix="timing/validation", + ) + logger.log_metrics( + val_metrics, current_epoch + 1, prefix="validation" + ) + + train_metrics = { + "policy_loss": train_results["loss"].mean().item(), + "value_loss": value_results["loss"].mean().item(), + "reward": repeated_batch["total_reward"].mean().item(), + } + + # --- Checkpointing --- + consumed_samples += master_config["ppo"]["num_prompts_per_step"] + total_steps += 1 + timeout.mark_iteration() + + is_last_epoch = epoch + 1 >= max_num_epochs + should_save_by_step = ( + is_last_epoch + or total_steps % master_config["checkpointing"]["save_period"] == 0 + ) + 
should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + policy.prepare_for_training() + + grpo_save_state["current_step"] = 0 + grpo_save_state["total_steps"] = total_steps + grpo_save_state["current_epoch"] = epoch + 1 + grpo_save_state["total_valid_tokens"] = total_valid_tokens + grpo_save_state["consumed_samples"] = consumed_samples + if val_metrics is not None: + grpo_save_state["val_reward"] = val_metrics.get( + "accuracy", -float("inf") + ) + elif "val_reward" in grpo_save_state: + del grpo_save_state["val_reward"] + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f"followed by the corresponding name in the metrics dictionary." + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = train_metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"You asked to save checkpoints based on {metric_name} but no " + f"{prefix} metrics were collected. 
This checkpoint will not be saved as top-k.",
+                        stacklevel=2,
+                    )
+                    if full_metric_name in grpo_save_state:
+                        del grpo_save_state[full_metric_name]
+                elif metric_name not in metrics_source:
+                    raise ValueError(
+                        f"Metric {metric_name} not found in {prefix} metrics"
+                    )
+                else:
+                    grpo_save_state[full_metric_name] = metrics_source[metric_name]
+
+            with timer.time("checkpointing"):
+                print(
+                    f"Saving checkpoint for epoch {epoch + 1} (step {total_steps})...",
+                    flush=True,
+                )
+                checkpoint_path = checkpointer.init_tmp_checkpoint(
+                    total_steps, grpo_save_state, master_config
+                )
+                policy.save_checkpoint(
+                    weights_path=os.path.join(checkpoint_path, "policy", "weights"),
+                    optimizer_path=os.path.join(
+                        checkpoint_path, "policy", "optimizer"
+                    ),
+                    tokenizer_path=os.path.join(
+                        checkpoint_path, "policy", "tokenizer"
+                    ),
+                    checkpointing_cfg=master_config["checkpointing"],
+                )
+                if value_model is not None:
+                    value_model.save_checkpoint(
+                        weights_path=os.path.join(
+                            checkpoint_path, "value", "weights"
+                        ),
+                        optimizer_path=os.path.join(
+                            checkpoint_path, "value", "optimizer"
+                        ),
+                        checkpointing_cfg=master_config["checkpointing"],
+                    )
+                torch.save(
+                    dataloader.state_dict(),
+                    os.path.join(checkpoint_path, "train_dataloader.pt"),
+                )
+                checkpointer.finalize_checkpoint(checkpoint_path)
+
+        timer.reset()
+
+
+def validate(
+    policy_generation: GenerationInterface,
+    val_dataloader: Optional[StatefulDataLoader],
+    tokenizer,
+    val_task_to_env: Optional[dict[str, EnvironmentInterface]],
+    step: int,
+    master_config: MasterConfig,
+    logger: Optional[Logger] = None,
+) -> tuple[dict[str, Any], dict[str, Any]]:
+    """Run validation on the validation dataset."""
+    if val_dataloader is None:
+        assert val_dataloader is not None or master_config["ppo"]["val_period"] == 0, (
+            "val_dataloader is None, so ppo.val_period must be 0"
+        )
+        print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
+        return {}, {}
+
+    timer = Timer()
+    with 
timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...", flush=True) + + total_rewards = [] + total_lengths = [] + all_message_logs = [] # Collect all message logs + + max_batches = ( + master_config["ppo"]["max_val_samples"] + // master_config["ppo"]["val_batch_size"] + ) + for batch_idx, val_batch in enumerate(val_dataloader): + if batch_idx >= max_batches: + break + + additional_metrics_to_report = dict() + + val_batch, gen_metrics = run_multi_turn_rollout( + policy_generation, + val_batch, + tokenizer, + val_task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["ppo"]["max_rollout_turns"], + greedy=False, + ) + + total_rewards.extend(val_batch["total_reward"].tolist()) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + + # Collect message logs for later display + to_env = [ + get_keys_from_message_log( + val_batch["message_log"][i], ["role", "content"] + ) + for i in range(len(val_batch["message_log"])) + ] + + all_message_logs.extend(to_env) + + # Calculate validation metrics + num_samples = len(total_rewards) + if num_samples > 0: + rewards_t = torch.tensor(total_rewards, dtype=torch.float32) + accuracy = rewards_t.mean().item() + else: + accuracy = 0.0 + + avg_length = ( + sum(total_lengths) / len(total_lengths) if len(total_lengths) > 0 else 0.0 + ) + + val_metrics = { + "accuracy": accuracy, + "avg_length": avg_length, + **additional_metrics_to_report, + } + + # Print sample conversations only once at the end of validation + try: + print_message_log_samples( + all_message_logs, + total_rewards, + num_samples=min( + master_config["logger"]["num_val_samples_to_print"], + len(all_message_logs), + ), + step=step, + ) + except Exception as e: + print(f"\n ⚠️ Error displaying message samples: {str(e)}") + print(" ⚠️ Continuing validation without displaying samples...", flush=True) + + # Get timing metrics + timing_metrics = 
timer.get_timing_metrics(reduction_op="sum")
+        validation_time = timing_metrics.get("total_validation_time", 0)
+
+        # Print summary of validation results
+        print("\n📊 Validation Results:")
+        print(f"  • Accuracy: {accuracy:.4f}")
+        print(f"  • Average response length: {avg_length:.1f} tokens")
+        print(f"  • Samples processed: {len(total_rewards)}", flush=True)
+
+        # Print timing information
+        print("\n  ⏱️ Validation Timing:")
+        validation_time = timing_metrics.get("total_validation_time", 0)
+        print(f"  • Total validation time: {validation_time:.2f}s", flush=True)
+
+        # Log validation data to JSONL file
+        if logger is not None:
+            val_log_data = {
+                "content": all_message_logs,
+                "rewards": total_rewards,
+            }
+            logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl")
+
+    # Make sure to reset the timer after validation
+    timer.reset()
+
+    # Explicit GPU memory cleanup after validation
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return val_metrics, timing_metrics
diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py
index f2f3da7d0b..cb3f572395 100644
--- a/nemo_rl/algorithms/utils.py
+++ b/nemo_rl/algorithms/utils.py
@@ -17,15 +17,11 @@
 import re
 import warnings
 from functools import partial, wraps
 from typing import Any, Optional
 
 import numpy as np
 import torch
-from transformers import (
-    AutoProcessor,
-    AutoTokenizer,
-    PreTrainedTokenizerBase,
-)
+from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase
 
 from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES
 from nemo_rl.models.policy import TokenizerConfig
@@ -222,6 +218,24 @@ def mask_out_neg_inf_logprobs(
     return logprobs
 
 
+def masked_var(
+    values: torch.Tensor,
+    mask: torch.Tensor,
+    mean: Optional[torch.Tensor | float] = None,
+    unbiased: bool = True,
+) -> torch.Tensor:
+    if mean is None:
+        mean = masked_mean(values, mask)
+    centered_values = values - mean
+    variance = masked_mean(centered_values**2, mask)
+
+    if 
unbiased:
+        normalization_factor = torch.sum(mask)
+        correction = normalization_factor / torch.clamp(normalization_factor - 1, min=1)
+        variance = variance * correction
+    return variance
+
+
 def set_seed(seed: int) -> None:
     """Sets the seed for python, numpy, and pytorch."""
     random.seed(seed)
diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py
index 95677873a4..d8eb88b31d 100644
--- a/nemo_rl/distributed/ray_actor_environment_registry.py
+++ b/nemo_rl/distributed/ray_actor_environment_registry.py
@@ -33,6 +33,7 @@
     "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
     "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP,
     "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
+    "nemo_rl.models.value.workers.dtensor_value_worker_v2.DTensorValueWorkerV2": PY_EXECUTABLES.AUTOMODEL,
     "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
     "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.math_environment.MathMultiRewardEnvironment": PY_EXECUTABLES.SYSTEM,
diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py
index ad90c78c67..d3e20e2dde 100644
--- a/nemo_rl/models/automodel/setup.py
+++ b/nemo_rl/models/automodel/setup.py
@@ -20,7 +20,10 @@
 import torch
 from accelerate import init_empty_weights
 from hydra.utils import get_class
-from nemo_automodel import NeMoAutoModelForSequenceClassification
+from nemo_automodel import (
+    NeMoAutoModelForSequenceClassification,
+    NeMoAutoModelForTokenClassification,
+)
 from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel.components._peft.lora import (
     PeftConfig,
@@ -168,6 +171,14 @@
                 "for the linear head of Bradley-Terry reward models." 
) model_config.num_labels = 1 + elif rm_type == "regression": + model_class = NeMoAutoModelForTokenClassification + if model_config.num_labels != 1: + print( + "model_config.num_labels is not 1. Setting it to 1 since this value is used as the out_features " + "for the linear head of regression reward models." + ) + model_config.num_labels = 1 else: raise ValueError(f"Unknown reward model type: {rm_type}") else: @@ -314,6 +325,7 @@ def setup_model_and_optimizer( init_optimizer: bool = True, weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, + optimizer_module_filter: Optional[list[str]] = None, ) -> ModelAndOptimizerState: """Set up model, parallelization, and optimizer. @@ -330,6 +342,7 @@ def setup_model_and_optimizer( init_optimizer: Whether to initialize optimizer weights_path: Optional path to checkpoint weights to load optimizer_path: Optional path to optimizer state to load + optimizer_module_filter: Optional list of module names to filter optimizer parameters Returns: ModelAndOptimizerState containing model, optimizer, scheduler, and metadata @@ -535,7 +548,16 @@ def setup_model_and_optimizer( optimizer = None if init_optimizer: optimizer_cls = get_class(config["optimizer"]["name"]) - optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["kwargs"]) + + if optimizer_module_filter is not None: + parameters = [ + p + for n, p in model.named_parameters() + if any(x in n for x in optimizer_module_filter) + ] + else: + parameters = model.parameters() + optimizer = optimizer_cls(parameters, **config["optimizer"]["kwargs"]) # Initialize scheduler scheduler = None diff --git a/nemo_rl/models/value/__init__.py b/nemo_rl/models/value/__init__.py new file mode 100644 index 0000000000..b3e412ee1f --- /dev/null +++ b/nemo_rl/models/value/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.models.value.config import ValueConfig +from nemo_rl.models.value.interfaces import ValueInterface, ValueOutputSpec +from nemo_rl.models.value.lm_value import Value + +__all__ = ["Value", "ValueConfig", "ValueInterface", "ValueOutputSpec"] diff --git a/nemo_rl/models/value/config.py b/nemo_rl/models/value/config.py new file mode 100644 index 0000000000..2b13511552 --- /dev/null +++ b/nemo_rl/models/value/config.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, NotRequired, TypedDict + +from nemo_rl.models.policy import ( + DTensorConfig, + DTensorConfigDisabled, + DynamicBatchingConfig, + DynamicBatchingConfigDisabled, + PytorchOptimizerConfig, + RewardModelConfig, + SchedulerMilestones, + SequencePackingConfig, + SequencePackingConfigDisabled, + SinglePytorchMilestonesConfig, + SinglePytorchSchedulerConfig, + TokenizerConfig, +) + + +class ValueConfig(TypedDict): + """Configuration for Value models in PPO. + + Value models use a subset of PolicyConfig fields, excluding generation-specific + and reference policy settings. + """ + + model_name: str + tokenizer: TokenizerConfig + + # Training batch sizes + train_global_batch_size: int + train_micro_batch_size: int + logprob_batch_size: NotRequired[int] # Used for value inference batch size + + # Precision + precision: str + + # Reward model config (value models use regression head) + reward_model_cfg: RewardModelConfig + + # Backend configuration - only DTensor is supported for value models + dtensor_cfg: DTensorConfig | DTensorConfigDisabled + + # HuggingFace config overrides + hf_config_overrides: NotRequired[dict[str, Any]] + + # Batching strategies + dynamic_batching: DynamicBatchingConfig | DynamicBatchingConfigDisabled + sequence_packing: NotRequired[SequencePackingConfig | SequencePackingConfigDisabled] + + # Sequence length settings + make_sequence_length_divisible_by: int + max_total_sequence_length: int + + # Gradient clipping + max_grad_norm: NotRequired[float | int | None] + + # Checkpoint loading + dequantize_base_checkpoint: NotRequired[bool] + + # Optimizer and scheduler + optimizer: NotRequired[PytorchOptimizerConfig | None] + scheduler: NotRequired[ + list[SinglePytorchSchedulerConfig | SinglePytorchMilestonesConfig] + | SchedulerMilestones + | None + ] diff --git a/nemo_rl/models/value/interfaces.py b/nemo_rl/models/value/interfaces.py new file mode 100644 index 0000000000..874e1f1eca --- /dev/null +++ 
b/nemo_rl/models/value/interfaces.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Optional, TypedDict + +import torch + +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.models.generation.interfaces import GenerationDatumSpec +from nemo_rl.utils.timer import Timer + + +class ValueOutputSpec(TypedDict): + """values: Tensor of value predictions [batch_size, sequence_length].""" + + values: torch.Tensor + + +class ValueInterface(ABC): + """Abstract base class defining the interface for value functions.""" + + @abstractmethod + def get_values( + self, + data: BatchedDataDict[GenerationDatumSpec], + timer: Optional[Timer] = None, + ) -> BatchedDataDict[ValueOutputSpec]: + """Get value predictions for observations. + + Args: + data: BatchedDataDict containing input sequences (tokens) + timer: Optional timer for profiling + + Returns: + BatchedDataDict containing: + - values: Tensor of value predictions [batch_size, sequence_length] + """ + pass + + @abstractmethod + def train( + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + *, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + timer: Optional[Timer] = None, + ) -> dict[str, Any]: + """Train the value function on a global batch of data. 
+ + Args: + data: BatchedDataDict containing training data + loss_fn: Loss function to use for training + eval_mode: Whether to run in evaluation mode (no gradient updates) + gbs: Global batch size override (if None, uses config default) + mbs: Micro batch size override (if None, uses config default) + timer: Optional timer for profiling + + Returns: + Dictionary containing training metrics (loss, grad_norm, etc.) + """ + pass + + @abstractmethod + def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: + """Prepare the value model for training (e.g., load to GPU).""" + pass + + @abstractmethod + def finish_training(self, *args: Any, **kwargs: Any) -> None: + """Clean up after training.""" + pass + + @abstractmethod + def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: + """Save model checkpoint.""" + pass + + @abstractmethod + def shutdown(self) -> bool: + """Shutdown workers and clean up resources.""" + pass diff --git a/nemo_rl/models/value/lm_value.py b/nemo_rl/models/value/lm_value.py new file mode 100644 index 0000000000..2213358099 --- /dev/null +++ b/nemo_rl/models/value/lm_value.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import warnings +from contextlib import nullcontext +from typing import Any, Optional, Union + +import numpy as np +import ray +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import ( + BatchedDataDict, + DynamicBatchingArgs, + SequencePackingArgs, + SlicedDataDict, +) +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.generation.interfaces import GenerationDatumSpec +from nemo_rl.models.value.config import ValueConfig +from nemo_rl.models.value.interfaces import ValueInterface, ValueOutputSpec +from nemo_rl.utils.checkpoint import CheckpointingConfig +from nemo_rl.utils.timer import Timer + +PathLike = Union[str, "os.PathLike[Any]"] + + +class Value(ValueInterface): + """Value function model for PPO using distributed training with Ray workers.""" + + def __init__( + self, + cluster: RayVirtualCluster, + config: ValueConfig, + tokenizer: PreTrainedTokenizerBase, + name_prefix: str = "lm_value", + workers_per_node: Optional[Union[int, list[int]]] = None, + init_optimizer: bool = True, + weights_path: Optional[PathLike] = None, + optimizer_path: Optional[PathLike] = None, + ): + """Initialize the Value model. 
+ + Args: + cluster: Ray virtual cluster for distributed training + config: Configuration for the value model + tokenizer: Tokenizer for the model + name_prefix: Prefix for worker names + workers_per_node: Number of workers per node + init_optimizer: Whether to initialize the optimizer + weights_path: Path to load model weights from + optimizer_path: Path to load optimizer state from + """ + if weights_path: + weights_path = os.path.abspath(weights_path) + if optimizer_path: + optimizer_path = os.path.abspath(optimizer_path) + + worker_builder_cls: str + tp_size = 1 + pp_size = 1 + cp_size = 1 + + # Value models use the same backend configuration as policy models + megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False)) + dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False)) + + if megatron_enable and dtensor_enable: + raise ValueError( + "Configure either Megatron (value.megatron_cfg.enabled=true) or " + "DTensor (value.dtensor_cfg.enabled=true), not both." + ) + + if megatron_enable: + raise NotImplementedError( + "Megatron backend is not yet implemented for Value models. " + "Please use DTensor backend (value.dtensor_cfg.enabled=true)." + ) + else: + if not dtensor_enable: + raise ValueError( + "Please set value.dtensor_cfg.enabled=true to use DTensor training backend." + ) + + # Check if _v2 is enabled (defaults to False for backward compatibility) + use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) + if use_v2: + worker_builder_cls = "nemo_rl.models.value.workers.dtensor_value_worker_v2.DTensorValueWorkerV2" + + if "TORCH_CUDA_ARCH_LIST" not in os.environ: + warnings.warn( + "TORCH_CUDA_ARCH_LIST is not set. This is needed if using DeepEP in DTensorValueWorker V2. " + "This variable is set in our container, but if you are running a custom container or baremetal, " + "you may need to set this variable manually. 
Example: export TORCH_CUDA_ARCH_LIST='9.0 10.0'" + ) + else: + raise NotImplementedError( + "DTensor V1 backend is not implemented for Value models. " + "Please set value.dtensor_cfg._v2=true to use DTensor V2." + ) + + tp_size = config["dtensor_cfg"]["tensor_parallel_size"] + cp_size = config["dtensor_cfg"]["context_parallel_size"] + + env_vars = config["dtensor_cfg"].get("env_vars", {}) + + # Validate world_size compatibility with parallelism configuration + model_parallel_size = pp_size * cp_size * tp_size + actual_world_size = cluster.world_size() + + if actual_world_size < model_parallel_size: + raise ValueError( + f"World size ({actual_world_size}) is insufficient for the parallelism configuration. " + f"Required minimum world size: PP({pp_size}) * CP({cp_size}) * TP({tp_size}) = {model_parallel_size}. " + f"This would result in DP = {actual_world_size}/{model_parallel_size} = {actual_world_size / model_parallel_size:.3f}, but DP must be ≥ 1. " + f"Please either increase the number of GPUs/nodes or reduce the parallelism parameters." + ) + + if actual_world_size % model_parallel_size != 0: + dp_size_float = actual_world_size / model_parallel_size + raise ValueError( + f"World size ({actual_world_size}) must be divisible by PP * CP * TP ({model_parallel_size}). " + f"The data parallel size (DP = world_size / (PP * CP * TP)) must be a positive integer. " + f"Current DP would be {actual_world_size}/{model_parallel_size} = {dp_size_float:.6f}, which is not an integer. " + f"Please adjust your cluster size or parallelism parameters." 
+ ) + + self.sharding_annotations = NamedSharding( + layout=np.arange(cluster.world_size()).reshape( + pp_size, # PP + -1, # DP + cp_size, # CP + tp_size, # TP + ), + names=[ + "pipeline_parallel", + "data_parallel", + "context_parallel", + "tensor_parallel", + ], + ) + + from ray.util.queue import Queue as RayQueue + + pre_init_queue = RayQueue() + worker_builder = RayWorkerBuilder( + worker_builder_cls, + config, + tokenizer=tokenizer, + init_optimizer=init_optimizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + worker_sharding_annotations=self.sharding_annotations, + pre_init_communication_queue=pre_init_queue, + ) + + if cluster._sorted_bundle_indices is not None: + # The cluster has initialized a unified placement group across nodes + group_size = cluster.num_gpus_per_node + tied_groups = [ + (i // group_size, [bundle_idx]) + for i, bundle_idx in enumerate(cluster._sorted_bundle_indices) + ] + + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + bundle_indices_list=tied_groups, + sharding_annotations=self.sharding_annotations, + env_vars=env_vars or {}, + ) + else: + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + workers_per_node=workers_per_node, + sharding_annotations=self.sharding_annotations, + env_vars=env_vars or {}, + ) + + # Configure dynamic batching + if config["dynamic_batching"]["enabled"]: + assert pp_size == 1, ( + "Dynamic batching is only supported for single pipeline parallel stage" + ) + self.use_dynamic_batches = True + self.dynamic_batching_args: DynamicBatchingArgs = { + "input_key": "input_ids", + "input_lengths_key": "input_lengths", + "sequence_length_round": config["dynamic_batching"][ + "sequence_length_round" + ], + "max_tokens_per_microbatch": 0, # Override in each call + } + assert not config["sequence_packing"]["enabled"], ( + "Dynamic Batching is exclusive of Sequence Packing. 
Please disable Sequence Packing to use Dynamic Batching" + ) + else: + self.use_dynamic_batches = False + + # Configure sequence packing + if config["sequence_packing"]["enabled"]: + self.use_sequence_packing = True + sequence_length_pad_multiple = ( + cp_size * 2 * tp_size if cp_size > 1 else tp_size + ) + self.sequence_packing_args: SequencePackingArgs = { + "algorithm": config["sequence_packing"]["algorithm"], + "input_key": "input_ids", + "input_lengths_key": "input_lengths", + "sequence_length_pad_multiple": sequence_length_pad_multiple, + } + assert not config["dynamic_batching"]["enabled"], ( + "Sequence Packing is exclusive of Dynamic Batching. Please disable Dynamic Batching" + ) + else: + self.use_sequence_packing = False + + self.cfg = config + + def get_values( + self, + data: BatchedDataDict[GenerationDatumSpec], + timer: Optional[Timer] = None, + ) -> BatchedDataDict[ValueOutputSpec]: + """Get value predictions for a batch of data. + + Args: + data: BatchedDataDict containing input sequences + timer: Optional timer for profiling + + Returns: + BatchedDataDict containing value predictions [batch_size, sequence_length] + """ + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + sharded_data: list[SlicedDataDict] + unsorted_data_indices: list[int] + + with timer.time("get_values/shard_data") if timer else nullcontext(): + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, + ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) + else: + 
sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + + with timer.time("get_values/submit_value_futures") if timer else nullcontext(): + futures = self.worker_group.run_all_workers_sharded_data( + "get_values", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + ) + values: BatchedDataDict[ValueOutputSpec] = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + + # Reorder data if dynamic batching or sequence packing was used + if self.use_dynamic_batches or self.use_sequence_packing: + values.reorder_data(unsorted_data_indices) + + return values + + def train( + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + *, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + timer: Optional[Timer] = None, + ) -> dict[str, Any]: + """Train the value function on a batch of data with a given loss function. + + Args: + data: BatchedDataDict containing training data + loss_fn: Loss function to use for training + eval_mode: Whether to run in evaluation mode (no gradient updates) + gbs: Global batch size override (if None, uses config default) + mbs: Micro batch size override (if None, uses config default) + timer: Optional timer for profiling + + Returns: + Dictionary containing training metrics (loss, grad_norm, etc.) 
+ """ + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + + # Shard and replicate the batch + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + with timer.time("value_training/sharding_data") if timer else nullcontext(): + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["train_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + dynamic_batching_args=self.dynamic_batching_args, + ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["train_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + sequence_packing_args=self.sequence_packing_args, + ) + else: + sharded_data = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + ) + + # Train each shard in parallel + with ( + timer.time("value_training/submit_training_futures") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "train", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={ + "loss_fn": loss_fn, + "eval_mode": eval_mode, + "gbs": batch_size, + "mbs": micro_batch_size, + }, + ) + results = self.worker_group.get_all_worker_results(futures) + + # Aggregate the results + aggregated_results = { + "loss": results[0]["global_loss"], + "grad_norm": results[0]["grad_norm"], + } + + # Aggregate metrics across all workers + from collections import defaultdict + + all_mb_metrics = defaultdict(list) + for r in results: + for k, v in r["all_mb_metrics"].items(): + all_mb_metrics[k].extend(v) + aggregated_results["all_mb_metrics"] = 
dict(all_mb_metrics) + + return aggregated_results + + def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: + """Prepare the value model for training (load to GPU).""" + futures = self.worker_group.run_all_workers_single_data("prepare_for_training") + ray.get(futures) + + def finish_training(self, *args: Any, **kwargs: Any) -> None: + """Clean up after training.""" + pass + + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, + ) -> None: + """Save a checkpoint of the value model.""" + use_v2 = self.cfg.get("dtensor_cfg", {}).get("_v2", False) + + if use_v2: + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path=weights_path, + optimizer_path=optimizer_path, + tokenizer_path=tokenizer_path, + checkpointing_cfg=checkpointing_cfg, + ) + else: + if ( + checkpointing_cfg is not None + and checkpointing_cfg.get("model_save_format", None) is not None + ): + raise ValueError( + "model_save_format must be None or omitted if using DTensorValueWorker (_v2=False)." 
+ ) + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path=weights_path, + optimizer_path=optimizer_path, + tokenizer_path=tokenizer_path, + ) + ray.get(futures) + + def shutdown(self) -> bool: + """Shut down all value workers and clean up resources.""" + try: + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + print(f"Error during value model shutdown: {e}") + return False + + def __del__(self) -> None: + """Shuts down the worker groups when the object is deleted or garbage collected.""" + if hasattr(self, "worker_group"): + self.worker_group.shutdown(cleanup_method="shutdown") diff --git a/nemo_rl/models/value/workers/__init__.py b/nemo_rl/models/value/workers/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_rl/models/value/workers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/models/value/workers/dtensor_value_worker_v2.py b/nemo_rl/models/value/workers/dtensor_value_worker_v2.py new file mode 100644 index 0000000000..e1eaed71ae --- /dev/null +++ b/nemo_rl/models/value/workers/dtensor_value_worker_v2.py @@ -0,0 +1,596 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import gc
+import warnings
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from re import S  # FIXME(review): looks like an accidental IDE auto-import — confirm unused and remove
+from typing import Any, Generator, Optional
+from unittest import skip  # FIXME(review): looks like an accidental IDE auto-import — confirm unused and remove
+
+import ray
+import torch
+from nemo_automodel.components.distributed.cp_utils import create_context_parallel_ctx
+from nemo_automodel.components.distributed.cp_utils import (
+    get_train_context as get_train_context_automodel,
+)
+from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm
+from torch import nn
+from transformers import AutoTokenizer
+
+from nemo_rl.algorithms.loss.interfaces import LossFunction
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.models.automodel.data import (
+    check_sequence_dim,
+    get_microbatch_iterator,
+    process_global_batch,
+)
+from nemo_rl.models.automodel.setup import (
+    setup_distributed,
+    setup_model_and_optimizer,
+    validate_and_prepare_config,
+)
+from nemo_rl.models.automodel.train import (
+    LogprobsPostProcessor,
+    LossPostProcessor,
+    ScorePostProcessor,
+    aggregate_training_statistics,
+    automodel_forward_backward,
+    forward_with_post_processing_fn,
+)
+from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
+from nemo_rl.models.policy.workers.patches import (
+    apply_torch_aten_alias_tensor_patch,
+    apply_transformer_engine_patch,
+)
+from nemo_rl.models.value.config import ValueConfig
+from nemo_rl.models.value.interfaces import ValueOutputSpec
+from nemo_rl.utils.automodel_checkpoint import
AutomodelCheckpointManager +from nemo_rl.utils.checkpoint import CheckpointingConfig +from nemo_rl.utils.nsys import wrap_with_nvtx_name + + +@contextlib.contextmanager +def get_train_context( + cp_size: int, + cp_mesh: Any, + cp_buffers: list, + sequence_dim: int, + dtype: torch.dtype, + autocast_enabled: bool = True, +) -> Generator[None, None, None]: + """Create combined context manager for training with context parallel and autocast.""" + with contextlib.ExitStack() as stack: + context_parallel_ctx = None + if cp_size > 1: + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + stack.enter_context( + get_train_context_automodel(False, False, context_parallel_ctx)() + ) + if autocast_enabled: + stack.enter_context(torch.autocast(device_type="cuda", dtype=dtype)) + yield + + +def get_runtime_env_for_value_worker(worker_type: str) -> dict: + """Get runtime environment for value worker.""" + from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker + + # Reuse policy worker runtime env + return get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") + + +@ray.remote( + runtime_env=get_runtime_env_for_value_worker("dtensor_value_worker_v2") +) # pragma: no cover +class DTensorValueWorkerV2(AbstractPolicyWorker): + def __repr__(self) -> str: + """Customizes the actor's prefix in the Ray logs.""" + if torch.distributed.is_initialized(): + return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]" + else: + return f"{self.__class__.__qualname__}" + + def __init__( + self, + config: ValueConfig, + tokenizer: AutoTokenizer, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, + init_optimizer: bool = True, + **kwargs: Any, + ): + """Initialize the DTensorValueWorkerV2. 
+ + Note: Value models don't need a reference model since they don't compute KL divergence. + """ + # Apply patches + apply_transformer_engine_patch() + apply_torch_aten_alias_tensor_patch() + + # Store configuration and tokenizer + self.cfg = config + self.tokenizer = tokenizer + self.lora_enabled = ( + config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False) + ) + + # Ensure reward model config is set for value models + if ( + "reward_model_cfg" not in config + or not config["reward_model_cfg"]["enabled"] + ): + # Value models use the reward model architecture but predict values instead + config["reward_model_cfg"] = { + "enabled": True, + "reward_model_type": "regression", # Value is a regression task + } + + print(f"Initializing DTensorValueWorkerV2") + + # Initialize checkpoint manager + self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + + if "hf_config_overrides" not in config: + config["hf_config_overrides"] = {} + config["hf_config_overrides"]["num_labels"] = 1 + + # Validate configuration and prepare runtime settings + runtime_config = validate_and_prepare_config( + config=config, + processor=None, # Value models don't use vision processors + rank=0, # Temporary, will be updated after distributed init + ) + + # Set up distributed environment + distributed_manager = setup_distributed( + config=config, + runtime_config=runtime_config, + ) + + # Set instance attributes from distributed manager + self.rank = torch.distributed.get_rank() + self.device_mesh = distributed_manager.device_mesh + self.dp_cp_mesh = self.device_mesh["dp_cp"] + self.dp_mesh = self.device_mesh["dp"] + self.tp_mesh = self.device_mesh["tp"] + self.cp_mesh = self.device_mesh["cp"] + self.moe_mesh = distributed_manager.moe_mesh + self.dp_size = distributed_manager.dp_size + self.tp_size = distributed_manager.tp_size + self.cp_size = distributed_manager.cp_size + + # Initialize checkpoint manager + self._init_checkpoint_manager( + config_updates={ + 
"model_repo_id": config["model_name"], + "dequantize_base_checkpoint": config.get( + "dequantize_base_checkpoint", False + ), + "is_peft": self.lora_enabled, + "skip_task_head_prefixes_for_base_model": ["score."], + }, + ) + + # Set up model and optimizer + model_and_optimizer_state = setup_model_and_optimizer( + config=config, + tokenizer=tokenizer, + runtime_config=runtime_config, + distributed_manager=distributed_manager, + checkpoint_manager=self.checkpoint_manager, + is_vlm=False, # Value models don't use vision + init_optimizer=init_optimizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + # optimizer_module_filter=["score."], + ) + + # Set instance attributes from model and optimizer state + ( + self.model, + self.model_state_dict_keys, + self.optimizer, + self.scheduler, + self.is_hf_model, + self.is_moe_model, + self._is_reward_model, + self.model_class, + self.model_config, + self.peft_config, + self.autocast_enabled, + ) = model_and_optimizer_state + + # Set instance attributes from runtime config + ( + self.model_class, + self.model_config, + self.hf_config_overrides, + self.allow_flash_attn_args, + self.attn_impl, + self.dtype, + self.enable_seq_packing, + self.max_grad_norm, + self.cpu_offload, + self.offload_optimizer_for_logprob, + self.is_generation_colocated, + self.sampling_params, + _runtime_is_reward_model, + ) = runtime_config + + @wrap_with_nvtx_name("dtensor_value_worker_v2/train") + def train( + self, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + """Train the value function on a batch of data with a given loss function.""" + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + 
op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + # Validate sequence dimension + sequence_dim, _ = check_sequence_dim(data) + + if eval_mode: + ctx: AbstractContextManager[Any] = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + self.model.train() + + # Create loss post-processor + loss_post_processor = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + dp_size=self.dp_size, + enable_seq_packing=self.enable_seq_packing, + ) + + # Create train context factory + def train_context_fn(processed_inputs): + return get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ) + + # Setup cache clearing callback if configured + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( + "clear_cache_every_n_steps" + ) + if empty_cache_steps: + warnings.warn( + f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead.", + ) + + def on_microbatch_start(mb_idx): + if empty_cache_steps and mb_idx % empty_cache_steps == 0: + torch.cuda.empty_cache() + + with ctx: + data = data.to("cuda") + + losses = [] + all_mb_metrics = [] + for gb_idx in range(num_global_batches): + # Process global batch + gb_result = process_global_batch( + data, + loss_fn, + self.dp_mesh.get_group(), + batch_idx=gb_idx, + batch_size=local_gbs, + ) + batch = gb_result["batch"] + global_valid_seqs = gb_result["global_valid_seqs"] + global_valid_toks = gb_result["global_valid_toks"] + + self.optimizer.zero_grad() + + # Get microbatch iterator + processed_iterator, iterator_len = get_microbatch_iterator( + batch, + self.cfg, + mbs, + self.dp_mesh, + tokenizer=self.tokenizer, + 
cp_size=self.cp_size, + ) + + # Use automodel_forward_backward for the training loop + mb_results = automodel_forward_backward( + model=self.model, + data_iterator=processed_iterator, + post_processing_fn=loss_post_processor, + forward_only=eval_mode, + is_reward_model=True, # Value models use reward model architecture + allow_flash_attn_args=False, # Typically False for value models + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + sampling_params=self.sampling_params, + sequence_dim=sequence_dim, + dp_size=self.dp_size, + cp_size=self.cp_size, + num_global_batches=num_global_batches, + train_context_fn=train_context_fn, + num_valid_microbatches=iterator_len, + on_microbatch_start=on_microbatch_start, + ) + + # Extract losses and metrics from results + mb_losses = [] + for mb_idx, (loss, loss_metrics) in enumerate(mb_results): + if mb_idx < iterator_len: + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + grad_norm: Optional[float | torch.Tensor] = None + if not eval_mode: + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + [self.model], + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=1, + dp_group_size=self.dp_size * self.cp_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) + + # Update parameters + self.optimizer.step() + + losses.append(torch.tensor(mb_losses).sum().item()) + + # Release gradient memory + self.optimizer.zero_grad() + # Increment scheduler + if not eval_mode: + 
self.scheduler.step() + # Clear cache + torch.cuda.empty_cache() + + # Aggregate training statistics + metrics = aggregate_training_statistics( + losses=losses, + all_mb_metrics=all_mb_metrics, + grad_norm=grad_norm, + dp_group=self.dp_mesh.get_group(), + dtype=self.dtype, + ) + + return metrics + + @wrap_with_nvtx_name("dtensor_value_worker_v2/get_values") + def get_values( + self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None + ) -> BatchedDataDict[ValueOutputSpec]: + """Get value predictions for a batch of data.""" + value_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg.get("logprob_batch_size", self.cfg["train_micro_batch_size"]) + ) + + # Validate sequence dimension + sequence_dim, seq_dim_size = check_sequence_dim(data) + + all_values = [] + self.model.eval() + + # Create value post-processor + value_post_processor = ScorePostProcessor( + cfg=self.cfg, + ) + + with torch.no_grad(): + data.to("cuda") + # Get microbatch iterator + processed_iterator, iterator_len = get_microbatch_iterator( + data, + self.cfg, + value_batch_size, + self.dp_mesh, + tokenizer=self.tokenizer, + cp_size=self.cp_size, + ) + + for batch_idx, processed_mb in enumerate(processed_iterator): + processed_inputs = processed_mb.processed_inputs + + with get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ): + # Use forward_with_post_processing_fn for forward pass + values, _metrics, _ = forward_with_post_processing_fn( + model=self.model, + post_processing_fn=value_post_processor, + processed_mb=processed_mb, + is_reward_model=True, # Value models use reward model architecture + allow_flash_attn_args=False, + sampling_params=self.sampling_params, + sequence_dim=sequence_dim, + ) + + # Skip dummy batches + if batch_idx >= iterator_len: + continue + + all_values.append(values) + + # 
Concatenate all batches + return_data = BatchedDataDict[ValueOutputSpec]() + + all_values_padded = [] + for val in all_values: + padding_needed = seq_dim_size - val.shape[1] + if padding_needed > 0: + val = torch.nn.functional.pad( + val, (0, padding_needed), mode="constant", value=0.0 + ) + all_values_padded.append(val) + return_data["values"] = torch.cat(all_values_padded, dim=0).cpu() + + return return_data + + @wrap_with_nvtx_name("dtensor_value_worker_v2/prepare_for_training") + def prepare_for_training(self, *args, **kwargs) -> None: + """Prepare for training by loading model and optimizer to GPU.""" + if not self.cpu_offload: + self.move_to_cuda(self.model) + else: + self.model = self.move_buffer_to_device(self.model, "cuda") + + self.model.train() + if self.optimizer is not None and not self.cpu_offload: + self.move_optimizer_to_device("cuda") + + torch.cuda.empty_cache() + + def move_optimizer_to_device(self, device: str | torch.device) -> None: + """Move optimizer state to specified device.""" + from torch.distributed.tensor import DTensor + + for state in self.optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, (DTensor, torch.Tensor)): + state[k] = v.to(device) + + def move_to_device(self, model: nn.Module, device: str | torch.device) -> nn.Module: + """Move model to specified device.""" + model = self.move_buffer_to_device(model, device) + return model.to(device) + + def move_buffer_to_device( + self, model: nn.Module, device: str | torch.device + ) -> nn.Module: + """Move model buffers to specified device.""" + for v in model.buffers(): + torch.utils.swap_tensors(v, v.to(device)) + return model + + def move_to_cuda(self, model: torch.nn.Module) -> torch.nn.Module: + """Move model to CUDA.""" + model = self.move_to_device(model, "cuda") + gc.collect() + torch.cuda.empty_cache() + return model + + def move_to_cpu(self, model: torch.nn.Module) -> torch.nn.Module: + """Move model to CPU.""" + model = self.move_to_device(model, "cpu") 
+ gc.collect() + torch.cuda.empty_cache() + return model + + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, + ) -> None: + """Save a checkpoint of the value model.""" + self.checkpoint_manager.save_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer, + optimizer_path=optimizer_path, + scheduler=self.scheduler, + tokenizer=self.tokenizer if tokenizer_path else None, + tokenizer_path=tokenizer_path, + checkpointing_cfg=checkpointing_cfg, + lora_enabled=self.lora_enabled, + peft_config=self.peft_config, + ) + + def load_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + ) -> None: + """Load a checkpoint into the value model.""" + self.checkpoint_manager.load_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer, + optimizer_path=optimizer_path, + scheduler=self.scheduler, + ) + + def _init_checkpoint_manager( + self, + config_updates: Optional[dict[str, Any]] = None, + checkpoint_root: Optional[str] = None, + ) -> None: + """Initialize the AutomodelCheckpointManager for this worker.""" + if self.checkpoint_manager is None: + self.checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=self.dp_mesh, + tp_mesh=self.tp_mesh, + model_state_dict_keys=getattr(self, "model_state_dict_keys", None), + moe_mesh=self.moe_mesh, + ) + self.checkpoint_manager.init_checkpointer( + config_updates=config_updates, + checkpoint_root=checkpoint_root, + ) diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py index bfba23fae5..702cd02178 100644 --- a/nemo_rl/utils/automodel_checkpoint.py +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -119,6 +119,9 @@ def init_checkpointer( dequantize_base_checkpoint=config_updates.get( "dequantize_base_checkpoint", False ), + 
skip_task_head_prefixes_for_base_model=config_updates.get( + "skip_task_head_prefixes_for_base_model", None + ), ) self.checkpoint_config = base_cfg self.checkpointer = Checkpointer( From 90116e9845ddda6b5e9a7a122af5c3e27c2f078e Mon Sep 17 00:00:00 2001 From: Lukasz Pierscieniewski Date: Tue, 24 Mar 2026 18:42:45 +0100 Subject: [PATCH 2/2] fix(ppo): naming issues after rebase --- nemo_rl/algorithms/loss/loss_functions.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 503a8d518d..35286763b0 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -987,21 +987,24 @@ def __call__( class MseValueLossFn(LossFunction): """Mean Squared Error value loss function.""" + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGIT + def __init__(self, loss_cfg): self.ratio_clip_min = loss_cfg["ratio_clip_min"] self.ratio_clip_max = loss_cfg["ratio_clip_max"] def __call__( self, - values: torch.Tensor, + logits: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute Mean Squared Error value loss.""" - if values.shape[-1] != 1: - values = values[..., 0] + if logits.shape[-1] != 1: + logits = logits[..., 0] token_mask = data["token_mask"] sample_mask = data["sample_mask"] @@ -1010,13 +1013,13 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) - values_clamped = values.clamp( + values_clamped = logits.clamp( old_values - self.ratio_clip_min, old_values + self.ratio_clip_max, ) loss = torch.max( - torch.square(values - returns), + torch.square(logits - returns), torch.square(values_clamped - returns), ) @@ -1026,7 +1029,7 @@ def __call__( metrics = { "loss": float(loss.item()), - "num_valid_samples": int(values.shape[0]), + "num_valid_samples": int(logits.shape[0]), } return loss, metrics