diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 1e3e5b3ae..729d9a428 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -26,6 +26,7 @@ from ...utils import tracking_utils from ...utils.profile_utils import TrainProfiler +from ...models.peft import LoRAConfig, apply_lora from . import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences from .lr_scheduler import get_lr_scheduler @@ -94,6 +95,16 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) + if args.use_lora: + lora_config = LoRAConfig( + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules, + ) + model = apply_lora(model, lora_config) + logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}") + model.train() full_state = model.state_dict() @@ -107,11 +118,22 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + # Use non-reentrant mode for gradient checkpointing + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if args.optimizer == "adam": + trainable_params = [p for p in self.model.parameters() if p.requires_grad] + + if args.use_lora: + total_params = sum(p.numel() for p in self.model.parameters()) + trainable_count = sum(p.numel() for p in trainable_params) + logger.info( + f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params " + f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)" + ) + self.optimizer = torch.optim.AdamW( - self.model.parameters(), + trainable_params, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, @@ -322,7 +344,11 @@ def save_model(self, iteration: int) -> None: if self.args.debug_rollout_only or self.args.save is None: return - checkpoint.save(self, iteration) + keys_filter = None + if self.args.use_lora: + keys_filter = lambda k: "lora_" in k + + checkpoint.save(self, iteration, keys_filter=keys_filter) def _compute_log_prob( self, diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 3c49a10f8..3846bd98c 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -6,11 +6,13 @@ from pathlib import Path from typing import Any +import safetensors.torch import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions from torch.distributed.checkpoint.stateful import Stateful +from miles.models.peft import LoRAConfig logger = logging.getLogger(__name__) @@ -18,31 +20,53 @@ class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model): + def __init__(self, model, keys_filter=None): self.model = model + self.keys_filter = keys_filter def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) + if self.keys_filter: + model_state_dict = {k: v for k, v in model_state_dict.items() if self.keys_filter(k)} return {"model": model_state_dict} def load_state_dict(self, state_dict): - set_state_dict(self.model, 
optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) + options = None + if self.keys_filter: + # For filtered loading (e.g., LoRA), use strict=False to allow partial loading + options = StateDictOptions(strict=False) + set_state_dict( + self.model, optimizers=[], + model_state_dict=state_dict["model"], + optim_state_dict=None, + options=options + ) class OptimizerState(Stateful): """Wrapper for optimizer state only.""" - def __init__(self, model, optimizer): + def __init__(self, model, optimizer, keys_filter=None): self.model = model self.optimizer = optimizer + self.keys_filter = keys_filter def state_dict(self): _, optimizer_state_dict = get_state_dict(self.model, optimizers=self.optimizer) + if self.keys_filter: + optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if self.keys_filter(k)} return {"optim": optimizer_state_dict} def load_state_dict(self, state_dict): + options = None + if self.keys_filter: + # For filtered loading (e.g., LoRA), use strict=False to allow partial loading + options = StateDictOptions(strict=False) set_state_dict( - self.model, optimizers=self.optimizer, model_state_dict=None, optim_state_dict=state_dict["optim"] + self.model, optimizers=self.optimizer, + model_state_dict=None, + optim_state_dict=state_dict["optim"], + options=options ) @@ -108,8 +132,13 @@ def load(actor: Any) -> dict[str, Any] | None: logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") return None + keys_filter = None + if actor.args.use_lora: + keys_filter = lambda k: "lora_" in k + logger.info("[FSDP] LoRA mode: loading only LoRA weights from checkpoint") + # Load model weights (always) - model_state = ModelState(actor.model) + model_state = ModelState(actor.model, keys_filter=keys_filter) state_dict = {"model_state": model_state} try: @@ -122,7 +151,7 @@ def load(actor: Any) -> dict[str, Any] | None: # Load optimizer state (optional) load_optimizer = not getattr(actor.args, "no_load_optim", False) and hasattr(actor, "optimizer") if load_optimizer and optimizer_dir.exists(): - optimizer_state = OptimizerState(actor.model, actor.optimizer) + optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter) optim_state_dict = {"optim_state": optimizer_state} try: dcp.load(state_dict=optim_state_dict, checkpoint_id=str(optimizer_dir)) @@ -187,7 +216,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None dist.barrier() -def save(actor: Any, iteration: int) -> None: +def save(actor: Any, iteration: int, keys_filter=None) -> None: """Save checkpoint to disk. Saves model weights and optimizer state to separate directories. 
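A minimal sketch (illustration only, not part of the patch) of the keys_filter contract threaded through ModelState, OptimizerState, and save(): the filter is any str -> bool predicate over fully-qualified state-dict keys, and under --use-lora it is lambda k: "lora_" in k, so only adapter tensors are written to or read from the DCP checkpoint.

def filter_state_dict(state_dict, keys_filter=None):
    # Keep only the entries whose key the predicate accepts; None keeps everything.
    if keys_filter is None:
        return state_dict
    return {k: v for k, v in state_dict.items() if keys_filter(k)}

# With LoRA applied, base weights sit under "...base_layer.weight" and adapter weights under
# "...lora_A" / "...lora_B", so filter_state_dict(sd, lambda k: "lora_" in k) keeps only the adapter.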
@@ -210,13 +239,13 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - model_state = ModelState(actor.model) + model_state = ModelState(actor.model, keys_filter=keys_filter) state_dict = {"model_state": model_state} dcp.save(state_dict, checkpoint_id=str(model_dir)) # Save optimizer state if hasattr(actor, "optimizer") and actor.optimizer is not None: - optimizer_state = OptimizerState(actor.model, actor.optimizer) + optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter) optim_state_dict = {"optim_state": optimizer_state} dcp.save(optim_state_dict, checkpoint_id=str(optimizer_dir)) @@ -246,4 +275,29 @@ def save(actor: Any, iteration: int) -> None: tracker_file.write_text(str(step_id)) logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}") + if actor.args.use_lora: + _save_hf_lora(actor, checkpoint_dir) + dist.barrier() + + +def _save_hf_lora(actor: Any, checkpoint_dir: Path) -> None: + """Save LoRA adapter in Hugging Face PEFT format.""" + + options = dcp.state_dict.StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(actor.model, options=options) + + if dist.get_rank() == 0: + lora_config = LoRAConfig( + lora_rank=actor.args.lora_rank, + lora_alpha=actor.args.lora_alpha, + lora_dropout=actor.args.lora_dropout, + target_modules=actor.args.lora_target_modules, + ) + peft_config = lora_config.to_hf_peft_config() + with open(checkpoint_dir / "adapter_config.json", "w") as f: + json.dump(peft_config, f, indent=2) + + lora_state_dict = {k: v for k, v in full_state_dict.items() if "lora_" in k} + safetensors.torch.save_file(lora_state_dict, checkpoint_dir / "adapter_model.safetensors") + logger.info(f"[FSDP] Saved HF LoRA adapter to {checkpoint_dir}") diff --git a/miles/models/peft/__init__.py b/miles/models/peft/__init__.py new file mode 100644 index 000000000..25e0269ea --- /dev/null +++ b/miles/models/peft/__init__.py @@ -0,0 +1,3 @@ +from .lora import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict +from .arguments import add_lora_arguments + diff --git a/miles/models/peft/arguments.py b/miles/models/peft/arguments.py new file mode 100644 index 000000000..3d12b6a72 --- /dev/null +++ b/miles/models/peft/arguments.py @@ -0,0 +1,38 @@ +import argparse + +def add_lora_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add LoRA arguments to the parser.""" + group = parser.add_argument_group(title="LoRA") + + group.add_argument( + "--use-lora", + action="store_true", + help="Whether to use LoRA for training.", + ) + group.add_argument( + "--lora-rank", + type=int, + default=8, + help="LoRA rank.", + ) + group.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha.", + ) + group.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout.", + ) + group.add_argument( + "--lora-target-modules", + type=str, + nargs="+", + default=["q_proj", "v_proj"], + help="List of module names to apply LoRA to.", + ) + + return parser diff --git a/miles/models/peft/lora.py b/miles/models/peft/lora.py new file mode 100644 index 000000000..196645a2a --- /dev/null +++ b/miles/models/peft/lora.py @@ -0,0 +1,145 @@ +import math +from dataclasses import dataclass, field +from typing import List, Dict, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class LoRAConfig: + """Configuration for LoRA.""" + lora_rank: int = 8 + lora_alpha: int = 16 + lora_dropout: float 
= 0.0 + target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) + bias: str = "none" # "none", "all", or "lora_only" - currently only "none" supported for simplicity + + def to_hf_peft_config(self) -> Dict[str, Any]: + """Convert to Hugging Face PEFT config format.""" + return { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": self.lora_rank, + "lora_alpha": self.lora_alpha, + "lora_dropout": self.lora_dropout, + "target_modules": self.target_modules, + "bias": self.bias, + } + + +class LoRALinear(nn.Module): + """ + LoRA linear layer that wraps a base linear layer. + + Args: + base_layer: The existing Linear layer to wrap. + rank: LoRA rank (r). + alpha: LoRA alpha (scaling factor). + dropout: Dropout probability for LoRA input. + """ + def __init__( + self, + base_layer: nn.Linear, + rank: int = 8, + alpha: int = 16, + dropout: float = 0.0 + ): + super().__init__() + self.base_layer = base_layer + self.rank = rank + self.alpha = alpha + self.scaling = alpha / rank + + self.lora_A = nn.Parameter(torch.zeros(rank, base_layer.in_features)) + self.lora_B = nn.Parameter(torch.zeros(base_layer.out_features, rank)) + + self.dropout = nn.Dropout(p=dropout) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + result = self.base_layer(x) + + lora_out = self.dropout(x) + lora_out = F.linear(lora_out, self.lora_A) + lora_out = F.linear(lora_out, self.lora_B) + + return result + lora_out * self.scaling + + def __repr__(self): + return ( + f"LoRALinear(in_features={self.base_layer.in_features}, " + f"out_features={self.base_layer.out_features}, " + f"rank={self.rank}, alpha={self.alpha})" + ) + + +def apply_lora(model: nn.Module, config: LoRAConfig) -> nn.Module: + """ + Apply LoRA to the model by replacing target linear layers with LoRALinear. + + Args: + model: The model to modify. + config: LoRA configuration. + + Returns: + The modified model. + """ + assert config.bias == "none", "Only bias='none' is currently supported" + target_modules = set(config.target_modules) + + # We need to collect replacements first to avoid modifying the dict while iterating + modules_to_replace = [] + + for name, module in model.named_modules(): + # Check if this module name ends with any of the target modules + # e.g. "model.layers.0.self_attn.q_proj" ends with "q_proj" + if any(name.endswith(target) for target in target_modules): + if isinstance(module, nn.Linear): + modules_to_replace.append((name, module)) + + if not modules_to_replace: + raise ValueError(f"No modules found matching {target_modules}") + + for name, module in modules_to_replace: + if '.' 
in name: + parent_name, child_name = name.rsplit('.', 1) + parent = model.get_submodule(parent_name) + else: + parent_name = "" + child_name = name + parent = model + + lora_layer = LoRALinear( + base_layer=module, + rank=config.lora_rank, + alpha=config.lora_alpha, + dropout=config.lora_dropout + ) + + setattr(parent, child_name, lora_layer) + + # Freeze all non-LoRA parameters + for n, p in model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False + + return model + + +def get_lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]: + """Return state dict with only LoRA parameters.""" + return {k: v for k, v in model.state_dict().items() if "lora_" in k} + + +def load_lora_state_dict(model: nn.Module, state_dict: Dict[str, torch.Tensor], strict: bool = False): + """Load LoRA parameters into the model.""" + # We only load keys that exist in the state_dict and match LoRA params + model.load_state_dict(state_dict, strict=strict) diff --git a/miles/rollout/data_source.py b/miles/rollout/data_source.py index 613319d34..9a4df15cf 100644 --- a/miles/rollout/data_source.py +++ b/miles/rollout/data_source.py @@ -42,8 +42,11 @@ def load(self, rollout_id=None): # TODO may further refactor data-loading part later class RolloutDataSource(DataSource): - def __init__(self, args): + def __init__(self, args, prompt_data=None): self.args = args + if prompt_data is None: + # For backwards compatibility with miles' default codepaths + prompt_data = args.prompt_data self.epoch_id = 0 self.sample_group_index = 0 @@ -63,7 +66,7 @@ def __init__(self, args): processor.save_pretrained(Path(d) / "processor") self.dataset = Dataset( - args.prompt_data, + prompt_data, tokenizer=tokenizer, processor=processor, max_length=args.rollout_max_prompt_len, diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..65066e3e6 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,6 +10,7 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.models.peft import add_lora_arguments from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger @@ -467,13 +468,29 @@ def add_data_arguments(parser): type=str, default=None, help=( - "The path to the prompt data. " - "Currently we only support jsonl format, and each line should contains --input-key and --label-key, " + "The path to the prompt data." + "Currently we only support jsonl/parquet format, and each line should contains --input-key and --label-key, " "which will be used as the prompt and the label respectively. " "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. " ), ) + # validation loss + parser.add_argument( + "--val-prompt-data", + type=str, + default=None, + help=( + "The path to the validation prompt data." + "Currently we only support jsonl/parquet format, and each line should contains --input-key and --label-key, " + "which will be used as the validation prompt and the label respectively. " + "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " + "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. 
" + ), + ) + parser.add_argument("--val-interval", type=int, default=0, help="Validation interval.") + parser.add_argument("--val-steps", type=int, default=0, help="Number of validation steps.") + parser.add_argument("--apply-chat-template", action="store_true", default=False) # Temporarily be JSON-serialized str, will be a real dict after using Omegaconf parser.add_argument("--apply-chat-template-kwargs", type=json.loads, default="{}") @@ -1241,6 +1258,7 @@ def add_sglang_tp_size(): parser = add_cluster_arguments(parser) parser = add_train_arguments(parser) + parser = add_lora_arguments(parser) parser = add_rollout_arguments(parser) parser = add_fault_tolerance_arguments(parser) parser = add_data_arguments(parser) diff --git a/scripts/run-sft-torchrun.sh b/scripts/run-sft-torchrun.sh new file mode 100644 index 000000000..33d99b3f9 --- /dev/null +++ b/scripts/run-sft-torchrun.sh @@ -0,0 +1,117 @@ +#!/bin/bash +# +# Ray-free SFT Training Script +# +export PYTHONUNBUFFERED=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# FIXME(f.srambical): this is hardcoded for now +GPUS_PER_NODE=${SLURM_GPUS_ON_NODE} +NUM_NODES=${SLURM_JOB_NUM_NODES} +NODE_RANK=${SLURM_NODEID} +MASTER_ADDR=${MASTER_ADDR:-$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)} + +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +export NCCL_DEBUG=INFO +export TORCH_DISTRIBUTED_DEBUG=INFO + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"} +LOAD_PATH="/fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B" +SAVE_PATH="/fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/${RUN_ID}/checkpoints" + +CKPT_ARGS=( + --hf-checkpoint /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --load ${LOAD_PATH} + --ref-load /fast/project/HFMI_SynergyUnit/tab_model/huggingface/Qwen3-0.6B + --save ${SAVE_PATH} + --save-interval 1000 +) + +SFT_ARGS=( + --rollout-function-path miles.rollout.sft_rollout.generate_rollout + --prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens.jsonl + --val-prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens_validation.jsonl + --val-interval 1000 + --val-steps 100 + --input-key messages + --apply-chat-template + --rollout-shuffle + --num-rollout 10000 + --rollout-batch-size 16 + --global-batch-size 16 + + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns +) + +LORA_ARGS=( + --use-lora + --lora-rank 8 + --lora-alpha 16 + --lora-dropout 0.0 + --lora-target-modules q_proj v_proj +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-4 + --lr-decay-style WSD + --lr-wsd-decay-style linear + --lr-warmup-iters 500 + --lr-decay-iters 10000 + --lr-wsd-decay-iters 2000 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project crowd-pilot-miles + --wandb-team instant-uv + --wandb-group qwen3-0.6b-sft-torchrun +) + +TRAIN_BACKEND_ARGS=( + --train-backend fsdp + --update-weight-buffer-size 536870912 + --gradient-checkpointing + --attn-implementation flash_attention_3 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +MISC_ARGS=( + --rollout-max-context-len 8192 + --rollout-max-prompt-len 8000 + --rollout-max-response-len 8192 + --dump-details 
/fast/project/HFMI_SynergyUnit/tab_model/huggingface/shared_data/qwen3-600M-fsdp-1116-noref/dump_details +) + +torchrun \ + --nproc_per_node=${GPUS_PER_NODE} \ + --nnodes=${NUM_NODES} \ + --node_rank=${NODE_RANK} \ + --master_addr=${MASTER_ADDR} \ + --master_port=${MASTER_PORT:-29500} \ + train_sft.py \ + ${CKPT_ARGS[@]} \ + ${SFT_ARGS[@]} \ + ${LORA_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${TRAIN_BACKEND_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/tests/models/peft/test_lora.py b/tests/models/peft/test_lora.py new file mode 100644 index 000000000..81ccc2e26 --- /dev/null +++ b/tests/models/peft/test_lora.py @@ -0,0 +1,64 @@ +import unittest +import torch +import torch.nn as nn +from miles.models.peft import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict + +class TestLoRA(unittest.TestCase): + def setUp(self): + self.input_dim = 10 + self.output_dim = 20 + self.base_layer = nn.Linear(self.input_dim, self.output_dim) + self.config = LoRAConfig(lora_rank=4, lora_alpha=8, lora_dropout=0.0) + + def test_lora_linear_forward(self): + lora_layer = LoRALinear(self.base_layer, rank=4, alpha=8, dropout=0.0) + x = torch.randn(5, self.input_dim) + + # Initial forward should match base layer (since B is zero) + out_lora = lora_layer(x) + out_base = self.base_layer(x) + torch.testing.assert_close(out_lora, out_base) + + # Modify LoRA weights + lora_layer.lora_B.data.fill_(1.0) + out_lora_mod = lora_layer(x) + self.assertFalse(torch.allclose(out_lora_mod, out_base)) + + def test_apply_lora(self): + model = nn.Sequential( + nn.Linear(10, 10), + nn.Linear(10, 10) + ) + # Name modules to match default target "q_proj", "v_proj" won't work here. + # Let's use custom config + config = LoRAConfig(target_modules=["0"], lora_rank=4) + + model = apply_lora(model, config) + + self.assertIsInstance(model[0], LoRALinear) + self.assertIsInstance(model[1], nn.Linear) + + # Check gradients + self.assertTrue(model[0].lora_A.requires_grad) + self.assertFalse(model[0].base_layer.weight.requires_grad) + self.assertFalse(model[1].weight.requires_grad) # Should be frozen by apply_lora + + def test_state_dict(self): + model = nn.Sequential( + nn.Linear(10, 10) + ) + config = LoRAConfig(target_modules=["0"], lora_rank=4) + model = apply_lora(model, config) + + state_dict = get_lora_state_dict(model) + self.assertEqual(len(state_dict), 2) # A and B + self.assertTrue(all("lora_" in k for k in state_dict.keys())) + + # Test loading + new_state = {k: torch.ones_like(v) for k, v in state_dict.items()} + load_lora_state_dict(model, new_state) + + self.assertTrue(torch.allclose(model[0].lora_A, torch.ones_like(model[0].lora_A))) + +if __name__ == "__main__": + unittest.main() diff --git a/train_sft.py b/train_sft.py new file mode 100644 index 000000000..5d18a607f --- /dev/null +++ b/train_sft.py @@ -0,0 +1,758 @@ +#!/usr/bin/env python3 +""" +Ray-free SFT Training Script for Miles. + +This script provides a simplified training path for Supervised Fine-Tuning (SFT) +that bypasses Ray entirely and uses torchrun for distributed training. + +Usage: + torchrun --nproc_per_node=2 train_sft.py \ + --hf-checkpoint /path/to/model \ + --prompt-data /path/to/data.parquet \ + --input-key messages \ + --apply-chat-template \ + ... + +This is equivalent to the Ray-based SFT with --debug-train-only, but without +the Ray overhead. 
+""" + +import logging +import os +from argparse import Namespace +from datetime import timedelta +from itertools import accumulate + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh +from tqdm import tqdm +from transformers import AutoConfig + +from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params + +from miles.models.peft import LoRAConfig, apply_lora +from miles.backends.fsdp_utils import checkpoint +from miles.backends.fsdp_utils.actor import ( + apply_fsdp2, + get_logprob_and_entropy_with_cp, + sum_of_sample_mean, +) +from miles.backends.fsdp_utils.data_packing import ( + pack_sequences, + pad_packed_sequence_with_cp, + unpack_sequences, +) +from miles.backends.fsdp_utils.lr_scheduler import get_lr_scheduler +from miles.rollout.data_source import RolloutDataSource +from miles.utils import tracking_utils +from miles.utils.arguments import parse_args +from miles.utils.data import get_minimum_num_micro_batch_size +from miles.utils.distributed_utils import get_gloo_group, init_gloo_group +from miles.utils.logging_utils import configure_logger +from miles.utils.mask_utils import MultiTurnLossMaskGenerator +from miles.utils.misc import should_run_periodic_action +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.profile_utils import TrainProfiler +from miles.utils.timer import timer +from miles.utils.tracking_utils import init_tracking +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class SFTTrainer: + """ + A simplified trainer for SFT that runs without Ray. + + This class combines the functionality of: + - FSDPTrainRayActor (model initialization, FSDP wrapping, training) + - RolloutManager (data loading via generate_rollout) + - The main training loop from train.py + """ + + def __init__(self, args: Namespace): + self.args = args + self.device = torch.device("cuda") + + self._init_distributed() + + self._setup_device_mesh() + + torch.manual_seed(args.seed) + + self._enable_true_on_policy_optimizations() + + if dist.get_rank() == 0: + init_tracking(args, primary=True) + + self._load_tokenizer_and_config() + + self._init_data_source() + + self._init_model() + + self._init_optimizer() + + self._load_checkpoint() + + self.prof = TrainProfiler(args) + self.prof.on_init_end() + + logger.info(f"[Rank {dist.get_rank()}] SFTTrainer initialized successfully") + + def _init_distributed(self): + """Initialize distributed training.""" + # torchrun sets these environment variables + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(f"cuda:{local_rank}") + + backend = self.args.distributed_backend + dist.init_process_group( + backend=backend, + timeout=timedelta(minutes=self.args.distributed_timeout_minutes), + ) + init_gloo_group() + + self.args.rank = dist.get_rank() + self.args.world_size = dist.get_world_size() + + logger.info( + f"[Rank {self.args.rank}] Distributed initialized: " + f"world_size={self.args.world_size}, local_rank={local_rank}" + ) + + def _setup_device_mesh(self): + """Setup device mesh for FSDP (no context parallelism for SFT).""" + world_size = dist.get_world_size() + rank = dist.get_rank() + + self.cp_size = self.args.context_parallel_size + self.dp_size = world_size // self.cp_size + + self.mesh = init_device_mesh( + "cuda", + mesh_shape=(self.dp_size, self.cp_size), + mesh_dim_names=("dp", "cp"), + ) + + self.dp_group = self.mesh.get_group("dp") + self.cp_group = self.mesh.get_group("cp") + 
self.dp_mesh = self.mesh["dp"] + + self.dp_rank = rank // self.cp_size + self.cp_rank = rank % self.cp_size + + logger.info( + f"[Rank {rank}] Device mesh: dp_size={self.dp_size}, cp_size={self.cp_size}, " + f"dp_rank={self.dp_rank}, cp_rank={self.cp_rank}" + ) + + # Setup Ring Flash Attention with CP group from mesh (only when cp_size > 1) + if self.cp_size > 1: + substitute_hf_flash_attn(self.cp_group, heads_k_stride=1) + logger.info(f"[Rank {rank}] CP initialized via device mesh") + + def _enable_true_on_policy_optimizations(self): + """Enable true on-policy optimizations or apply MoE patches.""" + if self.args.true_on_policy_mode: + from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode + + from miles.backends.fsdp_utils.models.qwen3_moe import ( + apply_true_on_policy_patch_for_qwen3_moe, + ) + + logger.info("SFTTrainer: enabling batch_invariant_mode for true-on-policy") + enable_batch_invariant_mode( + # In Qwen3, rope uses bmm; disabling makes it aligned + enable_bmm=False, + ) + + apply_true_on_policy_patch_for_qwen3_moe() + else: + from miles.backends.fsdp_utils.models.qwen3_moe_hf import apply_fsdp_moe_patch + + apply_fsdp_moe_patch() + + def _load_tokenizer_and_config(self): + """Load tokenizer and model config sequentially to avoid race conditions.""" + for i in range(dist.get_world_size()): + if i == dist.get_rank(): + self.hf_config = AutoConfig.from_pretrained( + self.args.hf_checkpoint, trust_remote_code=True + ) + self.tokenizer = load_tokenizer( + self.args.hf_checkpoint, trust_remote_code=True + ) + self.processor = None + if self.args.multimodal_keys: + self.processor = load_processor( + self.args.hf_checkpoint, trust_remote_code=True + ) + dist.barrier(group=get_gloo_group()) + + # Initialize loss mask generator for SFT + self.mask_generator = MultiTurnLossMaskGenerator( + self.tokenizer, + tokenizer_type=getattr(self.args, "loss_mask_type", None), + ) + + def _init_data_source(self): + """Initialize the data source for SFT training.""" + self.data_source = RolloutDataSource(self.args, self.args.prompt_data) + self.val_data_source = None + if self.args.val_prompt_data is not None: + self.val_data_source = RolloutDataSource(self.args, self.args.val_prompt_data) + + # Calculate num_rollout from dataset size + if self.args.num_rollout is None: + num_rollout_per_epoch = len(self.data_source.dataset) // self.args.rollout_batch_size + self.args.num_rollout = num_rollout_per_epoch * self.args.num_epoch + self.num_rollout_per_epoch = num_rollout_per_epoch + else: + self.num_rollout_per_epoch = None + + if getattr(self.args, "start_rollout_id", None) is None: + self.args.start_rollout_id = 0 + + logger.info( + f"[Rank {dist.get_rank()}] Data source initialized: " + f"dataset_size={len(self.data_source.dataset)}, " + f"num_rollout={self.args.num_rollout}" + ) + + def _get_init_weight_context_manager(self): + """Get context manager for model initialization.""" + from accelerate import init_empty_weights + + use_meta_tensor = not self.hf_config.tie_word_embeddings + + def cpu_init_weights(): + return torch.device("cpu") + + if use_meta_tensor: + return init_empty_weights if dist.get_rank() != 0 else cpu_init_weights + else: + return cpu_init_weights + + def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offload): + """Load full state dict into FSDP2 model with broadcast from rank 0.""" + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) + + if dist.get_rank() == 0: + model = 
model.to(device=torch.cuda.current_device(), non_blocking=True) + else: + model = model.to_empty(device=torch.cuda.current_device()) + + is_cpu_offload = cpu_offload is not None + options = StateDictOptions( + full_state_dict=True, cpu_offload=is_cpu_offload, broadcast_from_rank0=True + ) + + set_model_state_dict(model, full_state, options=options) + + for _name, buf in model.named_buffers(): + dist.broadcast(buf, src=0) + + if is_cpu_offload: + model.to("cpu", non_blocking=True) + for buf in model.buffers(): + buf.data = buf.data.to(torch.cuda.current_device()) + + return model + + def _get_model_cls(self): + """Get the appropriate model class based on config.""" + if hasattr(self.hf_config, "vision_config"): + from transformers import AutoModelForImageTextToText + + return AutoModelForImageTextToText + else: + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM + + def _init_model(self): + """Initialize and wrap model with FSDP.""" + self.fsdp_cpu_offload = getattr(self.args, "fsdp_cpu_offload", False) + + init_context = self._get_init_weight_context_manager() + + with init_context(): + model = self._get_model_cls().from_pretrained( + self.args.hf_checkpoint, + trust_remote_code=True, + attn_implementation=self.args.attn_implementation, + ) + + if self.args.use_lora: + lora_config = LoRAConfig( + lora_rank=self.args.lora_rank, + lora_alpha=self.args.lora_alpha, + lora_dropout=self.args.lora_dropout, + target_modules=self.args.lora_target_modules, + ) + model = apply_lora(model, lora_config) + logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}") + + model.train() + full_state = model.state_dict() + + model = apply_fsdp2( + model, mesh=self.dp_mesh, cpu_offload=self.fsdp_cpu_offload, args=self.args + ) + + model = self._fsdp2_load_full_state_dict( + model, + full_state, + self.dp_mesh, + cpu_offload=True if self.fsdp_cpu_offload else None, + ) + + self.model = model + + if self.args.gradient_checkpointing: + # Use non-reentrant mode for gradient checkpointing + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + logger.info(f"[Rank {dist.get_rank()}] Model initialized with FSDP") + + def _init_optimizer(self): + """Initialize optimizer and learning rate scheduler.""" + trainable_params = [p for p in self.model.parameters() if p.requires_grad] + + if self.args.use_lora: + total_params = sum(p.numel() for p in self.model.parameters()) + trainable_count = sum(p.numel() for p in trainable_params) + logger.info( + f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params " + f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)" + ) + + if self.args.optimizer == "adam": + self.optimizer = torch.optim.AdamW( + trainable_params, + lr=self.args.lr, + betas=(self.args.adam_beta1, self.args.adam_beta2), + eps=self.args.adam_eps, + weight_decay=self.args.weight_decay, + ) + else: + raise ValueError(f"Unsupported optimizer: {self.args.optimizer}") + + self.lr_scheduler = get_lr_scheduler(self.args, self.optimizer) + self.global_step = 0 + self.micro_step = 0 + + def _load_checkpoint(self): + """Load checkpoint if available.""" + checkpoint_payload = checkpoint.load(self) + checkpoint.finalize_load(self, checkpoint_payload) + + if self.args.rollout_global_dataset and self.args.start_rollout_id > 0: + self.data_source.load(self.args.start_rollout_id - 1) + + def generate_sft_rollout(self, rollout_id: int, data_source: RolloutDataSource) -> list[Sample]: + """Generate SFT rollout 
data (tokenize and create loss masks).""" + samples = data_source.get_samples(self.args.rollout_batch_size) + + result = [] + for i, (sample,) in enumerate(samples): + messages = sample.prompt + token_ids, loss_mask = self.mask_generator.get_loss_mask(messages) + response_length = self.mask_generator.get_response_lengths([loss_mask])[0] + + sample.tokens = token_ids + sample.response_length = response_length + sample.reward = 0 + sample.loss_mask = loss_mask[-response_length:] + result.append(sample) + + if i == 0 and rollout_id == 0 and dist.get_rank() == 0: + logger.info( + f"SFT rollout sample: tokens_len={len(token_ids)}, " + f"response_length={response_length}" + ) + + return result + + def _convert_samples_to_train_data(self, samples: list[Sample]) -> dict: + """Convert samples to training data format.""" + train_data = { + "tokens": [sample.tokens for sample in samples], + "response_lengths": [sample.response_length for sample in samples], + "rewards": [0.0 for _ in samples], + "raw_reward": [0.0 for _ in samples], + "truncated": [0 for _ in samples], + "sample_indices": [sample.index for sample in samples], + } + + loss_masks = [] + for sample in samples: + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + loss_masks.append(sample.loss_mask) + train_data["loss_masks"] = loss_masks + + return train_data + + def _split_train_data_by_dp(self, data: dict) -> dict: + """Split training data for current DP rank.""" + total_lengths = [len(t) for t in data["tokens"]] + data["total_lengths"] = total_lengths + + # Simple round-robin partitioning + partition = list(range(self.dp_rank, len(total_lengths), self.dp_size)) + + rollout_data = {"partition": partition, "total_lengths": total_lengths} + + for key in [ + "tokens", + "response_lengths", + "rewards", + "raw_reward", + "truncated", + "loss_masks", + "sample_indices", + ]: + if key in data: + rollout_data[key] = [data[key][j] for j in partition] + + return rollout_data + + def _packed_data(self, rollout_data: dict) -> tuple[list[dict], list[int]]: + """Pack variable-length sequences for efficient processing.""" + tokens = rollout_data["tokens"] + + packed_batches = [] + mbs_size_list = [] + local_batch_size = self.args.global_batch_size // self.dp_size + + if self.args.use_dynamic_batch_size: + max_tokens = self.args.max_tokens_per_gpu + if self.cp_size > 1: + max_tokens = max_tokens * self.cp_size + + for i in range(0, len(tokens), local_batch_size): + mbs_size_list.append( + get_minimum_num_micro_batch_size( + [len(t) for t in rollout_data["tokens"][i : i + local_batch_size]], + max_tokens, + ) + ) + num_microbatches = torch.tensor( + mbs_size_list, dtype=torch.int, device=torch.cuda.current_device() + ) + dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=self.dp_group) + num_microbatches = num_microbatches.tolist() + else: + num_microbatches = [ + self.args.global_batch_size // (self.args.micro_batch_size * self.dp_size) + ] * (len(tokens) // local_batch_size) + + start = 0 + for mbs_size in num_microbatches: + end = start + local_batch_size + # Create dummy advantages/returns for SFT (not used but required by pack_sequences) + dummy_advantages = [ + torch.zeros(rollout_data["response_lengths"][i]) + for i in range(start, end) + ] + packed_batches.extend( + pack_sequences( + rollout_data["tokens"][start:end], + rollout_data["loss_masks"][start:end], + rollout_data["rewards"][start:end], + rollout_data["raw_reward"][start:end], + rollout_data["response_lengths"][start:end], + dummy_advantages, # 
advantages + dummy_advantages, # returns + num_packs=mbs_size, + ) + ) + start = end + + grad_accum = list(accumulate(num_microbatches)) + return packed_batches, grad_accum + + def _get_model_inputs_args(self, packed_sequence: dict) -> dict: + """Prepare model input arguments from packed sequence.""" + input_ids = packed_sequence["tokens"].unsqueeze(0) + position_ids = packed_sequence["position_ids"].unsqueeze(0) + + if self.cp_size > 1: + packed_sequence = pad_packed_sequence_with_cp(packed_sequence, self.cp_size) + + if not packed_sequence["cu_seqlens"].is_cuda: + packed_sequence["cu_seqlens"] = packed_sequence["cu_seqlens"].cuda() + cu_seqlens = packed_sequence["cu_seqlens"] + update_ring_flash_attn_params(cu_seqlens, self.cp_group) + + input_ids = torch.chunk( + packed_sequence["tokens"].unsqueeze(0), self.cp_size, dim=1 + )[self.cp_rank] + position_ids = torch.chunk( + packed_sequence["position_ids"].unsqueeze(0), self.cp_size, dim=1 + )[self.cp_rank] + + model_args = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": None, + } + + if packed_sequence.get("multimodal_inputs"): + model_args.update(packed_sequence["multimodal_inputs"]) + + return model_args + + def _compute_sft_loss(self, unpacked_batches: list[dict], logits: torch.Tensor): + """Compute SFT loss (negative log likelihood).""" + loss_masks = [ + batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches + ] + response_lengths = [batch["response_lengths"] for batch in unpacked_batches] + log_probs = torch.cat( + [batch["cur_log_probs"] for batch in unpacked_batches], dim=0 + ) + loss = -sum_of_sample_mean(log_probs, response_lengths, loss_masks) + + if log_probs.numel() == 0: + loss += 0 * logits.sum() + + return loss, {"loss": loss.detach()} + + def _train_step( + self, + packed_batch: dict, + reported_accum: dict, + mbs_id: int, + grad_accum: list[int], + ): + """Execute one training step.""" + # Prepare model inputs + model_args = self._get_model_inputs_args(packed_batch) + logits = self.model(**model_args).logits.squeeze(0).float() + + # Compute log probs and entropy (unified for both CP and non-CP modes) + log_probs, entropy_result = get_logprob_and_entropy_with_cp( + logits=logits, + target_tokens=packed_batch["tokens"], + cp_rank=self.cp_rank, + cp_size=self.cp_size, + cp_group=self.cp_group, + model_input_ids=model_args["input_ids"], + allow_compile=not self.args.true_on_policy_mode, + temperature=self.args.rollout_temperature, + ) + packed_batch["cur_log_probs"] = log_probs + packed_batch["entropy"] = entropy_result + + unpacked_batches = unpack_sequences(packed_batch) + loss, reported = self._compute_sft_loss(unpacked_batches, logits) + + # Scale loss for gradient accumulation + loss = loss * self.dp_size / self.args.global_batch_size + loss.backward() + + # Accumulate reported metrics (store tensors for later mean) + for k, v in reported.items(): + reported_accum.setdefault(k, []).append(v) + + if (mbs_id + 1) in grad_accum: + # TODO: check if the grad norm is global grad norm. + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad) + # the grad norm used to be of DTensor + grad_norm = float(grad_norm) + + self.optimizer.step() + # Update learning rate + self.lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) + # Aggregate logs + aggregated = {k: torch.stack(v).sum().item() for k, v in reported_accum.items()} + # TODO: change this, this is slow. 
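# Hedged aside (a sketch, not part of the patch): the all_gather_object below pickles a Python
# dict through the CPU on every optimizer step. If every rank reports the same metric keys, the
# same reduction can be done with a single tensor collective:
#
#     keys = sorted(reported_accum)
#     vals = torch.stack([torch.stack(reported_accum[k]).sum() for k in keys])
#     dist.all_reduce(vals, op=dist.ReduceOp.SUM, group=self.dp_group)
#     aggregated = {k: v.item() / self.args.global_batch_size for k, v in zip(keys, vals)}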
+ reduced_aggregated = [None] * self.dp_size + dist.all_gather_object(reduced_aggregated, aggregated, group=self.dp_group) + aggregated = {} + for k in reported_accum.keys(): + aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (self.args.global_batch_size) + reported_accum.clear() + if dist.get_rank() == 0: + log_dict = { + f"train/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items() + } + log_dict["train/grad_norm"] = grad_norm + + # Log learning rate per parameter group; use scheduler's last computed LRs + lr_values = self.lr_scheduler.get_last_lr() + for gid, _group in enumerate(self.optimizer.param_groups): + log_dict[f"train/lr_{gid}"] = lr_values[gid] + + logger.info(f"step {self.global_step}: {log_dict}") + log_dict["train/step"] = self.global_step + tracking_utils.log(self.args, log_dict, step_key="train/step") + self.global_step += 1 + + def train_one_rollout(self, rollout_id: int): + """Execute one rollout's worth of training.""" + self.model.train() + samples = self.generate_sft_rollout(rollout_id, self.data_source) + + train_data = self._convert_samples_to_train_data(samples) + + rollout_data = self._split_train_data_by_dp(train_data) + + packed_batches, grad_accum = self._packed_data(rollout_data) + + if len(grad_accum) == 0: + logger.warning(f"[Rank {dist.get_rank()}] No batches to train on rollout {rollout_id}") + return + + with timer("actor_train"): + reported_accum = {} + self.optimizer.zero_grad(set_to_none=True) + + for mbs_id, packed_batch in enumerate( + tqdm(packed_batches, desc="actor_train", disable=dist.get_rank() != 0) + ): + self._train_step(packed_batch, reported_accum, mbs_id, grad_accum) + + self.prof.step(rollout_id=rollout_id) + + def calculate_val_loss(self, rollout_id: int): + """Calculate validation loss over `args.val_steps`.""" + self.model.eval() + reported_accum = {} + for v_step in tqdm(range(self.args.val_steps), desc="actor_val", disable=dist.get_rank() != 0): + samples = self.generate_sft_rollout(rollout_id, self.val_data_source) + val_data = self._convert_samples_to_train_data(samples) + rollout_data = self._split_train_data_by_dp(val_data) + packed_batches, accum = self._packed_data(rollout_data) + + if len(accum) == 0: + logger.warning(f"[Rank {dist.get_rank()}] No batches to validate on rollout {rollout_id}, validation step {v_step}") + return + + for mbs_id, packed_batch in enumerate(packed_batches): + reported = self._val_step(packed_batch) + for k, v in reported.items(): + reported_accum.setdefault(k, []).append(v) + + aggregated = {k: torch.stack(v).sum().item() for k, v in reported_accum.items()} + # TODO: change this, this is slow. 
+ reduced_aggregated = [None] * self.dp_size + dist.all_gather_object(reduced_aggregated, aggregated, group=self.dp_group) + aggregated = {} + for k in reported_accum.keys(): + aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (self.args.global_batch_size * self.args.val_steps) + reported_accum.clear() + if dist.get_rank() == 0: + log_dict = { + f"val/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items() + } + logger.info(f"step {self.global_step}: {log_dict}") + log_dict["val/step"] = self.global_step + tracking_utils.log(self.args, log_dict, step_key="val/step") + + def _val_step(self, packed_batch): + model_args = self._get_model_inputs_args(packed_batch) + with torch.no_grad(): + logits = self.model(**model_args).logits.squeeze(0).float() + + # Compute log probs and entropy (unified for both CP and non-CP modes) + log_probs, entropy_result = get_logprob_and_entropy_with_cp( + logits=logits, + target_tokens=packed_batch["tokens"], + cp_rank=self.cp_rank, + cp_size=self.cp_size, + cp_group=self.cp_group, + model_input_ids=model_args["input_ids"], + allow_compile=not self.args.true_on_policy_mode, + temperature=self.args.rollout_temperature, + ) + packed_batch["cur_log_probs"] = log_probs + packed_batch["entropy"] = entropy_result + + unpacked_batches = unpack_sequences(packed_batch) + _, reported = self._compute_sft_loss(unpacked_batches, logits) + return reported + + + def save_model(self, iteration: int): + """Save model checkpoint.""" + if self.args.save is None: + return + + keys_filter = None + if self.args.use_lora: + keys_filter = lambda k: "lora_" in k + + checkpoint.save(self, iteration, keys_filter=keys_filter) + + if self.args.rollout_global_dataset: + self.data_source.save(iteration) + + def train(self): + """Main training loop.""" + logger.info( + f"[Rank {dist.get_rank()}] Starting training: " + f"rollout_id {self.args.start_rollout_id} -> {self.args.num_rollout}" + ) + if self.args.val_prompt_data: + assert self.args.val_interval > 0, f"val_interval must be greater than 0 when val_prompt_data is provided, got {self.args.val_interval}" + assert self.args.val_steps > 0, f"val_steps must be greater than 0 when val_prompt_data is provided, got {self.args.val_steps}" + + # calculate val loss at the beginning of training + if self.args.val_prompt_data and self.args.start_rollout_id == 0: + self.calculate_val_loss(rollout_id=0) + + for rollout_id in range(self.args.start_rollout_id, self.args.num_rollout): + self.train_one_rollout(rollout_id) + + # Save checkpoint periodically + if should_run_periodic_action( + rollout_id, self.args.save_interval, self.num_rollout_per_epoch + ): + self.save_model(rollout_id) + + # Calculate val loss periodically + if self.args.val_prompt_data and should_run_periodic_action(rollout_id, self.args.val_interval): + self.calculate_val_loss(rollout_id) + + logger.info(f"[Rank {dist.get_rank()}] Training completed!") + + +def set_sft_defaults(args: Namespace) -> Namespace: + """Set default values appropriate for SFT training.""" + if not hasattr(args, "loss_type") or args.loss_type is None: + args.loss_type = "sft_loss" + + if not hasattr(args, "advantage_estimator"): + args.advantage_estimator = None + + args.offload_train = False + args.offload_rollout = False + args.colocate = False + + return args + + +def main(): + configure_logger() + + args = parse_args() + + args = set_sft_defaults(args) + + trainer = SFTTrainer(args) + trainer.train() + + +if __name__ == "__main__": + main() + +
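As a worked illustration of the adapter artifact written by _save_hf_lora, the snippet below folds a saved adapter back into plain full-rank weights using the standard LoRA update W' = W + (alpha / r) * B @ A. It is a sketch, not part of the patch, and assumes the key naming produced by LoRALinear above ("<module>.lora_A" / "<module>.lora_B" next to a base weight at "<module>.weight"); merge_lora_adapter is a hypothetical helper, adjust key handling if your checkpoint layout differs.

import json
from pathlib import Path

import safetensors.torch
import torch


def merge_lora_adapter(base_state: dict[str, torch.Tensor], adapter_dir: str) -> dict[str, torch.Tensor]:
    """Fold a saved LoRA adapter into full-rank weights (illustrative sketch)."""
    adapter_path = Path(adapter_dir)
    config = json.loads((adapter_path / "adapter_config.json").read_text())
    scaling = config["lora_alpha"] / config["r"]
    adapter = safetensors.torch.load_file(adapter_path / "adapter_model.safetensors")

    merged = dict(base_state)
    for key, lora_a in adapter.items():
        if not key.endswith("lora_A"):
            continue
        prefix = key[: -len("lora_A")]          # e.g. "model.layers.0.self_attn.q_proj."
        lora_b = adapter[prefix + "lora_B"]
        weight_key = prefix + "weight"
        delta = (lora_b @ lora_a) * scaling     # (out, r) @ (r, in) -> (out, in)
        merged[weight_key] = merged[weight_key] + delta.to(merged[weight_key].dtype)
    return merged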