From 771a4f9c6807cef580d7b369adaf2d3f42f857b1 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 23 Dec 2025 14:31:21 +0000 Subject: [PATCH 1/2] library upgrades + vllm1 weight update changes --- README.md | 6 ++--- conf/base.yaml | 11 +++++--- pipelinerl/finetune_loop.py | 25 ++++++++++------- pipelinerl/torch_utils.py | 53 +++++++++++++++++++++++++++++++++++++ pipelinerl/vllm0.py | 34 +++++++++++++++++++++++- pipelinerl/vllm1.py | 32 ++++++++++++++-------- pyproject.toml | 28 +++++++++++--------- 7 files changed, 148 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 8a172c0d..fccf83a6 100644 --- a/README.md +++ b/README.md @@ -195,9 +195,9 @@ cd PipelineRL Create the environments with dependencies. ```bash -conda create -n pipeline-rl -y python=3.11 -conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0 -conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation +conda create -n pipeline-rl -y python=3.12 +conda run --no-capture-output -n pipeline-rl pip install -e . +conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation ``` By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment: diff --git a/conf/base.yaml b/conf/base.yaml index 1f8d73cc..2dd03d03 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -57,14 +57,11 @@ test_llm: top_k: 50 vllm_config: - use_v1: false + use_v1: true quantization: null # or bf16_last_layer_fp32 vllm_kwargs: dtype: bfloat16 gpu-memory-utilization: 0.9 - num-scheduler-steps: 1 - disable-log-requests: "" - disable-frontend-multiprocessing: "" max-num-seqs: ${actor.llm_max_rollouts} max-num-batched-tokens: 1024 enable-chunked-prefill: "" @@ -73,6 +70,12 @@ vllm_config: pipeline-parallel-size: 1 generation-config: vllm max_model_len: 10000 + # V1 specific settings + # logprobs-mode: processed_logprobs + # V0 specific settings + # num-scheduler-steps: 1 + # disable-log-requests: "" + # disable-frontend-multiprocessing: "" world: replicas: 1 diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 57d2950e..c72d2997 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -27,7 +27,7 @@ from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead -import pipelinerl.torch_utils +from pipelinerl.torch_utils import stateless_init_process_group from pipelinerl.finetune.types import PipelineBatchEncoding from pipelinerl.finetune.checkpoints import ( load_model, @@ -212,7 +212,8 @@ def send_weight_update( for name, parameter in named_parameters.items(): with deepspeed.zero.GatheredParameters([parameter]): if get_accelerator().is_main_process: - dist.broadcast(parameter.data, src=0, group=self.actor_update_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) if get_accelerator().is_main_process: logger.info("Wait for HTTP requests") for future in futures: # type: ignore @@ -254,8 +255,8 @@ def send_weight_update( futures = self.request_weight_updates(messages) logger.info(f"Published weight update request for version {version}") for _, parameter in named_parameters.items(): - 
dist.broadcast(parameter.data, src=0, group=self.actor_update_group) - dist.barrier(self.actor_update_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream()) for future in futures: future.result() logger.info("Finished broadcasting weights") @@ -408,13 +409,18 @@ def run_finetuning_loop( get_accelerator().wait_for_everyone() if get_accelerator().is_main_process and args.send_weight_updates: - logger.info("Initializing actor process group") - actor_update_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + logger.info("Initializing actor process group using StatelessProcessGroup") + + # Explicitly set CUDA device before creating NCCL process group + current_device = get_accelerator().device + torch.cuda.set_device(current_device) + logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)") + + actor_update_group = stateless_init_process_group( init_method=cfg.me.weight_update_group_init_method, rank=0, world_size=cfg.me.weight_update_group_world_size, + device=current_device, ) logger.info("Actor process group initialized") else: @@ -493,8 +499,7 @@ def run_finetuning_loop( finally: if weight_update_manager is not None: weight_update_manager.shutdown() - if actor_update_group: - dist.destroy_process_group(actor_update_group) + # PyNcclCommunicator doesn't need explicit destroy like torch.distributed process groups def rl_finetuning_worker( diff --git a/pipelinerl/torch_utils.py b/pipelinerl/torch_utils.py index 2d16d99f..588aeab5 100644 --- a/pipelinerl/torch_utils.py +++ b/pipelinerl/torch_utils.py @@ -1,14 +1,51 @@ +import logging from datetime import timedelta from typing import Any, Optional, Union +from urllib.parse import urlparse + +import torch +import torch.distributed as dist from torch.distributed.distributed_c10d import ( Backend, PrefixStore, + ProcessGroupNCCL, Store, _new_process_group_helper, _world, default_pg_timeout, rendezvous, ) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.utils import StatelessProcessGroup + +logger = logging.getLogger(__name__) + + +def stateless_init_process_group(init_method, rank, world_size, device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + + Args: + init_method: TCP init method string (e.g., "tcp://localhost:9000") + rank: The rank of this process in the group + world_size: Total number of processes in the group + device: The CUDA device to use for NCCL communication + """ + # Parse master_address and master_port from init_method (e.g., "tcp://localhost:9000") + parsed = urlparse(init_method) + master_address = parsed.hostname or "localhost" + master_port = parsed.port or 9000 + logger.debug(f"Parsed master_address: {master_address}, master_port: {master_port}") + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl # Copy from pytorch to allow creating multiple main groups. 
@@ -22,6 +59,7 @@ def init_extra_process_group( store: Optional[Store] = None, group_name: str = None, pg_options: Optional[Any] = None, + device_id: Optional[torch.device] = None, ): assert (store is None) or (init_method is None), "Cannot specify both init_method and store." @@ -49,6 +87,19 @@ def init_extra_process_group( # different systems (e.g. RPC) in case the store is multi-tenant. store = PrefixStore(group_name, store) + # Create NCCL-specific options if using NCCL backend + logger.info(f"[{group_name}] Backend: {backend}, str(backend): {str(backend)}") + if pg_options is None and str(backend) == "nccl": + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = False + logger.info(f"[{group_name}] Created NCCL options: {pg_options}") + + # Ensure CUDA is synchronized before creating NCCL process group + if device_id is not None: + torch.cuda.synchronize(device_id) + logger.info(f"[{group_name}] CUDA synchronized on {device_id}") + + logger.info(f"[{group_name}] Creating process group: rank={rank}, world_size={world_size}, device_id={device_id}") pg, _ = _new_process_group_helper( world_size, rank, @@ -58,7 +109,9 @@ def init_extra_process_group( group_name=group_name, backend_options=pg_options, timeout=timeout, + device_id=device_id, ) + logger.info(f"[{group_name}] Process group created successfully") _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 93e99d60..fb3bea58 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -1,3 +1,36 @@ +""" +DEPRECATED - Kept only for backward compatibility with older vLLM versions. + +This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture. + +Compatibility: + - vLLM versions <= 0.8.x only + - The V0 engine was removed in vLLM 0.11.0 + - Use vllm1.py instead +""" +import warnings +from packaging import version as version_parser +import vllm + +# Check vLLM version compatibility +vllm_version = version_parser.parse(vllm.__version__) + +if vllm_version >= version_parser.parse("0.9.0"): + raise ImportError( + f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. " + "This module only works with vLLM <= 0.8.x. " + "Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead." + ) + +# Only show deprecation warning for compatible versions +warnings.warn( + "pipelinerl.vllm0 is DEPRECATED and will be removed in a future version. " + "This module only works with vLLM <= 0.8.x. 
" + "Please use pipelinerl.vllm1 as it is actively maintained.", + DeprecationWarning, + stacklevel=2, +) + import asyncio import json import logging @@ -14,7 +47,6 @@ ) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, create_server_socket, build_app, init_app_state, diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 1ac611d0..86e6be4a 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -2,7 +2,8 @@ import signal import torch import uvloop -from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.system_utils import set_ulimit from vllm.entrypoints.openai.cli_args import ( make_arg_parser, validate_parsed_serve_args, @@ -26,8 +27,8 @@ from pipelinerl.finetune_loop import WeightUpdateRequest from pipelinerl.vllm_quantization import string_to_dtype # reuse mapping +from pipelinerl.torch_utils import stateless_init_process_group from typing import Any, Protocol, runtime_checkable -import pipelinerl.torch_utils import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config logger = logging.getLogger(__name__) @@ -40,7 +41,6 @@ handler.setFormatter(formatter) logger.addHandler(handler) - @runtime_checkable class LikeWorker(Protocol): rank: int @@ -72,15 +72,18 @@ def init_actor_update_group( prefix + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" ) - self.process_group = pipelinerl.torch_utils.init_extra_process_group( - group_name="actor", - backend="nccl", + + # Use vLLM's StatelessProcessGroup instead of torch.distributed + self.model_update_group = stateless_init_process_group( init_method=weight_update_group_init_method, rank=self.pg_rank, world_size=weight_update_group_world_size, + device=self.device, ) + logger.info(prefix + "Actor update process group initialized") - def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest): + def receive_weight_update(self: LikeWorker, request_json: str): + request = WeightUpdateRequest.model_validate_json(request_json) torch.cuda.synchronize(self.device) logger.info("Start receiving weight update") expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) @@ -89,7 +92,8 @@ def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest): if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - torch.distributed.broadcast(buffer, src=0, group=self.process_group) + # Use PyNcclCommunicator's broadcast method instead of torch.distributed + self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore if len(loaded_params) != 1: raise ValueError(f"model {info.name} not found in model state dict") @@ -114,8 +118,13 @@ async def input_process_groups(self): ) async def receive_weight_update(self, request: WeightUpdateRequest): + # Ensure workers are ready by executing a dummy batch first + # This synchronizes workers before the NCCL collective + # logger.info("Synchronizing workers before weight update...") + # await self.engine_client.execute_dummy_batch_async() + # logger.info("Workers synchronized, starting weight update") await self.engine_client.collective_rpc_async( - "receive_weight_update", args=(request,) + 
"receive_weight_update", args=(request.model_dump_json(),) ) logger.info("Weight update processed") @@ -157,7 +166,7 @@ def signal_handler(*_) -> None: vllm_config=engine_config, usage_context=UsageContext.OPENAI_API_SERVER, disable_log_stats=engine_args.disable_log_stats, - disable_log_requests=engine_args.disable_log_requests, + enable_log_requests=engine_args.enable_log_requests, ) assert isinstance(engine.engine_core, AsyncMPClient) @@ -172,10 +181,11 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): + logger.info("Received weight update request") await weight_update_manager.receive_weight_update(request) return {"status": "ok"} - await init_app_state(engine, engine_config, app.state, args) + await init_app_state(engine, app.state, args) shutdown_task = await serve_http( app, sock, diff --git a/pyproject.toml b/pyproject.toml index 7fc9978a..ab7e69b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,11 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] +requires = [ + "setuptools>=61.0", + "wheel", + "numpy>=1.26.0", + "packaging>=23.0", + "torch~=2.9.0", +] build-backend = "setuptools.build_meta" [project] @@ -14,15 +20,13 @@ authors = [ ] dependencies = [ "aiohttp>=3.9.0", - "torch>=2.6", - "vllm==0.8.5.post1", - "accelerate==1.7.0", - "deepspeed==0.15.4", + "vllm==0.11.2", + "accelerate==1.12.0", + "deepspeed~=0.18.0", "browsergym>=0.13.0", "datasets>=2.21.0", - "transformers==4.51.1" , + "transformers~=4.57.0" , "fastapi>=0.115.0", - "flash-attn==2.7.4.post1", "joblib>=1.3.2", "jsonref>=1.1.0", "litellm>=1.61.0", @@ -32,11 +36,11 @@ dependencies = [ "Pillow>=10.0.0", "psutil>=5.9.0", "pydantic>=2.9.0", - "ring-flash-attn==0.1.6", - "math-verify[antlr4_9_3]==0.7.0", - "orjson==3.10.16", + "ring-flash-attn==0.1.8", + "math-verify[antlr4_9_3]==0.8.0", + "orjson~=3.11.0", "requests>=2.31.0", - "redis==5.2.1", + "redis~=7.0.0", "safetensors>=0.4.0", "tenacity>=8.2.0", "uvicorn>=0.29.0", @@ -50,7 +54,7 @@ tapeagents = [ "Tapeagents[finetune]==0.1.16", ] lora = [ - "peft==0.12.0", + "peft==0.18.0", ] [tool.setuptools.packages.find] From c0dc029df4318e9142ab2e1c5b4e1b3c85607bcb Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 23 Dec 2025 14:38:27 +0000 Subject: [PATCH 2/2] unused parameter (device_id) removed --- pipelinerl/torch_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pipelinerl/torch_utils.py b/pipelinerl/torch_utils.py index 588aeab5..d2b78d53 100644 --- a/pipelinerl/torch_utils.py +++ b/pipelinerl/torch_utils.py @@ -59,7 +59,6 @@ def init_extra_process_group( store: Optional[Store] = None, group_name: str = None, pg_options: Optional[Any] = None, - device_id: Optional[torch.device] = None, ): assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 
@@ -94,12 +93,6 @@ def init_extra_process_group( pg_options.is_high_priority_stream = False logger.info(f"[{group_name}] Created NCCL options: {pg_options}") - # Ensure CUDA is synchronized before creating NCCL process group - if device_id is not None: - torch.cuda.synchronize(device_id) - logger.info(f"[{group_name}] CUDA synchronized on {device_id}") - - logger.info(f"[{group_name}] Creating process group: rank={rank}, world_size={world_size}, device_id={device_id}") pg, _ = _new_process_group_helper( world_size, rank, @@ -109,7 +102,6 @@ def init_extra_process_group( group_name=group_name, backend_options=pg_options, timeout=timeout, - device_id=device_id, ) logger.info(f"[{group_name}] Process group created successfully")
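
Note: the sketch below ties together the pieces this patch introduces — stateless_init_process_group on the trainer side (rank 0) and PyNcclCommunicator.broadcast on both ends. It is a minimal illustration only: the function names, the tcp://localhost:9000 init method, and the two-process world size are assumptions for the example, not part of the patch; only StatelessProcessGroup.create, PyNcclCommunicator, and its broadcast method come from the code above.

import torch
from pipelinerl.torch_utils import stateless_init_process_group

# Trainer side (rank 0): build the communicator, then broadcast every parameter.
def publish_weights(named_parameters, init_method="tcp://localhost:9000"):
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)  # set the CUDA device before creating the NCCL communicator
    comm = stateless_init_process_group(
        init_method=init_method, rank=0, world_size=2, device=device
    )
    for _, parameter in named_parameters.items():
        # PyNcclCommunicator.broadcast replaces dist.broadcast from the old torch.distributed path
        comm.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream())

# vLLM worker side (rank 1): receive the tensors in the same order and load them.
def receive_weights(param_infos, model, init_method="tcp://localhost:9000"):
    device = torch.device("cuda:0")
    comm = stateless_init_process_group(
        init_method=init_method, rank=1, world_size=2, device=device
    )
    for name, shape, dtype in param_infos:
        buffer = torch.empty(tuple(shape), dtype=dtype, device=device)
        comm.broadcast(buffer, src=0, stream=torch.cuda.current_stream())
        model.load_weights(weights=[(name, buffer)])  # mirrors receive_weight_update in vllm1.py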