From 18ea7bd8288f13432600a4720c8d7e5ef6c5b527 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Wed, 14 Jan 2026 10:58:12 -0800 Subject: [PATCH 1/8] Adding example Atropos server, config, and startup scripts --- torchtitan/grpo/test/gsm8k_server.py | 375 ++++++++++++++++++ torchtitan/grpo/test/scripts/run_full_test.sh | 123 ++++++ torchtitan/grpo/test/scripts/start_api.sh | 19 + torchtitan/grpo/test/scripts/start_env.sh | 58 +++ torchtitan/grpo/test/scripts/start_sglang.sh | 58 +++ torchtitan/grpo/test/scripts/start_trainer.sh | 68 ++++ torchtitan/grpo/test/test_config.toml | 101 +++++ 7 files changed, 802 insertions(+) create mode 100644 torchtitan/grpo/test/gsm8k_server.py create mode 100755 torchtitan/grpo/test/scripts/run_full_test.sh create mode 100755 torchtitan/grpo/test/scripts/start_api.sh create mode 100755 torchtitan/grpo/test/scripts/start_env.sh create mode 100755 torchtitan/grpo/test/scripts/start_sglang.sh create mode 100755 torchtitan/grpo/test/scripts/start_trainer.sh create mode 100644 torchtitan/grpo/test/test_config.toml diff --git a/torchtitan/grpo/test/gsm8k_server.py b/torchtitan/grpo/test/gsm8k_server.py new file mode 100644 index 0000000000..422dcec923 --- /dev/null +++ b/torchtitan/grpo/test/gsm8k_server.py @@ -0,0 +1,375 @@ +import random +import time +from typing import Dict, List, Optional, Tuple, TypedDict, Union + +from datasets import load_dataset +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) +from atroposlib.type_definitions import Item + +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought " + "to deeply consider the problem and deliberate with yourself via systematic " + "reasoning processes to help come to a correct solution prior to answering. 
" + "You should enclose your thoughts and internal monologue inside " + "tags, and then provide your solution or response to the problem.\n\n" +) + +system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less. + +You will then provide your answer like this: \\boxed{your answer here} +It is important that you provide your answer in the correct format. +If you do not, you will not receive credit for your answer. +So please end your answer with \\boxed{your answer here}""" + + +class GSM8kRow(TypedDict): + question: str + answer: str + + +class GSM8kEnv(BaseEnv): + + name = "gsm8k" + + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + # Add tracking for wandb visualizations + self.rollouts_for_wandb = [] + self.completion_lengths = [] + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="Qwen/Qwen3-1.7B", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, + wandb_name="gsm8k_qwen3_test", + ) + server_configs = [ + APIServerConfig( + model_name="Qwen/Qwen3-1.7B", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ), + APIServerConfig( + model_name="Qwen/Qwen3-1.7B", + base_url="http://localhost:9002/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Try to calculate percent_correct, pass if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except 
ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + # Call the parent method to handle the server metrics + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42) + test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42) + self.test = list() + for item in test_data: + self.test.append( + { + "question": item["question"], + "gold_answer": item["answer"] + .split("#")[-1] + .strip() + .replace(",", ""), + } + ) + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, question: str, answer: str) -> dict: + """Rollout and score evaluation with detailed sample data collection.""" + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.6, + ) + + response_content = completion.choices[0].message.content + + # Parse gold answer + gold_parsed = parse( + "\\boxed{" + answer + "}", + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + + # Parse model answer + answer_parsed = parse( + response_content.split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + score = 1 if verify(answer_parsed, gold_parsed) else 0 + + sample = { + "messages": [ + 
{"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + {"role": "assistant", "content": response_content}, + ], + "question": question, + "gold_answer": answer, + "gold_parsed": str(gold_parsed) if gold_parsed else None, + "model_parsed": str(answer_parsed) if answer_parsed else None, + "score": int(score), + "correct": bool(score), + "finish_reason": completion.choices[0].finish_reason, + "response_after_think": ( + response_content.split("")[-1] + if "" in response_content + else response_content + ), + } + + return {"score": score, "sample": sample} + + async def evaluate(self, *args, **kwargs): + start_time = time.time() + + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval(item["question"], item["gold_answer"]) + ) + results = await tqdm_asyncio.gather(*eval_tasks) + + # Extract scores and samples + scores = [result["score"] for result in results] + samples = [result["sample"] for result in results] + + percent_correct = sum(scores) / len(scores) + + end_time = time.time() + + # Add to existing metrics for wandb + self.eval_metrics.append(("eval/percent_correct", percent_correct)) + + # Log evaluation results + eval_metrics = { + "eval/percent_correct": percent_correct, + } + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": 0.0, + "max_tokens": self.config.max_token_length, + }, + ) + + async def collect_trajectories( + self, item: GSM8kRow + ) -> Tuple[ScoredDataGroup, list[Item]]: + user_message = {"role": "user", "content": item["question"]} + gold_answer = ( + "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" + ) + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + + chat_completions = await managed.chat_completion( + messages=[{"role": "system", "content": system_prompt}, user_message], + n=self.config.group_size, + 
max_tokens=self.config.max_token_length, + temperature=1.0, + ) + + state = managed.get_state() + nodes = state["nodes"] + + to_score = list() + to_backlog = list() + for i, chat_completion in enumerate(chat_completions.choices): + messages = ( + {"role": "system", "content": system_prompt}, + user_message, + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append( + { + "messages": messages, + "gold_answer": gold_answer, + "finish_reason": chat_completion.finish_reason, + "tokens": nodes[i].tokens, + "masks": nodes[i].masked_tokens, + "logprobs": nodes[i].logprobs, + } + ) + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["inference_logprobs"] = list() + gold_parsed = parse( + rollout_group_data[0]["gold_answer"], + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + random.shuffle(rollout_group_data) + for item in rollout_group_data: + # print(item[0][-1]["content"]) + answer_parsed = parse( + item["messages"][-1]["content"].split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + # Reward 1 if the content is the same as the ground truth, 0 otherwise + reward = verify(answer_parsed, gold_parsed) + + tokens = item["tokens"] + masks = item["masks"] + logprobs = item["logprobs"] + + # remove obviously bad examples + 
if len([1 for i in masks if i != -100]) < 10: + continue + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["inference_logprobs"].append(logprobs) + scores["scores"].append(1.0 if reward else -1.0) + + if len(scores["tokens"]) >= self.config.group_size: + break + + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + + # check if all the same + # print(scores['scores']) + if all([score == 1 for score in scores["scores"]]): + # Do length penalty :) + token_lengths = [len(token) for token in scores["tokens"]] + if max(token_lengths) == 0: + # What? But don't want to crash a run so just in case... + return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 50% of max_token_length - no penalty below this + length_threshold = max_allowed_length * 0.5 + + # Apply modified length penalty with threshold + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + # No penalty for responses under threshold + scores["scores"].append(1.0) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + scores["scores"].append(1.0 - percentage_of_range) + if all([scores["scores"][0] == score for score in scores["scores"]]): + return None # If all the same, we return None + return scores + else: + # If the gold solution is not parseable, we return None + return None + + async def get_next_item(self) -> GSM8kRow: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GSM8kEnv.cli() diff --git a/torchtitan/grpo/test/scripts/run_full_test.sh 
b/torchtitan/grpo/test/scripts/run_full_test.sh new file mode 100755 index 0000000000..a6898a1f73 --- /dev/null +++ b/torchtitan/grpo/test/scripts/run_full_test.sh @@ -0,0 +1,123 @@ +#!/bin/bash +# Master script to launch the full RL test pipeline + +set -e + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +echo "========================================" +echo "GSM8k RL Test Pipeline Launcher" +echo "========================================" +echo "" +echo "This script will start all components:" +echo " 1. Atropos API Server" +echo " 2. SGLang Inference Servers" +echo " 3. GSM8k Environment Server" +echo " 4. TorchTitan Trainer" +echo "" +echo "Press Ctrl+C to stop all services" +echo "" + +# Cleanup function +cleanup() { + echo "" + echo "========================================" + echo "Shutting down all services..." + echo "========================================" + + if [ ! -z "$API_PID" ]; then + echo "Stopping Atropos API (PID: $API_PID)" + kill $API_PID 2>/dev/null || true + fi + + if [ ! -z "$SGLANG_PID" ]; then + echo "Stopping SGLang servers (PID: $SGLANG_PID)" + kill $SGLANG_PID 2>/dev/null || true + # Also kill any remaining sglang processes + pkill -f "sglang.launch_server" 2>/dev/null || true + fi + + if [ ! -z "$ENV_PID" ]; then + echo "Stopping GSM8k environment (PID: $ENV_PID)" + kill $ENV_PID 2>/dev/null || true + fi + + echo "All services stopped" + exit 0 +} + +# Set up trap for cleanup +trap cleanup EXIT INT TERM + +# Step 1: Start Atropos API +echo "Step 1/4: Starting Atropos API Server..." +"$SCRIPT_DIR/start_api.sh" > /tmp/atropos_api.log 2>&1 & +API_PID=$! +echo "API started (PID: $API_PID, log: /tmp/atropos_api.log)" +echo "Waiting for API to be ready..." +sleep 5 + +# Check if API is running +if ! 
curl -s http://localhost:8000/ > /dev/null; then + echo "ERROR: Atropos API failed to start" + echo "Check log at: /tmp/atropos_api.log" + exit 1 +fi +echo "API is ready" +echo "" + +# Step 2: Start SGLang servers +echo "Step 2/4: Starting SGLang Inference Servers..." +"$SCRIPT_DIR/start_sglang.sh" > /tmp/sglang_launcher.log 2>&1 & +SGLANG_PID=$! +echo "SGLang launcher started (PID: $SGLANG_PID)" +echo "Waiting for SGLang servers to load models (this may take ~30 seconds)..." +sleep 35 + +# Check if SGLang servers are running +SGLANG_READY=true +for PORT in 9001 9002; do + if ! curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then + echo "WARNING: SGLang server on port $PORT is not responding" + SGLANG_READY=false + fi +done + +if [ "$SGLANG_READY" = false ]; then + echo "WARNING: Some SGLang servers may not be ready" + echo "Check logs at: /tmp/sglang_server_*.log" + echo "Continuing anyway..." +else + echo "SGLang servers are ready" +fi +echo "" + +# Step 3: Start GSM8k environment +echo "Step 3/4: Starting GSM8k Environment Server..." +"$SCRIPT_DIR/start_env.sh" > /tmp/gsm8k_env_wrapper.log 2>&1 & +ENV_PID=$! +echo "Environment started (PID: $ENV_PID, log: /tmp/gsm8k_env.log)" +echo "Waiting for environment to register..." +sleep 5 +echo "Environment should be running" +echo "" + +# Step 4: Start trainer +echo "Step 4/4: Starting TorchTitan Trainer..." +echo "========================================" +echo "" +"$SCRIPT_DIR/start_trainer.sh" + +# If we get here, training completed successfully +echo "" +echo "========================================" +echo "Test completed successfully!" 
+echo "========================================" +echo "" +echo "Logs available at:" +echo " - Atropos API: /tmp/atropos_api.log" +echo " - SGLang servers: /tmp/sglang_server_*.log" +echo " - GSM8k environment: /tmp/gsm8k_env.log" +echo " - Trainer: $LOGDIR" +echo "" diff --git a/torchtitan/grpo/test/scripts/start_api.sh b/torchtitan/grpo/test/scripts/start_api.sh new file mode 100755 index 0000000000..5cf335fc99 --- /dev/null +++ b/torchtitan/grpo/test/scripts/start_api.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Start Atropos API Server + +set -e + +echo "========================================" +echo "Starting Atropos API Server" +echo "========================================" + +# Add Atropos to PYTHONPATH +export PYTHONPATH=/home/shared/atropos:$PYTHONPATH + +# Change to Atropos directory +cd /home/shared/atropos + +# Start the API server +# The server will listen on http://localhost:8000 +echo "Starting API server on http://localhost:8000" +run-api diff --git a/torchtitan/grpo/test/scripts/start_env.sh b/torchtitan/grpo/test/scripts/start_env.sh new file mode 100755 index 0000000000..bc0c18e479 --- /dev/null +++ b/torchtitan/grpo/test/scripts/start_env.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Start GSM8k Environment Server + +set -e + +echo "========================================" +echo "Starting GSM8k Environment Server" +echo "========================================" + +# Add Atropos to PYTHONPATH +export PYTHONPATH=/home/shared/atropos:$PYTHONPATH + +# Get the torchtitan root directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +echo "TorchTitan root: $TORCHTITAN_ROOT" +cd "$TORCHTITAN_ROOT" + +# Configuration +MODEL_NAME="Qwen/Qwen3-1.7B" +SGLANG_URL_1="http://localhost:9001/v1" +SGLANG_URL_2="http://localhost:9002/v1" + +# Check if Atropos is accessible +if ! python -c "from atroposlib.envs.base import BaseEnv" 2>/dev/null; then + echo "ERROR: Cannot import Atropos. 
Is PYTHONPATH set correctly?" + echo "PYTHONPATH=$PYTHONPATH" + exit 1 +fi + +# Check if SGLang servers are running +echo "Checking SGLang server availability..." +if ! curl -s "$SGLANG_URL_1/models" > /dev/null; then + echo "WARNING: SGLang server at $SGLANG_URL_1 is not responding" +fi +if ! curl -s "$SGLANG_URL_2/models" > /dev/null; then + echo "WARNING: SGLang server at $SGLANG_URL_2 is not responding" +fi + +# Check if Atropos API is running +echo "Checking Atropos API availability..." +if ! curl -s "http://localhost:8000/" > /dev/null; then + echo "ERROR: Atropos API is not running on http://localhost:8000" + echo "Please start the API server first (./start_api.sh)" + exit 1 +fi + +echo "" +echo "Starting GSM8k environment..." +python torchtitan/grpo/test/gsm8k_server.py serve \ + --slurm false \ + --openai.model_name "$MODEL_NAME" \ + 2>&1 | tee /tmp/gsm8k_env.log + +echo "" +echo "GSM8k environment stopped" +echo "Log available at: /tmp/gsm8k_env.log" diff --git a/torchtitan/grpo/test/scripts/start_sglang.sh b/torchtitan/grpo/test/scripts/start_sglang.sh new file mode 100755 index 0000000000..e331fae250 --- /dev/null +++ b/torchtitan/grpo/test/scripts/start_sglang.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Start SGLang Inference Servers + +set -e + +echo "========================================" +echo "Starting SGLang Inference Servers" +echo "========================================" + +# Configuration +MODEL_PATH="/home/shared/torchtitan-conversions/qwen3_1.7b" # update with model path +TP_SIZE=1 +NUM_SERVERS=2 +BASE_PORT=9001 + +# Check if model path exists +if [ ! -d "$MODEL_PATH" ]; then + echo "ERROR: Model path not found: $MODEL_PATH" + echo "Please update MODEL_PATH in this script to point to your Qwen3-1.7B checkpoint" + exit 1 +fi + +# Start SGLang servers +for i in $(seq 0 $((NUM_SERVERS - 1))); do + PORT=$((BASE_PORT + i)) + echo "Starting SGLang server $((i+1))/$NUM_SERVERS on port $PORT..." 
+ + python -m sglang.launch_server \ + --model-path "$MODEL_PATH" \ + --port $PORT \ + --tp $TP_SIZE \ + --host 0.0.0.0 \ + --log-level info \ + 2>&1 | tee "/tmp/sglang_server_${PORT}.log" & + + echo "SGLang server $((i+1)) starting (PID: $!)" +done + +echo "" +echo "All SGLang servers started!" +echo "Waiting for servers to be ready..." +sleep 30 + +# Test server connectivity +echo "" +echo "Testing server connectivity..." +for i in $(seq 0 $((NUM_SERVERS - 1))); do + PORT=$((BASE_PORT + i)) + if curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then + echo "Server on port $PORT is ready" + else + echo "Server on port $PORT is not responding" + fi +done + +echo "" +echo "SGLang servers are ready for inference!" +echo "Logs available at: /tmp/sglang_server_*.log" diff --git a/torchtitan/grpo/test/scripts/start_trainer.sh b/torchtitan/grpo/test/scripts/start_trainer.sh new file mode 100755 index 0000000000..fd36ebf689 --- /dev/null +++ b/torchtitan/grpo/test/scripts/start_trainer.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Start TorchTitan RL Trainer +# This pulls batches from Atropos API and trains the model + +set -e + +echo "========================================" +echo "Starting TorchTitan RL Trainer" +echo "========================================" + +# Add Atropos to PYTHONPATH +export PYTHONPATH=/home/shared/atropos:$PYTHONPATH + +# Get the torchtitan root directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +echo "TorchTitan root: $TORCHTITAN_ROOT" +cd "$TORCHTITAN_ROOT" + +# Configuration +CONFIG_FILE="torchtitan/grpo/test/test_config.toml" +NGPU=${NGPU:-4} # default to 4 GPUs, override with: NGPU=8 ./start_trainer.sh +LOG_RANK=${LOG_RANK:-0} + +# Set required environment variables +export LOGDIR="${LOGDIR:-/tmp/torchtitan_logs}" +mkdir -p "$LOGDIR" +echo "Logs will be written to: $LOGDIR" + +# Check if config file exists +if [ ! 
-f "$CONFIG_FILE" ]; then + echo "ERROR: Config file not found: $CONFIG_FILE" + exit 1 +fi + +# Check if Atropos API is running +echo "Checking Atropos API availability..." +if ! curl -s "http://localhost:8000/" > /dev/null; then + echo "ERROR: Atropos API is not running on http://localhost:8000" + echo "Please start the API server first (./start_api.sh)" + exit 1 +fi + +echo "" +echo "Configuration:" +echo " - Config file: $CONFIG_FILE" +echo " - Number of GPUs: $NGPU" +echo " - Log directory: $LOGDIR" +echo " - Log rank filter: $LOG_RANK" +echo "" + +# Launch trainer with torchrun +echo "Launching trainer..." +PYTORCH_ALLOC_CONF="expandable_segments:True" \ +torchrun \ + --nproc_per_node=$NGPU \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + --local-ranks-filter $LOG_RANK \ + --role rank \ + --tee 3 \ + -m torchtitan.grpo_train \ + --job.config_file "$CONFIG_FILE" + +echo "" +echo "Training completed!" +echo "Check logs at: $LOGDIR" diff --git a/torchtitan/grpo/test/test_config.toml b/torchtitan/grpo/test/test_config.toml new file mode 100644 index 0000000000..612d2338fc --- /dev/null +++ b/torchtitan/grpo/test/test_config.toml @@ -0,0 +1,101 @@ +# torchtitan config.toml - GSM8k Test Configuration +# Test configuration for Qwen3-1.7B on GSM8k environment + +[job] +dump_folder = "/tmp/gsm8k_test_run" +description = "gsm8k_qwen3_1.7b_test" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "qwen3" +flavor = "1.7B" +tokenizer_path = "Qwen/Qwen3-1.7B" + +[optimizer] +name = "AdamW" +lr = 1e-6 +beta1 = 0.9 +beta2 = 0.95 +weight_decay = 0.1 + +[lr_scheduler] +warmup_steps = 10 +decay_type = "linear" +decay_ratio = 0.1 + +[training] +local_batch_size = 1 +seq_len = 2048 +global_batch_size = 32 # 8 rollouts * 4 gradient accumulation steps + +max_norm = 0.25 # grad norm clipping +steps = 100 # 
test run - 100 steps to validate pipeline + +[compile] +enable = false +components = ["model", "loss"] + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 # Use all available GPUs for data parallel sharding +tensor_parallel_degree = 1 +context_parallel_degree = 1 +pipeline_parallel_degree = 1 + +[grpo] +sglang_tp = 1 +# update these URLs based on SGLang server configuration +sglang_urls = ["localhost:9001", "localhost:9002"] +sglang_slurm_num_nodes = 1 +sglang_port = 26756 + +# GRPO hyperparameters +logit_loss_weight = 0.0 +entropy_loss_weight = 0.0000 +kl_beta = 0.000 +kl_estimator_type = "k3" +ref_model_ema = 0.999 +clip_ratio_lower_bound = 0.0003 +clip_ratio_upper_bound = 0.0004 +policy_ratio_type = "sequence" +pos_scaler = 1.00 +neg_scaler = 1.00 +grpo_by_token = true +scale_adv_by_len = false +num_microbatches = 2 +onpolicy_logp_threshold = 0.0 +rollout_is_level = "sequence" +rollout_is_mode = "truncate" +rollout_is_threshold = 4.0 + +# disabled for this test +ptx_mixin_batchsize = 0 +ptx_scale_by_tokens = false + +[checkpoint] +enable = true +folder = "checkpoints" +# Update this path to point to your Qwen3-1.7B checkpoint +initial_load_path = "/home/shared/torchtitan-conversions/qwen3_1.7b" +initial_load_legacy = true +interval = 50 +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false From 71a9a1d993f4fd6bb3b9041258de567a38cc00d6 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Wed, 14 Jan 2026 15:59:52 -0800 Subject: [PATCH 2/8] Updating scripts to use 1 server and killing sglang script when tests are cancelled / done --- torchtitan/grpo/test/gsm8k_server.py | 6 -- 
torchtitan/grpo/test/scripts/run_full_test.sh | 5 +- torchtitan/grpo/test/scripts/start_api.sh | 3 +- torchtitan/grpo/test/scripts/start_env.sh | 20 ++--- torchtitan/grpo/test/scripts/start_sglang.sh | 74 ++++++++++--------- torchtitan/grpo/test/scripts/start_trainer.sh | 3 +- torchtitan/grpo/test/test_config.toml | 3 +- 7 files changed, 52 insertions(+), 62 deletions(-) diff --git a/torchtitan/grpo/test/gsm8k_server.py b/torchtitan/grpo/test/gsm8k_server.py index 422dcec923..fc9ecf47b8 100644 --- a/torchtitan/grpo/test/gsm8k_server.py +++ b/torchtitan/grpo/test/gsm8k_server.py @@ -74,12 +74,6 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: api_key="x", num_requests_for_eval=256, ), - APIServerConfig( - model_name="Qwen/Qwen3-1.7B", - base_url="http://localhost:9002/v1", - api_key="x", - num_requests_for_eval=256, - ), ] return env_config, server_configs diff --git a/torchtitan/grpo/test/scripts/run_full_test.sh b/torchtitan/grpo/test/scripts/run_full_test.sh index a6898a1f73..c5e1b77107 100755 --- a/torchtitan/grpo/test/scripts/run_full_test.sh +++ b/torchtitan/grpo/test/scripts/run_full_test.sh @@ -34,10 +34,11 @@ cleanup() { if [ ! -z "$SGLANG_PID" ]; then echo "Stopping SGLang servers (PID: $SGLANG_PID)" kill $SGLANG_PID 2>/dev/null || true - # Also kill any remaining sglang processes - pkill -f "sglang.launch_server" 2>/dev/null || true fi + echo "Force-killing any remaining SGLang processes..." + pkill -9 -f "sglang.launch_server" 2>/dev/null || true + if [ ! 
-z "$ENV_PID" ]; then echo "Stopping GSM8k environment (PID: $ENV_PID)" kill $ENV_PID 2>/dev/null || true diff --git a/torchtitan/grpo/test/scripts/start_api.sh b/torchtitan/grpo/test/scripts/start_api.sh index 5cf335fc99..a0374b6b2e 100755 --- a/torchtitan/grpo/test/scripts/start_api.sh +++ b/torchtitan/grpo/test/scripts/start_api.sh @@ -7,8 +7,7 @@ echo "========================================" echo "Starting Atropos API Server" echo "========================================" -# Add Atropos to PYTHONPATH -export PYTHONPATH=/home/shared/atropos:$PYTHONPATH +source /home/nightwing/Projects/torchtitan/.venv/bin/activate # Change to Atropos directory cd /home/shared/atropos diff --git a/torchtitan/grpo/test/scripts/start_env.sh b/torchtitan/grpo/test/scripts/start_env.sh index bc0c18e479..acc6e35565 100755 --- a/torchtitan/grpo/test/scripts/start_env.sh +++ b/torchtitan/grpo/test/scripts/start_env.sh @@ -7,8 +7,8 @@ echo "========================================" echo "Starting GSM8k Environment Server" echo "========================================" -# Add Atropos to PYTHONPATH -export PYTHONPATH=/home/shared/atropos:$PYTHONPATH +# Activate TorchTitan venv (has atroposlib installed) +source /home/nightwing/Projects/torchtitan/.venv/bin/activate # Get the torchtitan root directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -19,23 +19,19 @@ cd "$TORCHTITAN_ROOT" # Configuration MODEL_NAME="Qwen/Qwen3-1.7B" -SGLANG_URL_1="http://localhost:9001/v1" -SGLANG_URL_2="http://localhost:9002/v1" +SGLANG_URL="http://localhost:9001/v1" # Check if Atropos is accessible if ! python -c "from atroposlib.envs.base import BaseEnv" 2>/dev/null; then - echo "ERROR: Cannot import Atropos. Is PYTHONPATH set correctly?" - echo "PYTHONPATH=$PYTHONPATH" + echo "ERROR: Cannot import Atropos. Is it installed in the venv?" 
+ echo "Run: pip install -e /home/shared/atropos" exit 1 fi -# Check if SGLang servers are running +# Check if SGLang server is running echo "Checking SGLang server availability..." -if ! curl -s "$SGLANG_URL_1/models" > /dev/null; then - echo "WARNING: SGLang server at $SGLANG_URL_1 is not responding" -fi -if ! curl -s "$SGLANG_URL_2/models" > /dev/null; then - echo "WARNING: SGLang server at $SGLANG_URL_2 is not responding" +if ! curl -s "$SGLANG_URL/models" > /dev/null; then + echo "WARNING: SGLang server at $SGLANG_URL is not responding" fi # Check if Atropos API is running diff --git a/torchtitan/grpo/test/scripts/start_sglang.sh b/torchtitan/grpo/test/scripts/start_sglang.sh index e331fae250..31618248f9 100755 --- a/torchtitan/grpo/test/scripts/start_sglang.sh +++ b/torchtitan/grpo/test/scripts/start_sglang.sh @@ -1,58 +1,60 @@ #!/bin/bash -# Start SGLang Inference Servers +# Start SGLang Inference Server set -e echo "========================================" -echo "Starting SGLang Inference Servers" +echo "Starting SGLang Inference Server" echo "========================================" +# Cleanup function +cleanup_sglang() { + echo "Cleaning up any existing SGLang processes..." + pkill -9 -f "sglang.launch_server" 2>/dev/null || true + sleep 2 +} + +# Run cleanup first +cleanup_sglang + +source /home/nightwing/Projects/torchtitan/sglangvenv/bin/activate + # Configuration -MODEL_PATH="/home/shared/torchtitan-conversions/qwen3_1.7b" # update with model path +MODEL_PATH="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" # HF path TP_SIZE=1 -NUM_SERVERS=2 +NUM_SERVERS=1 # Using 1 server for testing to avoid OOM BASE_PORT=9001 -# Check if model path exists -if [ ! 
-d "$MODEL_PATH" ]; then - echo "ERROR: Model path not found: $MODEL_PATH" - echo "Please update MODEL_PATH in this script to point to your Qwen3-1.7B checkpoint" - exit 1 -fi +# Note: Using HF model name - SGLang will auto-download if needed -# Start SGLang servers -for i in $(seq 0 $((NUM_SERVERS - 1))); do - PORT=$((BASE_PORT + i)) - echo "Starting SGLang server $((i+1))/$NUM_SERVERS on port $PORT..." +# Start SGLang server +echo "Starting SGLang server on port $BASE_PORT..." - python -m sglang.launch_server \ - --model-path "$MODEL_PATH" \ - --port $PORT \ - --tp $TP_SIZE \ - --host 0.0.0.0 \ - --log-level info \ - 2>&1 | tee "/tmp/sglang_server_${PORT}.log" & +python -m sglang.launch_server \ + --model-path "$MODEL_PATH" \ + --port $BASE_PORT \ + --tp $TP_SIZE \ + --host 0.0.0.0 \ + --log-level info \ + 2>&1 | tee "/tmp/sglang_server_${BASE_PORT}.log" & - echo "SGLang server $((i+1)) starting (PID: $!)" -done +SERVER_PID=$! +echo "SGLang server starting (PID: $SERVER_PID)" echo "" -echo "All SGLang servers started!" -echo "Waiting for servers to be ready..." -sleep 30 +echo "Waiting for server to be ready (~30 seconds)..." +sleep 60 # Test server connectivity echo "" echo "Testing server connectivity..." -for i in $(seq 0 $((NUM_SERVERS - 1))); do - PORT=$((BASE_PORT + i)) - if curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then - echo "Server on port $PORT is ready" - else - echo "Server on port $PORT is not responding" - fi -done +if curl -s "http://localhost:${BASE_PORT}/v1/models" > /dev/null; then + echo "Server on port $BASE_PORT is ready" +else + echo "Server on port $BASE_PORT is not responding" + echo "Check log at: /tmp/sglang_server_${BASE_PORT}.log" +fi echo "" -echo "SGLang servers are ready for inference!" -echo "Logs available at: /tmp/sglang_server_*.log" +echo "SGLang server ready for inference!" 
+echo "Log available at: /tmp/sglang_server_${BASE_PORT}.log" diff --git a/torchtitan/grpo/test/scripts/start_trainer.sh b/torchtitan/grpo/test/scripts/start_trainer.sh index fd36ebf689..bea514916f 100755 --- a/torchtitan/grpo/test/scripts/start_trainer.sh +++ b/torchtitan/grpo/test/scripts/start_trainer.sh @@ -8,8 +8,7 @@ echo "========================================" echo "Starting TorchTitan RL Trainer" echo "========================================" -# Add Atropos to PYTHONPATH -export PYTHONPATH=/home/shared/atropos:$PYTHONPATH +source /home/nightwing/Projects/torchtitan/.venv/bin/activate # Get the torchtitan root directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" diff --git a/torchtitan/grpo/test/test_config.toml b/torchtitan/grpo/test/test_config.toml index 612d2338fc..de95b30915 100644 --- a/torchtitan/grpo/test/test_config.toml +++ b/torchtitan/grpo/test/test_config.toml @@ -54,8 +54,7 @@ pipeline_parallel_degree = 1 [grpo] sglang_tp = 1 -# update these URLs based on SGLang server configuration -sglang_urls = ["localhost:9001", "localhost:9002"] +sglang_urls = ["localhost:9001"] sglang_slurm_num_nodes = 1 sglang_port = 26756 From abc8fb4b3a03cb1a3b5fc3810792d2ba04588cb5 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Wed, 14 Jan 2026 16:06:42 -0800 Subject: [PATCH 3/8] Moving to vLLM inference --- torchtitan/grpo/test/scripts/run_full_test.sh | 48 +++++++-------- torchtitan/grpo/test/scripts/start_env.sh | 10 ++-- torchtitan/grpo/test/scripts/start_sglang.sh | 60 ------------------- torchtitan/grpo/test/scripts/start_vllm.sh | 57 ++++++++++++++++++ 4 files changed, 86 insertions(+), 89 deletions(-) delete mode 100755 torchtitan/grpo/test/scripts/start_sglang.sh create mode 100755 torchtitan/grpo/test/scripts/start_vllm.sh diff --git a/torchtitan/grpo/test/scripts/run_full_test.sh b/torchtitan/grpo/test/scripts/run_full_test.sh index c5e1b77107..7f70fb3a0f 100755 --- a/torchtitan/grpo/test/scripts/run_full_test.sh +++ 
b/torchtitan/grpo/test/scripts/run_full_test.sh @@ -12,7 +12,7 @@ echo "========================================" echo "" echo "This script will start all components:" echo " 1. Atropos API Server" -echo " 2. SGLang Inference Servers" +echo " 2. vLLM Inference Server" echo " 3. GSM8k Environment Server" echo " 4. TorchTitan Trainer" echo "" @@ -31,13 +31,14 @@ cleanup() { kill $API_PID 2>/dev/null || true fi - if [ ! -z "$SGLANG_PID" ]; then - echo "Stopping SGLang servers (PID: $SGLANG_PID)" - kill $SGLANG_PID 2>/dev/null || true + if [ ! -z "$VLLM_PID" ]; then + echo "Stopping vLLM server (PID: $VLLM_PID)" + kill $VLLM_PID 2>/dev/null || true fi - echo "Force-killing any remaining SGLang processes..." - pkill -9 -f "sglang.launch_server" 2>/dev/null || true + echo "Force-killing any remaining vLLM processes..." + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + pkill -9 -f "vllm serve" 2>/dev/null || true if [ ! -z "$ENV_PID" ]; then echo "Stopping GSM8k environment (PID: $ENV_PID)" @@ -68,29 +69,28 @@ fi echo "API is ready" echo "" -# Step 2: Start SGLang servers -echo "Step 2/4: Starting SGLang Inference Servers..." -"$SCRIPT_DIR/start_sglang.sh" > /tmp/sglang_launcher.log 2>&1 & -SGLANG_PID=$! -echo "SGLang launcher started (PID: $SGLANG_PID)" -echo "Waiting for SGLang servers to load models (this may take ~30 seconds)..." +# Step 2: Start vLLM server +echo "Step 2/4: Starting vLLM Inference Server..." +"$SCRIPT_DIR/start_vllm.sh" > /tmp/vllm_launcher.log 2>&1 & +VLLM_PID=$! +echo "vLLM launcher started (PID: $VLLM_PID)" +echo "Waiting for vLLM server to load model (this may take ~30 seconds)..." sleep 35 -# Check if SGLang servers are running -SGLANG_READY=true -for PORT in 9001 9002; do - if ! curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then - echo "WARNING: SGLang server on port $PORT is not responding" - SGLANG_READY=false - fi -done +# Check if vLLM server is running +VLLM_READY=true +PORT=9001 +if ! 
curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then + echo "WARNING: vLLM server on port $PORT is not responding" + VLLM_READY=false +fi -if [ "$SGLANG_READY" = false ]; then - echo "WARNING: Some SGLang servers may not be ready" - echo "Check logs at: /tmp/sglang_server_*.log" +if [ "$VLLM_READY" = false ]; then + echo "WARNING: vLLM server may not be ready" + echo "Check log at: /tmp/vllm_server_${PORT}.log" echo "Continuing anyway..." else - echo "SGLang servers are ready" + echo "vLLM server is ready" fi echo "" diff --git a/torchtitan/grpo/test/scripts/start_env.sh b/torchtitan/grpo/test/scripts/start_env.sh index acc6e35565..7160cf88ac 100755 --- a/torchtitan/grpo/test/scripts/start_env.sh +++ b/torchtitan/grpo/test/scripts/start_env.sh @@ -19,7 +19,7 @@ cd "$TORCHTITAN_ROOT" # Configuration MODEL_NAME="Qwen/Qwen3-1.7B" -SGLANG_URL="http://localhost:9001/v1" +VLLM_URL="http://localhost:9001/v1" # Check if Atropos is accessible if ! python -c "from atroposlib.envs.base import BaseEnv" 2>/dev/null; then @@ -28,10 +28,10 @@ if ! python -c "from atroposlib.envs.base import BaseEnv" 2>/dev/null; then exit 1 fi -# Check if SGLang server is running -echo "Checking SGLang server availability..." -if ! curl -s "$SGLANG_URL/models" > /dev/null; then - echo "WARNING: SGLang server at $SGLANG_URL is not responding" +# Check if vLLM server is running +echo "Checking vLLM server availability..." +if ! 
curl -s "$VLLM_URL/models" > /dev/null; then + echo "WARNING: vLLM server at $VLLM_URL is not responding" fi # Check if Atropos API is running diff --git a/torchtitan/grpo/test/scripts/start_sglang.sh b/torchtitan/grpo/test/scripts/start_sglang.sh deleted file mode 100755 index 31618248f9..0000000000 --- a/torchtitan/grpo/test/scripts/start_sglang.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Start SGLang Inference Server - -set -e - -echo "========================================" -echo "Starting SGLang Inference Server" -echo "========================================" - -# Cleanup function -cleanup_sglang() { - echo "Cleaning up any existing SGLang processes..." - pkill -9 -f "sglang.launch_server" 2>/dev/null || true - sleep 2 -} - -# Run cleanup first -cleanup_sglang - -source /home/nightwing/Projects/torchtitan/sglangvenv/bin/activate - -# Configuration -MODEL_PATH="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" # HF path -TP_SIZE=1 -NUM_SERVERS=1 # Using 1 server for testing to avoid OOM -BASE_PORT=9001 - -# Note: Using HF model name - SGLang will auto-download if needed - -# Start SGLang server -echo "Starting SGLang server on port $BASE_PORT..." - -python -m sglang.launch_server \ - --model-path "$MODEL_PATH" \ - --port $BASE_PORT \ - --tp $TP_SIZE \ - --host 0.0.0.0 \ - --log-level info \ - 2>&1 | tee "/tmp/sglang_server_${BASE_PORT}.log" & - -SERVER_PID=$! -echo "SGLang server starting (PID: $SERVER_PID)" - -echo "" -echo "Waiting for server to be ready (~30 seconds)..." -sleep 60 - -# Test server connectivity -echo "" -echo "Testing server connectivity..." -if curl -s "http://localhost:${BASE_PORT}/v1/models" > /dev/null; then - echo "Server on port $BASE_PORT is ready" -else - echo "Server on port $BASE_PORT is not responding" - echo "Check log at: /tmp/sglang_server_${BASE_PORT}.log" -fi - -echo "" -echo "SGLang server ready for inference!" 
-echo "Log available at: /tmp/sglang_server_${BASE_PORT}.log" diff --git a/torchtitan/grpo/test/scripts/start_vllm.sh b/torchtitan/grpo/test/scripts/start_vllm.sh new file mode 100755 index 0000000000..d3488ebcd9 --- /dev/null +++ b/torchtitan/grpo/test/scripts/start_vllm.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Start vLLM Inference Server + +set -e + +echo "========================================" +echo "Starting vLLM Inference Server" +echo "========================================" + +# Cleanup function +cleanup_vllm() { + echo "Cleaning up any existing vLLM processes..." + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + pkill -9 -f "vllm serve" 2>/dev/null || true + sleep 2 +} + +cleanup_vllm + +source /home/nightwing/Projects/torchtitan/.venv/bin/activate + +# Configuration +MODEL_PATH="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" # HF checkpoint path or name +TP_SIZE=1 +BASE_PORT=9001 + +echo "Starting vLLM server on port $BASE_PORT..." +echo "Model: $MODEL_PATH" + +python -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --port $BASE_PORT \ + --host 0.0.0.0 \ + --tensor-parallel-size $TP_SIZE \ + --trust-remote-code \ + --gpu-memory-utilization 0.85 \ + 2>&1 | tee "/tmp/vllm_server_${BASE_PORT}.log" & + +SERVER_PID=$! +echo "vLLM server starting (PID: $SERVER_PID)" + +echo "" +echo "Waiting for server to be ready (~30 seconds)..." +sleep 60 + +echo "" +echo "Testing server connectivity..." +if curl -s "http://localhost:${BASE_PORT}/v1/models" > /dev/null; then + echo "✓ vLLM server on port $BASE_PORT is ready" +else + echo "✗ vLLM server on port $BASE_PORT is not responding" + echo "Check log at: /tmp/vllm_server_${BASE_PORT}.log" +fi + +echo "" +echo "vLLM server ready for inference!" 
+echo "Log available at: /tmp/vllm_server_${BASE_PORT}.log" From 237e76987762b5b344fb534df5e2512ea48efb1b Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Fri, 23 Jan 2026 17:20:15 -0700 Subject: [PATCH 4/8] Working update --- torchtitan/grpo/test/scripts/run_full_test.sh | 6 +- torchtitan/grpo/test/scripts/start_vllm.sh | 31 ++++-- torchtitan/grpo/test/scripts/test_launcher.sh | 103 ++++++++++++++++++ torchtitan/grpo/test/test_config.toml | 4 +- torchtitan/grpo/test/test_full_rl.slurm | 52 +++++++++ torchtitan/grpo/test/test_single_node.slurm | 40 +++++++ 6 files changed, 226 insertions(+), 10 deletions(-) create mode 100644 torchtitan/grpo/test/scripts/test_launcher.sh create mode 100644 torchtitan/grpo/test/test_full_rl.slurm create mode 100644 torchtitan/grpo/test/test_single_node.slurm diff --git a/torchtitan/grpo/test/scripts/run_full_test.sh b/torchtitan/grpo/test/scripts/run_full_test.sh index 7f70fb3a0f..fb7228538d 100755 --- a/torchtitan/grpo/test/scripts/run_full_test.sh +++ b/torchtitan/grpo/test/scripts/run_full_test.sh @@ -39,6 +39,8 @@ cleanup() { echo "Force-killing any remaining vLLM processes..." pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "torchtitan.grpo.vllm_handling.vllm_runner" 2>/dev/null || true + lsof -ti:9001 | xargs kill -9 2>/dev/null || true if [ ! -z "$ENV_PID" ]; then echo "Stopping GSM8k environment (PID: $ENV_PID)" @@ -80,7 +82,7 @@ sleep 35 # Check if vLLM server is running VLLM_READY=true PORT=9001 -if ! curl -s "http://localhost:${PORT}/v1/models" > /dev/null; then +if ! 
curl -s "http://localhost:${PORT}/health" > /dev/null; then echo "WARNING: vLLM server on port $PORT is not responding" VLLM_READY=false fi @@ -88,6 +90,8 @@ fi if [ "$VLLM_READY" = false ]; then echo "WARNING: vLLM server may not be ready" echo "Check log at: /tmp/vllm_server_${PORT}.log" + echo "Last 30 lines of log:" + tail -30 /tmp/vllm_server_${PORT}.log 2>/dev/null || echo "Log file not found" echo "Continuing anyway..." else echo "vLLM server is ready" diff --git a/torchtitan/grpo/test/scripts/start_vllm.sh b/torchtitan/grpo/test/scripts/start_vllm.sh index d3488ebcd9..622054ea30 100755 --- a/torchtitan/grpo/test/scripts/start_vllm.sh +++ b/torchtitan/grpo/test/scripts/start_vllm.sh @@ -12,29 +12,44 @@ cleanup_vllm() { echo "Cleaning up any existing vLLM processes..." pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "torchtitan.grpo.vllm_handling.vllm_runner" 2>/dev/null || true + lsof -ti:9001 | xargs kill -9 2>/dev/null || true sleep 2 } cleanup_vllm -source /home/nightwing/Projects/torchtitan/.venv/bin/activate +# Use the separate vLLM environment (not the training env) +source /home/nightwing/envs/vllm/.venv/bin/activate # Configuration MODEL_PATH="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" # HF checkpoint path or name TP_SIZE=1 BASE_PORT=9001 +# Set LOGDIR to match the trainer (needed for distributed_updater coordination) +export LOGDIR="${LOGDIR:-/tmp/torchtitan_logs}" +mkdir -p "$LOGDIR" + +# Set NUM_INFERENCE_NODES=0 for single node setup (required by distributed_updater) +export NUM_INFERENCE_NODES=0 + echo "Starting vLLM server on port $BASE_PORT..." 
+echo "CUDA_VISIBLE_DEVICES: 4" echo "Model: $MODEL_PATH" +echo "LOGDIR: $LOGDIR" +echo "NUM_INFERENCE_NODES: $NUM_INFERENCE_NODES" -python -m vllm.entrypoints.openai.api_server \ +# Run vLLM on GPU 4 (training uses GPUs 0-3) +# IMPORTANT: Set CUDA_VISIBLE_DEVICES as prefix, not export +CUDA_VISIBLE_DEVICES=4 nohup python -m torchtitan.grpo.vllm_handling.vllm_runner \ --model "$MODEL_PATH" \ --port $BASE_PORT \ --host 0.0.0.0 \ - --tensor-parallel-size $TP_SIZE \ - --trust-remote-code \ - --gpu-memory-utilization 0.85 \ - 2>&1 | tee "/tmp/vllm_server_${BASE_PORT}.log" & + --gpu-memory-utilization 0.75 \ + --dtype="bfloat16" \ + --log-level="error" \ + > "${LOGDIR}/vllm_${BASE_PORT}.log" 2>&1 & SERVER_PID=$! echo "vLLM server starting (PID: $SERVER_PID)" @@ -45,11 +60,13 @@ sleep 60 echo "" echo "Testing server connectivity..." -if curl -s "http://localhost:${BASE_PORT}/v1/models" > /dev/null; then +if curl -s "http://localhost:${BASE_PORT}/health" > /dev/null; then echo "✓ vLLM server on port $BASE_PORT is ready" else echo "✗ vLLM server on port $BASE_PORT is not responding" echo "Check log at: /tmp/vllm_server_${BASE_PORT}.log" + echo "Last 20 lines of log:" + tail -20 "/tmp/vllm_server_${BASE_PORT}.log" 2>/dev/null || echo "Log file not found" fi echo "" diff --git a/torchtitan/grpo/test/scripts/test_launcher.sh b/torchtitan/grpo/test/scripts/test_launcher.sh new file mode 100644 index 0000000000..ba12623911 --- /dev/null +++ b/torchtitan/grpo/test/scripts/test_launcher.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# set -e temporarily disabled to see errors +set -x # Print commands for debugging + +printenv +ulimit -n 32000 + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." 
&& pwd)" +cd "$TORCHTITAN_ROOT" + +# Set defaults if not running under SLURM +: "${SLURM_NODEID:=0}" +: "${NUM_TRAINING_NODES:=1}" +: "${NUM_INFERENCE_NODES:=0}" +: "${LOGDIR:=${TORCHTITAN_ROOT}/logs/test_run}" +: "${MODEL_NAME:=/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf}" +: "${CONFIG_FILE:=${TORCHTITAN_ROOT}/torchtitan/grpo/test/test_config.toml}" +: "${API_ENV:=/home/nightwing/Projects/torchtitan/.venv}" +: "${TRAIN_ENV:=/home/nightwing/Projects/torchtitan/.venv}" +: "${VLLM_ENV:=/home/nightwing/envs/vllm/.venv}" + +# Export LOGDIR so child processes can see it +export LOGDIR +export NUM_INFERENCE_NODES +export MODEL_NAME +export CONFIG_FILE + +mkdir -p "$LOGDIR" + +echo "Starting test at $(date)" +echo "SLURM_NODEID: $SLURM_NODEID" +echo "NUM_TRAINING_NODES: $NUM_TRAINING_NODES" +echo "NUM_INFERENCE_NODES: $NUM_INFERENCE_NODES" +echo "LOGDIR: $LOGDIR" +echo "MODEL_NAME: $MODEL_NAME" + +# Start API and environment (always on node 0) +if [[ "$SLURM_NODEID" -eq 0 ]]; then + echo "Starting API and environment server..." + source ${API_ENV}/bin/activate + + # Start Atropos API + cd /home/shared/atropos + run-api > ${LOGDIR}/api.log 2>&1 & + cd "$TORCHTITAN_ROOT" + + # Start GSM8k environment server + python torchtitan/grpo/test/gsm8k_server.py serve --slurm=True --openai.model_name="$MODEL_NAME" > ${LOGDIR}/env_server.log 2>&1 & + + deactivate + echo "Started API and environment server..." +fi + +# Start training (on training nodes) +if [[ "$SLURM_NODEID" -lt "$NUM_TRAINING_NODES" ]]; then + echo "Setting up training environment..." + source ${TRAIN_ENV}/bin/activate + + nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) + nodes_array=($nodes) + head_node=${nodes_array[0]} + + export LOGLEVEL=INFO + export NCCL_DEBUG=WARN + export PYTHONFAULTHANDLER=1 + export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH + export CUDA_LAUNCH_BLOCKING=0 + + # Launch trainer (vLLM runs on separate inference node) + echo "Launching trainer..." 
+ torchrun --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint="$head_node:29500" --role rank --tee 3 \ + -m torchtitan.grpo_train --job.config_file ${CONFIG_FILE} +# else we're on an inference node +else + echo "Starting vLLM inference server..." + source ${VLLM_ENV}/bin/activate + + PORT_BASE=9000 + LOG_OFFSET=$((SLURM_NODEID * 8)) + + # Start 8 vLLM instances on GPUs 0-7 (matching dakota's setup) + for i in {0..7}; do + GPU_ID=$i + LOG_ID=$((GPU_ID + LOG_OFFSET)) + PORT=$((PORT_BASE + i)) + echo "Starting vLLM instance on GPU $GPU_ID, port $PORT" + CUDA_VISIBLE_DEVICES=$GPU_ID nohup python -m torchtitan.grpo.vllm_handling.vllm_runner \ + --model "$MODEL_NAME" \ + --host 0.0.0.0 \ + --gpu-memory-utilization 0.75 \ + --dtype="bfloat16" \ + --log-level="error" \ + --port $PORT > ${LOGDIR}/vllm_${LOG_ID}.log 2>&1 & + sleep 3 + done + + # Wait indefinitely (keep inference node alive) + wait +fi + +echo "Test completed at $(date)" diff --git a/torchtitan/grpo/test/test_config.toml b/torchtitan/grpo/test/test_config.toml index de95b30915..60fe9e32d6 100644 --- a/torchtitan/grpo/test/test_config.toml +++ b/torchtitan/grpo/test/test_config.toml @@ -14,7 +14,7 @@ profile_freq = 100 log_freq = 1 enable_tensorboard = false save_tb_folder = "tb" -enable_wandb = false +enable_wandb = true [model] name = "qwen3" @@ -55,7 +55,7 @@ pipeline_parallel_degree = 1 [grpo] sglang_tp = 1 sglang_urls = ["localhost:9001"] -sglang_slurm_num_nodes = 1 +sglang_slurm_num_nodes = 0 sglang_port = 26756 # GRPO hyperparameters diff --git a/torchtitan/grpo/test/test_full_rl.slurm b/torchtitan/grpo/test/test_full_rl.slurm new file mode 100644 index 0000000000..df9cdc567f --- /dev/null +++ b/torchtitan/grpo/test/test_full_rl.slurm @@ -0,0 +1,52 @@ +#!/bin/bash +#SBATCH --job-name=grpo_full_test +#SBATCH --output=logs/%j.out +#SBATCH --error=logs/%j.err +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-task=64 + +# 
Create logs directory +mkdir -p logs/$SLURM_JOB_ID + +# Set ulimit higher +ulimit -n 32000 +export LOGDIR="$(pwd)/logs/${SLURM_JOB_ID}" + +# echo slurm nodes +echo "SLURM nodes: $SLURM_JOB_NODELIST" + +# Basic config stuff - pointing to test setup +export CONFIG_FILE="$(pwd)/torchtitan/grpo/test/test_config.toml" +export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" +export PYTHON_SCRIPT="$(pwd)/torchtitan/grpo/test/gsm8k_server.py" +export PYTHON_ARGS="" +export TRAINING_ARGS="" +export NUM_TRAINING_NODES=1 +export NUM_INFERENCE_NODES=1 + +# NCCL settings +export NCCL_BUFFSIZE=33554432 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_IB_AR_THRESHOLD=0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_IB_QPS_PER_CONNECTION=2 +export NCCL_IB_SPLIT_DATA_ON_QPS=0 +export NCCL_IGNORE_CPU_AFFINITY=1 + +# Define environment paths +export TRAIN_PATH="$(pwd)" +export TRAIN_ENV="${TRAIN_PATH}/.venv" +export VLLM_ENV="/home/nightwing/envs/vllm/.venv" +export API_ENV="${TRAIN_ENV}" + +# Get head node info +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +# Use the proven working vllm_launch.sh script +srun -l --export=ALL ./vllm_launch.sh diff --git a/torchtitan/grpo/test/test_single_node.slurm b/torchtitan/grpo/test/test_single_node.slurm new file mode 100644 index 0000000000..05fb8989a3 --- /dev/null +++ b/torchtitan/grpo/test/test_single_node.slurm @@ -0,0 +1,40 @@ +#!/bin/bash +#SBATCH --job-name=grpo_two_node_test +#SBATCH --output=logs/%j.out +#SBATCH --error=logs/%j.err +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-task=64 + +# Create logs directory +mkdir -p logs/$SLURM_JOB_ID + +# Set ulimit higher +ulimit -n 32000 +export LOGDIR="$(pwd)/logs/${SLURM_JOB_ID}" + +# Basic config stuff +export 
CONFIG_FILE="$(pwd)/torchtitan/grpo/test/test_config.toml" +export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" +export PYTHON_SCRIPT="$(pwd)/torchtitan/grpo/test/gsm8k_server.py" +export PYTHON_ARGS="" +export TRAINING_ARGS="" +export NUM_TRAINING_NODES=1 +export NUM_INFERENCE_NODES=1 + +# Define environment paths +export TRAIN_PATH="$(pwd)" +export TRAIN_ENV="${TRAIN_PATH}/.venv" +export VLLM_ENV="/home/nightwing/envs/vllm/.venv" +export API_ENV="${TRAIN_ENV}" + +# Get head node info +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +# Run the test using the same pattern as vllm_launch.sh +srun -l --export=ALL ./torchtitan/grpo/test/scripts/test_launcher.sh From 7c71712adeea462d093d9253685d4555914555b2 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Fri, 23 Jan 2026 19:47:38 -0700 Subject: [PATCH 5/8] Cleaning up --- torchtitan/grpo/test/test_single_node.slurm | 40 --------------------- 1 file changed, 40 deletions(-) delete mode 100644 torchtitan/grpo/test/test_single_node.slurm diff --git a/torchtitan/grpo/test/test_single_node.slurm b/torchtitan/grpo/test/test_single_node.slurm deleted file mode 100644 index 05fb8989a3..0000000000 --- a/torchtitan/grpo/test/test_single_node.slurm +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=grpo_two_node_test -#SBATCH --output=logs/%j.out -#SBATCH --error=logs/%j.err -#SBATCH --nodes=2 -#SBATCH --ntasks-per-node=1 -#SBATCH --exclusive -#SBATCH --gpus-per-task=8 -#SBATCH --cpus-per-task=64 - -# Create logs directory -mkdir -p logs/$SLURM_JOB_ID - -# Set ulimit higher -ulimit -n 32000 -export LOGDIR="$(pwd)/logs/${SLURM_JOB_ID}" - -# Basic config stuff -export CONFIG_FILE="$(pwd)/torchtitan/grpo/test/test_config.toml" -export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" -export 
PYTHON_SCRIPT="$(pwd)/torchtitan/grpo/test/gsm8k_server.py" -export PYTHON_ARGS="" -export TRAINING_ARGS="" -export NUM_TRAINING_NODES=1 -export NUM_INFERENCE_NODES=1 - -# Define environment paths -export TRAIN_PATH="$(pwd)" -export TRAIN_ENV="${TRAIN_PATH}/.venv" -export VLLM_ENV="/home/nightwing/envs/vllm/.venv" -export API_ENV="${TRAIN_ENV}" - -# Get head node info -nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) -nodes_array=($nodes) -head_node=${nodes_array[0]} -export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - -# Run the test using the same pattern as vllm_launch.sh -srun -l --export=ALL ./torchtitan/grpo/test/scripts/test_launcher.sh From 9e31cb8938cbf270f973f189c3559698cd14c681 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Thu, 5 Feb 2026 10:51:17 -0800 Subject: [PATCH 6/8] Updating all necessary pieces for a run --- torchtitan/grpo/data_handling.py | 9 ++-- torchtitan/grpo/test/gsm8k_server.py | 63 ++++++++++++++++++------- torchtitan/grpo/test/test_config.toml | 12 ++--- torchtitan/grpo/test/test_full_rl.slurm | 6 +-- torchtitan/grpo_train.py | 21 +++++---- vllm_launch.sh | 29 ++++++++---- 6 files changed, 88 insertions(+), 52 deletions(-) diff --git a/torchtitan/grpo/data_handling.py b/torchtitan/grpo/data_handling.py index a2b26c04a9..fd978e8e31 100644 --- a/torchtitan/grpo/data_handling.py +++ b/torchtitan/grpo/data_handling.py @@ -435,8 +435,7 @@ def data_handling( flag = flag + 1 torch.distributed.broadcast(flag, 0) - if dp_replicate_rank == 0: - send_wait(sglang_nccl_group, device) + send_wait(sglang_nccl_group, device) max_token_len = torch.tensor(max_token_len).to(device) torch.distributed.all_reduce(max_token_len) # back to int @@ -456,13 +455,11 @@ def data_handling( else: logger.debug("No batch yet, retrying...") torch.distributed.broadcast(flag, 0) - if dp_replicate_rank == 0: - send_wait(sglang_nccl_group, device) + send_wait(sglang_nccl_group, device) else: logger.debug("Waiting for batch 
from server...") torch.distributed.broadcast(flag, 0) - if dp_replicate_rank == 0: - send_wait(sglang_nccl_group, device) + send_wait(sglang_nccl_group, device) if flag.item() > 0: # Got the batch max_token_len = torch.tensor(0).to(device) diff --git a/torchtitan/grpo/test/gsm8k_server.py b/torchtitan/grpo/test/gsm8k_server.py index fc9ecf47b8..c9b6c17042 100644 --- a/torchtitan/grpo/test/gsm8k_server.py +++ b/torchtitan/grpo/test/gsm8k_server.py @@ -48,6 +48,10 @@ def __init__( testing=False, ): super().__init__(config, server_configs, slurm, testing) + print(f"DEBUG: GSM8kEnv initialized with {len(self.server.servers)} servers") + for i, server in enumerate(self.server.servers): + if hasattr(server, 'config'): + print(f"DEBUG: Server {i}: {server.config.base_url}") self.percent_correct_buffer = list() self.eval_metrics = list() # Add tracking for wandb visualizations @@ -57,7 +61,7 @@ def __init__( @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( - tokenizer_name="Qwen/Qwen3-1.7B", + tokenizer_name="Qwen/Qwen2.5-7B", group_size=8, use_wandb=True, rollout_server_url="http://localhost:8000", @@ -69,7 +73,7 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: ) server_configs = [ APIServerConfig( - model_name="Qwen/Qwen3-1.7B", + model_name="Qwen/Qwen2.5-7B", base_url="http://localhost:9001/v1", api_key="x", num_requests_for_eval=256, @@ -227,22 +231,33 @@ async def evaluate(self, *args, **kwargs): async def collect_trajectories( self, item: GSM8kRow ) -> Tuple[ScoredDataGroup, list[Item]]: - user_message = {"role": "user", "content": item["question"]} - gold_answer = ( - "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" - ) + print(f"DEBUG: collect_trajectories() called for question: {item['question'][:80]}...") + try: + user_message = {"role": "user", "content": item["question"]} + gold_answer = ( + "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", 
"") + "}" + ) - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + print(f"DEBUG: About to call managed_server...") + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + print(f"DEBUG: Inside managed_server context, about to call chat_completion...") - chat_completions = await managed.chat_completion( - messages=[{"role": "system", "content": system_prompt}, user_message], - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=1.0, - ) + chat_completions = await managed.chat_completion( + messages=[{"role": "system", "content": system_prompt}, user_message], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=1.0, + ) + print(f"DEBUG: chat_completion returned, got {len(chat_completions.choices)} completions") - state = managed.get_state() - nodes = state["nodes"] + state = managed.get_state() + nodes = state["nodes"] + print(f"DEBUG: Got state with {len(nodes)} nodes") + except Exception as e: + print(f"ERROR in collect_trajectories: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + return None, [] to_score = list() to_backlog = list() @@ -268,6 +283,7 @@ async def collect_trajectories( async def score( self, rollout_group_data ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + print(f"DEBUG: score() called with {len(rollout_group_data)} rollouts") scores = ScoredDataGroup() scores["tokens"] = list() scores["masks"] = list() @@ -278,6 +294,7 @@ async def score( extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) + print(f"DEBUG: Gold answer parsed: {len(gold_parsed)} elements") if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) @@ -310,7 +327,9 @@ async def score( logprobs = item["logprobs"] # remove obviously bad examples - if len([1 for i in masks if i != -100]) < 10: + num_valid_tokens = 
len([1 for i in masks if i != -100]) + if num_valid_tokens < 5: # Lowered from 10 to 5 to be less strict + print(f"Filtering out sample with only {num_valid_tokens} valid tokens") continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -320,6 +339,11 @@ async def score( if len(scores["tokens"]) >= self.config.group_size: break + # Check if we have enough valid samples after filtering + if len(scores["tokens"]) < self.config.group_size: + print(f"Warning: Only got {len(scores['tokens'])} samples after filtering, need {self.config.group_size}") + return None + for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) @@ -352,11 +376,14 @@ async def score( percentage_of_range = min(percentage_of_range, 1.0) # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) - if all([scores["scores"][0] == score for score in scores["scores"]]): - return None # If all the same, we return None + # allow training even when all scores are identical + # if all([scores["scores"][0] == score for score in scores["scores"]]): + # return None # If all the same, we return None + print(f"DEBUG: Returning scores with {len(scores['tokens'])} samples, scores: {scores['scores']}") return scores else: # If the gold solution is not parseable, we return None + print("DEBUG: Gold solution not parseable, returning None") return None async def get_next_item(self) -> GSM8kRow: diff --git a/torchtitan/grpo/test/test_config.toml b/torchtitan/grpo/test/test_config.toml index 60fe9e32d6..e5ca401c41 100644 --- a/torchtitan/grpo/test/test_config.toml +++ b/torchtitan/grpo/test/test_config.toml @@ -17,9 +17,9 @@ save_tb_folder = "tb" enable_wandb = true [model] -name = "qwen3" -flavor = "1.7B" -tokenizer_path = "Qwen/Qwen3-1.7B" +name = "qwen2" +flavor = "7B" +tokenizer_path = "Qwen/Qwen2.5-7B" [optimizer] name = "AdamW" @@ -54,8 +54,8 @@ pipeline_parallel_degree = 1 [grpo] sglang_tp = 1 -sglang_urls = ["localhost:9001"] 
-sglang_slurm_num_nodes = 0 +sglang_urls = [] +sglang_slurm_num_nodes = 1 sglang_port = 26756 # GRPO hyperparameters @@ -85,7 +85,7 @@ ptx_scale_by_tokens = false enable = true folder = "checkpoints" # Update this path to point to your Qwen3-1.7B checkpoint -initial_load_path = "/home/shared/torchtitan-conversions/qwen3_1.7b" +initial_load_path = "/home/shared/torchtitan-conversions/qwen_2-5_7b" initial_load_legacy = true interval = 50 export_dtype = "float32" diff --git a/torchtitan/grpo/test/test_full_rl.slurm b/torchtitan/grpo/test/test_full_rl.slurm index df9cdc567f..f2513da7a5 100644 --- a/torchtitan/grpo/test/test_full_rl.slurm +++ b/torchtitan/grpo/test/test_full_rl.slurm @@ -20,7 +20,7 @@ echo "SLURM nodes: $SLURM_JOB_NODELIST" # Basic config stuff - pointing to test setup export CONFIG_FILE="$(pwd)/torchtitan/grpo/test/test_config.toml" -export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" +export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen2.5-7b" export PYTHON_SCRIPT="$(pwd)/torchtitan/grpo/test/gsm8k_server.py" export PYTHON_ARGS="" export TRAINING_ARGS="" @@ -38,8 +38,8 @@ export NCCL_IGNORE_CPU_AFFINITY=1 # Define environment paths export TRAIN_PATH="$(pwd)" -export TRAIN_ENV="${TRAIN_PATH}/.venv" -export VLLM_ENV="/home/nightwing/envs/vllm/.venv" +export TRAIN_ENV="/home/nightwing/miniconda3/envs/torchtitan/" +export VLLM_ENV="/home/nightwing/miniconda3/envs/vllm/" export API_ENV="${TRAIN_ENV}" # Get head node info diff --git a/torchtitan/grpo_train.py b/torchtitan/grpo_train.py index eeff74dee3..393ab9b32b 100644 --- a/torchtitan/grpo_train.py +++ b/torchtitan/grpo_train.py @@ -522,16 +522,17 @@ def __init__(self, job_config: JobConfig): time.sleep(1) with open(f"{os.environ['LOGDIR']}/sglang_dtypes.json", "r") as f: self.weight_dtypes = json.load(f) - logger.debug( - f"Setting up SGlang process groups, dp_shard_degree: {self.dp_shard_degree}, tp_degree: {self.tp_degree}" - ) - hostname = "localhost" if local_rank < 
8 else get_hostname_url()
-        logger.debug(
-            f"total: {self.total_group_size}, rank: {self.dp_shard_rank * self.tp_degree + self.tp_rank}, pg_server: {hostname}"
-        )
-        self.sglang_nccl_group, self.sglang_gloo_group = setup_group(
-            hostname, job_config.grpo.sglang_port, self.total_group_size, local_rank
-        )
+        # All ranks must join the process group (collective operation)
+        logger.debug(
+            f"Setting up SGlang process groups, dp_shard_degree: {self.dp_shard_degree}, tp_degree: {self.tp_degree}"
+        )
+        hostname = get_hostname_url()
+        logger.debug(
+            f"total: {self.total_group_size}, rank: {self.dp_shard_rank * self.tp_degree + self.tp_rank}, pg_server: {hostname}"
+        )
+        self.sglang_nccl_group, self.sglang_gloo_group = setup_group(
+            hostname, job_config.grpo.sglang_port, self.total_group_size, local_rank
+        )
         if job_config.grpo.ptx_mixin_batchsize > 0:
             self.dataloader = self.train_spec.build_dataloader_fn(
                 dp_world_size=dp_degree,
diff --git a/vllm_launch.sh b/vllm_launch.sh
index b84d3d3c0c..4a77950da5 100644
--- a/vllm_launch.sh
+++ b/vllm_launch.sh
@@ -10,7 +10,7 @@ if [[ "$SLURM_NODEID" -eq 0 ]]; then
     echo "Starting trajectory handler..."
     run-api > ${LOGDIR}/api.log 2>&1 &
     python $PYTHON_SCRIPT serve --slurm=True $PYTHON_ARGS > ${LOGDIR}/env_server.log 2>&1 &
-    deactivate
+    deactivate
     echo "Started trajectory handler..."
fi echo $SLURM_NODEID ", " $NUM_TRAINING_NODES @@ -30,12 +30,13 @@ if [[ "$SLURM_NODEID" -lt "$NUM_TRAINING_NODES" ]]; then # export NCCL_P2P_DISABLE=1 # export NCCL_IB_DISABLE=1 + export NCCL_IB_DISABLE=1 + export NCCL_P2P_LEVEL=SYS + # debugging flags (optional) - export NCCL_DEBUG=WARN + export NCCL_DEBUG=INFO + export NCCL_DEBUG_SUBSYS=NET export PYTHONFAULTHANDLER=1 - # optional debug settings - # export NCCL_DEBUG=INFO - # NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV # export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH @@ -44,10 +45,10 @@ if [[ "$SLURM_NODEID" -lt "$NUM_TRAINING_NODES" ]]; then # on your cluster you might need these: # set the network interface -# export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" -# export NCCL_BUFFSIZE=2097152 -# export TORCH_DIST_INIT_BARRIER=1 -# export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + export NCCL_SOCKET_IFNAME=bond0 + export NCCL_BUFFSIZE=2097152 + export TORCH_DIST_INIT_BARRIER=1 + export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 # dcgmi profile --pause # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below @@ -69,6 +70,16 @@ else source ${VLLM_ENV}/bin/activate + # Set NCCL network settings for vLLM weight sync + export NCCL_SOCKET_IFNAME=bond0 + export NCCL_IB_DISABLE=1 + export NCCL_P2P_LEVEL=SYS + export NCCL_BUFFSIZE=2097152 + export TORCH_DIST_INIT_BARRIER=1 + export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + export NCCL_DEBUG=INFO + export NCCL_DEBUG_SUBSYS=NET + PORT_BASE=9000 # Start 8 vllm instances on GPUs 0-3 From 0048320f91c607099bf57c40f850c540ebe72e74 Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Thu, 5 Feb 2026 10:52:57 -0800 Subject: [PATCH 7/8] cleaning up --- torchtitan/grpo/test/scripts/run_full_test.sh | 128 ------------------ torchtitan/grpo/test/scripts/start_api.sh | 18 --- torchtitan/grpo/test/scripts/start_env.sh | 54 -------- torchtitan/grpo/test/scripts/start_trainer.sh | 67 --------- torchtitan/grpo/test/scripts/start_vllm.sh | 74 ---------- 
torchtitan/grpo/test/scripts/test_launcher.sh | 103 -------------- 6 files changed, 444 deletions(-) delete mode 100755 torchtitan/grpo/test/scripts/run_full_test.sh delete mode 100755 torchtitan/grpo/test/scripts/start_api.sh delete mode 100755 torchtitan/grpo/test/scripts/start_env.sh delete mode 100755 torchtitan/grpo/test/scripts/start_trainer.sh delete mode 100755 torchtitan/grpo/test/scripts/start_vllm.sh delete mode 100644 torchtitan/grpo/test/scripts/test_launcher.sh diff --git a/torchtitan/grpo/test/scripts/run_full_test.sh b/torchtitan/grpo/test/scripts/run_full_test.sh deleted file mode 100755 index fb7228538d..0000000000 --- a/torchtitan/grpo/test/scripts/run_full_test.sh +++ /dev/null @@ -1,128 +0,0 @@ -#!/bin/bash -# Master script to launch the full RL test pipeline - -set -e - -# Get script directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -echo "========================================" -echo "GSM8k RL Test Pipeline Launcher" -echo "========================================" -echo "" -echo "This script will start all components:" -echo " 1. Atropos API Server" -echo " 2. vLLM Inference Server" -echo " 3. GSM8k Environment Server" -echo " 4. TorchTitan Trainer" -echo "" -echo "Press Ctrl+C to stop all services" -echo "" - -# Cleanup function -cleanup() { - echo "" - echo "========================================" - echo "Shutting down all services..." - echo "========================================" - - if [ ! -z "$API_PID" ]; then - echo "Stopping Atropos API (PID: $API_PID)" - kill $API_PID 2>/dev/null || true - fi - - if [ ! -z "$VLLM_PID" ]; then - echo "Stopping vLLM server (PID: $VLLM_PID)" - kill $VLLM_PID 2>/dev/null || true - fi - - echo "Force-killing any remaining vLLM processes..." 
- pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true - pkill -9 -f "vllm serve" 2>/dev/null || true - pkill -9 -f "torchtitan.grpo.vllm_handling.vllm_runner" 2>/dev/null || true - lsof -ti:9001 | xargs kill -9 2>/dev/null || true - - if [ ! -z "$ENV_PID" ]; then - echo "Stopping GSM8k environment (PID: $ENV_PID)" - kill $ENV_PID 2>/dev/null || true - fi - - echo "All services stopped" - exit 0 -} - -# Set up trap for cleanup -trap cleanup EXIT INT TERM - -# Step 1: Start Atropos API -echo "Step 1/4: Starting Atropos API Server..." -"$SCRIPT_DIR/start_api.sh" > /tmp/atropos_api.log 2>&1 & -API_PID=$! -echo "API started (PID: $API_PID, log: /tmp/atropos_api.log)" -echo "Waiting for API to be ready..." -sleep 5 - -# Check if API is running -if ! curl -s http://localhost:8000/ > /dev/null; then - echo "ERROR: Atropos API failed to start" - echo "Check log at: /tmp/atropos_api.log" - exit 1 -fi -echo "API is ready" -echo "" - -# Step 2: Start vLLM server -echo "Step 2/4: Starting vLLM Inference Server..." -"$SCRIPT_DIR/start_vllm.sh" > /tmp/vllm_launcher.log 2>&1 & -VLLM_PID=$! -echo "vLLM launcher started (PID: $VLLM_PID)" -echo "Waiting for vLLM server to load model (this may take ~30 seconds)..." -sleep 35 - -# Check if vLLM server is running -VLLM_READY=true -PORT=9001 -if ! curl -s "http://localhost:${PORT}/health" > /dev/null; then - echo "WARNING: vLLM server on port $PORT is not responding" - VLLM_READY=false -fi - -if [ "$VLLM_READY" = false ]; then - echo "WARNING: vLLM server may not be ready" - echo "Check log at: /tmp/vllm_server_${PORT}.log" - echo "Last 30 lines of log:" - tail -30 /tmp/vllm_server_${PORT}.log 2>/dev/null || echo "Log file not found" - echo "Continuing anyway..." -else - echo "vLLM server is ready" -fi -echo "" - -# Step 3: Start GSM8k environment -echo "Step 3/4: Starting GSM8k Environment Server..." -"$SCRIPT_DIR/start_env.sh" > /tmp/gsm8k_env_wrapper.log 2>&1 & -ENV_PID=$! 
-echo "Environment started (PID: $ENV_PID, log: /tmp/gsm8k_env.log)" -echo "Waiting for environment to register..." -sleep 5 -echo "Environment should be running" -echo "" - -# Step 4: Start trainer -echo "Step 4/4: Starting TorchTitan Trainer..." -echo "========================================" -echo "" -"$SCRIPT_DIR/start_trainer.sh" - -# If we get here, training completed successfully -echo "" -echo "========================================" -echo "Test completed successfully!" -echo "========================================" -echo "" -echo "Logs available at:" -echo " - Atropos API: /tmp/atropos_api.log" -echo " - SGLang servers: /tmp/sglang_server_*.log" -echo " - GSM8k environment: /tmp/gsm8k_env.log" -echo " - Trainer: $LOGDIR" -echo "" diff --git a/torchtitan/grpo/test/scripts/start_api.sh b/torchtitan/grpo/test/scripts/start_api.sh deleted file mode 100755 index a0374b6b2e..0000000000 --- a/torchtitan/grpo/test/scripts/start_api.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# Start Atropos API Server - -set -e - -echo "========================================" -echo "Starting Atropos API Server" -echo "========================================" - -source /home/nightwing/Projects/torchtitan/.venv/bin/activate - -# Change to Atropos directory -cd /home/shared/atropos - -# Start the API server -# The server will listen on http://localhost:8000 -echo "Starting API server on http://localhost:8000" -run-api diff --git a/torchtitan/grpo/test/scripts/start_env.sh b/torchtitan/grpo/test/scripts/start_env.sh deleted file mode 100755 index 7160cf88ac..0000000000 --- a/torchtitan/grpo/test/scripts/start_env.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Start GSM8k Environment Server - -set -e - -echo "========================================" -echo "Starting GSM8k Environment Server" -echo "========================================" - -# Activate TorchTitan venv (has atroposlib installed) -source /home/nightwing/Projects/torchtitan/.venv/bin/activate - -# Get the 
torchtitan root directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" - -echo "TorchTitan root: $TORCHTITAN_ROOT" -cd "$TORCHTITAN_ROOT" - -# Configuration -MODEL_NAME="Qwen/Qwen3-1.7B" -VLLM_URL="http://localhost:9001/v1" - -# Check if Atropos is accessible -if ! python -c "from atroposlib.envs.base import BaseEnv" 2>/dev/null; then - echo "ERROR: Cannot import Atropos. Is it installed in the venv?" - echo "Run: pip install -e /home/shared/atropos" - exit 1 -fi - -# Check if vLLM server is running -echo "Checking vLLM server availability..." -if ! curl -s "$VLLM_URL/models" > /dev/null; then - echo "WARNING: vLLM server at $VLLM_URL is not responding" -fi - -# Check if Atropos API is running -echo "Checking Atropos API availability..." -if ! curl -s "http://localhost:8000/" > /dev/null; then - echo "ERROR: Atropos API is not running on http://localhost:8000" - echo "Please start the API server first (./start_api.sh)" - exit 1 -fi - -echo "" -echo "Starting GSM8k environment..." 
-python torchtitan/grpo/test/gsm8k_server.py serve \ - --slurm false \ - --openai.model_name "$MODEL_NAME" \ - 2>&1 | tee /tmp/gsm8k_env.log - -echo "" -echo "GSM8k environment stopped" -echo "Log available at: /tmp/gsm8k_env.log" diff --git a/torchtitan/grpo/test/scripts/start_trainer.sh b/torchtitan/grpo/test/scripts/start_trainer.sh deleted file mode 100755 index bea514916f..0000000000 --- a/torchtitan/grpo/test/scripts/start_trainer.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash -# Start TorchTitan RL Trainer -# This pulls batches from Atropos API and trains the model - -set -e - -echo "========================================" -echo "Starting TorchTitan RL Trainer" -echo "========================================" - -source /home/nightwing/Projects/torchtitan/.venv/bin/activate - -# Get the torchtitan root directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" - -echo "TorchTitan root: $TORCHTITAN_ROOT" -cd "$TORCHTITAN_ROOT" - -# Configuration -CONFIG_FILE="torchtitan/grpo/test/test_config.toml" -NGPU=${NGPU:-4} # default to 4 GPUs, override with: NGPU=8 ./start_trainer.sh -LOG_RANK=${LOG_RANK:-0} - -# Set required environment variables -export LOGDIR="${LOGDIR:-/tmp/torchtitan_logs}" -mkdir -p "$LOGDIR" -echo "Logs will be written to: $LOGDIR" - -# Check if config file exists -if [ ! -f "$CONFIG_FILE" ]; then - echo "ERROR: Config file not found: $CONFIG_FILE" - exit 1 -fi - -# Check if Atropos API is running -echo "Checking Atropos API availability..." -if ! 
curl -s "http://localhost:8000/" > /dev/null; then - echo "ERROR: Atropos API is not running on http://localhost:8000" - echo "Please start the API server first (./start_api.sh)" - exit 1 -fi - -echo "" -echo "Configuration:" -echo " - Config file: $CONFIG_FILE" -echo " - Number of GPUs: $NGPU" -echo " - Log directory: $LOGDIR" -echo " - Log rank filter: $LOG_RANK" -echo "" - -# Launch trainer with torchrun -echo "Launching trainer..." -PYTORCH_ALLOC_CONF="expandable_segments:True" \ -torchrun \ - --nproc_per_node=$NGPU \ - --rdzv_backend c10d \ - --rdzv_endpoint="localhost:0" \ - --local-ranks-filter $LOG_RANK \ - --role rank \ - --tee 3 \ - -m torchtitan.grpo_train \ - --job.config_file "$CONFIG_FILE" - -echo "" -echo "Training completed!" -echo "Check logs at: $LOGDIR" diff --git a/torchtitan/grpo/test/scripts/start_vllm.sh b/torchtitan/grpo/test/scripts/start_vllm.sh deleted file mode 100755 index 622054ea30..0000000000 --- a/torchtitan/grpo/test/scripts/start_vllm.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash -# Start vLLM Inference Server - -set -e - -echo "========================================" -echo "Starting vLLM Inference Server" -echo "========================================" - -# Cleanup function -cleanup_vllm() { - echo "Cleaning up any existing vLLM processes..." 
- pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true - pkill -9 -f "vllm serve" 2>/dev/null || true - pkill -9 -f "torchtitan.grpo.vllm_handling.vllm_runner" 2>/dev/null || true - lsof -ti:9001 | xargs kill -9 2>/dev/null || true - sleep 2 -} - -cleanup_vllm - -# Use the separate vLLM environment (not the training env) -source /home/nightwing/envs/vllm/.venv/bin/activate - -# Configuration -MODEL_PATH="/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf" # HF checkpoint path or name -TP_SIZE=1 -BASE_PORT=9001 - -# Set LOGDIR to match the trainer (needed for distributed_updater coordination) -export LOGDIR="${LOGDIR:-/tmp/torchtitan_logs}" -mkdir -p "$LOGDIR" - -# Set NUM_INFERENCE_NODES=0 for single node setup (required by distributed_updater) -export NUM_INFERENCE_NODES=0 - -echo "Starting vLLM server on port $BASE_PORT..." -echo "CUDA_VISIBLE_DEVICES: 4" -echo "Model: $MODEL_PATH" -echo "LOGDIR: $LOGDIR" -echo "NUM_INFERENCE_NODES: $NUM_INFERENCE_NODES" - -# Run vLLM on GPU 4 (training uses GPUs 0-3) -# IMPORTANT: Set CUDA_VISIBLE_DEVICES as prefix, not export -CUDA_VISIBLE_DEVICES=4 nohup python -m torchtitan.grpo.vllm_handling.vllm_runner \ - --model "$MODEL_PATH" \ - --port $BASE_PORT \ - --host 0.0.0.0 \ - --gpu-memory-utilization 0.75 \ - --dtype="bfloat16" \ - --log-level="error" \ - > "${LOGDIR}/vllm_${BASE_PORT}.log" 2>&1 & - -SERVER_PID=$! -echo "vLLM server starting (PID: $SERVER_PID)" - -echo "" -echo "Waiting for server to be ready (~30 seconds)..." -sleep 60 - -echo "" -echo "Testing server connectivity..." -if curl -s "http://localhost:${BASE_PORT}/health" > /dev/null; then - echo "✓ vLLM server on port $BASE_PORT is ready" -else - echo "✗ vLLM server on port $BASE_PORT is not responding" - echo "Check log at: /tmp/vllm_server_${BASE_PORT}.log" - echo "Last 20 lines of log:" - tail -20 "/tmp/vllm_server_${BASE_PORT}.log" 2>/dev/null || echo "Log file not found" -fi - -echo "" -echo "vLLM server ready for inference!" 
-echo "Log available at: /tmp/vllm_server_${BASE_PORT}.log" diff --git a/torchtitan/grpo/test/scripts/test_launcher.sh b/torchtitan/grpo/test/scripts/test_launcher.sh deleted file mode 100644 index ba12623911..0000000000 --- a/torchtitan/grpo/test/scripts/test_launcher.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/bin/bash -# set -e temporarily disabled to see errors -set -x # Print commands for debugging - -printenv -ulimit -n 32000 - -# Get script directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -TORCHTITAN_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" -cd "$TORCHTITAN_ROOT" - -# Set defaults if not running under SLURM -: "${SLURM_NODEID:=0}" -: "${NUM_TRAINING_NODES:=1}" -: "${NUM_INFERENCE_NODES:=0}" -: "${LOGDIR:=${TORCHTITAN_ROOT}/logs/test_run}" -: "${MODEL_NAME:=/home/nightwing/Projects/torchtitan/tmp/qwen3-1.7b-hf}" -: "${CONFIG_FILE:=${TORCHTITAN_ROOT}/torchtitan/grpo/test/test_config.toml}" -: "${API_ENV:=/home/nightwing/Projects/torchtitan/.venv}" -: "${TRAIN_ENV:=/home/nightwing/Projects/torchtitan/.venv}" -: "${VLLM_ENV:=/home/nightwing/envs/vllm/.venv}" - -# Export LOGDIR so child processes can see it -export LOGDIR -export NUM_INFERENCE_NODES -export MODEL_NAME -export CONFIG_FILE - -mkdir -p "$LOGDIR" - -echo "Starting test at $(date)" -echo "SLURM_NODEID: $SLURM_NODEID" -echo "NUM_TRAINING_NODES: $NUM_TRAINING_NODES" -echo "NUM_INFERENCE_NODES: $NUM_INFERENCE_NODES" -echo "LOGDIR: $LOGDIR" -echo "MODEL_NAME: $MODEL_NAME" - -# Start API and environment (always on node 0) -if [[ "$SLURM_NODEID" -eq 0 ]]; then - echo "Starting API and environment server..." - source ${API_ENV}/bin/activate - - # Start Atropos API - cd /home/shared/atropos - run-api > ${LOGDIR}/api.log 2>&1 & - cd "$TORCHTITAN_ROOT" - - # Start GSM8k environment server - python torchtitan/grpo/test/gsm8k_server.py serve --slurm=True --openai.model_name="$MODEL_NAME" > ${LOGDIR}/env_server.log 2>&1 & - - deactivate - echo "Started API and environment server..." 
-fi - -# Start training (on training nodes) -if [[ "$SLURM_NODEID" -lt "$NUM_TRAINING_NODES" ]]; then - echo "Setting up training environment..." - source ${TRAIN_ENV}/bin/activate - - nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) - nodes_array=($nodes) - head_node=${nodes_array[0]} - - export LOGLEVEL=INFO - export NCCL_DEBUG=WARN - export PYTHONFAULTHANDLER=1 - export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH - export CUDA_LAUNCH_BLOCKING=0 - - # Launch trainer (vLLM runs on separate inference node) - echo "Launching trainer..." - torchrun --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint="$head_node:29500" --role rank --tee 3 \ - -m torchtitan.grpo_train --job.config_file ${CONFIG_FILE} -# else we're on an inference node -else - echo "Starting vLLM inference server..." - source ${VLLM_ENV}/bin/activate - - PORT_BASE=9000 - LOG_OFFSET=$((SLURM_NODEID * 8)) - - # Start 8 vLLM instances on GPUs 0-7 (matching dakota's setup) - for i in {0..7}; do - GPU_ID=$i - LOG_ID=$((GPU_ID + LOG_OFFSET)) - PORT=$((PORT_BASE + i)) - echo "Starting vLLM instance on GPU $GPU_ID, port $PORT" - CUDA_VISIBLE_DEVICES=$GPU_ID nohup python -m torchtitan.grpo.vllm_handling.vllm_runner \ - --model "$MODEL_NAME" \ - --host 0.0.0.0 \ - --gpu-memory-utilization 0.75 \ - --dtype="bfloat16" \ - --log-level="error" \ - --port $PORT > ${LOGDIR}/vllm_${LOG_ID}.log 2>&1 & - sleep 3 - done - - # Wait indefinitely (keep inference node alive) - wait -fi - -echo "Test completed at $(date)" From 947a503963af4d95d16bda052c560ad952ac4edb Mon Sep 17 00:00:00 2001 From: Sam Herring Date: Tue, 17 Feb 2026 14:33:16 -0800 Subject: [PATCH 8/8] Reverting send wait and rank changes --- torchtitan/grpo/data_handling.py | 9 ++++++--- torchtitan/grpo_train.py | 21 ++++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/torchtitan/grpo/data_handling.py b/torchtitan/grpo/data_handling.py index fd978e8e31..a2b26c04a9 100644 --- 
a/torchtitan/grpo/data_handling.py +++ b/torchtitan/grpo/data_handling.py @@ -435,7 +435,8 @@ def data_handling( flag = flag + 1 torch.distributed.broadcast(flag, 0) - send_wait(sglang_nccl_group, device) + if dp_replicate_rank == 0: + send_wait(sglang_nccl_group, device) max_token_len = torch.tensor(max_token_len).to(device) torch.distributed.all_reduce(max_token_len) # back to int @@ -455,11 +456,13 @@ def data_handling( else: logger.debug("No batch yet, retrying...") torch.distributed.broadcast(flag, 0) - send_wait(sglang_nccl_group, device) + if dp_replicate_rank == 0: + send_wait(sglang_nccl_group, device) else: logger.debug("Waiting for batch from server...") torch.distributed.broadcast(flag, 0) - send_wait(sglang_nccl_group, device) + if dp_replicate_rank == 0: + send_wait(sglang_nccl_group, device) if flag.item() > 0: # Got the batch max_token_len = torch.tensor(0).to(device) diff --git a/torchtitan/grpo_train.py b/torchtitan/grpo_train.py index 393ab9b32b..eeff74dee3 100644 --- a/torchtitan/grpo_train.py +++ b/torchtitan/grpo_train.py @@ -522,17 +522,16 @@ def __init__(self, job_config: JobConfig): time.sleep(1) with open(f"{os.environ['LOGDIR']}/sglang_dtypes.json", "r") as f: self.weight_dtypes = json.load(f) - # All ranks must join the process group (collective operation) - logger.debug( - f"Setting up SGlang process groups, dp_shard_degree: {self.dp_shard_degree}, tp_degree: {self.tp_degree}" - ) - hostname = get_hostname_url() - logger.debug( - f"total: {self.total_group_size}, rank: {self.dp_shard_rank * self.tp_degree + self.tp_rank}, pg_server: {hostname}" - ) - self.sglang_nccl_group, self.sglang_gloo_group = setup_group( - hostname, job_config.grpo.sglang_port, self.total_group_size, local_rank - ) + logger.debug( + f"Setting up SGlang process groups, dp_shard_degree: {self.dp_shard_degree}, tp_degree: {self.tp_degree}" + ) + hostname = "localhost" if local_rank < 8 else get_hostname_url() + logger.debug( + f"total: {self.total_group_size}, rank: 
{self.dp_shard_rank * self.tp_degree + self.tp_rank}, pg_server: {hostname}" + ) + self.sglang_nccl_group, self.sglang_gloo_group = setup_group( + hostname, job_config.grpo.sglang_port, self.total_group_size, local_rank + ) if job_config.grpo.ptx_mixin_batchsize > 0: self.dataloader = self.train_spec.build_dataloader_fn( dp_world_size=dp_degree,