From 4bb3180112548fd3ac18ca142a7df9b3c555897f Mon Sep 17 00:00:00 2001 From: FlamingoPg <1106310035@qq.com> Date: Fri, 6 Mar 2026 15:32:04 +0800 Subject: [PATCH 01/10] add vllm support --- README.md | 36 +- configs/vllm_qwen3_8b.yaml | 76 ++ docs/code_architecture.md | 3 +- examples/README.md | 12 +- examples/qwen3-8b-single-node/run.sh | 6 +- patches/vllm/v0.15.1/vllm.patch | 1 - pyproject.toml | 6 +- tests/test_vllm_engine.py | 426 ++++++++++ tools/build_conda.sh | 149 +++- torchspec/config/inference_config.py | 44 + torchspec/config/train_config.py | 1 + torchspec/controller/loop.py | 2 +- torchspec/inference/engine/__init__.py | 6 +- torchspec/inference/engine/vllm_engine.py | 706 ++++++++++++++++ .../inference/engine/vllm_worker_extension.py | 781 ++++++++++++++++++ torchspec/inference/factory.py | 139 +++- torchspec/training/data_fetcher.py | 12 +- 17 files changed, 2360 insertions(+), 46 deletions(-) create mode 100644 configs/vllm_qwen3_8b.yaml delete mode 100644 patches/vllm/v0.15.1/vllm.patch create mode 100644 tests/test_vllm_engine.py create mode 100644 torchspec/inference/engine/vllm_engine.py create mode 100644 torchspec/inference/engine/vllm_worker_extension.py diff --git a/README.md b/README.md index 076406c..20af846 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,32 @@ TorchSpec is a torch-native speculative decoding training framework. 
We introduc ## Setup +### Choose Your Backend + +TorchSpec supports two inference backends: + +| Backend | Best For | Installation | +|---------|----------|--------------| +| **SGLang** | Production workloads, high throughput | `./tools/build_conda.sh 1 sglang` (default) | +| **vLLM** | Flexibility, easier deployment | `./tools/build_conda.sh 1 vllm` | +| **Both** | Development, comparison testing | `./tools/build_conda.sh 1 both` | + +### Quick Setup + ```bash +# Install with SGLang (default) ./tools/build_conda.sh micromamba activate torchspec + +# Or install with vLLM +./tools/build_conda.sh 1 vllm +micromamba activate torchspec ``` -To install into your current environment instead: `./tools/build_conda.sh current` +To install into your current environment instead: +```bash +./tools/build_conda.sh current sglang # or 'vllm' or 'both' +``` Optional — install Flash Attention: @@ -21,6 +41,20 @@ Optional — install Flash Attention: pip install -e ".[fa]" ``` +### Backend-Specific Usage + +**SGLang (default):** +```bash +./examples/qwen3-8b-single-node/run.sh +``` + +**vLLM:** +```bash +./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml +``` + +TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden states extraction. 
+ ## Quick Start Train an Eagle3 draft model for Qwen3-8B using inference engine (4 GPUs: 2 training + 2 inference): diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml new file mode 100644 index 0000000..5fe7126 --- /dev/null +++ b/configs/vllm_qwen3_8b.yaml @@ -0,0 +1,76 @@ +# Configuration for train_entry.py with vLLM Engine inference (nested config format) +# +# GPU allocation: +# - 2 GPUs for inference (duplicate mode: each engine has full model copy) +# - 2 GPUs for training (DP/FSDP: model sharded across 2 GPUs) +# - Total: 4 GPUs +# +# Installation: +# pip install -e ".[vllm]" # Install vLLM backend +# +# Usage: +# python -m torchspec.train_entry --config configs/vllm_qwen3_8b.yaml +# +# Note: Uses vLLM Worker Extension to hook into model forward pass for hidden states capture. + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + +dataset: + train_data_path: examples/data/sample_conversations.jsonl + eval_data_path: examples/data/eval_conversations.jsonl + eval_interval: 100 + chat_template: qwen + prompt_key: conversations + +# Use GPUs 4-7 to avoid zombie process on GPU 0-1 +# GPU 4-5: inference (vLLM, TP=2) +# GPU 6-7: training (DP=2) + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 1 + learning_rate: 1e-4 + max_concurrent_batches: 1 + max_grad_norm: 0.5 + max_seq_length: 16384 + num_epochs: 1 + seed: 42 + training_num_gpus_per_node: 2 + training_num_nodes: 1 + ttt_length: 7 + save_per_epoch: true + warmup_ratio: 0.015 + +inference: + inference_engine_type: vllm + inference_num_gpus: 2 + inference_num_gpus_per_engine: 2 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + vllm: + tp_size: 2 + mem_fraction_static: 0.7 + use_worker_extension: true + extra_args: + max_num_batched_tokens: 32768 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + 
local_buffer_size: 4GB + +output_dir: ./outputs/vllm_qwen3_8b-single-node +cache_dir: ./cache +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false diff --git a/docs/code_architecture.md b/docs/code_architecture.md index 5fa2379..ab10ad4 100644 --- a/docs/code_architecture.md +++ b/docs/code_architecture.md @@ -24,7 +24,8 @@ torchspec/ │ ├── base.py # InferenceEngine (ABC) │ ├── hf_engine.py # HFEngine (Ray actor, inherits RayActor) │ ├── hf_runner.py # HFRunner (core inference logic) -│ └── sgl_engine.py # SglEngine (Ray actor, inherits RayActor) +│ ├── sgl_engine.py # SglEngine (Ray actor, inherits RayActor) +│ └── vllm_engine.py # VllmEngine (Ray actor, captures hidden states via vLLM Worker Extension) ├── models/ # Model definitions │ ├── eagle3.py # Eagle3Model (core forward/loss) │ ├── draft/ # Draft model implementations diff --git a/examples/README.md b/examples/README.md index ebd664f..1ae5f0d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -17,12 +17,22 @@ If you just want to try TorchSpec locally, start with **hf-quickstart** (3 GPUs, ./examples/hf-quickstart/run.sh ``` -For production workloads with SGLang async inference, use **qwen3-8b-single-node**: +For production workloads with async inference, use **qwen3-8b-single-node**: ```bash ./examples/qwen3-8b-single-node/run.sh ``` +## Switching inference backends + +Examples use SGLang by default. To use vLLM instead: + +```bash +# Use vLLM backend with qwen3-8b-single-node example +./examples/qwen3-8b-single-node/run.sh \ + --config configs/vllm_qwen3_8b.yaml +``` + ## Data Sample training data is in [`data/sample_conversations.jsonl`](data/sample_conversations.jsonl). All examples that use local data point to this file by default. 
diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index ab68399..14675bb 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Train with SglEngine async inference (multi-GPU version) +# Train with SGLang/vLLM async inference (multi-GPU version) # # GPU allocation (default: 4 GPUs total): # - 2 GPUs for inference (duplicate mode: each engine has full model copy) @@ -46,7 +46,7 @@ INFERENCE_GPUS=2 LOCAL_IP=$(python3 -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); s.connect(('8.8.8.8', 80)); print(s.getsockname()[0]); s.close()") echo "==============================================" -echo "Train with SglEngine inference" +echo "Train with async inference" echo "==============================================" echo "Config: $CONFIG_FILE (nested format)" echo "Total GPUs: $TOTAL_GPUS (CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES)" @@ -59,11 +59,9 @@ echo "==============================================" python3 -m torchspec.train_entry \ --config "$CONFIG_FILE" \ training.training_num_gpus_per_node="$TRAIN_GPUS" \ - inference.inference_engine_type="sgl" \ inference.inference_num_gpus="$INFERENCE_GPUS" \ inference.inference_num_gpus_per_engine=2 \ inference.inference_num_gpus_per_node="$TOTAL_GPUS" \ - inference.sglang.tp_size=2 \ "$@" echo "==============================================" diff --git a/patches/vllm/v0.15.1/vllm.patch b/patches/vllm/v0.15.1/vllm.patch deleted file mode 100644 index 390bdff..0000000 --- a/patches/vllm/v0.15.1/vllm.patch +++ /dev/null @@ -1 +0,0 @@ -# PLACEHOLDER diff --git a/pyproject.toml b/pyproject.toml index 12a7721..6a23e90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,12 @@ dev = [ "ruff", ] +vllm = [ + "vllm>=0.16.0", +] + fa = [ - "flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute", + 
"flash-attention-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute", "nvidia-cutlass-dsl==4.4.0.dev1", "nvidia-cutlass-dsl-libs-base==4.4.0.dev1", ] diff --git a/tests/test_vllm_engine.py b/tests/test_vllm_engine.py new file mode 100644 index 0000000..6c3ee56 --- /dev/null +++ b/tests/test_vllm_engine.py @@ -0,0 +1,426 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Tests for vLLM Worker Extension. 
+ +This file contains both: +- Unit tests: Test logic with mocks (no GPU/vLLM/Mooncake needed) +- Integration tests: Test with real vLLM engine (requires GPU + infrastructure) +""" + +import os +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# ============================================================================= +# Helpers +# ============================================================================= + + +@dataclass +class MockArgs: + """Mock args for VllmWorkerExtension initialization.""" + + target_model_path: str = "Qwen/Qwen3-8B" + tensor_parallel_size: int = 2 + max_model_len: int = 2048 + trust_remote_code: bool = True + + +def _import_vllm_worker_extension(): + """Import VllmWorkerExtension, skipping test if dependencies unavailable.""" + try: + from torchspec.inference.engine.vllm_worker_extension import ( + VllmWorkerExtension, + _sanitize_mooncake_key, + ) + + return VllmWorkerExtension, _sanitize_mooncake_key + except ImportError as e: + pytest.skip(f"VllmWorkerExtension import failed (missing deps): {e}") + + +# ============================================================================= +# Unit Tests (No real vLLM/GPU/Mooncake needed) +# ============================================================================= + + +class TestSanitizeMooncakeKey: + """Unit tests for _sanitize_mooncake_key pure function.""" + + def test_alphanumeric_unchanged(self): + """Test alphanumeric keys pass through unchanged.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("req_abc_123") == "req_abc_123" + + def test_special_chars_replaced(self): + """Test special characters are replaced with underscores.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("req@abc#123") == "req_abc_123" + assert _sanitize("req.id.name") == "req_id_name" + assert _sanitize("req:name|value") == "req_name_value" + + def test_leading_digit_prefixed(self): + """Test leading digits 
get 'k' prefix.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("123_req") == "k123_req" + assert _sanitize("1abc") == "k1abc" + + def test_empty_string(self): + """Test empty string handling.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("") == "" + + +class TestVllmWorkerExtensionState: + """Unit tests for VllmWorkerExtension state management.""" + + def test_init_stores_config(self): + """Test constructor initializes state correctly.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + + assert ext._layer_ids == frozenset() + assert ext._captured_states is None + assert ext._request_metadata == [] + assert ext._current_request_metadata is None + assert ext._mooncake_store is None + assert ext._store_initialized is False + + def test_set_request_metadata(self): + """Test setting request metadata.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + metadata = {"req_1": 100, "req_2": 200} + packed_map = {"req_1": "0,3", "req_2": "0,5"} + input_ids_map = {"req_1": [1, 2, 3], "req_2": [4, 5, 6]} + + ext._set_request_metadata(metadata, packed_map, input_ids_map) + + assert ext._current_request_metadata == metadata + assert ext._packed_loss_mask_map == packed_map + assert ext._input_ids_map == input_ids_map + + def test_reset_capture_clears_state(self): + """Test reset_capture clears all captured state.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._layer_ids = frozenset({5, 10, 15}) + ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] + ext._captured_input_ids = torch.tensor([1, 2, 3]) + ext._request_metadata = [{"req_1": 10}] + ext._current_request_metadata = {"req_1": 10} + ext._packed_loss_mask_map = {"req_1": "0,3"} + ext._input_ids_map = {"req_1": [1, 2, 3]} + + ext._reset_capture() + + assert ext._captured_states is None + assert ext._captured_input_ids 
is None + assert ext._request_metadata == [] + assert ext._current_request_metadata is None + assert ext._packed_loss_mask_map == {} + assert ext._input_ids_map == {} + + def test_reset_capture_requires_prior_setup(self): + """Test reset_capture requires _setup_hidden_states_capture first.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + # Don't set _layer_ids + + with pytest.raises(RuntimeError, match="Must call _setup_hidden_states_capture"): + ext._reset_capture() + + +class TestStoreCapturedStates: + """Unit tests for _store_captured_states with mocked dependencies.""" + + def test_store_first_capture(self): + """Test first capture initializes the state lists.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] + + ext._store_captured_states(tensors) + + assert ext._captured_states is not None + assert len(ext._captured_states) == 2 + assert torch.equal(ext._captured_states[0][0], tensors[0]) + assert torch.equal(ext._captured_states[1][0], tensors[1]) + + def test_store_appends_to_existing(self): + """Test subsequent captures append to existing lists.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] + + new_tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] + ext._store_captured_states(new_tensors) + + assert len(ext._captured_states[0]) == 2 + assert len(ext._captured_states[1]) == 2 + assert torch.equal(ext._captured_states[0][1], new_tensors[0]) + + def test_store_extracts_metadata_from_input_batch(self): + """Test metadata extraction from model_runner.input_batch.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + + # Mock model_runner with input_batch + mock_batch = MagicMock() + mock_batch.req_ids = ["req_1", "req_2"] + 
mock_batch.req_id_to_index = {"req_1": 0, "req_2": 1} + mock_batch.num_tokens = [100, 200] + mock_batch.num_computed_tokens = [0, 0] + + ext.model_runner = MagicMock() + ext.model_runner.input_batch = mock_batch + + tensors = [torch.randn(10, 4096)] + ext._store_captured_states(tensors) + + assert len(ext._request_metadata) == 1 + assert "req_1" in ext._request_metadata[0] + assert "req_2" in ext._request_metadata[0] + + +class TestCudaDeviceSafe: + """Unit tests for _get_cuda_device_safe with mocked torch.cuda.""" + + @patch("torch.cuda.is_initialized") + @patch("torch.cuda.current_device") + def test_initialized_context(self, mock_current, mock_initialized): + """Test when CUDA is already initialized.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + mock_initialized.return_value = True + mock_current.return_value = 1 + + ext = VllmWorkerExtension() + device = ext._get_cuda_device_safe() + + assert str(device) == "cuda:1" + + @patch("torch.cuda.is_initialized") + def test_uninitialized_context_fallback(self, mock_initialized): + """Test fallback when CUDA not initialized (V1 engine).""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + mock_initialized.return_value = False + + ext = VllmWorkerExtension() + device = ext._get_cuda_device_safe() + + assert str(device) == "cuda:0" + + +class TestTokenSlicingLogic: + """Unit tests for token distribution and slicing logic.""" + + def test_ratio_based_distribution(self): + """Test ratio calculation for token distribution.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._current_request_metadata = {"req_1": 100, "req_2": 200} + + external_ids = list(ext._current_request_metadata.keys()) + token_counts = list(ext._current_request_metadata.values()) + total_expected = sum(token_counts) # 300 + total_captured = 150 # Half the expected tokens + + ratio = total_captured / total_expected # 0.5 + + # Calculate actual tokens per request + 
actual_tokens = {ext_id: int(tc * ratio) for ext_id, tc in zip(external_ids, token_counts)} + + assert actual_tokens == {"req_1": 50, "req_2": 100} + + def test_concatenated_tensors_shape(self): + """Test tensor concatenation from multiple iterations.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + # Simulate 2 iterations with 5 tokens each + ext._captured_states = [ + [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 0 + [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 1 + ] + + # Concatenate (simulating _store_and_get_metadata logic) + concatenated = [torch.cat(layer_tensors, dim=0) for layer_tensors in ext._captured_states] + + assert concatenated[0].shape == (10, 4096) + assert concatenated[1].shape == (10, 4096) + + +# ============================================================================= +# Integration Tests (Requires real GPU + vLLM + Mooncake) +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for TP=2") +class TestVllmWorkerExtensionIntegration: + """Integration tests for vLLM Worker Extension with real infrastructure.""" + + @pytest.fixture(autouse=True) + def setup_env(self): + """Setup Mooncake environment variables.""" + os.environ.setdefault("MOONCAKE_MASTER_HOST", "0.0.0.0") + os.environ.setdefault("MOONCAKE_MASTER_PORT", "50051") + os.environ.setdefault("MOONCAKE_METADATA_PORT", "8090") + yield + # Cleanup not needed for env vars + + def test_vllm_worker_extension_mooncake(self): + """Test vLLM Worker Extension stores and retrieves hidden states from Mooncake.""" + from transformers import AutoTokenizer + from vllm import LLM, SamplingParams + + from torchspec.transfer.mooncake import EagleMooncakeStore, MooncakeConfig + + model_path = "Qwen/Qwen3-8B" + + # Initialize 
tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + + # Test inputs + input_ids_list = [ + [1, 2345, 6789], + [100, 200, 300, 400], + [500, 600], + ] + data_ids = ["test_req_0", "test_req_1", "test_req_2"] + + # Initialize vLLM with Worker Extension + engine = LLM( + model=model_path, + tensor_parallel_size=2, + gpu_memory_utilization=0.7, + trust_remote_code=True, + worker_extension_cls="torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension", + max_model_len=2048, + ) + + try: + # Configure hidden states capture + engine.collective_rpc("_setup_hidden_states_capture", args=([5, 10, 15],)) + + # Prepare generation + prompts = [tokenizer.decode(ids) for ids in input_ids_list] + sampling_params = SamplingParams(max_tokens=32, temperature=0) + + # Setup request metadata + request_metadata = {data_ids[i]: len(ids) for i, ids in enumerate(input_ids_list)} + engine.collective_rpc("_reset_capture") + engine.collective_rpc("_set_request_metadata", args=(request_metadata,)) + + # Generate + print("=== Generating with vLLM Worker Extension ===") + outputs = engine.generate(prompts, sampling_params) + assert len(outputs) == len(input_ids_list), "Generation output count mismatch" + + for i, output in enumerate(outputs): + print(f"\n--- Request {i} ---") + print(f"output_ids: {output.prompt_token_ids + list(output.outputs[0].token_ids)}") + print(f"num tokens generated: {len(output.outputs[0].token_ids)}") + + # Retrieve metadata from Mooncake + print("\n=== Retrieving metadata from Mooncake ===") + metadata_list = engine.collective_rpc("_store_and_get_metadata") + assert metadata_list is not None, "No metadata returned from workers" + + all_keys = [] + seq_lens = [] + for metadata in metadata_list: + if isinstance(metadata, dict): + for req_id, meta in metadata.items(): + assert "mooncake_key" in meta + assert "tensor_shapes" in meta + assert "num_layers" in meta + assert meta["num_layers"] == 3 + all_keys.append(meta["mooncake_key"]) + 
seq_lens.append(request_metadata[req_id]) + print( + f" {req_id}: key={meta['mooncake_key']}, layers={meta['num_layers']}" + ) + + # Fetch data from Mooncake Store + print("\n=== Fetching data from Mooncake Store ===") + mooncake_config = MooncakeConfig.from_env() + mooncake_store = EagleMooncakeStore(mooncake_config) + mooncake_store.setup(device="cuda") + + # Qwen3-8B dimensions + hidden_dim = 12288 # 3 layers concatenated (4096 * 3) + last_hidden_dim = 4096 + + for i, key in enumerate(all_keys): + seq_len = seq_lens[i] + shapes = { + "hidden_states": (seq_len, hidden_dim), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, last_hidden_dim), + } + dtypes = { + "hidden_states": torch.bfloat16, + "input_ids": torch.long, + "last_hidden_states": torch.bfloat16, + } + + data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") + print(f"\n Key: {key}") + print( + f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}" + ) + print(f" input_ids: {data.input_ids.tolist()}") + print(f" last_hidden_states: shape={data.last_hidden_states.shape}") + + # Verify tensor device consistency + assert data.hidden_states.device == data.input_ids.device, ( + f"Device mismatch: hidden_states={data.hidden_states.device}, input_ids={data.input_ids.device}" + ) + + print("\n✓ Test completed - hidden states sent to Mooncake and retrieved successfully") + + finally: + # Cleanup + if hasattr(engine, "shutdown"): + engine.shutdown() + + +# ============================================================================= +# Legacy main block (kept for backward compatibility) +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/build_conda.sh b/tools/build_conda.sh index 3a38884..3668ce1 100755 --- a/tools/build_conda.sh +++ b/tools/build_conda.sh @@ -6,11 +6,31 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null 
&& pwd)" PROJECT_ROOT="$(cd -- "$SCRIPT_DIR/.." && pwd)" # Parse command line arguments -# Usage: ./build_conda.sh [MODE] -# 1 - Create new micromamba env and install (default) -# current - Install into current environment -# 0 - Skip env creation and installation +# Usage: ./build_conda.sh [MODE] [BACKEND] +# MODE: +# 1 - Create new micromamba env and install (default) +# current - Install into current environment +# 0 - Skip env creation and installation +# BACKEND: +# sglang - Install SGLang only (default) +# vllm - Install vLLM only +# both - Install both backends + MODE="${1:-1}" +BACKEND="${2:-sglang}" + +# Validate backend +if [[ ! "$BACKEND" =~ ^(sglang|vllm|both)$ ]]; then + echo "Error: Invalid backend '$BACKEND'" + echo "Usage: $0 [MODE] [BACKEND]" + echo " BACKEND options: sglang (default), vllm, both" + exit 1 +fi + +echo "==========================================" +echo "TorchSpec Installation" +echo "Backend: $BACKEND" +echo "==========================================" if [ "$MODE" = "1" ]; then if ! command -v micromamba &> /dev/null; then @@ -31,41 +51,114 @@ else echo "Skipping micromamba setup (mode=0)" fi -SGLANG_VERSION="${SGLANG_VERSION:-v0.5.8.post1}" -SGLANG_COMMIT=0f2df9370a1de1b4fb11b071d39ab3ce2287a350 -SGLANG_FOLDER_NAME="_sglang" +# Install SGLang if requested +if [ "$BACKEND" = "sglang" ] || [ "$BACKEND" = "both" ]; then + echo "==========================================" + echo "Installing SGLang..." + echo "==========================================" + + SGLANG_VERSION="${SGLANG_VERSION:-v0.5.8.post1}" + SGLANG_COMMIT=0f2df9370a1de1b4fb11b071d39ab3ce2287a350 + SGLANG_FOLDER_NAME="_sglang" + + # Install sglang inside the conda environment + if [ ! -d "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" ]; then + git clone https://github.com/sgl-project/sglang.git "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + fi + + # Avoid pythonpath conflict, because we are using the offline engine. 
+ cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + git checkout $SGLANG_COMMIT + git reset --hard HEAD + + cd "$PROJECT_ROOT" + + if [ "$MODE" = "1" ]; then + micromamba run -n torchspec pip install -e "${SGLANG_FOLDER_NAME}/python[all]" + elif [ "$MODE" = "current" ]; then + pip install -e "${SGLANG_FOLDER_NAME}/python[all]" + fi + + cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" -# Install sglang inside the conda environment -if [ ! -d "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" ]; then - git clone https://github.com/sgl-project/sglang.git "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + # Apply sglang patch (matches Docker build behavior) + git apply "$PROJECT_ROOT/patches/sglang/$SGLANG_VERSION/sglang.patch" + + cd "$PROJECT_ROOT" fi -# Avoid pythonpath conflict, because we are using the offline engine. -cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" -git checkout $SGLANG_COMMIT -git reset --hard HEAD +# Install vLLM if requested +if [ "$BACKEND" = "vllm" ] || [ "$BACKEND" = "both" ]; then + echo "==========================================" + echo "Installing vLLM..." + echo "==========================================" -cd "$PROJECT_ROOT" + if [ "$MODE" = "1" ]; then + micromamba run -n torchspec uv pip install "vllm>=0.16.0" + elif [ "$MODE" = "current" ]; then + pip install "vllm>=0.16.0" + fi +fi +# Install torchspec with appropriate extras if [ "$MODE" = "1" ]; then - micromamba run -n torchspec pip install -e "${SGLANG_FOLDER_NAME}/python[all]" - micromamba run -n torchspec uv pip install -e ".[dev]" + echo "==========================================" + echo "Installing TorchSpec..." + echo "==========================================" + + EXTRAS="dev" + if [ "$BACKEND" = "vllm" ]; then + EXTRAS="dev,vllm" + elif [ "$BACKEND" = "both" ]; then + EXTRAS="dev,vllm" + fi + + micromamba run -n torchspec uv pip install -e ".[$EXTRAS]" - echo "torchspec environment setup complete!" + echo "" + echo "==========================================" + echo "✓ TorchSpec environment setup complete!" 
+ echo "==========================================" echo "Activate with: micromamba activate torchspec" + echo "" + if [ "$BACKEND" = "sglang" ]; then + echo "Backend: SGLang" + echo "Run: ./examples/qwen3-8b-single-node/run.sh" + elif [ "$BACKEND" = "vllm" ]; then + echo "Backend: vLLM" + echo "Run: ./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml" + elif [ "$BACKEND" = "both" ]; then + echo "Backends: SGLang + vLLM" + echo "SGLang: ./examples/qwen3-8b-single-node/run.sh" + echo "vLLM: ./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml" + fi elif [ "$MODE" = "current" ]; then - pip install -e "${SGLANG_FOLDER_NAME}/python[all]" - pip install -e ".[dev]" + EXTRAS="dev" + if [ "$BACKEND" = "vllm" ]; then + EXTRAS="dev,vllm" + elif [ "$BACKEND" = "both" ]; then + EXTRAS="dev,vllm" + fi + + pip install -e ".[$EXTRAS]" - echo "torchspec installed into current environment!" + echo "" + echo "==========================================" + echo "✓ TorchSpec installed into current environment!" 
+ echo "==========================================" else + echo "" echo "Skipping package installation (mode=0)" echo "Please install packages manually:" - echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" - echo " pip install -e \".[dev]\"" + if [ "$BACKEND" = "sglang" ]; then + echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" + echo " pip install -e \".[dev]\"" + elif [ "$BACKEND" = "vllm" ]; then + echo " pip install \"vllm>=0.16.0\"" + echo " pip install -e \".[dev,vllm]\"" + elif [ "$BACKEND" = "both" ]; then + echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" + echo " pip install \"vllm>=0.16.0\"" + echo " pip install -e \".[dev,vllm]\"" + fi fi - -cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" - -# Apply sglang patch (matches Docker build behavior) -git apply "$PROJECT_ROOT/patches/sglang/$SGLANG_VERSION/sglang.patch" diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index d097511..9375dd8 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -67,6 +67,49 @@ class SGLangConfig: extra_args: Dict[str, Any] = field(default_factory=dict) +@dataclass +class VllmConfig: + """Essential vLLM engine configuration. + + Only fields that TorchSpec explicitly uses are listed here. + Any additional vLLM engine kwargs can be supplied via ``extra_args`` + and will be forwarded as-is. + + Hidden states are captured by the vLLM Worker Extension unless ``use_worker_extension`` is False. 
+ """ + + # Parallelism + tp_size: int = 8 + pp_size: int = 1 + nnodes: int = 1 + + # Memory + mem_fraction_static: float = 0.8 + + # Observability + enable_metrics: bool = False + + # Multimodal + enable_multimodal: bool = False + + # Networking (port is auto-selected by VllmEngine) + dist_init_addr: Optional[str] = None + dist_timeout: int = 60 + init_timeout: int = 300 + + # Hidden states extraction + num_speculative_tokens: int = 1 + + # Use worker extension for hidden states capture (new implementation) + # If False, falls back to LLM class with speculative_config + use_worker_extension: bool = True + + # Passthrough: forwarded as-is to vLLM LLM. + # Use this for any vLLM kwarg that TorchSpec doesn't need to + # inspect (e.g. quantization, max_model_len, trust_remote_code, ...). + extra_args: Dict[str, Any] = field(default_factory=dict) + + @dataclass class InferenceConfig: aux_hidden_states_layers: Optional[list] = None @@ -79,6 +122,7 @@ class InferenceConfig: inference_num_gpus_per_node: int = 8 max_sample_pool_size: int = 0 sglang: SGLangConfig = field(default_factory=SGLangConfig) + vllm: VllmConfig = field(default_factory=VllmConfig) @dataclass diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 1bacfda..338efc5 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -211,6 +211,7 @@ def load_config( _PREFIXED_SECTIONS = { "mooncake": "mooncake_", "sglang": "sglang_", + "vllm": "vllm_", } diff --git a/torchspec/controller/loop.py b/torchspec/controller/loop.py index 206e45e..bd360a0 100644 --- a/torchspec/controller/loop.py +++ b/torchspec/controller/loop.py @@ -333,7 +333,7 @@ def _try_eval(step: int, eval_cached: bool) -> tuple[dict, bool]: if step_time > 0: metrics["perf/train_capacity"] = args.global_batch_size / step_time - if wandb.run is not None: + if getattr(wandb, "run", None) is not None: wandb.log(metrics) # ── Eval at explicit interval (if configured) ───────── diff --git 
a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py
index 7ebc757..ce1a187 100644
--- a/torchspec/inference/engine/__init__.py
+++ b/torchspec/inference/engine/__init__.py
@@ -21,7 +21,11 @@
 from torchspec.inference.engine.base import InferenceEngine
 from torchspec.inference.engine.hf_engine import HFEngine
 from torchspec.inference.engine.hf_runner import HFRunner
-from torchspec.inference.engine.sgl_engine import SglEngine
+
+try:
+    from torchspec.inference.engine.sgl_engine import SglEngine
+except ModuleNotFoundError:
+    SglEngine = None
 
 __all__ = [
     "InferenceEngine",
diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py
new file mode 100644
index 0000000..cff3032
--- /dev/null
+++ b/torchspec/inference/engine/vllm_engine.py
@@ -0,0 +1,733 @@
+# Copyright (c) 2026 LightSeek Foundation
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""
+VLLM Ray actor engine for distributed deployment.
+
+Uses Worker Extension mode with MultiprocExecutor for reliable hidden states
+extraction via model.forward patching in worker processes.
+"""
+
+import os
+import socket
+import tempfile
+import uuid
+from typing import Any
+
+import ray
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from torchspec.inference.engine.base import InferenceEngine
+from torchspec.ray.ray_actor import RayActor
+from torchspec.utils.logging import logger, setup_file_logging
+from torchspec.utils.misc import get_default_eagle3_aux_layer_ids
+
+_PROTECTION_ENGINE_KEYS = frozenset(
+    {
+        "model",
+        "tensor_parallel_size",
+        "gpu_memory_utilization",
+        "nnodes",
+        "node_rank",
+        "distributed_backend",
+    }
+)
+
+
+class VllmEngine(InferenceEngine, RayActor):
+    """Ray actor wrapper for vLLM LLM engine with distributed deployment support.
+
+    Uses Worker Extension mode with MultiprocExecutor and VllmWorkerExtension
+    for reliable hidden states extraction by patching model.forward in worker processes.
+    """
+
+    def __init__(
+        self,
+        args,
+        rank: int,
+        base_gpu_id: int | None = None,
+        num_gpus_per_engine: int = 1,
+        node_rank: int = 0,
+        engine_group: int = 0,
+    ):
+        """Record placement parameters; the vLLM engine itself is created in init()."""
+        self.args = args
+        self.rank = rank
+        self.base_gpu_id = base_gpu_id
+        self.num_gpus_per_engine = num_gpus_per_engine
+        self.node_rank = node_rank
+        self._engine = None
+        self._mooncake_config = None
+        self._mooncake_store = None
+        self._hidden_size = None
+        self.local_gpu_id = None
+        self._storage_path = None
+        setup_file_logging("inference", self.rank, group=engine_group)
+
+    def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None:
+        """Create the vLLM engine and, when configured, the Mooncake store.
+
+        Args:
+            mooncake_config: Optional Mooncake settings. When given, its
+                ``local_hostname`` is filled in, it is exported to the environment
+                and the Mooncake master is checked before the engine is built.
+            dist_init_addr: Forwarded to vLLM as ``distributed_init_address`` when
+                ``vllm_nnodes > 1``.
+        """
+        if self.base_gpu_id is not None:
+            self.local_gpu_id = self.setup_gpu(self.base_gpu_id)
+            logger.info(
+                f"VllmEngine rank {self.rank}: base_gpu_id={self.base_gpu_id}, "
+                f"using local GPU {self.local_gpu_id}"
+            )
+
+        self._mooncake_config = mooncake_config
+
+        if mooncake_config is not None:
+            try:
+                # connect() on a UDP socket lets getsockname() report the local address;
+                # the context manager closes the socket even if connect() fails.
+                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+                    s.connect(("8.8.8.8", 80))
+                    local_ip = s.getsockname()[0]
+            except OSError:
+                local_ip = "localhost"
+                logger.warning(
+                    f"VllmEngine rank {self.rank}: failed to get local IP, using localhost"
+                )
+
+            mooncake_config.local_hostname = local_ip
+            mooncake_config.export_env()
+
+            from torchspec.transfer.mooncake.utils import (
+                check_mooncake_master_available,
+            )
+
+            check_mooncake_master_available(
+                mooncake_config.master_server_address,
+                mooncake_config.metadata_server,
+            )
+
+        mem_fraction = getattr(self.args, "vllm_mem_fraction_static", 0.8)
+        pp_size = getattr(self.args, "vllm_pp_size", 1)
+
+        if self.args.aux_hidden_states_layers is not None:
+            self.aux_hidden_state_layer_ids = self.args.aux_hidden_states_layers
+        else:
+            self.aux_hidden_state_layer_ids = get_default_eagle3_aux_layer_ids(
+                self.args.target_model_path
+            )
+            if self.rank == 0:
+                logger.info(
+                    f"Using default aux hidden state layer ids: {self.aux_hidden_state_layer_ids}"
+                )
+
+        nnodes = getattr(self.args, "vllm_nnodes", 1)
+        tp_size = nnodes * self.num_gpus_per_engine
+
+        logger.info(
+            f"VllmEngine rank {self.rank}: BEFORE init - "
+            f"base_gpu_id={self.base_gpu_id}, num_gpus={self.num_gpus_per_engine}, "
+            f"tp_size={tp_size}, pp_size={pp_size}, nnodes={nnodes}, node_rank={self.node_rank}, "
+            f"aux_hidden_state_layer_ids={self.aux_hidden_state_layer_ids}"
+        )
+
+        self._init_engine(tp_size, pp_size, nnodes, mem_fraction, dist_init_addr)
+
+        self._hidden_size = self._get_hidden_size_from_engine()
+
+        if self._mooncake_config is not None:
+            self._init_mooncake_store()
+
+        logger.info(
+            f"VllmEngine rank {self.rank}: initialized from {self.args.target_model_path} "
+            f"(tp_size={tp_size}, aux_layers={self.aux_hidden_state_layer_ids}, hidden_size={self._hidden_size})"
+        )
+
+    def _init_engine(
+        self,
+        tp_size: int,
+        pp_size: int,
+        nnodes: int,
+        mem_fraction: float,
+        dist_init_addr: str | None,
+    ) -> None:
+        """Initialize the vLLM engine using Worker Extension mode."""
+        self._init_worker_extension_mode(tp_size, pp_size, nnodes, mem_fraction, dist_init_addr)
+
+    def _init_worker_extension_mode(
+        self,
+        tp_size: int,
+        pp_size: int,
+        nnodes: int,
+        mem_fraction: float,
+        dist_init_addr: str | None,
+    ) -> None:
+        """Initialize LLM with worker extension enabled.
+
+        NOTE(review): ``pp_size`` is accepted but not forwarded to vLLM here;
+        confirm whether ``pipeline_parallel_size`` should be set from it.
+        """
+        from vllm import LLM
+
+        self._storage_path = tempfile.mkdtemp(prefix="vllm_hidden_states_")
+
+        engine_kwargs = {
+            "model": self.args.target_model_path,
+            "tensor_parallel_size": tp_size,
+            "gpu_memory_utilization": mem_fraction,
+            "trust_remote_code": getattr(self.args, "trust_remote_code", True),
+            "distributed_executor_backend": "mp",
+            "disable_custom_all_reduce": True,
+            "worker_extension_cls": (
+                "torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension"
+            ),
+        }
+
+        extra_args = getattr(self.args, "vllm_extra_args", None)
+        if extra_args:
+            if isinstance(extra_args, DictConfig):
+                extra = OmegaConf.to_container(extra_args, resolve=True)
+            else:
+                extra = dict(extra_args) if not isinstance(extra_args, dict) else extra_args
+            blocked = extra.keys() & _PROTECTION_ENGINE_KEYS
+            if blocked:
+                logger.warning(
+                    f"vllm extra_args contains protected keys that will be ignored: "
+                    f"{sorted(blocked)}. These are managed internally by TorchSpec."
+                )
+                extra = {k: v for k, v in extra.items() if k not in _PROTECTION_ENGINE_KEYS}
+            engine_kwargs.update(extra)
+
+        max_seq_length = getattr(self.args, "max_seq_length", None)
+        if max_seq_length:
+            engine_kwargs["max_model_len"] = max_seq_length
+            # Disable chunked prefill to encourage single-step processing
+            if "enable_chunked_prefill" not in engine_kwargs:
+                engine_kwargs["enable_chunked_prefill"] = False
+
+        if nnodes > 1:
+            engine_kwargs["nnodes"] = nnodes
+            engine_kwargs["node_rank"] = self.node_rank
+            if dist_init_addr:
+                engine_kwargs["distributed_backend"] = "nccl"
+                engine_kwargs["distributed_init_address"] = dist_init_addr
+
+        # Worker processes inherit the environment when LLM() spawns them, so the
+        # Mooncake variables have to be exported before construction.
+        self._export_mooncake_env()
+        self._engine = LLM(**engine_kwargs)
+        self._setup_rpc_hidden_states_capture()
+        logger.info(
+            f"VllmEngine rank {self.rank}: initialized worker extension mode "
+            f"with layers={self.aux_hidden_state_layer_ids}"
+        )
+
+    def _setup_rpc_hidden_states_capture(self) -> None:
+        """Initialize worker-side hidden-state capture hooks."""
+        if self._engine is None:
+            raise RuntimeError("VllmEngine not initialized. Call init() first.")
+        if not hasattr(self._engine, "collective_rpc"):
+            raise RuntimeError("vLLM LLM.collective_rpc is required for worker extension mode")
+
+        # Kept for callers that invoke this method directly; already-running
+        # workers do not observe environment changes made here.
+        self._export_mooncake_env()
+
+        layer_ids = list(self.aux_hidden_state_layer_ids)
+        results = self._engine.collective_rpc(
+            "_setup_hidden_states_capture",
+            args=(layer_ids,),
+        )
+        logger.info(f"VllmEngine rank {self.rank}: worker capture setup replies={results}")
+
+    def _export_mooncake_env(self) -> None:
+        """Export ``TORCHSPEC_MOONCAKE_*`` variables for the vLLM worker processes.
+
+        No-op when no Mooncake config was supplied to ``init()``.
+        """
+        cfg = self._mooncake_config
+        if cfg is None:
+            return
+        os.environ["TORCHSPEC_MOONCAKE_MASTER_ADDR"] = cfg.master_server_address
+        # The worker extension reads ..._METADATA_SERVER; ..._METADATA_PORT is kept
+        # for backward compatibility.
+        os.environ["TORCHSPEC_MOONCAKE_METADATA_SERVER"] = cfg.metadata_server
+        os.environ["TORCHSPEC_MOONCAKE_METADATA_PORT"] = str(
+            cfg.metadata_server.split(":")[-1].replace("/metadata", "")
+        )
+        os.environ["TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME"] = cfg.local_hostname
+        os.environ["TORCHSPEC_MOONCAKE_PROTOCOL"] = cfg.protocol
+        if cfg.device_name:
+            os.environ["TORCHSPEC_MOONCAKE_DEVICE_NAME"] = cfg.device_name
+        logger.info(
+            f"VllmEngine rank {self.rank}: Set Mooncake env vars for workers: "
+            f"master={cfg.master_server_address}"
+        )
+
+    def generate(
+        self,
+        data_id: str | list[str],
+        input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None = None,
+        packed_loss_mask_list: list[str] | None = None,
+        formatted_prompts: list[str] | None = None,
+        return_last_hidden_states: bool = False,
+        return_logits: bool = True,
+        multimodal_inputs: list[dict] | None = None,
+    ) -> list[dict]:
+        """Generate hidden states for training data using Worker Extension mode."""
+        return self._generate_worker_extension(
+            data_id,
+            input_ids_ref,
+            packed_loss_mask_list,
+            formatted_prompts,
+            return_last_hidden_states,
+            return_logits,
+            multimodal_inputs,
+        )
+
+    def _generate_worker_extension(
+        self,
+        data_id: str | list[str],
+        input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None,
+        packed_loss_mask_list: list[str] | None,
formatted_prompts: list[str] | None, + return_last_hidden_states: bool, + return_logits: bool, + multimodal_inputs: list[dict] | None, + ) -> list[dict]: + """Generate using worker extension mode.""" + if self._engine is None: + raise RuntimeError("VllmEngine not initialized. Call init() first.") + + if (input_ids_ref is None) == (formatted_prompts is None): + raise ValueError("Exactly one of input_ids_ref or formatted_prompts must be set") + + use_prompts = formatted_prompts is not None + input_ids_list: list[torch.Tensor] | None = None + + if use_prompts: + batch_size = len(formatted_prompts) + prompts = formatted_prompts + else: + if isinstance(input_ids_ref, ray.ObjectRef): + input_ids_list = ray.get(input_ids_ref) + else: + input_ids_list = input_ids_ref + if input_ids_list is None: + raise ValueError("input_ids_ref resolved to None") + batch_size = len(input_ids_list) + prompts = self._convert_input_ids_to_prompts(input_ids_list) + + if isinstance(data_id, str): + data_ids = [f"{data_id}_{i}" for i in range(batch_size)] + elif len(data_id) == batch_size: + data_ids = data_id + else: + raise ValueError( + f"data_id length {len(data_id)} does not match batch size {batch_size}" + ) + + from vllm import SamplingParams + + sampling_params = SamplingParams(max_tokens=1, temperature=0) + request_metadata = {} + if input_ids_list is not None: + for i, ids in enumerate(input_ids_list): + request_metadata[data_ids[i]] = int(self._normalize_input_ids(ids).numel()) + + # Build packed_loss_mask_map for workers + packed_loss_mask_map = {} + if packed_loss_mask_list is not None: + for i, data_id in enumerate(data_ids): + if i < len(packed_loss_mask_list): + packed_loss_mask_map[data_id] = packed_loss_mask_list[i] + + # Build input_ids_map for workers (pass real input_ids via RPC) + input_ids_map = {} + if input_ids_list is not None: + for i, data_id in enumerate(data_ids): + if i < len(input_ids_list): + ids = self._normalize_input_ids(input_ids_list[i]) + 
input_ids_map[data_id] = ids.cpu().tolist() + + try: + self._engine.collective_rpc("_reset_capture") + if request_metadata: + self._engine.collective_rpc( + "_set_request_metadata", + args=(request_metadata, packed_loss_mask_map, input_ids_map), + ) + except Exception as e: + logger.warning(f"Could not reset capture via worker extension: {e}") + + outputs = self._engine.generate(prompts, sampling_params) + + # Get metadata from workers (tensors are already stored in Mooncake by workers) + metadata_by_request: dict[str, dict] = {} + try: + # Workers store tensors directly to Mooncake and return metadata only + metadata_list = self._engine.collective_rpc("_store_and_get_metadata") + if isinstance(metadata_list, list): + for metadata in metadata_list: + if isinstance(metadata, dict): + metadata_by_request.update(metadata) + elif isinstance(metadata_list, dict): + metadata_by_request = metadata_list + except Exception as e: + logger.warning(f"Could not get metadata from worker extension: {e}") + + results = [] + for i, output in enumerate(outputs): + seq_len = len(output.prompt_token_ids) + data_id = data_ids[i] + + # Get metadata for this request + metadata = metadata_by_request.get(data_id) + if metadata is None: + logger.error( + f"VllmEngine rank {self.rank}: No metadata for data_id={data_id}. " + f"Training may be corrupted." 
+ ) + continue + + # Extract info from metadata (tensors are already in Mooncake) + mooncake_key = metadata.get("mooncake_key", data_id) + tensor_shapes = metadata.get("tensor_shapes", {}) + tensor_dtypes = metadata.get("tensor_dtypes", {}) + + result = { + "mooncake_key": mooncake_key, + "tensor_shapes": tensor_shapes, + "tensor_dtypes": tensor_dtypes, + "data_id": data_id, + "seq_len": seq_len, + } + # Get packed_loss_mask from metadata (returned by worker) + packed_loss_mask = metadata.get("packed_loss_mask") + if packed_loss_mask is not None: + result["packed_loss_mask"] = packed_loss_mask + # Get input_ids_list from metadata (returned by worker via RPC) + input_ids_list = metadata.get("input_ids_list") + if input_ids_list is not None: + result["input_ids_list"] = input_ids_list + results.append(result) + + # No need to flush here - workers already flushed after storing + + logger.debug( + f"VllmEngine rank {self.rank}: generated {len(results)} mooncake results " + f"for data_ids={data_ids}" + ) + return results + + def _init_mooncake_store(self) -> None: + if self._mooncake_store is not None or self._mooncake_config is None: + return + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + self._mooncake_store = EagleMooncakeStore(self._mooncake_config) + if torch.cuda.is_available(): + self._mooncake_store.setup(device=torch.cuda.current_device()) + else: + self._mooncake_store.setup() + + def _normalize_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + if input_ids.dim() == 2 and input_ids.shape[0] == 1: + return input_ids.squeeze(0) + if input_ids.dim() == 1: + return input_ids + raise ValueError(f"Unexpected input_ids shape: {input_ids.shape}") + + def _get_sample_input_ids( + self, + index: int, + input_ids_list: list[torch.Tensor] | None, + output: Any, + ) -> torch.Tensor: + if input_ids_list is not None: + return self._normalize_input_ids(input_ids_list[index]).to(dtype=torch.long) + return 
torch.tensor(output.prompt_token_ids, dtype=torch.long) + + def _merge_captured_states( + self, + captured_states: Any, + ) -> tuple[dict[str, list[torch.Tensor]], list[list[torch.Tensor]]]: + merged: dict[str, list[torch.Tensor]] = {} + ordered: list[list[torch.Tensor]] = [] + + # Handle different return types from collective_rpc + if captured_states is None: + return merged, ordered + + # If it's a single dict, wrap it in a list + if isinstance(captured_states, dict): + captured_states = [captured_states] + + if not isinstance(captured_states, list): + logger.warning(f"Unexpected captured_states type: {type(captured_states)}") + return merged, ordered + + # Collect layer states from all workers for each request + # With tensor parallelism, we need to concatenate along hidden dim + request_states: dict[str, list[list[torch.Tensor]]] = {} + + for reply in captured_states: + if not isinstance(reply, dict): + logger.debug(f"Skipping non-dict reply: {type(reply)}") + continue + for request_id, layer_states in reply.items(): + if not isinstance(layer_states, list): + logger.debug( + f"Skipping non-list layer_states for {request_id}: {type(layer_states)}" + ) + continue + if request_id not in request_states: + request_states[request_id] = [] + request_states[request_id].append(layer_states) + + # Merge states: concatenate tensors from different workers along hidden dim + for request_id, worker_states_list in request_states.items(): + if not worker_states_list: + continue + + # Get number of layers from first worker + num_layers = len(worker_states_list[0]) + logger.debug( + f"Merging {len(worker_states_list)} workers for request {request_id} with {num_layers} layers" + ) + + # Concatenate tensors from all workers for each layer + merged_layers = [] + for layer_idx in range(num_layers): + layer_tensors = [ + worker_states[layer_idx] + for worker_states in worker_states_list + if layer_idx < len(worker_states) + ] + + # Check if layer_tensors contains lists (nested 
structure) + if layer_tensors and isinstance(layer_tensors[0], list): + # This shouldn't happen after proper extraction, but handle it + logger.warning(f"Unexpected nested list structure for layer {layer_idx}") + layer_tensors = [ + item + for sublist in layer_tensors + for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + + if len(layer_tensors) == 1: + merged_layers.append(layer_tensors[0]) + elif len(layer_tensors) > 1: + # Concatenate along hidden dimension (dim=-1) + merged_layers.append(torch.cat(layer_tensors, dim=-1)) + else: + # No tensors for this layer + logger.warning(f"No tensors for layer {layer_idx} in request {request_id}") + merged_layers.append(None) # type: ignore[arg-type] + + merged[request_id] = merged_layers + ordered.append(merged_layers) + + return merged, ordered + + def _store_tensors_to_mooncake( + self, + data_id: str, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + last_hidden_states: torch.Tensor | None, + ) -> tuple[str, dict[str, tuple[int, ...]], dict[str, torch.dtype]] | None: + if self._mooncake_store is None: + self._init_mooncake_store() + if self._mooncake_store is None: + return None + + if input_ids.dtype != torch.long: + input_ids = input_ids.to(dtype=torch.long) + if hidden_states.dtype != torch.bfloat16: + hidden_states = hidden_states.to(dtype=torch.bfloat16) + if last_hidden_states is not None and last_hidden_states.dtype != torch.bfloat16: + last_hidden_states = last_hidden_states.to(dtype=torch.bfloat16) + + mooncake_key = f"vllm_{self.rank}_{data_id}_{uuid.uuid4().hex}" + tensor_shapes = self._mooncake_store.put( + key=mooncake_key, + hidden_states=hidden_states, + input_ids=input_ids, + last_hidden_states=last_hidden_states, + target=None, + ) + tensor_dtypes = { + "hidden_states": hidden_states.dtype, + "input_ids": input_ids.dtype, + "last_hidden_states": ( + last_hidden_states.dtype if last_hidden_states is not None else hidden_states.dtype + ), + } + return mooncake_key, 
tensor_shapes, tensor_dtypes + + def _store_sample_to_mooncake( + self, + data_id: str, + input_ids: torch.Tensor, + layer_states: list[torch.Tensor] | None, + hidden_states_path: str | None, + ) -> tuple[str, dict[str, tuple[int, ...]], dict[str, torch.dtype]] | None: + if layer_states: + # Debug: log the structure of layer_states + logger.debug(f"layer_states type: {type(layer_states)}, len: {len(layer_states)}") + if layer_states: + logger.debug(f"layer_states[0] type: {type(layer_states[0])}") + if isinstance(layer_states[0], list): + logger.error(f"layer_states[0] is a list with len {len(layer_states[0])}") + # Flatten the list if needed + layer_states = [ + item + for sublist in layer_states + for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + logger.debug(f"After flattening: layer_states len: {len(layer_states)}") + + # Filter out any non-tensor elements + layer_states = [ls for ls in layer_states if isinstance(ls, torch.Tensor)] + + if not layer_states: + logger.error(f"No valid tensor layers found for data_id={data_id}") + return None + + hidden_states = ( + torch.cat(layer_states, dim=-1) if len(layer_states) > 1 else layer_states[0] + ) + last_hidden_states = layer_states[-1] + return self._store_tensors_to_mooncake( + data_id=data_id, + input_ids=input_ids, + hidden_states=hidden_states, + last_hidden_states=last_hidden_states, + ) + + if hidden_states_path is None or not os.path.exists(hidden_states_path): + return None + + data = torch.load(hidden_states_path, map_location="cpu") + hidden_states = data.get("hidden_states") + if not isinstance(hidden_states, torch.Tensor): + return None + stored_input_ids = data.get("input_ids") + if isinstance(stored_input_ids, torch.Tensor): + input_ids = self._normalize_input_ids(stored_input_ids) + last_hidden_states = data.get("last_hidden_states") + if not isinstance(last_hidden_states, torch.Tensor): + if self._hidden_size is not None and hidden_states.shape[-1] >= self._hidden_size: + 
last_hidden_states = hidden_states[:, -self._hidden_size :] + else: + last_hidden_states = hidden_states + + return self._store_tensors_to_mooncake( + data_id=data_id, + input_ids=input_ids, + hidden_states=hidden_states, + last_hidden_states=last_hidden_states, + ) + + def _convert_input_ids_to_prompts( + self, input_ids_list: list[torch.Tensor] + ) -> list[dict[str, list[int]]]: + prompts = [] + for ids in input_ids_list: + prompts.append({"prompt_token_ids": self._normalize_input_ids(ids).tolist()}) + return prompts + + def health_check(self, timeout: float = 5.0) -> bool: + return self._engine is not None + + def shutdown(self) -> None: + if self._mooncake_store is not None: + try: + self._mooncake_store.close() + except Exception as e: + logger.warning(f"VllmEngine rank {self.rank}: Error closing mooncake store: {e}") + self._mooncake_store = None + + if self._engine is not None: + del self._engine + self._engine = None + + if self._storage_path and os.path.exists(self._storage_path): + import shutil + + try: + shutil.rmtree(self._storage_path) + except Exception: + pass + + logger.info(f"VllmEngine rank {self.rank}: shutdown complete") + + def get_status(self) -> dict: + return { + "rank": self.rank, + "initialized": self._engine is not None, + "base_gpu_id": self.base_gpu_id, + "hidden_size": self._hidden_size, + } + + def _get_hidden_size_from_engine(self) -> int: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained( + self.args.target_model_path, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + ) + hidden_size = getattr(config, "hidden_size", None) + if hidden_size is None: + text_config = getattr(config, "text_config", None) + if text_config is not None: + hidden_size = getattr(text_config, "hidden_size", None) + if hidden_size is None: + raise ValueError( + f"Could not determine hidden_size from model config: {self.args.target_model_path}" + ) + return hidden_size + + def _get_tensor_shapes(self, seq_len: int) 
-> dict: + aux_hidden_state_layer_ids = self.aux_hidden_state_layer_ids + num_aux_layers = len(aux_hidden_state_layer_ids) + if self._hidden_size is None: + raise ValueError( + f"VllmEngine rank {self.rank}: hidden_size not initialized. Call init() first." + ) + hidden_size = self._hidden_size + + concat_hidden_size = num_aux_layers * hidden_size + + return { + "hidden_states": (seq_len, concat_hidden_size), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, hidden_size), + } + + def _get_tensor_dtypes(self) -> dict: + return { + "hidden_states": torch.bfloat16, + "input_ids": torch.long, + "last_hidden_states": torch.bfloat16, + } diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py new file mode 100644 index 0000000..867de58 --- /dev/null +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -0,0 +1,781 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""vLLM Worker Extension for Hidden States Capture. + +This module provides a TorchSpec-style worker extension for vLLM that enables +reliable hidden states extraction during inference. It patches the model's +forward method in each worker process to capture intermediate layer activations +and store them directly to Mooncake to avoid RPC serialization issues. + +Based on the vllm-speculators approach but integrated into TorchSpec's +architecture with Ray Actors and Mooncake storage. +""" + +import logging +import os +import re +import types +from collections import defaultdict +from itertools import islice +from typing import Any, Dict, List, Optional + +import torch +from vllm.distributed import get_pp_group, get_tp_group +from vllm.sequence import IntermediateTensors + +logger = logging.getLogger(__name__) + + +def _sanitize_mooncake_key(key: str) -> str: + """Sanitize a key for use with Mooncake store. + + Mooncake keys should only contain alphanumeric characters, hyphens, and underscores. + This function replaces invalid characters with underscores. + + Args: + key: The original key (e.g., vLLM req_id) + + Returns: + A sanitized key safe for Mooncake operations + """ + sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) + if sanitized and sanitized[0].isdigit(): + sanitized = "k" + sanitized + return sanitized + + +def _patched_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, +) -> Any: + """Patched forward pass that captures hidden states from specified layers. + + This function is dynamically bound to base_model instances via types.MethodType. 
+ It expects base_model to have an _extension attribute pointing to the + VllmWorkerExtension instance. + + Args: + input_ids: Input token IDs + positions: Position IDs + intermediate_tensors: For pipeline parallelism + inputs_embeds: Pre-computed input embeddings (for multimodal) + **kwargs: Additional arguments + + Returns: + Hidden states or IntermediateTensors (for PP) + """ + # Get extension reference + extension = self._extension # noqa: SLF001 + + # Handle pipeline parallelism - first rank does embedding + if get_pp_group().is_first_rank: + hidden_states = ( + inputs_embeds if inputs_embeds is not None else self.embed_input_ids(input_ids) + ) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Track auxiliary hidden states for capture + aux_hidden_states: List[torch.Tensor] = [] + + # Only capture on TP rank 0 to avoid duplicates + should_capture = get_tp_group().rank_in_group == 0 + target_layers = extension._layer_ids if should_capture else frozenset() # noqa: SLF001 + + # Capture input_ids only on first call (prefill phase) to avoid including generated tokens + if should_capture and get_pp_group().is_first_rank and extension._captured_input_ids is None: + # input_ids shape: (batch_size, seq_len) or (seq_len,) + if input_ids.dim() == 2: + # Flatten batch dimension + extension._captured_input_ids = input_ids.view(-1).clone() + else: + extension._captured_input_ids = input_ids.clone() + + # Process each layer + for idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + residual=residual, + ) + absolute_layer_idx = self.start_layer + idx + + # Capture intermediate layers (not the last) before normalization + if absolute_layer_idx in target_layers: + # Add residual before capturing (matching speculators pattern) + captured = 
( + (hidden_states + residual).clone() + if residual is not None + else hidden_states.clone() + ) + aux_hidden_states.append(captured) + + # Handle pipeline parallelism - return intermediate tensors if not last rank + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + + # Final normalization (only on last PP rank) + hidden_states, _ = self.norm(hidden_states, residual) + + # Store captured states (only on last PP rank and TP rank 0) + if should_capture and aux_hidden_states: + extension._store_captured_states(aux_hidden_states) # noqa: SLF001 + + return hidden_states + + +class VllmWorkerExtension: + """Worker extension that adds hidden states capture functionality to vLLM. + + This extension hooks into vLLM's Worker by being specified in the worker + initialization. It patches the model's forward pass to intercept and capture + intermediate layer hidden states during inference. + + Key behaviors: + - Only captures on tensor parallel (TP) rank 0 to avoid duplicate data when + using tensor parallelism. All TP ranks compute the same hidden states, so + capturing from rank 0 is sufficient. + - Stores captured states in GPU memory during batch processing, then writes + directly to Mooncake to avoid RPC serialization issues. + - Supports pipeline parallelism by handling IntermediateTensors correctly. + - Tracks request metadata to map captured states back to original requests + across chunked prefill iterations. 
+ + Attributes: + _layer_ids: Frozenset of layer indices for O(1) lookup during capture + _captured_states: Accumulated hidden states per layer (GPU tensors) + _request_metadata: Metadata tracking tokens per request per iteration + _mooncake_store: EagleMooncakeStore instance for direct storage + model_runner: Reference to the vLLM model runner + """ + + def __init__(self): + """Initialize the worker extension with Mooncake store support.""" + self._layer_ids: frozenset = frozenset() + self._captured_states: Optional[List[List[torch.Tensor]]] = None + self._request_metadata: List[Dict[str, int]] = [] + self._current_request_metadata: Optional[Dict[str, int]] = None + self._mooncake_store: Optional[Any] = None + self._store_initialized: bool = False + self._store_setup_complete: bool = False + self._init_retry_count: int = 0 + self._max_init_retries: int = 3 + self.model_runner: Optional[Any] = None + + def _get_cuda_device_safe(self) -> torch.device: + """Safely get CUDA device, handling uninitialized context (V1 compatibility). + + In vLLM V1, CUDA context may not be initialized when this method is called. + This method safely handles both initialized and uninitialized contexts. + + Returns: + torch.device: The CUDA device to use. Falls back to cuda:0 if context + is not yet initialized (common in V1 engine). 
+ """ + try: + if torch.cuda.is_initialized(): + current_device = torch.cuda.current_device() + logger.debug(f"CUDA initialized, using device cuda:{current_device}") + return torch.device(f"cuda:{current_device}") + else: + # CUDA not initialized yet (V1), use device 0 as fallback + # V1 will initialize context during model loading + logger.debug("CUDA not initialized yet (V1), falling back to cuda:0") + return torch.device("cuda:0") + except RuntimeError as e: + # CUDA context not available + logger.warning(f"Failed to get CUDA device: {e}, falling back to cuda:0") + return torch.device("cuda:0") + + def _init_mooncake_store(self) -> bool: + """Initialize Mooncake store connection in the worker. + + Uses environment variables set by the main process to connect to + the Mooncake master and metadata servers. + + Returns: + True if initialization successful, False otherwise. + """ + if self._store_initialized: + return True + + # Only initialize on TP rank 0 - other ranks don't capture hidden states + try: + if get_tp_group().rank_in_group != 0: + logger.debug("Skipping Mooncake store init on non-zero TP rank") + return False + except Exception: + # If we can't get TP group info, proceed anyway (for backward compatibility) + pass + + try: + # Import here to avoid circular dependencies + from torchspec.config.mooncake_config import MooncakeConfig + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + # Get connection info from environment (set by main process) + # Try TORCHSPEC_MOONCAKE_* first, then fall back to MOONCAKE_* + master_addr = os.environ.get("TORCHSPEC_MOONCAKE_MASTER_ADDR") or os.environ.get( + "MOONCAKE_MASTER_SERVER" + ) + metadata_server = os.environ.get( + "TORCHSPEC_MOONCAKE_METADATA_SERVER" + ) or os.environ.get("MOONCAKE_METADATA_SERVER") + local_hostname = os.environ.get("TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME") or os.environ.get( + "MOONCAKE_LOCAL_HOSTNAME", "localhost" + ) + protocol = 
os.environ.get("TORCHSPEC_MOONCAKE_PROTOCOL") or os.environ.get( + "MOONCAKE_PROTOCOL", "tcp" + ) + + if not master_addr: + logger.warning( + "Mooncake master address not available in worker environment. " + "Set TORCHSPEC_MOONCAKE_MASTER_ADDR or MOONCAKE_MASTER_SERVER environment variable." + ) + return False + + # Parse metadata_server to get port if not explicitly set + if metadata_server: + # Extract port from URL like "http://host:port/metadata" + try: + metadata_port = metadata_server.split(":")[-1].replace("/metadata", "") + except Exception: + metadata_port = "8090" + else: + metadata_port = "8090" + metadata_server = f"http://{master_addr.split(':')[0]}:{metadata_port}/metadata" + + # Read buffer sizes from environment (set by main process via export_env) + host_buffer_size_env = os.environ.get( + "TORCHSPEC_MOONCAKE_HOST_BUFFER_SIZE" + ) or os.environ.get("MOONCAKE_HOST_BUFFER_SIZE") + + # Build config kwargs with optional overrides from environment + config_kwargs = { + "master_server_address": master_addr, + "metadata_server": metadata_server, + "local_hostname": local_hostname, + "protocol": protocol, + "device_name": os.environ.get("TORCHSPEC_MOONCAKE_DEVICE_NAME") + or os.environ.get("MOONCAKE_DEVICE_NAME", ""), + "async_put_pool_size": int( + os.environ.get("TORCHSPEC_MOONCAKE_ASYNC_POOL_SIZE") + or os.environ.get("MOONCAKE_ASYNC_PUT_POOL_SIZE", "2") + ), + "enable_gpu_direct": ( + os.environ.get("TORCHSPEC_MOONCAKE_GPU_DIRECT") + or os.environ.get("MOONCAKE_ENABLE_GPU_DIRECT", "0") + ).lower() + in ("true", "1", "yes"), + } + + # Only override defaults if environment variables are set + if host_buffer_size_env: + config_kwargs["host_buffer_size"] = int(host_buffer_size_env) + + global_segment_size_env = os.environ.get( + "TORCHSPEC_MOONCAKE_GLOBAL_SEGMENT_SIZE" + ) or os.environ.get("MOONCAKE_GLOBAL_SEGMENT_SIZE") + if global_segment_size_env: + config_kwargs["global_segment_size"] = int(global_segment_size_env) + + local_buffer_size_env = 
os.environ.get( + "TORCHSPEC_MOONCAKE_LOCAL_BUFFER_SIZE" + ) or os.environ.get("MOONCAKE_LOCAL_BUFFER_SIZE") + if local_buffer_size_env: + config_kwargs["local_buffer_size"] = int(local_buffer_size_env) + + # Create config for worker + config = MooncakeConfig(**config_kwargs) + + # Create store object but don't call setup() yet + # setup() will be called lazily when CUDA context is ready + self._mooncake_store = EagleMooncakeStore(config) + # Mark as initialized but not yet setup + # setup() will be called on first put() when CUDA context is ready + self._store_initialized = True + self._store_setup_complete = False + + logger.info( + f"Worker initialized Mooncake store (setup deferred): master={master_addr}, protocol={protocol}" + ) + return True + + except Exception as e: + logger.error(f"Failed to initialize Mooncake store in worker: {e}", exc_info=True) + self._mooncake_store = None + self._store_initialized = False + return False + + def _ensure_mooncake_store(self) -> bool: + """Ensure Mooncake store is initialized and setup, with retry logic. + + This method handles lazy initialization for vLLM V1 compatibility. + In V1, CUDA context may not be ready during initial Worker initialization, + so we defer the actual setup() call until first use. + + Returns: + True if store is ready for use, False otherwise. 
+ """ + # Ensure attributes exist (for vLLM V1 compatibility where __init__ may not be called) + if not hasattr(self, "_store_initialized"): + self._store_initialized = False + if not hasattr(self, "_store_setup_complete"): + self._store_setup_complete = False + if not hasattr(self, "_init_retry_count"): + self._init_retry_count = 0 + if not hasattr(self, "_max_init_retries"): + self._max_init_retries = 3 + if not hasattr(self, "_mooncake_store"): + self._mooncake_store = None + + # Already fully initialized and setup + if self._store_initialized and self._store_setup_complete: + return True + + # Check retry limit + if self._init_retry_count >= self._max_init_retries: + logger.error( + f"Max retries ({self._max_init_retries}) exceeded for Mooncake store initialization" + ) + return False + + try: + # Initialize store if not already done + if not self._store_initialized: + if not self._init_mooncake_store(): + self._init_retry_count += 1 + logger.warning( + f"Mooncake store init failed (attempt {self._init_retry_count}/{self._max_init_retries})" + ) + return False + + # Setup store if not already done + if not self._store_setup_complete and self._mooncake_store is not None: + try: + # Use safe CUDA device detection for V1 compatibility + device = self._get_cuda_device_safe() + logger.info(f"Setting up Mooncake store on device {device}") + self._mooncake_store.setup(device=device) + + try: + logger.info("Warming up Mooncake RDMA path...") + self._mooncake_store.warmup_rdma() + logger.info("Mooncake RDMA warmup completed successfully") + except Exception as warmup_error: + logger.warning(f"Mooncake RDMA warmup failed: {warmup_error}") + + self._store_setup_complete = True + logger.info("Mooncake store setup completed successfully") + return True + except Exception as e: + self._init_retry_count += 1 + # Check if this is a CUDA context error (common in V1) + error_msg = str(e).lower() + if "cuda" in error_msg or "device" in error_msg: + logger.warning( + f"CUDA 
context not ready (attempt {self._init_retry_count}/{self._max_init_retries}): {e}. " + f"Will retry on next put." + ) + else: + logger.error(f"Mooncake store setup failed: {e}") + return False + + return True + + except Exception as e: + self._init_retry_count += 1 + logger.error(f"Unexpected error in _ensure_mooncake_store: {e}", exc_info=True) + return False + + def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: + """Store captured hidden states from a forward pass. + + Args: + aux_hidden_states: List of tensors, one per target layer + """ + if getattr(self, "_captured_states", None) is None: + # First capture - initialize lists for each layer + self._captured_states = [[h] for h in aux_hidden_states] + else: + # Append to existing lists + for i, h in enumerate(aux_hidden_states): + self._captured_states[i].append(h) + + # Track how many tokens were captured in this step + # Get from input_batch if available, otherwise use metadata + model_runner = getattr(self, "model_runner", None) + input_batch = getattr(model_runner, "input_batch", None) + if input_batch is not None and hasattr(input_batch, "req_ids"): + # Track by internal request IDs - will map to external IDs later + step_tokens = {} + for req_id in input_batch.req_ids: + num_tokens = 0 + req_idx = getattr(input_batch, "req_id_to_index", {}).get(req_id) + if req_idx is not None: + num_computed = getattr( + input_batch, "num_computed_tokens", [0] * len(input_batch.req_ids) + )[req_idx] + num_total = getattr(input_batch, "num_tokens", [0] * len(input_batch.req_ids))[ + req_idx + ] + num_tokens = num_total - num_computed + step_tokens[req_id] = num_tokens + self._request_metadata.append(step_tokens) + else: + # Fallback: assume all requests in one step + self._request_metadata.append({}) + + def _store_input_ids(self, input_ids: torch.Tensor) -> None: + """Store input_ids from a forward pass. 
+ + Args: + input_ids: Input token IDs tensor (batch_size, seq_len) or (seq_len,) + """ + # Flatten if needed and store + if input_ids.dim() == 2: + # (batch_size, seq_len) - flatten to (batch_size * seq_len,) + input_ids = input_ids.view(-1) + if getattr(self, "_captured_input_ids", None) is None: + self._captured_input_ids = input_ids.clone() + else: + self._captured_input_ids = torch.cat([self._captured_input_ids, input_ids], dim=0) + + def _setup_hidden_states_capture(self, layer_ids: List[int]) -> None: + """Setup model to capture auxiliary hidden states from specific layers. + + This method patches the model's forward method to intercept hidden states + during the forward pass. + + Args: + layer_ids: List of layer indices to capture from + """ + self._layer_ids = frozenset(layer_ids) + self._captured_states = None + self._request_metadata = [] + self._current_request_metadata = None + self._packed_loss_mask_map: Dict[str, str] = {} + self._store_initialized = False + self._store_setup_complete = False + self._init_retry_count = 0 + self._mooncake_store = None + + model_runner = getattr(self, "model_runner", None) + if model_runner is None and hasattr(self, "model"): + model_runner = self + if model_runner is None: + raise AttributeError("Could not find model_runner for worker extension setup") + + self.model_runner = model_runner + model = self.model_runner.model # type: ignore[attr-defined] + + # Handle vision-language models (e.g., Qwen-VL) + if hasattr(model, "get_language_model"): + base_model = model.get_language_model().model + # Handle standard text models + elif hasattr(model, "model") and hasattr(model.model, "layers"): + base_model = model.model + else: + # Try to find model with layers attribute + attrs = [a for a in dir(model) if not a.startswith("_")] + raise AttributeError( + f"Could not find base model with 'layers' attribute. 
" + f"Model type: {type(model).__name__}, " + f"Available attributes: {attrs}" + ) + + # Attach extension reference and patch forward method + base_model._extension = self # noqa: SLF001 + base_model.forward = types.MethodType(_patched_forward, base_model) + + logger.info(f"Hidden states capture setup complete for layers {layer_ids}") + + def _set_request_metadata( + self, + request_metadata: Dict[str, int], + packed_loss_mask_map: Optional[Dict[str, str]] = None, + input_ids_map: Optional[Dict[str, List[int]]] = None, + ) -> None: + """Set request metadata for the next forward pass. + + This is called before each scheduler iteration to track which tokens + belong to which request. + + Args: + request_metadata: Dict mapping request_id -> num_prefill_tokens + packed_loss_mask_map: Optional dict mapping request_id -> packed_loss_mask string + input_ids_map: Optional dict mapping request_id -> input_ids list (passed via RPC) + """ + self._current_request_metadata = request_metadata + self._packed_loss_mask_map = packed_loss_mask_map or {} + self._input_ids_map = input_ids_map or {} + + def _reset_capture(self) -> None: + """Reset captured states before starting a new batch. + + Must be called before processing a new batch of requests. + """ + if not hasattr(self, "_layer_ids") or len(self._layer_ids) == 0: + raise RuntimeError("Must call _setup_hidden_states_capture before capturing states") + self._captured_states = None + self._captured_input_ids: Optional[torch.Tensor] = None + self._request_metadata = [] + self._current_request_metadata = None + self._packed_loss_mask_map = {} + self._input_ids_map = {} + + def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: + """Store captured hidden states to Mooncake and return metadata. + + This method stores tensors directly to Mooncake from the worker process, + avoiding RPC serialization issues. It returns only lightweight metadata + that can be safely serialized and returned via collective_rpc. 
+ + Returns: + Dict mapping request_id to metadata dict with keys: + - 'mooncake_key': str, the base key used for storage + - 'tensor_shapes': dict of tensor shapes + - 'tensor_dtypes': dict of dtype names + - 'num_layers': int, number of captured layers + or None if no states captured or not on TP rank 0. + """ + # Only TP rank 0 has captured data + if get_tp_group().rank_in_group != 0: + return None + if self._captured_states is None: + return None + + # Ensure Mooncake store is initialized and setup (with retry for V1 compatibility) + if not self._ensure_mooncake_store(): + logger.error( + "Failed to initialize/setup Mooncake store, cannot store hidden states. " + "This may be due to CUDA context not being ready in V1 engine." + ) + return None + + # Concatenate captured states from all scheduler iterations + concatenated_layers = [ + torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states + ] + total_captured_tokens = concatenated_layers[0].shape[0] + + # Slice and group by request using external IDs + external_ids = ( + list(self._current_request_metadata.keys()) if self._current_request_metadata else [] + ) + token_counts = ( + list(self._current_request_metadata.values()) if self._current_request_metadata else [] + ) + total_expected_tokens = sum(token_counts) if token_counts else 0 + + request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( + lambda: [[] for _ in range(len(concatenated_layers))] + ) + current_idx = 0 + + # Handle multi-step scheduling: distribute the actually captured tokens proportionally + if total_expected_tokens > 0 and total_captured_tokens > 0: + ratio = total_captured_tokens / total_expected_tokens + for req_idx, (external_id, expected_tokens) in enumerate( + zip(external_ids, token_counts) + ): + # Allocate the actually captured token count proportionally + actual_tokens = int(expected_tokens * ratio) + for req_idx, (external_id, expected_tokens) in enumerate( + zip(external_ids, token_counts) + ): + # Allocate the actually captured token count proportionally + actual_tokens = int(expected_tokens * ratio) + 
actual_tokens = min(actual_tokens, total_captured_tokens - current_idx) + if actual_tokens > 0: + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = ( + layer_tensor[current_idx : current_idx + actual_tokens].clone().cpu() + ) + request_chunks[external_id][layer_idx].append(chunk) + current_idx += actual_tokens + else: + # Fallback: simple sequential slicing + for req_idx, (external_id, num_tokens) in enumerate(zip(external_ids, token_counts)): + if current_idx < total_captured_tokens: + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = layer_tensor[current_idx : current_idx + num_tokens].clone().cpu() + request_chunks[external_id][layer_idx].append(chunk) + current_idx += num_tokens + + # Store to Mooncake and collect metadata + result: Dict[str, Dict[str, Any]] = {} + for req_id, layer_chunks in request_chunks.items(): + mooncake_key = _sanitize_mooncake_key(req_id) + if mooncake_key != req_id: + logger.debug(f"Sanitized key '{req_id}' -> '{mooncake_key}'") + + # Concatenate all layer chunks for this request (keep on GPU) + layer_tensors = [torch.cat(chunks, dim=0) for chunks in layer_chunks] + + # Concatenate all layers along hidden dimension + if len(layer_tensors) > 1: + hidden_states = torch.cat(layer_tensors, dim=-1) + else: + hidden_states = layer_tensors[0] + + last_hidden_states = layer_tensors[-1] + + # Use real input_ids from RPC, otherwise create dummy + if req_id in self._input_ids_map: + input_ids_list = self._input_ids_map[req_id] + input_ids = torch.tensor( + input_ids_list, dtype=torch.long, device=hidden_states.device + ) + else: + seq_len = hidden_states.shape[0] + input_ids = torch.zeros(seq_len, dtype=torch.long, device=hidden_states.device) + + # Skip empty tensors + if hidden_states.numel() == 0: + logger.error(f"Request {req_id}: hidden_states is empty! 
Skipping.") + continue + + try: + logger.debug( + f"Storing to Mooncake: key={mooncake_key}, " + f"hidden_states_shape={hidden_states.shape}" + ) + + # Store to Mooncake + tensor_shapes = self._mooncake_store.put( + key=mooncake_key, + hidden_states=hidden_states, + input_ids=input_ids, + last_hidden_states=last_hidden_states, + target=None, + ) + + logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") + + # Convert dtype to string for RPC serialization + # Include input_ids as a list for reconstruction (avoids Mooncake storage issues) + result[req_id] = { + "mooncake_key": mooncake_key, + "tensor_shapes": tensor_shapes, + "tensor_dtypes": { + "hidden_states": str(hidden_states.dtype).replace("torch.", ""), + "input_ids": str(input_ids.dtype).replace("torch.", ""), + "last_hidden_states": str(last_hidden_states.dtype).replace("torch.", ""), + }, + "num_layers": len(layer_tensors), + "packed_loss_mask": self._packed_loss_mask_map.get(req_id), + "input_ids_list": input_ids.cpu().tolist(), # Serialize via RPC instead of Mooncake + } + except Exception as e: + logger.error( + f"Failed to store tensors to Mooncake for {req_id} (key={mooncake_key}): {e}" + ) + # Continue with other requests even if one fails + continue + + # Flush to ensure all writes are complete before returning + if self._mooncake_store is not None: + self._mooncake_store.flush() + + # Clear intermediate storage to free memory + self._captured_states = None + self._captured_input_ids = None + self._request_metadata = [] + self._input_ids_map = {} + + return result if result else None + + def _get_captured_states(self) -> Optional[Dict[str, List[torch.Tensor]]]: + """Legacy method - now delegates to _store_and_get_metadata. + + This method is kept for backward compatibility but should not be used + in production due to RPC serialization issues. Use _store_and_get_metadata + instead which stores tensors directly to Mooncake. 
+ + Returns: + Dict mapping request_id to list of tensors (one per layer), + or None if no states captured. + """ + # If Mooncake store is available, use the new method + if self._store_initialized or self._init_mooncake_store(): + metadata = self._store_and_get_metadata() + if metadata is None: + return None + # Return empty dict to signal success - actual data is in Mooncake + return {} + + # Fallback to old behavior if Mooncake not available + if self._captured_states is None: + return None + + # Concatenate captured states from all scheduler iterations + concatenated_layers = [ + torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states + ] + + # Slice and group by request + request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( + lambda: [[] for _ in range(len(concatenated_layers))] + ) + current_idx = 0 + + # Use external IDs for slicing + external_ids = ( + list(self._current_request_metadata.keys()) if self._current_request_metadata else [] + ) + token_counts = ( + list(self._current_request_metadata.values()) if self._current_request_metadata else [] + ) + + req_idx = 0 + for step_metadata in self._request_metadata: + step_tokens = sum(step_metadata.values()) if step_metadata else 0 + if step_tokens == 0 and req_idx < len(token_counts): + step_tokens = token_counts[req_idx] + + if req_idx < len(external_ids): + external_id = external_ids[req_idx] + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = layer_tensor[current_idx : current_idx + step_tokens].clone().cpu() + request_chunks[external_id][layer_idx].append(chunk) + current_idx += step_tokens + req_idx += 1 + + # Concatenate chunks for each request across iterations + result: Dict[str, List[torch.Tensor]] = { + req_id: [torch.cat(chunks, dim=0) for chunks in layer_chunks] + for req_id, layer_chunks in request_chunks.items() + } + + # Clear intermediate storage to free memory + self._captured_states = None + self._request_metadata = [] + + 
return result diff --git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index eefc40c..054b816 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -24,7 +24,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torchspec.inference.engine.hf_engine import HFEngine -from torchspec.inference.engine.sgl_engine import SglEngine +from torchspec.inference.engine.vllm_engine import VllmEngine from torchspec.utils.env import get_torchspec_env_vars from torchspec.utils.logging import logger @@ -36,7 +36,7 @@ def create_inference_engines(args, inference_pg, mooncake_config, engine_group: int = 0): """Create inference engines based on configured engine type (blocking). - Supports "hf" and "sgl" engine types via inference_engine_type config. + Supports "hf", "sgl", and "vllm" engine types via inference_engine_type config. Returns: List of head engines used for dispatching requests. Multi-node TP @@ -44,7 +44,7 @@ def create_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl"): + if engine_type not in ("hf", "sgl", "vllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Using {engine_type} engine for inference") @@ -76,15 +76,19 @@ def prepare_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl"): + if engine_type not in ("hf", "sgl", "vllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Preparing {engine_type} inference engines...") if engine_type == "hf": engines, init_refs = _prepare_hf_engines(args, inference_pg, mooncake_config, engine_group) - else: + elif engine_type == "sgl": engines, init_refs = _prepare_sgl_engines(args, inference_pg, mooncake_config, engine_group) + else: + engines, 
init_refs = _prepare_vllm_engines( + args, inference_pg, mooncake_config, engine_group + ) return engines, init_refs @@ -95,7 +99,7 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: Args: args: Configuration arguments. pg: Placement group tuple (pg, reordered_bundle_indices, reordered_gpu_ids). - engine_type: Engine type ("hf" or "sgl"). + engine_type: Engine type ("hf", "sgl", or "vllm"). mooncake_config: MooncakeConfig object. Returns: @@ -105,6 +109,8 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: return _init_hf_engines(args, pg, mooncake_config, engine_group) elif engine_type == "sgl": return _init_sgl_engines(args, pg, mooncake_config, engine_group) + elif engine_type == "vllm": + return _init_vllm_engines(args, pg, mooncake_config, engine_group) else: raise ValueError(f"Unknown engine_type: {engine_type}") @@ -160,6 +166,8 @@ def _prepare_sgl_engines( accept generate() calls. init_handles are ObjectRefs for ALL engines (head + worker) that must be waited on before use. """ + from torchspec.inference.engine.sgl_engine import SglEngine + nnodes = getattr(args, "sglang_nnodes", 1) num_gpus_total = getattr(args, "inference_num_gpus", 1) @@ -268,6 +276,125 @@ def _init_sgl_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> return head_engines +def _prepare_vllm_engines( + args, pg, mooncake_config=None, engine_group: int = 0 +) -> tuple[list, list]: + """Create vLLM engine actors and fire init calls without waiting. + + Handles three cases: + - Single-node, multiple engines: one engine per group of GPUs + - Multi-node, single replica: one engine per node, all forming one TP group + - Multi-node, multiple replicas: N independent TP groups, each spanning nnodes + + For multi-node, worker engines are stored in a module-level list to prevent GC. + + Returns: + Tuple of (head_engines, init_handles). head_engines are the engines that + accept generate() calls. 
init_handles are ObjectRefs for ALL engines + (head + worker) that must be waited on before use. + """ + nnodes = getattr(args, "vllm_nnodes", 1) + num_gpus_total = getattr(args, "inference_num_gpus", 1) + + if nnodes > 1: + gpus_per_node = getattr(args, "inference_num_gpus_per_node", 8) + gpus_per_replica = nnodes * gpus_per_node + num_replicas = num_gpus_total // gpus_per_replica + num_engines = num_replicas * nnodes + gpus_per_engine = gpus_per_node + else: + gpus_per_engine = getattr(args, "inference_num_gpus_per_engine", 1) + num_replicas = num_gpus_total // gpus_per_engine + num_engines = num_replicas + + logger.info( + f"Initializing {num_engines} vLLM engines " + f"({gpus_per_engine} GPU(s) each, nnodes={nnodes}, replicas={num_replicas})" + ) + + pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg + VllmRayActor = ray.remote(VllmEngine) + env_vars = get_torchspec_env_vars() + + engines = [] + for i in range(num_engines): + node_rank = i % nnodes if nnodes > 1 else 0 + + bundle_offset = i * gpus_per_engine + base_gpu_id = int(reordered_gpu_ids[bundle_offset]) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg_obj, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[bundle_offset], + ) + + engine = VllmRayActor.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=scheduling_strategy, + runtime_env={"env_vars": env_vars}, + ).remote( + args=args, + rank=i, + base_gpu_id=base_gpu_id, + num_gpus_per_engine=gpus_per_engine, + node_rank=node_rank, + engine_group=engine_group, + ) + engines.append(engine) + + dist_init_addrs: dict[int, str] = {} + if nnodes > 1: + configured_addr = getattr(args, "vllm_dist_init_addr", None) + for replica_idx in range(num_replicas): + if configured_addr and num_replicas == 1: + dist_init_addrs[replica_idx] = configured_addr + logger.info( + f"Replica {replica_idx}: using configured dist_init_addr: {configured_addr}" + ) + else: + head_engine 
= engines[replica_idx * nnodes] + ip, port = ray.get( + [head_engine.get_node_ip.remote(), head_engine.find_free_port.remote()], + timeout=30, + ) + addr = f"{ip}:{port}" + dist_init_addrs[replica_idx] = addr + logger.info(f"Replica {replica_idx}: auto-negotiated dist_init_addr: {addr}") + + init_handles = [] + for i, engine in enumerate(engines): + replica_idx = i // nnodes if nnodes > 1 else i + init_handles.append( + engine.init.remote( + mooncake_config=mooncake_config, + dist_init_addr=dist_init_addrs.get(replica_idx), + ) + ) + + if nnodes > 1: + head_engines = [engines[i] for i in range(num_engines) if i % nnodes == 0] + worker_engines = [engines[i] for i in range(num_engines) if i % nnodes != 0] + _alive_worker_engines.extend(worker_engines) + logger.info( + f"Prepared multi-node vLLM engines: {len(head_engines)} heads + " + f"{len(worker_engines)} workers ({num_replicas} replicas)" + ) + return head_engines, init_handles + + return engines, init_handles + + +def _init_vllm_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> list: + """Initialize vLLM engines with Ray placement groups (blocking).""" + head_engines, init_handles = _prepare_vllm_engines(args, pg, mooncake_config, engine_group) + nnodes = getattr(args, "vllm_nnodes", 1) + init_timeout = getattr(args, "vllm_init_timeout", 300 if nnodes == 1 else 600) + _wait_for_init(init_handles, "Vllm", timeout=init_timeout) + return head_engines + + def _create_and_init_actors( args, pg, diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index e41f5e1..18d3f5e 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -74,7 +74,17 @@ def __init__( def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: """Load tensors from mooncake key into device memory.""" - dtypes = sample.tensor_dtypes or {} + dtypes_raw = sample.tensor_dtypes or {} + + # Convert string dtypes to torch.dtype objects + dtypes = {} + for key, 
dtype_val in dtypes_raw.items(): + if isinstance(dtype_val, str): + # Handle "bfloat16" or "torch.bfloat16" format + dtype_str = dtype_val.replace("torch.", "") + dtypes[key] = getattr(torch, dtype_str) + else: + dtypes[key] = dtype_val # DEBUG: Print the shapes we're requesting logger.debug( From 3e06c7b4a3b22ca5ceac60dd453520123c18c634 Mon Sep 17 00:00:00 2001 From: FlamingoPg <1106310035@qq.com> Date: Fri, 6 Mar 2026 15:41:25 +0800 Subject: [PATCH 02/10] upd --- configs/vllm_qwen3_8b.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml index 5fe7126..9075c32 100644 --- a/configs/vllm_qwen3_8b.yaml +++ b/configs/vllm_qwen3_8b.yaml @@ -24,10 +24,6 @@ dataset: chat_template: qwen prompt_key: conversations -# Use GPUs 4-7 to avoid zombie process on GPU 0-1 -# GPU 4-5: inference (vLLM, TP=2) -# GPU 6-7: training (DP=2) - training: attention_backend: flex_attention micro_batch_size: 1 @@ -66,7 +62,7 @@ mooncake: global_segment_size: 16GB local_buffer_size: 4GB -output_dir: ./outputs/vllm_qwen3_8b-single-node +output_dir: ./outputs/qwen3-8b-single-node cache_dir: ./cache model_download_dir: null From 7e152d9cdcdc55237cbf4b85a49917db7ada2c70 Mon Sep 17 00:00:00 2001 From: FlamingoPg <1106310035@qq.com> Date: Fri, 6 Mar 2026 15:45:44 +0800 Subject: [PATCH 03/10] fix import --- torchspec/inference/engine/__init__.py | 6 ++++++ torchspec/inference/factory.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py index ce1a187..12e41a5 100644 --- a/torchspec/inference/engine/__init__.py +++ b/torchspec/inference/engine/__init__.py @@ -27,9 +27,15 @@ except ModuleNotFoundError: SglEngine = None +try: + from torchspec.inference.engine.vllm_engine import VllmEngine +except ModuleNotFoundError: + VllmEngine = None + __all__ = [ "InferenceEngine", "HFEngine", "HFRunner", "SglEngine", + "VllmEngine", ] diff 
--git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index 054b816..348eb05 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -24,6 +24,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torchspec.inference.engine.hf_engine import HFEngine +from torchspec.inference.engine.sgl_engine import SglEngine from torchspec.inference.engine.vllm_engine import VllmEngine from torchspec.utils.env import get_torchspec_env_vars from torchspec.utils.logging import logger @@ -166,7 +167,6 @@ def _prepare_sgl_engines( accept generate() calls. init_handles are ObjectRefs for ALL engines (head + worker) that must be waited on before use. """ - from torchspec.inference.engine.sgl_engine import SglEngine nnodes = getattr(args, "sglang_nnodes", 1) num_gpus_total = getattr(args, "inference_num_gpus", 1) From d58e727086470718afeb2435570e6c98648390a6 Mon Sep 17 00:00:00 2001 From: Fan Yin <1106310035@qq.com> Date: Fri, 6 Mar 2026 15:47:19 +0800 Subject: [PATCH 04/10] Update dataset paths to use relative paths --- configs/vllm_qwen3_8b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml index 9075c32..21d4f40 100644 --- a/configs/vllm_qwen3_8b.yaml +++ b/configs/vllm_qwen3_8b.yaml @@ -18,8 +18,8 @@ model: trust_remote_code: true dataset: - train_data_path: examples/data/sample_conversations.jsonl - eval_data_path: examples/data/eval_conversations.jsonl + train_data_path: ./examples/data/sample_conversations.jsonl + eval_data_path: ./examples/data/eval_conversations.jsonl eval_interval: 100 chat_template: qwen prompt_key: conversations From 3d9d634cc5bd9feed5ac7e164ceab674e8cb5dae Mon Sep 17 00:00:00 2001 From: Fan Yin <1106310035@qq.com> Date: Fri, 6 Mar 2026 15:50:13 +0800 Subject: [PATCH 05/10] Refactor imports in __init__.py for engines Removed conditional imports for SglEngine and VllmEngine, now importing them directly. 
--- torchspec/inference/engine/__init__.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py index 12e41a5..fdf43f4 100644 --- a/torchspec/inference/engine/__init__.py +++ b/torchspec/inference/engine/__init__.py @@ -21,16 +21,8 @@ from torchspec.inference.engine.base import InferenceEngine from torchspec.inference.engine.hf_engine import HFEngine from torchspec.inference.engine.hf_runner import HFRunner - -try: - from torchspec.inference.engine.sgl_engine import SglEngine -except ModuleNotFoundError: - SglEngine = None - -try: - from torchspec.inference.engine.vllm_engine import VllmEngine -except ModuleNotFoundError: - VllmEngine = None +from torchspec.inference.engine.sgl_engine import SglEngine +from torchspec.inference.engine.vllm_engine import VllmEngine __all__ = [ "InferenceEngine", From fcd55071d628aa39142145730c040b3840d963a1 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Mar 2026 23:11:40 +0000 Subject: [PATCH 06/10] fix vllm integration --- ...e.py => test_sglang_engine_integration.py} | 0 tests/test_vllm_engine_integration.py | 223 +++++++++++++ torchspec/inference/engine/vllm_engine.py | 303 ++++-------------- .../inference/engine/vllm_worker_extension.py | 175 ++++++---- torchspec/ray/placement_group.py | 13 +- 5 files changed, 405 insertions(+), 309 deletions(-) rename tests/{test_sglang_engine.py => test_sglang_engine_integration.py} (100%) create mode 100644 tests/test_vllm_engine_integration.py diff --git a/tests/test_sglang_engine.py b/tests/test_sglang_engine_integration.py similarity index 100% rename from tests/test_sglang_engine.py rename to tests/test_sglang_engine_integration.py diff --git a/tests/test_vllm_engine_integration.py b/tests/test_vllm_engine_integration.py new file mode 100644 index 0000000..4c8f98e --- /dev/null +++ b/tests/test_vllm_engine_integration.py @@ -0,0 +1,223 @@ +"""Standalone integration script that 
tests vLLM Worker Extension hidden states collection behavior. + +Tests: + 1. Short sequences via input_ids (basic capture) + 2. Longer sequences via input_ids + 3. formatted_prompts path (defer tokenization mode) +""" + +import os +import socket + +import torch +from transformers import AutoTokenizer + +from torchspec.transfer.mooncake import EagleMooncakeStore, MooncakeConfig + +# Detect local IP for Mooncake connections +try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + LOCAL_IP = s.getsockname()[0] + s.close() +except Exception: + LOCAL_IP = "localhost" + +# Mooncake env vars for MooncakeConfig.from_env() (retrieval side) +os.environ["MOONCAKE_MASTER_HOST"] = LOCAL_IP +os.environ["MOONCAKE_MASTER_PORT"] = "50051" +os.environ["MOONCAKE_METADATA_PORT"] = "8090" +os.environ["MOONCAKE_LOCAL_HOSTNAME"] = LOCAL_IP + +# Mooncake env vars for worker extension (storage side) +os.environ["TORCHSPEC_MOONCAKE_MASTER_ADDR"] = f"{LOCAL_IP}:50051" +os.environ["TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME"] = LOCAL_IP +os.environ["TORCHSPEC_MOONCAKE_PROTOCOL"] = "tcp" + + +def collect_metadata(engine, internal_to_external=None): + """Call _store_and_get_metadata and merge results from all TP ranks.""" + args = (internal_to_external,) if internal_to_external else () + metadata_list = engine.collective_rpc("_store_and_get_metadata", args=args) + merged = {} + if isinstance(metadata_list, list): + for m in metadata_list: + if isinstance(m, dict): + merged.update(m) + elif isinstance(metadata_list, dict): + merged = metadata_list + return merged + + +def verify_from_mooncake(mooncake_store, keys, seq_lens, hidden_dim, last_hidden_dim): + """Fetch tensors from Mooncake and verify shapes.""" + for i, key in enumerate(keys): + seq_len = seq_lens[i] + shapes = { + "hidden_states": (seq_len, hidden_dim), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, last_hidden_dim), + } + dtypes = { + "hidden_states": torch.bfloat16, + "input_ids": 
torch.long, + "last_hidden_states": torch.bfloat16, + } + data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") + print(f"\n Key: {key}") + print(f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}") + print(f" input_ids: {data.input_ids.tolist()[:10]}{'...' if seq_len > 10 else ''}") + print(f" last_hidden_states: shape={data.last_hidden_states.shape}") + + assert data.hidden_states.shape == (seq_len, hidden_dim), ( + f"hidden_states shape {data.hidden_states.shape} != expected {(seq_len, hidden_dim)}" + ) + assert data.input_ids.shape == (seq_len,) + assert data.last_hidden_states.shape == (seq_len, last_hidden_dim) + + +if __name__ == "__main__": + model_path = "Qwen/Qwen3-8B" + aux_layer_ids = [2, 4, 6] + tp_size = 4 + hidden_size = 4096 + num_aux_layers = len(aux_layer_ids) + hidden_dim = num_aux_layers * hidden_size + last_hidden_dim = hidden_size + + tokenizer = AutoTokenizer.from_pretrained(model_path) + + from vllm import LLM, SamplingParams + + engine = LLM( + model=model_path, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.7, + trust_remote_code=True, + distributed_executor_backend="mp", + disable_custom_all_reduce=True, + disable_log_stats=True, + worker_extension_cls="torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension", + max_model_len=2048, + enable_chunked_prefill=False, + ) + + engine.collective_rpc("_setup_hidden_states_capture", args=(aux_layer_ids,)) + + mooncake_config = MooncakeConfig.from_env() + mooncake_store = EagleMooncakeStore(mooncake_config) + mooncake_store.setup(device="cuda") + + sampling_params = SamplingParams(max_tokens=1, temperature=0) + + # ========================================================================= + # Test 1: Short sequences + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 1: Short sequences") + print("=" * 60) + + input_ids_list = [ + [1, 2345, 6789], + [100, 
200, 300, 400], + [500, 600], + ] + data_ids = ["short_0", "short_1", "short_2"] + + prompts = [{"prompt_token_ids": ids} for ids in input_ids_list] + request_metadata = {data_ids[i]: len(ids) for i, ids in enumerate(input_ids_list)} + input_ids_map = {data_ids[i]: ids for i, ids in enumerate(input_ids_list)} + + engine.collective_rpc("_reset_capture") + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + outputs = engine.generate(prompts, sampling_params, use_tqdm=False) + for i, output in enumerate(outputs): + print(f" Request {i}: {len(output.prompt_token_ids)} prompt tokens") + + metadata = collect_metadata(engine) + all_keys = [metadata[did]["mooncake_key"] for did in data_ids] + seq_lens = [request_metadata[did] for did in data_ids] + assert len(metadata) == len(data_ids), f"Expected {len(data_ids)} results, got {len(metadata)}" + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 1 passed") + + # ========================================================================= + # Test 2: Longer sequences + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 2: Longer sequences") + print("=" * 60) + + long_input_ids_list = [ + list(range(1, 101)), + list(range(200, 351)), + list(range(400, 465)), + ] + long_data_ids = ["long_0", "long_1", "long_2"] + + prompts = [{"prompt_token_ids": ids} for ids in long_input_ids_list] + request_metadata = {long_data_ids[i]: len(ids) for i, ids in enumerate(long_input_ids_list)} + input_ids_map = {long_data_ids[i]: ids for i, ids in enumerate(long_input_ids_list)} + + engine.collective_rpc("_reset_capture") + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + outputs = engine.generate(prompts, sampling_params, use_tqdm=False) + for i, output in enumerate(outputs): + print(f" Request {i}: {len(output.prompt_token_ids)} prompt 
tokens") + + metadata = collect_metadata(engine) + all_keys = [metadata[did]["mooncake_key"] for did in long_data_ids] + seq_lens = [request_metadata[did] for did in long_data_ids] + assert len(metadata) == len(long_data_ids), ( + f"Expected {len(long_data_ids)} results, got {len(metadata)}" + ) + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 2 passed") + + # ========================================================================= + # Test 3: formatted_prompts path (defer tokenization) + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 3: formatted_prompts path (defer tokenization)") + print("=" * 60) + + text_prompts = [ + "Hello, world!", + "The quick brown fox jumps over the lazy dog.", + "Once upon a time", + ] + prompt_data_ids = ["prompt_0", "prompt_1", "prompt_2"] + + engine.collective_rpc("_reset_capture") + + outputs = engine.generate(text_prompts, sampling_params, use_tqdm=False) + + # Build metadata from outputs post-generation (same as VllmEngine does) + request_metadata = {} + input_ids_map = {} + for i, output in enumerate(outputs): + did = prompt_data_ids[i] + request_metadata[did] = len(output.prompt_token_ids) + input_ids_map[did] = list(output.prompt_token_ids) + print(f" Request {i}: \"{text_prompts[i]}\" -> {len(output.prompt_token_ids)} tokens") + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + metadata = collect_metadata(engine) + all_keys = [metadata[did]["mooncake_key"] for did in prompt_data_ids] + seq_lens = [request_metadata[did] for did in prompt_data_ids] + assert len(metadata) == len(prompt_data_ids), ( + f"Expected {len(prompt_data_ids)} results, got {len(metadata)}" + ) + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 3 passed") + + # 
========================================================================= + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + del engine diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index cff3032..d1569e7 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -27,9 +27,6 @@ import os import socket -import tempfile -import uuid -from typing import Any import ray import torch @@ -78,7 +75,7 @@ def __init__( self._mooncake_store = None self._hidden_size = None self.local_gpu_id = None - self._storage_path = None + setup_file_logging("inference", self.rank, group=engine_group) def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None: @@ -158,23 +155,10 @@ def _init_engine( nnodes: int, mem_fraction: float, dist_init_addr: str | None, - ) -> None: - """Initialize the vLLM engine using Worker Extension mode.""" - self._init_worker_extension_mode(tp_size, pp_size, nnodes, mem_fraction, dist_init_addr) - - def _init_worker_extension_mode( - self, - tp_size: int, - pp_size: int, - nnodes: int, - mem_fraction: float, - dist_init_addr: str | None, ) -> None: """Initialize LLM with worker extension enabled.""" from vllm import LLM - self._storage_path = tempfile.mkdtemp(prefix="vllm_hidden_states_") - engine_kwargs = { "model": self.args.target_model_path, "tensor_parallel_size": tp_size, @@ -202,12 +186,24 @@ def _init_worker_extension_mode( extra = {k: v for k, v in extra.items() if k not in _PROTECTION_ENGINE_KEYS} engine_kwargs.update(extra) + inference_batch_size = getattr(self.args, "inference_batch_size", None) + if inference_batch_size is not None: + comp_cfg = engine_kwargs.get("compilation_config", {}) + if isinstance(comp_cfg, dict) and "max_cudagraph_capture_size" not in comp_cfg: + comp_cfg["max_cudagraph_capture_size"] = inference_batch_size + engine_kwargs["compilation_config"] = comp_cfg + logger.info( + 
f"VllmEngine rank {self.rank}: defaulting " + f"max_cudagraph_capture_size={inference_batch_size} from inference_batch_size" + ) + + # Disable prefix caching and chunked prefill + engine_kwargs["enable_prefix_caching"] = False + engine_kwargs["enable_chunked_prefill"] = False + max_seq_length = getattr(self.args, "max_seq_length", None) if max_seq_length: engine_kwargs["max_model_len"] = max_seq_length - # Disable chunked prefill to encourage single-step processing - if "enable_chunked_prefill" not in engine_kwargs: - engine_kwargs["enable_chunked_prefill"] = False if nnodes > 1: engine_kwargs["nnodes"] = nnodes @@ -267,27 +263,6 @@ def generate( multimodal_inputs: list[dict] | None = None, ) -> list[dict]: """Generate hidden states for training data using Worker Extension mode.""" - return self._generate_worker_extension( - data_id, - input_ids_ref, - packed_loss_mask_list, - formatted_prompts, - return_last_hidden_states, - return_logits, - multimodal_inputs, - ) - - def _generate_worker_extension( - self, - data_id: str | list[str], - input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None, - packed_loss_mask_list: list[str] | None, - formatted_prompts: list[str] | None, - return_last_hidden_states: bool, - return_logits: bool, - multimodal_inputs: list[dict] | None, - ) -> list[dict]: - """Generate using worker extension mode.""" if self._engine is None: raise RuntimeError("VllmEngine not initialized. 
Call init() first.") @@ -308,7 +283,7 @@ def _generate_worker_extension( if input_ids_list is None: raise ValueError("input_ids_ref resolved to None") batch_size = len(input_ids_list) - prompts = self._convert_input_ids_to_prompts(input_ids_list) + prompts = self._format_input_ids_for_vllm(input_ids_list) if isinstance(data_id, str): data_ids = [f"{data_id}_{i}" for i in range(batch_size)] @@ -352,13 +327,40 @@ def _generate_worker_extension( except Exception as e: logger.warning(f"Could not reset capture via worker extension: {e}") - outputs = self._engine.generate(prompts, sampling_params) + outputs = self._engine.generate(prompts, sampling_params, use_tqdm=False) + + # outputs are sorted by int(request_id), matching submission order. + # Build mapping from vLLM's internal worker IDs ("{request_id}-{uuid}") + # to our external data_ids. + internal_to_external = {} + for i, output in enumerate(outputs): + internal_to_external[output.request_id] = data_ids[i] + + # For the formatted_prompts path, request_metadata and input_ids_map + # were not set before generation (no input_ids available). Build them + # from the outputs so the worker can map captured states to requests. 
+ if use_prompts and not request_metadata: + for i, output in enumerate(outputs): + did = data_ids[i] + request_metadata[did] = len(output.prompt_token_ids) + input_ids_map[did] = list(output.prompt_token_ids) + try: + self._engine.collective_rpc( + "_set_request_metadata", + args=(request_metadata, packed_loss_mask_map, input_ids_map), + ) + except Exception as e: + logger.warning( + f"VllmEngine rank {self.rank}: Could not set post-generation " + f"request metadata: {e}" + ) # Get metadata from workers (tensors are already stored in Mooncake by workers) metadata_by_request: dict[str, dict] = {} try: - # Workers store tensors directly to Mooncake and return metadata only - metadata_list = self._engine.collective_rpc("_store_and_get_metadata") + metadata_list = self._engine.collective_rpc( + "_store_and_get_metadata", args=(internal_to_external,) + ) if isinstance(metadata_list, list): for metadata in metadata_list: if isinstance(metadata, dict): @@ -368,6 +370,14 @@ def _generate_worker_extension( except Exception as e: logger.warning(f"Could not get metadata from worker extension: {e}") + if not metadata_by_request: + logger.error( + f"VllmEngine rank {self.rank}: metadata_by_request is EMPTY for " + f"data_ids={data_ids}. Worker returned metadata_list={metadata_list!r}. " + f"use_prompts={use_prompts}, request_metadata_keys={list(request_metadata.keys())}, " + f"internal_to_external={internal_to_external}" + ) + results = [] for i, output in enumerate(outputs): seq_len = len(output.prompt_token_ids) @@ -378,6 +388,7 @@ def _generate_worker_extension( if metadata is None: logger.error( f"VllmEngine rank {self.rank}: No metadata for data_id={data_id}. " + f"metadata_by_request has keys={list(metadata_by_request.keys())}. " f"Training may be corrupted." 
) continue @@ -430,199 +441,7 @@ def _normalize_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return input_ids raise ValueError(f"Unexpected input_ids shape: {input_ids.shape}") - def _get_sample_input_ids( - self, - index: int, - input_ids_list: list[torch.Tensor] | None, - output: Any, - ) -> torch.Tensor: - if input_ids_list is not None: - return self._normalize_input_ids(input_ids_list[index]).to(dtype=torch.long) - return torch.tensor(output.prompt_token_ids, dtype=torch.long) - - def _merge_captured_states( - self, - captured_states: Any, - ) -> tuple[dict[str, list[torch.Tensor]], list[list[torch.Tensor]]]: - merged: dict[str, list[torch.Tensor]] = {} - ordered: list[list[torch.Tensor]] = [] - - # Handle different return types from collective_rpc - if captured_states is None: - return merged, ordered - - # If it's a single dict, wrap it in a list - if isinstance(captured_states, dict): - captured_states = [captured_states] - - if not isinstance(captured_states, list): - logger.warning(f"Unexpected captured_states type: {type(captured_states)}") - return merged, ordered - - # Collect layer states from all workers for each request - # With tensor parallelism, we need to concatenate along hidden dim - request_states: dict[str, list[list[torch.Tensor]]] = {} - - for reply in captured_states: - if not isinstance(reply, dict): - logger.debug(f"Skipping non-dict reply: {type(reply)}") - continue - for request_id, layer_states in reply.items(): - if not isinstance(layer_states, list): - logger.debug( - f"Skipping non-list layer_states for {request_id}: {type(layer_states)}" - ) - continue - if request_id not in request_states: - request_states[request_id] = [] - request_states[request_id].append(layer_states) - - # Merge states: concatenate tensors from different workers along hidden dim - for request_id, worker_states_list in request_states.items(): - if not worker_states_list: - continue - - # Get number of layers from first worker - num_layers = 
len(worker_states_list[0]) - logger.debug( - f"Merging {len(worker_states_list)} workers for request {request_id} with {num_layers} layers" - ) - - # Concatenate tensors from all workers for each layer - merged_layers = [] - for layer_idx in range(num_layers): - layer_tensors = [ - worker_states[layer_idx] - for worker_states in worker_states_list - if layer_idx < len(worker_states) - ] - - # Check if layer_tensors contains lists (nested structure) - if layer_tensors and isinstance(layer_tensors[0], list): - # This shouldn't happen after proper extraction, but handle it - logger.warning(f"Unexpected nested list structure for layer {layer_idx}") - layer_tensors = [ - item - for sublist in layer_tensors - for item in (sublist if isinstance(sublist, list) else [sublist]) - ] - - if len(layer_tensors) == 1: - merged_layers.append(layer_tensors[0]) - elif len(layer_tensors) > 1: - # Concatenate along hidden dimension (dim=-1) - merged_layers.append(torch.cat(layer_tensors, dim=-1)) - else: - # No tensors for this layer - logger.warning(f"No tensors for layer {layer_idx} in request {request_id}") - merged_layers.append(None) # type: ignore[arg-type] - - merged[request_id] = merged_layers - ordered.append(merged_layers) - - return merged, ordered - - def _store_tensors_to_mooncake( - self, - data_id: str, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - last_hidden_states: torch.Tensor | None, - ) -> tuple[str, dict[str, tuple[int, ...]], dict[str, torch.dtype]] | None: - if self._mooncake_store is None: - self._init_mooncake_store() - if self._mooncake_store is None: - return None - - if input_ids.dtype != torch.long: - input_ids = input_ids.to(dtype=torch.long) - if hidden_states.dtype != torch.bfloat16: - hidden_states = hidden_states.to(dtype=torch.bfloat16) - if last_hidden_states is not None and last_hidden_states.dtype != torch.bfloat16: - last_hidden_states = last_hidden_states.to(dtype=torch.bfloat16) - - mooncake_key = 
f"vllm_{self.rank}_{data_id}_{uuid.uuid4().hex}" - tensor_shapes = self._mooncake_store.put( - key=mooncake_key, - hidden_states=hidden_states, - input_ids=input_ids, - last_hidden_states=last_hidden_states, - target=None, - ) - tensor_dtypes = { - "hidden_states": hidden_states.dtype, - "input_ids": input_ids.dtype, - "last_hidden_states": ( - last_hidden_states.dtype if last_hidden_states is not None else hidden_states.dtype - ), - } - return mooncake_key, tensor_shapes, tensor_dtypes - - def _store_sample_to_mooncake( - self, - data_id: str, - input_ids: torch.Tensor, - layer_states: list[torch.Tensor] | None, - hidden_states_path: str | None, - ) -> tuple[str, dict[str, tuple[int, ...]], dict[str, torch.dtype]] | None: - if layer_states: - # Debug: log the structure of layer_states - logger.debug(f"layer_states type: {type(layer_states)}, len: {len(layer_states)}") - if layer_states: - logger.debug(f"layer_states[0] type: {type(layer_states[0])}") - if isinstance(layer_states[0], list): - logger.error(f"layer_states[0] is a list with len {len(layer_states[0])}") - # Flatten the list if needed - layer_states = [ - item - for sublist in layer_states - for item in (sublist if isinstance(sublist, list) else [sublist]) - ] - logger.debug(f"After flattening: layer_states len: {len(layer_states)}") - - # Filter out any non-tensor elements - layer_states = [ls for ls in layer_states if isinstance(ls, torch.Tensor)] - - if not layer_states: - logger.error(f"No valid tensor layers found for data_id={data_id}") - return None - - hidden_states = ( - torch.cat(layer_states, dim=-1) if len(layer_states) > 1 else layer_states[0] - ) - last_hidden_states = layer_states[-1] - return self._store_tensors_to_mooncake( - data_id=data_id, - input_ids=input_ids, - hidden_states=hidden_states, - last_hidden_states=last_hidden_states, - ) - - if hidden_states_path is None or not os.path.exists(hidden_states_path): - return None - - data = torch.load(hidden_states_path, 
map_location="cpu") - hidden_states = data.get("hidden_states") - if not isinstance(hidden_states, torch.Tensor): - return None - stored_input_ids = data.get("input_ids") - if isinstance(stored_input_ids, torch.Tensor): - input_ids = self._normalize_input_ids(stored_input_ids) - last_hidden_states = data.get("last_hidden_states") - if not isinstance(last_hidden_states, torch.Tensor): - if self._hidden_size is not None and hidden_states.shape[-1] >= self._hidden_size: - last_hidden_states = hidden_states[:, -self._hidden_size :] - else: - last_hidden_states = hidden_states - - return self._store_tensors_to_mooncake( - data_id=data_id, - input_ids=input_ids, - hidden_states=hidden_states, - last_hidden_states=last_hidden_states, - ) - - def _convert_input_ids_to_prompts( + def _format_input_ids_for_vllm( self, input_ids_list: list[torch.Tensor] ) -> list[dict[str, list[int]]]: prompts = [] @@ -645,14 +464,6 @@ def shutdown(self) -> None: del self._engine self._engine = None - if self._storage_path and os.path.exists(self._storage_path): - import shutil - - try: - shutil.rmtree(self._storage_path) - except Exception: - pass - logger.info(f"VllmEngine rank {self.rank}: shutdown complete") def get_status(self) -> dict: diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index 867de58..aaef8d8 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -142,9 +142,11 @@ def _patched_forward( # Final normalization (only on last PP rank) hidden_states, _ = self.norm(hidden_states, residual) - # Store captured states (only on last PP rank and TP rank 0) - if should_capture and aux_hidden_states: - extension._store_captured_states(aux_hidden_states) # noqa: SLF001 + # Store captured states (only on last PP rank, TP rank 0, and during prefill) + if should_capture and not extension._prefill_complete: # noqa: SLF001 + if aux_hidden_states: + 
extension._store_captured_states(aux_hidden_states) # noqa: SLF001 + extension._store_last_hidden_states(hidden_states) # noqa: SLF001 return hidden_states @@ -414,27 +416,30 @@ def _ensure_mooncake_store(self) -> bool: logger.error(f"Unexpected error in _ensure_mooncake_store: {e}", exc_info=True) return False + def _store_last_hidden_states(self, hidden_states: torch.Tensor) -> None: + """Store post-norm hidden states from a forward pass for use as last_hidden_states""" + if getattr(self, "_captured_last_hs", None) is None: + self._captured_last_hs = [hidden_states.clone()] + else: + self._captured_last_hs.append(hidden_states.clone()) + def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: """Store captured hidden states from a forward pass. Args: aux_hidden_states: List of tensors, one per target layer """ - if getattr(self, "_captured_states", None) is None: - # First capture - initialize lists for each layer + if self._captured_states is None: self._captured_states = [[h] for h in aux_hidden_states] else: - # Append to existing lists for i, h in enumerate(aux_hidden_states): self._captured_states[i].append(h) - # Track how many tokens were captured in this step - # Get from input_batch if available, otherwise use metadata + # Track per-request token counts for this scheduler step model_runner = getattr(self, "model_runner", None) input_batch = getattr(model_runner, "input_batch", None) + step_tokens: Dict[str, int] = {} if input_batch is not None and hasattr(input_batch, "req_ids"): - # Track by internal request IDs - will map to external IDs later - step_tokens = {} for req_id in input_batch.req_ids: num_tokens = 0 req_idx = getattr(input_batch, "req_id_to_index", {}).get(req_id) @@ -447,10 +452,21 @@ def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: ] num_tokens = num_total - num_computed step_tokens[req_id] = num_tokens - self._request_metadata.append(step_tokens) - else: - # Fallback: assume all 
requests in one step - self._request_metadata.append({}) + self._request_metadata.append(step_tokens) + + # With max_tokens=1 the prefill forward pass already generates the + # single allowed token, so no decode step is scheduled by vLLM. + # This check handles chunked prefill where multiple forward calls + # sum up to the total prefill token count. + if self._current_request_metadata and not self._prefill_complete: + expected = sum(self._current_request_metadata.values()) + captured = sum(t.shape[0] for t in self._captured_states[0]) + if captured == expected: + self._prefill_complete = True + elif captured > expected: + logger.warning( + f"Captured more tokens than expected: {captured} > {expected}" + ) def _store_input_ids(self, input_ids: torch.Tensor) -> None: """Store input_ids from a forward pass. @@ -544,13 +560,15 @@ def _reset_capture(self) -> None: if not hasattr(self, "_layer_ids") or len(self._layer_ids) == 0: raise RuntimeError("Must call _setup_hidden_states_capture before capturing states") self._captured_states = None + self._captured_last_hs: Optional[List[torch.Tensor]] = None self._captured_input_ids: Optional[torch.Tensor] = None + self._prefill_complete = False self._request_metadata = [] self._current_request_metadata = None self._packed_loss_mask_map = {} self._input_ids_map = {} - def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: + def _store_and_get_metadata(self, internal_to_external: Optional[Dict[str, str]] = None) -> Optional[Dict[str, Dict[str, Any]]]: """Store captured hidden states to Mooncake and return metadata. 
This method stores tensors directly to Mooncake from the worker process, @@ -569,11 +587,15 @@ def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: if get_tp_group().rank_in_group != 0: return None if self._captured_states is None: + logger.warning( + "_store_and_get_metadata: captured_states is None " + "(forward patch may not be running or no prefill occurred)" + ) return None # Ensure Mooncake store is initialized and setup (with retry for V1 compatibility) if not self._ensure_mooncake_store(): - logger.error( + logger.warning( "Failed to initialize/setup Mooncake store, cannot store hidden states. " "This may be due to CUDA context not being ready in V1 engine." ) @@ -585,49 +607,83 @@ def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: ] total_captured_tokens = concatenated_layers[0].shape[0] - # Slice and group by request using external IDs - external_ids = ( - list(self._current_request_metadata.keys()) if self._current_request_metadata else [] - ) - token_counts = ( - list(self._current_request_metadata.values()) if self._current_request_metadata else [] - ) - total_expected_tokens = sum(token_counts) if token_counts else 0 + # Concatenate post-norm hidden states for last_hidden_states + concatenated_last_hs = None + if getattr(self, "_captured_last_hs", None): + concatenated_last_hs = torch.cat(self._captured_last_hs, dim=0) + + internal_to_external = internal_to_external or {} + ext_token_counts = dict(self._current_request_metadata) if self._current_request_metadata else {} + + # Build worker-visible ID -> external ID lookup once. + # In V1, the worker sees "{counter}-{uuid8}" while internal_to_external + # maps bare counter strings (from output.request_id) to external data_ids. 
+ worker_to_ext: Dict[str, str] = dict(internal_to_external) + for step_meta in self._request_metadata: + for worker_id in step_meta: + if worker_id not in worker_to_ext: + for counter, ext_id in internal_to_external.items(): + if worker_id.startswith(f"{counter}-"): + worker_to_ext[worker_id] = ext_id + break + + request_slices: List[tuple] = [] # (external_id, num_tokens) + seen_ext_ids: set = set() + + for step_meta in self._request_metadata: + for int_id in step_meta.keys(): + ext_id = worker_to_ext.get(int_id, int_id) + if ext_id not in seen_ext_ids: + n_tokens = ext_token_counts.get(ext_id, 0) + if n_tokens > 0: + request_slices.append((ext_id, n_tokens)) + seen_ext_ids.add(ext_id) + + # Fallback if _request_metadata didn't produce results + if not request_slices and ext_token_counts: + logger.warning("Internal request metadata mapping failed; falling back to external order") + for ext_id, n_tokens in ext_token_counts.items(): + request_slices.append((ext_id, n_tokens)) + + if not request_slices: + logger.warning( + f"_store_and_get_metadata: request_slices is empty — cannot map " + f"captured tokens to requests. 
" + f"total_captured_tokens={total_captured_tokens}, " + f"_request_metadata steps={len(self._request_metadata)}, " + f"internal_to_external keys={list(internal_to_external.keys())[:5]}, " + f"ext_token_counts keys={list(ext_token_counts.keys())[:5]}, " + f"current_request_metadata={self._current_request_metadata is not None}" + ) + total_expected_tokens = sum(n for _, n in request_slices) + + if total_captured_tokens != total_expected_tokens and total_expected_tokens > 0: + logger.warning( + f"Token count mismatch: captured={total_captured_tokens}, " + f"expected={total_expected_tokens}" + ) + + num_aux_layers = len(concatenated_layers) request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( - lambda: [[] for _ in range(len(concatenated_layers))] + lambda: [[] for _ in range(num_aux_layers)] ) + request_last_hs: defaultdict[str, List[torch.Tensor]] = defaultdict(list) current_idx = 0 - # Handle multi-step scheduling:按比例分配实际捕获的tokens - if total_expected_tokens > 0 and total_captured_tokens > 0: - ratio = total_captured_tokens / total_expected_tokens - for req_idx, (external_id, expected_tokens) in enumerate( - zip(external_ids, token_counts) - ): - # 按比例分配实际捕获的token数 - actual_tokens = int(expected_tokens * ratio) - for req_idx, (external_id, expected_tokens) in enumerate( - zip(external_ids, token_counts) - ): - # 按比例分配实际捕获的token数 - actual_tokens = int(expected_tokens * ratio) - actual_tokens = min(actual_tokens, total_captured_tokens - current_idx) - if actual_tokens > 0: - for layer_idx, layer_tensor in enumerate(concatenated_layers): - chunk = ( - layer_tensor[current_idx : current_idx + actual_tokens].clone().cpu() - ) - request_chunks[external_id][layer_idx].append(chunk) - current_idx += actual_tokens - else: - # Fallback: simple sequential slicing - for req_idx, (external_id, num_tokens) in enumerate(zip(external_ids, token_counts)): - if current_idx < total_captured_tokens: - for layer_idx, layer_tensor in enumerate(concatenated_layers): - 
chunk = layer_tensor[current_idx : current_idx + num_tokens].clone().cpu() - request_chunks[external_id][layer_idx].append(chunk) - current_idx += num_tokens + for external_id, num_tokens in request_slices: + if current_idx >= total_captured_tokens: + break + actual_tokens = min(num_tokens, total_captured_tokens - current_idx) + if actual_tokens > 0: + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = layer_tensor[current_idx : current_idx + actual_tokens] + request_chunks[external_id][layer_idx].append(chunk) + if concatenated_last_hs is not None: + request_last_hs[external_id].append( + concatenated_last_hs[current_idx : current_idx + actual_tokens] + ) + current_idx += actual_tokens # Store to Mooncake and collect metadata result: Dict[str, Dict[str, Any]] = {} @@ -636,16 +692,17 @@ def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: if mooncake_key != req_id: logger.debug(f"Sanitized key '{req_id}' -> '{mooncake_key}'") - # Concatenate all layer chunks for this request (keep on GPU) layer_tensors = [torch.cat(chunks, dim=0) for chunks in layer_chunks] - # Concatenate all layers along hidden dimension if len(layer_tensors) > 1: hidden_states = torch.cat(layer_tensors, dim=-1) else: hidden_states = layer_tensors[0] - last_hidden_states = layer_tensors[-1] + if req_id in request_last_hs and request_last_hs[req_id]: + last_hidden_states = torch.cat(request_last_hs[req_id], dim=0) + else: + last_hidden_states = layer_tensors[-1] # Use real input_ids from RPC, otherwise create dummy if req_id in self._input_ids_map: @@ -694,10 +751,9 @@ def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: "input_ids_list": input_ids.cpu().tolist(), # Serialize via RPC instead of Mooncake } except Exception as e: - logger.error( + logger.warning( f"Failed to store tensors to Mooncake for {req_id} (key={mooncake_key}): {e}" ) - # Continue with other requests even if one fails continue # Flush to ensure all writes are 
complete before returning @@ -706,6 +762,7 @@ def _store_and_get_metadata(self) -> Optional[Dict[str, Dict[str, Any]]]: # Clear intermediate storage to free memory self._captured_states = None + self._captured_last_hs = None self._captured_input_ids = None self._request_metadata = [] self._input_ids_map = {} diff --git a/torchspec/ray/placement_group.py b/torchspec/ray/placement_group.py index c7efbc6..5f56273 100644 --- a/torchspec/ray/placement_group.py +++ b/torchspec/ray/placement_group.py @@ -57,9 +57,12 @@ def sort_key(x): return (node_ip_parts, gpu_id) -def _create_placement_group(num_gpus, strategy="PACK", name=None): +def _create_placement_group(num_gpus, strategy="PACK", name=None, node_ip=None): """Create a placement group with the specified number of GPUs.""" - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + bundle = {"GPU": 1, "CPU": 1} + if node_ip: + bundle[f"node:{node_ip}"] = 0.001 + bundles = [bundle.copy() for _ in range(num_gpus)] pg = placement_group(bundles, strategy=strategy, name=name) num_bundles = len(bundles) @@ -193,14 +196,16 @@ def create_placement_groups(args): f"{num_training_gpus} GPUs for training..." 
) + pin_node_ip = os.environ.get("TORCHSPEC_PIN_NODE_IP") + logger.info("Creating inference placement group...") inference_pg, inference_bundle_indices, inference_gpu_ids = _create_placement_group( - num_inference_gpus, strategy="PACK", name="inference_pg" + num_inference_gpus, strategy="PACK", name="inference_pg", node_ip=pin_node_ip ) logger.info("Creating training placement group...") training_pg, training_bundle_indices, training_gpu_ids = _create_placement_group( - num_training_gpus, strategy="PACK", name="training_pg" + num_training_gpus, strategy="PACK", name="training_pg", node_ip=pin_node_ip ) return { From 7121138f805dd5bafe7d103b36b9bcc9ac3eeef6 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Mar 2026 23:14:39 +0000 Subject: [PATCH 07/10] revert placement group --- torchspec/ray/placement_group.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torchspec/ray/placement_group.py b/torchspec/ray/placement_group.py index 5f56273..c7efbc6 100644 --- a/torchspec/ray/placement_group.py +++ b/torchspec/ray/placement_group.py @@ -57,12 +57,9 @@ def sort_key(x): return (node_ip_parts, gpu_id) -def _create_placement_group(num_gpus, strategy="PACK", name=None, node_ip=None): +def _create_placement_group(num_gpus, strategy="PACK", name=None): """Create a placement group with the specified number of GPUs.""" - bundle = {"GPU": 1, "CPU": 1} - if node_ip: - bundle[f"node:{node_ip}"] = 0.001 - bundles = [bundle.copy() for _ in range(num_gpus)] + bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] pg = placement_group(bundles, strategy=strategy, name=name) num_bundles = len(bundles) @@ -196,16 +193,14 @@ def create_placement_groups(args): f"{num_training_gpus} GPUs for training..." 
) - pin_node_ip = os.environ.get("TORCHSPEC_PIN_NODE_IP") - logger.info("Creating inference placement group...") inference_pg, inference_bundle_indices, inference_gpu_ids = _create_placement_group( - num_inference_gpus, strategy="PACK", name="inference_pg", node_ip=pin_node_ip + num_inference_gpus, strategy="PACK", name="inference_pg" ) logger.info("Creating training placement group...") training_pg, training_bundle_indices, training_gpu_ids = _create_placement_group( - num_training_gpus, strategy="PACK", name="training_pg", node_ip=pin_node_ip + num_training_gpus, strategy="PACK", name="training_pg" ) return { From 062134bb7198e11cf6aebef5050f333df9bfbd2c Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Mar 2026 23:16:44 +0000 Subject: [PATCH 08/10] lint --- tests/test_vllm_engine_integration.py | 6 ++++-- torchspec/inference/engine/vllm_engine.py | 1 - .../inference/engine/vllm_worker_extension.py | 16 ++++++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_vllm_engine_integration.py b/tests/test_vllm_engine_integration.py index 4c8f98e..4af40ef 100644 --- a/tests/test_vllm_engine_integration.py +++ b/tests/test_vllm_engine_integration.py @@ -65,7 +65,9 @@ def verify_from_mooncake(mooncake_store, keys, seq_lens, hidden_dim, last_hidden } data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") print(f"\n Key: {key}") - print(f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}") + print( + f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}" + ) print(f" input_ids: {data.input_ids.tolist()[:10]}{'...' 
if seq_len > 10 else ''}") print(f" last_hidden_states: shape={data.last_hidden_states.shape}") @@ -203,7 +205,7 @@ def verify_from_mooncake(mooncake_store, keys, seq_lens, hidden_dim, last_hidden did = prompt_data_ids[i] request_metadata[did] = len(output.prompt_token_ids) input_ids_map[did] = list(output.prompt_token_ids) - print(f" Request {i}: \"{text_prompts[i]}\" -> {len(output.prompt_token_ids)} tokens") + print(f' Request {i}: "{text_prompts[i]}" -> {len(output.prompt_token_ids)} tokens') engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) metadata = collect_metadata(engine) diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index d1569e7..7da1776 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -25,7 +25,6 @@ extraction via model.forward patching in worker processes. """ -import os import socket import ray diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index aaef8d8..ad5ebbf 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -464,9 +464,7 @@ def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: if captured == expected: self._prefill_complete = True elif captured > expected: - logger.warning( - f"Captured more tokens than expected: {captured} > {expected}" - ) + logger.warning(f"Captured more tokens than expected: {captured} > {expected}") def _store_input_ids(self, input_ids: torch.Tensor) -> None: """Store input_ids from a forward pass. 
@@ -568,7 +566,9 @@ def _reset_capture(self) -> None: self._packed_loss_mask_map = {} self._input_ids_map = {} - def _store_and_get_metadata(self, internal_to_external: Optional[Dict[str, str]] = None) -> Optional[Dict[str, Dict[str, Any]]]: + def _store_and_get_metadata( + self, internal_to_external: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, Dict[str, Any]]]: """Store captured hidden states to Mooncake and return metadata. This method stores tensors directly to Mooncake from the worker process, @@ -613,7 +613,9 @@ def _store_and_get_metadata(self, internal_to_external: Optional[Dict[str, str]] concatenated_last_hs = torch.cat(self._captured_last_hs, dim=0) internal_to_external = internal_to_external or {} - ext_token_counts = dict(self._current_request_metadata) if self._current_request_metadata else {} + ext_token_counts = ( + dict(self._current_request_metadata) if self._current_request_metadata else {} + ) # Build worker-visible ID -> external ID lookup once. # In V1, the worker sees "{counter}-{uuid8}" while internal_to_external @@ -641,7 +643,9 @@ def _store_and_get_metadata(self, internal_to_external: Optional[Dict[str, str]] # Fallback if _request_metadata didn't produce results if not request_slices and ext_token_counts: - logger.warning("Internal request metadata mapping failed; falling back to external order") + logger.warning( + "Internal request metadata mapping failed; falling back to external order" + ) for ext_id, n_tokens in ext_token_counts.items(): request_slices.append((ext_id, n_tokens)) From dee5460c1dfb3e5c121ec3d1455715fee69ce9cf Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Mar 2026 23:37:07 +0000 Subject: [PATCH 09/10] fix env parsing --- configs/vllm_qwen3_8b.yaml | 6 +- tests/test_vllm_engine_integration.py | 8 +- torchspec/config/mooncake_config.py | 10 +++ torchspec/inference/engine/vllm_engine.py | 14 +--- .../inference/engine/vllm_worker_extension.py | 79 ++----------------- 5 files changed, 24 insertions(+), 
93 deletions(-) diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml index 21d4f40..96ce140 100644 --- a/configs/vllm_qwen3_8b.yaml +++ b/configs/vllm_qwen3_8b.yaml @@ -18,8 +18,8 @@ model: trust_remote_code: true dataset: - train_data_path: ./examples/data/sample_conversations.jsonl - eval_data_path: ./examples/data/eval_conversations.jsonl + train_data_path: ../examples/data/sample_conversations.jsonl + eval_data_path: ../examples/data/eval_conversations.jsonl eval_interval: 100 chat_template: qwen prompt_key: conversations @@ -54,6 +54,8 @@ inference: use_worker_extension: true extra_args: max_num_batched_tokens: 32768 + compilation_config: + max_cudagraph_capture_size: 8 mooncake: master_server_address: null diff --git a/tests/test_vllm_engine_integration.py b/tests/test_vllm_engine_integration.py index 4af40ef..3718b10 100644 --- a/tests/test_vllm_engine_integration.py +++ b/tests/test_vllm_engine_integration.py @@ -23,16 +23,12 @@ except Exception: LOCAL_IP = "localhost" -# Mooncake env vars for MooncakeConfig.from_env() (retrieval side) +# Mooncake env vars read by MooncakeConfig.from_env() on both sides os.environ["MOONCAKE_MASTER_HOST"] = LOCAL_IP os.environ["MOONCAKE_MASTER_PORT"] = "50051" os.environ["MOONCAKE_METADATA_PORT"] = "8090" os.environ["MOONCAKE_LOCAL_HOSTNAME"] = LOCAL_IP - -# Mooncake env vars for worker extension (storage side) -os.environ["TORCHSPEC_MOONCAKE_MASTER_ADDR"] = f"{LOCAL_IP}:50051" -os.environ["TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME"] = LOCAL_IP -os.environ["TORCHSPEC_MOONCAKE_PROTOCOL"] = "tcp" +os.environ["MOONCAKE_MASTER_SERVER"] = f"{LOCAL_IP}:50051" def collect_metadata(engine, internal_to_external=None): diff --git a/torchspec/config/mooncake_config.py b/torchspec/config/mooncake_config.py index c8b8479..a8bd95c 100644 --- a/torchspec/config/mooncake_config.py +++ b/torchspec/config/mooncake_config.py @@ -177,6 +177,16 @@ def export_env(self) -> None: os.environ["MOONCAKE_ENABLE_GPU_DIRECT"] = "1" if 
self.enable_gpu_direct else "0" if self.async_put_pool_size is not None: os.environ["MOONCAKE_ASYNC_PUT_POOL_SIZE"] = str(self.async_put_pool_size) + os.environ["MOONCAKE_STORE_FULL_WAIT_SECONDS"] = str(self.store_full_wait_seconds) + os.environ["MOONCAKE_STORE_FULL_LOG_INTERVAL_SECONDS"] = str( + self.store_full_log_interval_seconds + ) + os.environ["MOONCAKE_STORE_FULL_MAX_WAIT_SECONDS"] = str(self.store_full_max_wait_seconds) + os.environ["MOONCAKE_GET_RETRY_WAIT_SECONDS"] = str(self.get_retry_wait_seconds) + os.environ["MOONCAKE_GET_RETRY_LOG_INTERVAL_SECONDS"] = str( + self.get_retry_log_interval_seconds + ) + os.environ["MOONCAKE_GET_RETRY_MAX_WAIT_SECONDS"] = str(self.get_retry_max_wait_seconds) @classmethod def from_env(cls) -> "MooncakeConfig": diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index 7da1776..d4912d5 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -225,20 +225,8 @@ def _setup_rpc_hidden_states_capture(self) -> None: if not hasattr(self._engine, "collective_rpc"): raise RuntimeError("vLLM LLM.collective_rpc is required for worker extension mode") - # Set environment variables so workers can connect to Mooncake if self._mooncake_config is not None: - import os - - os.environ["TORCHSPEC_MOONCAKE_MASTER_ADDR"] = ( - self._mooncake_config.master_server_address - ) - os.environ["TORCHSPEC_MOONCAKE_METADATA_PORT"] = str( - self._mooncake_config.metadata_server.split(":")[-1].replace("/metadata", "") - ) - os.environ["TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME"] = self._mooncake_config.local_hostname - os.environ["TORCHSPEC_MOONCAKE_PROTOCOL"] = self._mooncake_config.protocol - if self._mooncake_config.device_name: - os.environ["TORCHSPEC_MOONCAKE_DEVICE_NAME"] = self._mooncake_config.device_name + self._mooncake_config.export_env() logger.info( f"VllmEngine rank {self.rank}: Set Mooncake env vars for workers: " 
f"master={self._mooncake_config.master_server_address}" diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index ad5ebbf..a474ac7 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -236,85 +236,19 @@ def _init_mooncake_store(self) -> bool: pass try: - # Import here to avoid circular dependencies from torchspec.config.mooncake_config import MooncakeConfig from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore - # Get connection info from environment (set by main process) - # Try TORCHSPEC_MOONCAKE_* first, then fall back to MOONCAKE_* - master_addr = os.environ.get("TORCHSPEC_MOONCAKE_MASTER_ADDR") or os.environ.get( - "MOONCAKE_MASTER_SERVER" - ) - metadata_server = os.environ.get( - "TORCHSPEC_MOONCAKE_METADATA_SERVER" - ) or os.environ.get("MOONCAKE_METADATA_SERVER") - local_hostname = os.environ.get("TORCHSPEC_MOONCAKE_LOCAL_HOSTNAME") or os.environ.get( - "MOONCAKE_LOCAL_HOSTNAME", "localhost" - ) - protocol = os.environ.get("TORCHSPEC_MOONCAKE_PROTOCOL") or os.environ.get( - "MOONCAKE_PROTOCOL", "tcp" - ) - - if not master_addr: + if not os.environ.get("MOONCAKE_MASTER_SERVER") and not os.environ.get( + "MOONCAKE_MASTER_HOST" + ): logger.warning( "Mooncake master address not available in worker environment. " - "Set TORCHSPEC_MOONCAKE_MASTER_ADDR or MOONCAKE_MASTER_SERVER environment variable." + "Set MOONCAKE_MASTER_SERVER environment variable." 
) return False - # Parse metadata_server to get port if not explicitly set - if metadata_server: - # Extract port from URL like "http://host:port/metadata" - try: - metadata_port = metadata_server.split(":")[-1].replace("/metadata", "") - except Exception: - metadata_port = "8090" - else: - metadata_port = "8090" - metadata_server = f"http://{master_addr.split(':')[0]}:{metadata_port}/metadata" - - # Read buffer sizes from environment (set by main process via export_env) - host_buffer_size_env = os.environ.get( - "TORCHSPEC_MOONCAKE_HOST_BUFFER_SIZE" - ) or os.environ.get("MOONCAKE_HOST_BUFFER_SIZE") - - # Build config kwargs with optional overrides from environment - config_kwargs = { - "master_server_address": master_addr, - "metadata_server": metadata_server, - "local_hostname": local_hostname, - "protocol": protocol, - "device_name": os.environ.get("TORCHSPEC_MOONCAKE_DEVICE_NAME") - or os.environ.get("MOONCAKE_DEVICE_NAME", ""), - "async_put_pool_size": int( - os.environ.get("TORCHSPEC_MOONCAKE_ASYNC_POOL_SIZE") - or os.environ.get("MOONCAKE_ASYNC_PUT_POOL_SIZE", "2") - ), - "enable_gpu_direct": ( - os.environ.get("TORCHSPEC_MOONCAKE_GPU_DIRECT") - or os.environ.get("MOONCAKE_ENABLE_GPU_DIRECT", "0") - ).lower() - in ("true", "1", "yes"), - } - - # Only override defaults if environment variables are set - if host_buffer_size_env: - config_kwargs["host_buffer_size"] = int(host_buffer_size_env) - - global_segment_size_env = os.environ.get( - "TORCHSPEC_MOONCAKE_GLOBAL_SEGMENT_SIZE" - ) or os.environ.get("MOONCAKE_GLOBAL_SEGMENT_SIZE") - if global_segment_size_env: - config_kwargs["global_segment_size"] = int(global_segment_size_env) - - local_buffer_size_env = os.environ.get( - "TORCHSPEC_MOONCAKE_LOCAL_BUFFER_SIZE" - ) or os.environ.get("MOONCAKE_LOCAL_BUFFER_SIZE") - if local_buffer_size_env: - config_kwargs["local_buffer_size"] = int(local_buffer_size_env) - - # Create config for worker - config = MooncakeConfig(**config_kwargs) + config = 
MooncakeConfig.from_env() # Create store object but don't call setup() yet # setup() will be called lazily when CUDA context is ready @@ -325,7 +259,8 @@ def _init_mooncake_store(self) -> bool: self._store_setup_complete = False logger.info( - f"Worker initialized Mooncake store (setup deferred): master={master_addr}, protocol={protocol}" + f"Worker initialized Mooncake store (setup deferred): " + f"master={config.master_server_address}, protocol={config.protocol}" ) return True From 48453044dde98b225516f06c86fe923462e7ca36 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Mar 2026 00:04:57 +0000 Subject: [PATCH 10/10] fix defer tokenization --- tests/test_vllm_engine.py | 295 ++++++++++-------- tests/test_vllm_engine_integration.py | 7 +- torchspec/config/train_config.py | 16 + torchspec/controller/inference_manager.py | 2 +- torchspec/inference/engine/vllm_engine.py | 35 +-- .../inference/engine/vllm_worker_extension.py | 7 +- 6 files changed, 211 insertions(+), 151 deletions(-) diff --git a/tests/test_vllm_engine.py b/tests/test_vllm_engine.py index 6c3ee56..5f4ea85 100644 --- a/tests/test_vllm_engine.py +++ b/tests/test_vllm_engine.py @@ -25,7 +25,6 @@ - Integration tests: Test with real vLLM engine (requires GPU + infrastructure) """ -import os from dataclasses import dataclass from unittest.mock import MagicMock, patch @@ -282,145 +281,189 @@ def test_concatenated_tensors_shape(self): # ============================================================================= -# Integration Tests (Requires real GPU + vLLM + Mooncake) +# VllmEngine.generate() metadata flow tests # ============================================================================= -@pytest.mark.integration -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for TP=2") -class TestVllmWorkerExtensionIntegration: - """Integration tests for vLLM Worker Extension with real infrastructure.""" 
+def _make_mock_output(request_id: str, prompt_token_ids: list[int]): + """Create a mock vLLM RequestOutput.""" + out = MagicMock() + out.request_id = request_id + out.prompt_token_ids = prompt_token_ids + return out - @pytest.fixture(autouse=True) - def setup_env(self): - """Setup Mooncake environment variables.""" - os.environ.setdefault("MOONCAKE_MASTER_HOST", "0.0.0.0") - os.environ.setdefault("MOONCAKE_MASTER_PORT", "50051") - os.environ.setdefault("MOONCAKE_METADATA_PORT", "8090") - yield - # Cleanup not needed for env vars - def test_vllm_worker_extension_mooncake(self): - """Test vLLM Worker Extension stores and retrieves hidden states from Mooncake.""" - from transformers import AutoTokenizer - from vllm import LLM, SamplingParams +def _build_engine_with_mock_vllm(metadata_by_request: dict): + """Build a VllmEngine whose _engine is a mock vLLM LLM. - from torchspec.transfer.mooncake import EagleMooncakeStore, MooncakeConfig + Returns (engine, mock_llm) so tests can inspect collective_rpc calls. + """ + try: + from torchspec.inference.engine.vllm_engine import VllmEngine + except ImportError as e: + pytest.skip(f"VllmEngine import failed: {e}") + + args = MagicMock() + args.target_model_path = "mock-model" + args.trust_remote_code = True + engine = VllmEngine.__new__(VllmEngine) + engine.args = args + engine.rank = 0 + engine.base_gpu_id = 0 + engine._hidden_size = 4096 + engine.aux_hidden_state_layer_ids = [2, 4] + + mock_llm = MagicMock() + + def _collective_rpc(method, args=(), kwargs=None): + if method == "_store_and_get_metadata": + return [metadata_by_request] + return [None] + + mock_llm.collective_rpc = MagicMock(side_effect=_collective_rpc) + engine._engine = mock_llm + return engine, mock_llm + + +class TestGenerateMetadataFlow: + """Test that generate() builds and sends request_metadata for both + the input_ids path and the formatted_prompts (defer_tokenization) path. 
+ """ + + def test_input_ids_path_sends_metadata_twice(self): + """input_ids path: _set_request_metadata is called both pre- and + post-generation with correct token counts.""" + ids_a = torch.tensor([10, 20, 30]) + ids_b = torch.tensor([40, 50, 60, 70]) + data_ids = ["a", "b"] + + worker_meta = { + "a": { + "mooncake_key": "a", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": ids_a.tolist(), + }, + "b": { + "mooncake_key": "b", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": ids_b.tolist(), + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + + mock_llm.generate.return_value = [ + _make_mock_output("0", ids_a.tolist()), + _make_mock_output("1", ids_b.tolist()), + ] + + results = engine.generate( + data_id=data_ids, + input_ids_ref=[ids_a, ids_b], + ) + + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + ] + assert len(set_meta_calls) == 2, ( + f"Expected 2 _set_request_metadata calls, got {len(set_meta_calls)}" + ) - model_path = "Qwen/Qwen3-8B" + # Post-gen call (last one) must carry authoritative token counts + post_gen_args = set_meta_calls[-1][1]["args"] + req_meta = post_gen_args[0] + assert req_meta == {"a": 3, "b": 4} + + input_ids_map = post_gen_args[2] + assert input_ids_map == {"a": ids_a.tolist(), "b": ids_b.tolist()} + + assert len(results) == 2 + assert results[0]["data_id"] == "a" + assert results[1]["data_id"] == "b" + + def test_formatted_prompts_path_sends_metadata_post_gen(self): + """formatted_prompts (defer_tokenization) path: _set_request_metadata + is sent after generation with token counts from vLLM outputs.""" + prompt_tokens_a = [10, 20, 30, 40, 50] + prompt_tokens_b = [60, 70, 80] + data_ids = ["p0", "p1"] + + worker_meta = { + "p0": { + "mooncake_key": "p0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": prompt_tokens_a, + }, + "p1": { + "mooncake_key": "p1", + "tensor_shapes": {}, + "tensor_dtypes": {}, 
+ "input_ids_list": prompt_tokens_b, + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + + mock_llm.generate.return_value = [ + _make_mock_output("0", prompt_tokens_a), + _make_mock_output("1", prompt_tokens_b), + ] - # Initialize tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) + results = engine.generate( + data_id=data_ids, + formatted_prompts=["Hello world", "Goodbye"], + ) - # Test inputs - input_ids_list = [ - [1, 2345, 6789], - [100, 200, 300, 400], - [500, 600], + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" ] - data_ids = ["test_req_0", "test_req_1", "test_req_2"] - - # Initialize vLLM with Worker Extension - engine = LLM( - model=model_path, - tensor_parallel_size=2, - gpu_memory_utilization=0.7, - trust_remote_code=True, - worker_extension_cls="torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension", - max_model_len=2048, + # Only the post-gen call (pre-gen is skipped because request_metadata + # is empty before generation). 
+ assert len(set_meta_calls) == 1 + + post_gen_args = set_meta_calls[0][1]["args"] + req_meta = post_gen_args[0] + assert req_meta == {"p0": 5, "p1": 3} + + input_ids_map = post_gen_args[2] + assert input_ids_map == {"p0": prompt_tokens_a, "p1": prompt_tokens_b} + + assert len(results) == 2 + assert results[0]["input_ids_list"] == prompt_tokens_a + assert results[1]["input_ids_list"] == prompt_tokens_b + + def test_formatted_prompts_with_no_packed_loss_mask(self): + """defer_tokenization path with packed_loss_mask_list=None works.""" + tokens = [1, 2, 3] + worker_meta = { + "d0": { + "mooncake_key": "d0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": tokens, + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + mock_llm.generate.return_value = [_make_mock_output("0", tokens)] + + results = engine.generate( + data_id=["d0"], + formatted_prompts=["test"], + packed_loss_mask_list=None, ) - try: - # Configure hidden states capture - engine.collective_rpc("_setup_hidden_states_capture", args=([5, 10, 15],)) - - # Prepare generation - prompts = [tokenizer.decode(ids) for ids in input_ids_list] - sampling_params = SamplingParams(max_tokens=32, temperature=0) - - # Setup request metadata - request_metadata = {data_ids[i]: len(ids) for i, ids in enumerate(input_ids_list)} - engine.collective_rpc("_reset_capture") - engine.collective_rpc("_set_request_metadata", args=(request_metadata,)) - - # Generate - print("=== Generating with vLLM Worker Extension ===") - outputs = engine.generate(prompts, sampling_params) - assert len(outputs) == len(input_ids_list), "Generation output count mismatch" - - for i, output in enumerate(outputs): - print(f"\n--- Request {i} ---") - print(f"output_ids: {output.prompt_token_ids + list(output.outputs[0].token_ids)}") - print(f"num tokens generated: {len(output.outputs[0].token_ids)}") - - # Retrieve metadata from Mooncake - print("\n=== Retrieving metadata from Mooncake ===") - metadata_list = 
engine.collective_rpc("_store_and_get_metadata") - assert metadata_list is not None, "No metadata returned from workers" - - all_keys = [] - seq_lens = [] - for metadata in metadata_list: - if isinstance(metadata, dict): - for req_id, meta in metadata.items(): - assert "mooncake_key" in meta - assert "tensor_shapes" in meta - assert "num_layers" in meta - assert meta["num_layers"] == 3 - all_keys.append(meta["mooncake_key"]) - seq_lens.append(request_metadata[req_id]) - print( - f" {req_id}: key={meta['mooncake_key']}, layers={meta['num_layers']}" - ) - - # Fetch data from Mooncake Store - print("\n=== Fetching data from Mooncake Store ===") - mooncake_config = MooncakeConfig.from_env() - mooncake_store = EagleMooncakeStore(mooncake_config) - mooncake_store.setup(device="cuda") - - # Qwen3-8B dimensions - hidden_dim = 12288 # 3 layers concatenated (4096 * 3) - last_hidden_dim = 4096 - - for i, key in enumerate(all_keys): - seq_len = seq_lens[i] - shapes = { - "hidden_states": (seq_len, hidden_dim), - "input_ids": (seq_len,), - "last_hidden_states": (seq_len, last_hidden_dim), - } - dtypes = { - "hidden_states": torch.bfloat16, - "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, - } - - data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") - print(f"\n Key: {key}") - print( - f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}" - ) - print(f" input_ids: {data.input_ids.tolist()}") - print(f" last_hidden_states: shape={data.last_hidden_states.shape}") - - # Verify tensor device consistency - assert data.hidden_states.device == data.input_ids.device, ( - f"Device mismatch: hidden_states={data.hidden_states.device}, input_ids={data.input_ids.device}" - ) - - print("\n✓ Test completed - hidden states sent to Mooncake and retrieved successfully") - - finally: - # Cleanup - if hasattr(engine, "shutdown"): - engine.shutdown() + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if 
c[0][0] == "_set_request_metadata" + ] + assert len(set_meta_calls) == 1 + packed_map = set_meta_calls[0][1]["args"][1] + assert packed_map == {} + + assert len(results) == 1 + assert "packed_loss_mask" not in results[0] -# ============================================================================= -# Legacy main block (kept for backward compatibility) -# ============================================================================= if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_vllm_engine_integration.py b/tests/test_vllm_engine_integration.py index 3718b10..258d96b 100644 --- a/tests/test_vllm_engine_integration.py +++ b/tests/test_vllm_engine_integration.py @@ -194,17 +194,20 @@ def verify_from_mooncake(mooncake_store, keys, seq_lens, hidden_dim, last_hidden outputs = engine.generate(text_prompts, sampling_params, use_tqdm=False) - # Build metadata from outputs post-generation (same as VllmEngine does) + # Build authoritative metadata from outputs and set on workers, + # mirroring VllmEngine.generate()'s unconditional post-generation path. 
request_metadata = {} input_ids_map = {} + internal_to_external = {} for i, output in enumerate(outputs): did = prompt_data_ids[i] request_metadata[did] = len(output.prompt_token_ids) input_ids_map[did] = list(output.prompt_token_ids) + internal_to_external[output.request_id] = did print(f' Request {i}: "{text_prompts[i]}" -> {len(output.prompt_token_ids)} tokens') engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) - metadata = collect_metadata(engine) + metadata = collect_metadata(engine, internal_to_external=internal_to_external) all_keys = [metadata[did]["mooncake_key"] for did in prompt_data_ids] seq_lens = [request_metadata[did] for did in prompt_data_ids] assert len(metadata) == len(prompt_data_ids), ( diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 3a6477c..20da08b 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -194,6 +194,19 @@ def _resolve_relative_paths( OmegaConf.update(config, dotted_key, os.path.abspath(os.path.join(base_dir, expanded))) +def _validate_vllm_config(config: DictConfig) -> None: + """Raise if the vllm backend is selected with unsupported feature flags.""" + if config.model.target_model_backend != "vllm": + return + unsupported_flags = { + "inference.vllm.enable_multimodal": "enable_multimodal", + "training.train_with_decode": "train_with_decode", + } + for key, label in unsupported_flags.items(): + if OmegaConf.select(config, key): + raise NotImplementedError(f"{label} is not yet supported with the vllm backend!") + + def _save_config_snapshot(config: DictConfig) -> None: """Save the resolved config to output_dir/config.yaml if output_dir is set.""" output_dir = OmegaConf.select(config, "output_dir", default=None) @@ -236,6 +249,9 @@ def load_config( config = OmegaConf.merge(*configs_to_merge) _resolve_relative_paths(config, os.getcwd()) + + _validate_vllm_config(config) + if save_snapshot: _save_config_snapshot(config) 
diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index b43a59c..4e032dd 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -408,7 +408,7 @@ def _prepare_engine_inputs(self, entries: list[InferenceInput]) -> dict: if self._defer_tokenization: input_ids_ref = None - packed_loss_mask_list = [None] * len(entries) + packed_loss_mask_list = None assert all(e.formatted_prompt is not None for e in entries), ( "formatted_prompt is required when defer_tokenization is True" ) diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index d4912d5..a6681d1 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -243,7 +243,7 @@ def generate( self, data_id: str | list[str], input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None = None, - packed_loss_mask_list: list[str] | None = None, + packed_loss_mask_list: list[str | None] | None = None, formatted_prompts: list[str] | None = None, return_last_hidden_states: bool = False, return_logits: bool = True, @@ -323,24 +323,21 @@ def generate( for i, output in enumerate(outputs): internal_to_external[output.request_id] = data_ids[i] - # For the formatted_prompts path, request_metadata and input_ids_map - # were not set before generation (no input_ids available). Build them - # from the outputs so the worker can map captured states to requests. 
- if use_prompts and not request_metadata: - for i, output in enumerate(outputs): - did = data_ids[i] - request_metadata[did] = len(output.prompt_token_ids) - input_ids_map[did] = list(output.prompt_token_ids) - try: - self._engine.collective_rpc( - "_set_request_metadata", - args=(request_metadata, packed_loss_mask_map, input_ids_map), - ) - except Exception as e: - logger.warning( - f"VllmEngine rank {self.rank}: Could not set post-generation " - f"request metadata: {e}" - ) + # Always build request_metadata and input_ids_map from the + # outputs. + for i, output in enumerate(outputs): + did = data_ids[i] + request_metadata[did] = len(output.prompt_token_ids) + input_ids_map[did] = list(output.prompt_token_ids) + try: + self._engine.collective_rpc( + "_set_request_metadata", + args=(request_metadata, packed_loss_mask_map, input_ids_map), + ) + except Exception as e: + logger.warning( + f"VllmEngine rank {self.rank}: Could not set post-generation request metadata: {e}" + ) # Get metadata from workers (tensors are already stored in Mooncake by workers) metadata_by_request: dict[str, dict] = {} diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index a474ac7..99b8a4d 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -429,7 +429,7 @@ def _setup_hidden_states_capture(self, layer_ids: List[int]) -> None: self._captured_states = None self._request_metadata = [] self._current_request_metadata = None - self._packed_loss_mask_map: Dict[str, str] = {} + self._packed_loss_mask_map: Dict[str, Optional[str]] = {} self._store_initialized = False self._store_setup_complete = False self._init_retry_count = 0 @@ -468,7 +468,7 @@ def _setup_hidden_states_capture(self, layer_ids: List[int]) -> None: def _set_request_metadata( self, request_metadata: Dict[str, int], - packed_loss_mask_map: Optional[Dict[str, str]] = None, + 
packed_loss_mask_map: Optional[Dict[str, Optional[str]]] = None, input_ids_map: Optional[Dict[str, List[int]]] = None, ) -> None: """Set request metadata for the next forward pass. @@ -478,7 +478,8 @@ def _set_request_metadata( Args: request_metadata: Dict mapping request_id -> num_prefill_tokens - packed_loss_mask_map: Optional dict mapping request_id -> packed_loss_mask string + packed_loss_mask_map: Optional dict mapping request_id -> packed_loss_mask + string (values may be None when loss masks are not available). input_ids_map: Optional dict mapping request_id -> input_ids list (passed via RPC) """ self._current_request_metadata = request_metadata