
[skyrl][inference] Rollout plan for the new inference backend #1014


Overview

This document outlines the incremental execution plan to validate and battle-test the new HTTP-based inference stack (RemoteInferenceClient + ServerGroup + InferenceRouter) against all existing functionality in skyrl_train/inference_engines/.

Goal: Achieve feature parity before Phase 4 (removing legacy code).

Feature Flag: All tests should pass with _SKYRL_USE_NEW_INFERENCE=1

Initial RFC: #845 outlines the intended change.
Last PR in the series (feature-gated by `_SKYRL_USE_NEW_INFERENCE`): #931
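
For orientation, a minimal sketch of how the flag might gate backend selection; the factory names (`create_server_group`, `start_inference_router`, `create_legacy_engines`) are placeholders for illustration, not the real entry points:

```python
import os

def build_inference_backend(cfg):
    # Hypothetical factory: the real entry point lives in skyrl_train;
    # the function names here are placeholders, not the actual API.
    if os.environ.get("_SKYRL_USE_NEW_INFERENCE") == "1":
        # New HTTP stack: Ray-managed vLLM servers behind an HTTP router.
        server_group = create_server_group(cfg)        # ServerGroup + VLLMServerActor
        router = start_inference_router(server_group)  # InferenceRouter
        return RemoteInferenceClient(router.url)
    # Legacy stack: Ray-wrapped engines with client-side routing.
    return create_legacy_engines(cfg)  # InferenceEngineClient path
```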


Incremental Rollout Sequence

Master Task Table

| # | Scenario | vLLM Deps | Priority | Complexity | Status | Owner |
|---|----------|-----------|----------|------------|--------|-------|
| | **Core (Wave 1 - Validate Now)** | | | | | |
| 1 | Basic colocated: 4x TP=1, FSDP2, NCCL sync | Current | P0 | Low | | |
| 2 | Non-colocated: 4x TP=1 inference, TP4 training | Current | P0 | Low | | |
| 3 | Remote server: standalone vLLM, HTTP weight sync | Current | P0 | Low | | |
| | **Async (Wave 2 - Core Async Features)** | | | | | |
| 4 | One-step off-policy async | Current | P1 | Medium | | |
| 5 | Fully async: in-flight weight update (abort mode) | Current | P1 | High | | |
| | **Multi-Turn (Wave 3 - Session Routing)** | | | | | |
| 6 | Token-in/Token-out (TITO) multi-turn | Current | P1 | Medium | | |
| 7 | SkyRL Gym HTTP generator (chat completions) | Current | P1 | Medium | | |
| | **Parallelism (Wave 4 - Scale Testing)** | | | | | |
| 8 | Colocated TP > 1 (2x TP=2) | Current | P2 | Low | | |
| 9 | Megatron backend (TP2×PP2 training) | Current | P2 | Medium | | |
| 10 | MoE with DP Expert Parallelism | Current | P2 | Medium | | |
| | **Advanced (Wave 5 - Feature Parity)** | | | | | |
| 11 | LoRA adapter fine-tuning | Current | P2 | Medium | | |
| 12 | FlashRL (int8/fp8 quantization) | Current; might need RFC #31848 and the QERL PR | P2 | High | | |
| | **Blocked on vLLM RFCs (Wave 6)** | | | | | |
| 13 | Fully async: pause mode="keep" (no abort) | RFC #32103 | P3 | Low | | |
| 14 | Native weight sync (no worker extension) | RFC #31848 | P3 | Medium | | |
| 15 | CUDA IPC weight sync (colocated fast path) | RFC #31848 | P3 | Medium | | |

Wave 1: Core Scenarios (Validate Now)

These are the foundational tests that MUST pass before any rollout.

Task 1: Basic Colocated (4x TP=1, FSDP2, NCCL)

| Attribute | Value |
|-----------|-------|
| Example | `examples/gsm8k/run_gsm8k.sh` |
| Config | `colocate_all=true`, `num_engines=4`, `tp=1`, `weight_sync=nccl` |
| Tests | Sleep/wake cycle, NCCL broadcast, session routing |
| GPUs | 4 |
| Duration | ~10 min (1 epoch) |

```bash
# Baseline (legacy)
bash examples/gsm8k/run_gsm8k.sh trainer.epochs=1

# New HTTP stack
_SKYRL_USE_NEW_INFERENCE=1 bash examples/gsm8k/run_gsm8k.sh trainer.epochs=1
```

Validation Checklist:

  • ServerGroup creates 4 VLLMServerActors with shared placement group
  • InferenceRouter starts and routes requests
  • Sleep before training step, wake before generation
  • Weight sync via NCCL/CUDA_IPC broadcast completes
  • Eval accuracy ≥ legacy path
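
A minimal sketch of the colocated step sequence this checklist exercises; method names are assumptions for illustration, not the exact client API:

```python
# Illustrative colocated training step under the new stack: generate while
# awake, sleep to free GPU memory for training, wake, then broadcast weights.
async def colocated_step(client, trainer, batch):
    rollouts = await client.generate(batch.prompts)  # servers awake for generation
    await client.sleep()                             # sleep before the training step
    metrics = trainer.train_step(rollouts)
    await client.wake_up()                           # wake before the next generation
    await client.sync_weights()                      # NCCL/CUDA-IPC broadcast
    return metrics
```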

Task 2: Non-Colocated (4x TP=1 inference, TP4 training)

| Attribute | Value |
|-----------|-------|
| Example | `examples/async/async_run_gsm8k.sh` (without the async trainer) |
| Config | `colocate_all=false`, separate placement groups for training/inference |
| Tests | NCCL broadcast across separate GPUs |
| GPUs | 8 (4 training + 4 inference) |
| Duration | ~10 min |

```bash
# Create a simple non-colocated test script
_SKYRL_USE_NEW_INFERENCE=1 bash examples/gsm8k/run_gsm8k.sh \
  trainer.epochs=1 \
  trainer.placement.colocate_all=false \
  trainer.placement.policy_num_gpus_per_node=4 \
  trainer.placement.ref_num_gpus_per_node=4
```

Validation Checklist:

  • ServerGroup creates servers on separate GPU set
  • No sleep/wake needed (separate GPUs)
  • Weight sync still works (NCCL across GPU sets)

Task 3: Remote Inference Server (Standalone vLLM)

Use vLLM + vllm-router to stand up the server externally. This path does not support colocation.

| Attribute | Value |
|-----------|-------|
| Example | `examples/remote_inference_engine/run_remote.sh` |
| Config | `external_server_urls=["http://127.0.0.1:8001"]` |
| Tests | HTTP weight sync to an external server |
| GPUs | 4 training + 4 inference (separate process) |
| Duration | ~15 min |

```bash
# Terminal 1: Start the standalone server
bash examples/remote_inference_engine/run_vllm_server.sh

# Terminal 2: Run training
_SKYRL_USE_NEW_INFERENCE=1 bash examples/remote_inference_engine/run_remote.sh trainer.epochs=1
```

Validation Checklist:

  • Training connects to external server via HTTP
    - Open question: which endpoints are still missing and waiting on the RFCs to land?
  • Internal router created for data plane
  • Control plane fan-out to external server works
  • Weight sync via HTTP + NCCL works
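
A sketch of the control-plane fan-out the last two items rely on; the endpoint paths come from this plan, but the payloads and helper are placeholders:

```python
import asyncio
import httpx

# Illustrative fan-out of one control-plane call to every server. The
# /init_weight_transfer and /update_weights paths are from this document;
# the payload schemas below are placeholders.
async def fan_out(server_urls, path, payload):
    async with httpx.AsyncClient() as http:
        responses = await asyncio.gather(
            *(http.post(f"{url}{path}", json=payload) for url in server_urls)
        )
    for r in responses:
        r.raise_for_status()

# Usage: announce the transfer group, then trigger the broadcast.
# await fan_out(urls, "/init_weight_transfer", {"master_addr": "...", "world_size": 5})
# await fan_out(urls, "/update_weights", {"names": [...], "dtypes": [...]})
```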

Wave 2: Async Features

Task 4: One-Step Off-Policy Async

| Attribute | Value |
|-----------|-------|
| Example | `examples/async/async_run_gsm8k.sh` |
| Config | `colocate_all=false`, `AsyncRayPPOTrainer` |
| Tests | Pipelined generation/training, weight sync between stages |
| GPUs | 8 |

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/async/async_run_gsm8k.sh trainer.epochs=1
```

Validation Checklist:

  • Generation step N+1 overlaps training step N
  • Weight sync between pipeline stages
  • generation_ack / sync_finished events work
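
A sketch of the overlap being validated; names are illustrative, not the actual AsyncRayPPOTrainer API:

```python
import asyncio

# One-step off-policy pipelining: generation for step N+1 runs while the
# trainer consumes rollouts from step N, so rollouts are one weight-sync stale.
async def one_step_off_policy(client, trainer, batches):
    pending = asyncio.create_task(client.generate(batches[0].prompts))
    for next_batch in batches[1:]:
        rollouts = await pending                                             # step N
        pending = asyncio.create_task(client.generate(next_batch.prompts))   # step N+1
        await asyncio.to_thread(trainer.train_step, rollouts)  # overlaps generation
        await client.sync_weights()  # step N+1 generated on one-step-stale weights
    await asyncio.to_thread(trainer.train_step, await pending)
```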

Task 5: Fully Async with In-Flight Weight Update

| Attribute | Value |
|-----------|-------|
| Example | `examples/fully_async/fully_async_run_gsm8k.sh` |
| Config | `batched=false`, `async_engine=true`, staleness manager |
| Tests | Pause (abort) → weight sync → resume → retry |
| GPUs | 4 (2 train + 2 gen typical) |
| Critical | Tests the core in-flight update semantics |

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/fully_async/fully_async_run_gsm8k.sh \
  trainer.epochs=1 \
  trainer.fully_async.max_staleness_steps=2
```

Validation Checklist:

  • Pause aborts in-flight requests
  • Partial responses accumulated correctly
  • Retry on abort works (accumulated tokens re-fed)
  • Weight sync during pause succeeds
  • Resume continues generation
  • Staleness stays within bounds

Key Code Path: RemoteInferenceClient.generate() retry loop + pause(mode=ABORT)
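
For reference, a sketch of the abort-and-retry semantics described in the checklist; `generate_once` and the result fields are illustrative stand-ins, not the actual `RemoteInferenceClient` code:

```python
# If pause(mode=ABORT) cuts a request short, the tokens generated so far are
# accumulated and re-fed as part of the prompt on retry, so decoding resumes
# where it left off after the weight update. Illustrative only.
async def generate_with_retry(client, prompt_ids, sampling_params):
    accumulated = list(prompt_ids)
    while True:
        result = await client.generate_once(accumulated, sampling_params)
        accumulated.extend(result.token_ids)   # keep the partial output
        if not result.aborted:                 # finished normally
            return accumulated[len(prompt_ids):]
        # Aborted mid-flight for a weight update; loop retries with the
        # accumulated tokens once the server resumes.
```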


Wave 3: Multi-Turn & Session Routing

Task 6: Token-In/Token-Out (TITO)

| Attribute | Value |
|-----------|-------|
| Example | Multi-turn env (e.g., `terminal_bench`, `llm_as_a_judge`) |
| Config | `batched=false`, session IDs for routing |
| Tests | Token preservation, session affinity |

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/llm_as_a_judge/run_llm_judge.sh trainer.epochs=1
```

Validation Checklist:

  • Session ID passed in X-Session-ID header
  • Same session → same server (hash-based routing)
  • Token IDs preserved across turns
  • Chat template applied correctly
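
A minimal sketch of the hash-based session affinity being checked; the `X-Session-ID` header is from this plan, the router internals here are illustrative:

```python
import hashlib

# Hash-based session affinity: the same X-Session-ID always maps to the same
# backend server, keeping multi-turn state (e.g., prefix cache) local.
def pick_server(session_id: str, servers: list[str]) -> str:
    digest = hashlib.sha256(session_id.encode()).digest()
    return servers[int.from_bytes(digest[:8], "big") % len(servers)]

servers = ["http://10.0.0.1:8000", "http://10.0.0.2:8000"]
assert pick_server("session-42", servers) == pick_server("session-42", servers)
```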

Task 7: SkyRL Gym HTTP Generator

| Attribute | Value |
|-----------|-------|
| Tests | `/v1/chat/completions` endpoint through the router |
| Config | `SkyRLGymHTTPGenerator` with the HTTP client |

Validation Checklist:

  • Generator uses `/v1/chat/completions` or `/v1/completions`
  • Error handling for server errors
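
A minimal OpenAI-compatible request through the router; the URL and model name are placeholders:

```python
import httpx

# Chat completion via the router; session affinity rides on X-Session-ID.
resp = httpx.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # placeholder router URL
    headers={"X-Session-ID": "session-42"},
    json={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",    # placeholder model
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "max_tokens": 64,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```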

Wave 4: Parallelism Configurations

Task 8: Colocated TP > 1 (2x TP=2)

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/gsm8k/run_gsm8k.sh \
  trainer.epochs=1 \
  generator.num_inference_engines=2 \
  generator.inference_engine_tensor_parallel_size=2
```

Validation: get_world_size() returns 4 (2 servers × 2 TP)
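
The arithmetic behind that check, as a hedged sketch (illustrative, not the actual implementation):

```python
# The inference-side world size is engines × TP ranks per engine, so this
# configuration should report 2 × 2 == 4 from get_world_size().
num_inference_engines = 2
tensor_parallel_size = 2
assert num_inference_engines * tensor_parallel_size == 4
```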


Task 9: Megatron Backend

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/megatron/run_megatron.sh trainer.epochs=1
```

Validation: Megatron ↔ vLLM weight mapping works


Task 10: MoE with Expert Parallelism

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/moe/run_qwen1_5_MoE_A2_7B.sh trainer.epochs=1
```

Validation: EP inference + weight sync


Wave 5: Advanced Features

Task 11: LoRA Fine-Tuning

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/lora/run_qwen2_5_0.5b_gsm8k_grpo_lora.sh trainer.epochs=1
```

Validation: Only LoRA adapter weights synced


Task 12: FlashRL (Quantized)

⚠️ FlashRL uses custom get_inference_client() override.

```bash
_SKYRL_USE_NEW_INFERENCE=1 bash examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh trainer.epochs=1
```

Required work:

  • Verify FlashRL engine creation works with ServerGroup
  • May need a create_server_group_flashrl() helper
  • Weight sync with quantized models

Wave 6: Blocked on Upstream vLLM RFCs

Task 13: Pause Mode "keep" (RFC #32103)

Current Behavior: pause(mode=ABORT) aborts in-flight requests, client retries with accumulated tokens.

Future (RFC #32103): pause(mode=KEEP) preserves KV cache and scheduler state. No abort, no retry needed.

Benefit: Cleaner async RL, no wasted compute on aborted tokens.

Action: Wait for vLLM RFC #32103 to land, then:

  1. Upgrade vLLM
  2. Add PauseMode.KEEP to RemoteInferenceClient
  3. Update fully async trainer to use mode=KEEP
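
A hypothetical client-side shape once the RFC lands; the enum and method names are assumptions:

```python
from enum import Enum

# PauseMode.ABORT is today's behavior; KEEP is the future no-abort mode
# enabled by RFC #32103. Names here are assumptions, not the shipped API.
class PauseMode(Enum):
    ABORT = "abort"  # abort in-flight requests; client retries with partials
    KEEP = "keep"    # preserve KV cache and scheduler state; no retry needed

async def sync_weights_in_flight(client):
    await client.pause(mode=PauseMode.KEEP)  # requests stay queued, KV kept
    await client.sync_weights()
    await client.resume()                    # decoding continues where it stopped
```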

Task 14: Native Weight Sync (RFC #31848)

Current Behavior: Weight sync via custom WorkerWrap extension + /collective_rpc.

Future (RFC #31848): vLLM native /update_weights endpoint.

Benefit: No custom worker extension needed, cleaner integration.

Action: Wait for RFC, then:

  1. Upgrade vLLM
  2. Remove WorkerWrap from server startup
  3. Use native endpoints
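
A sketch of what the native path could look like; the `/update_weights` endpoint name is from this plan, but the request schema is a placeholder pending the RFC:

```python
import httpx

# Once RFC #31848 lands, weight sync could be a plain HTTP call per server
# instead of /collective_rpc into the WorkerWrap extension. The payload
# below is a placeholder; the real schema is defined by the upstream RFC.
def native_update_weights(server_url: str, checkpoint_ref: str) -> None:
    resp = httpx.post(f"{server_url}/update_weights", json={"source": checkpoint_ref})
    resp.raise_for_status()
```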

Task 15: CUDA IPC Weight Sync

Current Behavior: Colocated mode uses NCCL broadcast (same as non-colocated).

Future: CUDA IPC for same-GPU weight transfer (faster).

Dependency: RFC #31848 for native weight sync, then CUDA IPC strategy.


Architecture Comparison

| Component | Legacy Stack | New HTTP Stack |
|-----------|--------------|----------------|
| Managed servers | `RayWrappedInferenceEngine` (Ray actors) | `ServerGroup` + `VLLMServerActor` |
| Remote servers | `RemoteInferenceEngine` (HTTP wrapper) | `RemoteInferenceClient` (direct HTTP) |
| Client | `InferenceEngineClient` (tokenizer, routing) | `RemoteInferenceClient` (pure HTTP) |
| Routing | Client-side `route_prompts_to_engines()` | Server-side `InferenceRouter` (`X-Session-ID`) |
| Control plane | Ray actor methods | HTTP fan-out to `/pause`, `/resume`, etc. |
| Weight sync | `WeightLoader` per engine | HTTP fan-out to `/init_weight_transfer`, `/update_weights` |

cc @SumanthRH @ahao-anyscale @CharlieFRuan
