diff --git a/README.md b/README.md index 076406c..20af846 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,32 @@ TorchSpec is a torch-native speculative decoding training framework. We introduc ## Setup +### Choose Your Backend + +TorchSpec supports two inference backends: + +| Backend | Best For | Installation | +|---------|----------|--------------| +| **SGLang** | Production workloads, high throughput | `./tools/build_conda.sh 1 sglang` (default) | +| **vLLM** | Flexibility, easier deployment | `./tools/build_conda.sh 1 vllm` | +| **Both** | Development, comparison testing | `./tools/build_conda.sh 1 both` | + +### Quick Setup + ```bash +# Install with SGLang (default) ./tools/build_conda.sh micromamba activate torchspec + +# Or install with vLLM +./tools/build_conda.sh 1 vllm +micromamba activate torchspec ``` -To install into your current environment instead: `./tools/build_conda.sh current` +To install into your current environment instead: +```bash +./tools/build_conda.sh current sglang # or 'vllm' or 'both' +``` Optional — install Flash Attention: @@ -21,6 +41,20 @@ Optional — install Flash Attention: pip install -e ".[fa]" ``` +### Backend-Specific Usage + +**SGLang (default):** +```bash +./examples/qwen3-8b-single-node/run.sh +``` + +**vLLM:** +```bash +./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml +``` + +TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden states extraction. + ## Quick Start Train an Eagle3 draft model for Qwen3-8B using inference engine (4 GPUs: 2 training + 2 inference): diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml new file mode 100644 index 0000000..96ce140 --- /dev/null +++ b/configs/vllm_qwen3_8b.yaml @@ -0,0 +1,74 @@ +# Configuration for train_entry.py with vLLM Engine inference (nested config format) +# +# GPU allocation: +# - 2 GPUs for inference (duplicate mode: each engine has full model copy) +# - 2 GPUs for training (DP/FSDP: model sharded across 2 GPUs) +# - Total: 4 GPUs +# +# Installation: +# pip install -e ".[vllm]" # Install vLLM backend +# +# Usage: +# python -m torchspec.train_entry --config configs/vllm_qwen3_8b.yaml +# +# Note: Uses vLLM Worker Extension to hook into model forward pass for hidden states capture. + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + +dataset: + train_data_path: ../examples/data/sample_conversations.jsonl + eval_data_path: ../examples/data/eval_conversations.jsonl + eval_interval: 100 + chat_template: qwen + prompt_key: conversations + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 1 + learning_rate: 1e-4 + max_concurrent_batches: 1 + max_grad_norm: 0.5 + max_seq_length: 16384 + num_epochs: 1 + seed: 42 + training_num_gpus_per_node: 2 + training_num_nodes: 1 + ttt_length: 7 + save_per_epoch: true + warmup_ratio: 0.015 + +inference: + inference_engine_type: vllm + inference_num_gpus: 2 + inference_num_gpus_per_engine: 2 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + vllm: + tp_size: 2 + mem_fraction_static: 0.7 + use_worker_extension: true + extra_args: + max_num_batched_tokens: 32768 + compilation_config: + max_cudagraph_capture_size: 8 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + local_buffer_size: 4GB + +output_dir: ./outputs/qwen3-8b-single-node +cache_dir: ./cache +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false diff --git a/docs/code_architecture.md b/docs/code_architecture.md index 5fa2379..ab10ad4 100644 --- a/docs/code_architecture.md +++ b/docs/code_architecture.md @@ -24,7 +24,8 @@ torchspec/ │ ├── base.py # InferenceEngine (ABC) │ ├── hf_engine.py # HFEngine (Ray actor, inherits RayActor) │ ├── hf_runner.py # HFRunner (core inference logic) -│ └── sgl_engine.py # SglEngine (Ray actor, inherits RayActor) +│ ├── sgl_engine.py # SglEngine (Ray actor, inherits RayActor) +│ └── vllm_engine.py # VllmEngine (Ray actor, uses vLLM extract_hidden_states) ├── models/ # Model definitions │ ├── eagle3.py # Eagle3Model (core forward/loss) │ ├── draft/ # Draft model implementations diff --git a/examples/README.md b/examples/README.md index ebd664f..1ae5f0d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -17,12 +17,22 @@ If you just want to try TorchSpec locally, start with **hf-quickstart** (3 GPUs, ./examples/hf-quickstart/run.sh ``` -For production workloads with SGLang async inference, use **qwen3-8b-single-node**: +For production workloads with async inference, use **qwen3-8b-single-node**: ```bash ./examples/qwen3-8b-single-node/run.sh ``` +## Switching inference backends + +Examples use SGLang by default. To use vLLM instead: + +```bash +# Use vLLM backend with qwen3-8b-single-node example +./examples/qwen3-8b-single-node/run.sh \ + --config configs/vllm_qwen3_8b.yaml \ +``` + ## Data Sample training data is in [`data/sample_conversations.jsonl`](data/sample_conversations.jsonl). All examples that use local data point to this file by default. diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index ab68399..14675bb 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Train with SglEngine async inference (multi-GPU version) +# Train with SGLang/vLLM async inference (multi-GPU version) # # GPU allocation (default: 4 GPUs total): # - 2 GPUs for inference (duplicate mode: each engine has full model copy) @@ -46,7 +46,7 @@ INFERENCE_GPUS=2 LOCAL_IP=$(python3 -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); s.connect(('8.8.8.8', 80)); print(s.getsockname()[0]); s.close()") echo "==============================================" -echo "Train with SglEngine inference" +echo "Train with async inference" echo "==============================================" echo "Config: $CONFIG_FILE (nested format)" echo "Total GPUs: $TOTAL_GPUS (CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES)" @@ -59,11 +59,9 @@ echo "==============================================" python3 -m torchspec.train_entry \ --config "$CONFIG_FILE" \ training.training_num_gpus_per_node="$TRAIN_GPUS" \ - inference.inference_engine_type="sgl" \ inference.inference_num_gpus="$INFERENCE_GPUS" \ inference.inference_num_gpus_per_engine=2 \ inference.inference_num_gpus_per_node="$TOTAL_GPUS" \ - inference.sglang.tp_size=2 \ "$@" echo "==============================================" diff --git a/patches/vllm/v0.15.1/vllm.patch b/patches/vllm/v0.15.1/vllm.patch deleted file mode 100644 index 390bdff..0000000 --- a/patches/vllm/v0.15.1/vllm.patch +++ /dev/null @@ -1 +0,0 @@ -# PLACEHOLDER diff --git a/pyproject.toml b/pyproject.toml index 12a7721..6a23e90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,12 @@ dev = [ "ruff", ] +vllm = [ + "vllm>=0.16.0", +] + fa = [ - "flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute", + "flash-attention-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute", "nvidia-cutlass-dsl==4.4.0.dev1", "nvidia-cutlass-dsl-libs-base==4.4.0.dev1", ] diff --git a/tests/test_sglang_engine.py b/tests/test_sglang_engine_integration.py similarity index 100% rename from tests/test_sglang_engine.py rename to tests/test_sglang_engine_integration.py diff --git a/tests/test_vllm_engine.py b/tests/test_vllm_engine.py new file mode 100644 index 0000000..5f4ea85 --- /dev/null +++ b/tests/test_vllm_engine.py @@ -0,0 +1,469 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Tests for vLLM Worker Extension. + +This file contains both: +- Unit tests: Test logic with mocks (no GPU/vLLM/Mooncake needed) +- Integration tests: Test with real vLLM engine (requires GPU + infrastructure) +""" + +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# ============================================================================= +# Helpers +# ============================================================================= + + +@dataclass +class MockArgs: + """Mock args for VllmWorkerExtension initialization.""" + + target_model_path: str = "Qwen/Qwen3-8B" + tensor_parallel_size: int = 2 + max_model_len: int = 2048 + trust_remote_code: bool = True + + +def _import_vllm_worker_extension(): + """Import VllmWorkerExtension, skipping test if dependencies unavailable.""" + try: + from torchspec.inference.engine.vllm_worker_extension import ( + VllmWorkerExtension, + _sanitize_mooncake_key, + ) + + return VllmWorkerExtension, _sanitize_mooncake_key + except ImportError as e: + pytest.skip(f"VllmWorkerExtension import failed (missing deps): {e}") + + +# ============================================================================= +# Unit Tests (No real vLLM/GPU/Mooncake needed) +# ============================================================================= + + +class TestSanitizeMooncakeKey: + """Unit tests for _sanitize_mooncake_key pure function.""" + + def test_alphanumeric_unchanged(self): + """Test alphanumeric keys pass through unchanged.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("req_abc_123") == "req_abc_123" + + def test_special_chars_replaced(self): + """Test special characters are replaced with underscores.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("req@abc#123") == "req_abc_123" + assert _sanitize("req.id.name") == "req_id_name" + assert _sanitize("req:name|value") == "req_name_value" + + def test_leading_digit_prefixed(self): + """Test leading digits get 'k' prefix.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("123_req") == "k123_req" + assert _sanitize("1abc") == "k1abc" + + def test_empty_string(self): + """Test empty string handling.""" + _, _sanitize = _import_vllm_worker_extension() + assert _sanitize("") == "" + + +class TestVllmWorkerExtensionState: + """Unit tests for VllmWorkerExtension state management.""" + + def test_init_stores_config(self): + """Test constructor initializes state correctly.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + + assert ext._layer_ids == frozenset() + assert ext._captured_states is None + assert ext._request_metadata == [] + assert ext._current_request_metadata is None + assert ext._mooncake_store is None + assert ext._store_initialized is False + + def test_set_request_metadata(self): + """Test setting request metadata.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + metadata = {"req_1": 100, "req_2": 200} + packed_map = {"req_1": "0,3", "req_2": "0,5"} + input_ids_map = {"req_1": [1, 2, 3], "req_2": [4, 5, 6]} + + ext._set_request_metadata(metadata, packed_map, input_ids_map) + + assert ext._current_request_metadata == metadata + assert ext._packed_loss_mask_map == packed_map + assert ext._input_ids_map == input_ids_map + + def test_reset_capture_clears_state(self): + """Test reset_capture clears all captured state.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._layer_ids = frozenset({5, 10, 15}) + ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] + ext._captured_input_ids = torch.tensor([1, 2, 3]) + ext._request_metadata = [{"req_1": 10}] + ext._current_request_metadata = {"req_1": 10} + ext._packed_loss_mask_map = {"req_1": "0,3"} + ext._input_ids_map = {"req_1": [1, 2, 3]} + + ext._reset_capture() + + assert ext._captured_states is None + assert ext._captured_input_ids is None + assert ext._request_metadata == [] + assert ext._current_request_metadata is None + assert ext._packed_loss_mask_map == {} + assert ext._input_ids_map == {} + + def test_reset_capture_requires_prior_setup(self): + """Test reset_capture requires _setup_hidden_states_capture first.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + # Don't set _layer_ids + + with pytest.raises(RuntimeError, match="Must call _setup_hidden_states_capture"): + ext._reset_capture() + + +class TestStoreCapturedStates: + """Unit tests for _store_captured_states with mocked dependencies.""" + + def test_store_first_capture(self): + """Test first capture initializes the state lists.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] + + ext._store_captured_states(tensors) + + assert ext._captured_states is not None + assert len(ext._captured_states) == 2 + assert torch.equal(ext._captured_states[0][0], tensors[0]) + assert torch.equal(ext._captured_states[1][0], tensors[1]) + + def test_store_appends_to_existing(self): + """Test subsequent captures append to existing lists.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] + + new_tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] + ext._store_captured_states(new_tensors) + + assert len(ext._captured_states[0]) == 2 + assert len(ext._captured_states[1]) == 2 + assert torch.equal(ext._captured_states[0][1], new_tensors[0]) + + def test_store_extracts_metadata_from_input_batch(self): + """Test metadata extraction from model_runner.input_batch.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + + # Mock model_runner with input_batch + mock_batch = MagicMock() + mock_batch.req_ids = ["req_1", "req_2"] + mock_batch.req_id_to_index = {"req_1": 0, "req_2": 1} + mock_batch.num_tokens = [100, 200] + mock_batch.num_computed_tokens = [0, 0] + + ext.model_runner = MagicMock() + ext.model_runner.input_batch = mock_batch + + tensors = [torch.randn(10, 4096)] + ext._store_captured_states(tensors) + + assert len(ext._request_metadata) == 1 + assert "req_1" in ext._request_metadata[0] + assert "req_2" in ext._request_metadata[0] + + +class TestCudaDeviceSafe: + """Unit tests for _get_cuda_device_safe with mocked torch.cuda.""" + + @patch("torch.cuda.is_initialized") + @patch("torch.cuda.current_device") + def test_initialized_context(self, mock_current, mock_initialized): + """Test when CUDA is already initialized.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + mock_initialized.return_value = True + mock_current.return_value = 1 + + ext = VllmWorkerExtension() + device = ext._get_cuda_device_safe() + + assert str(device) == "cuda:1" + + @patch("torch.cuda.is_initialized") + def test_uninitialized_context_fallback(self, mock_initialized): + """Test fallback when CUDA not initialized (V1 engine).""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + mock_initialized.return_value = False + + ext = VllmWorkerExtension() + device = ext._get_cuda_device_safe() + + assert str(device) == "cuda:0" + + +class TestTokenSlicingLogic: + """Unit tests for token distribution and slicing logic.""" + + def test_ratio_based_distribution(self): + """Test ratio calculation for token distribution.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + ext._current_request_metadata = {"req_1": 100, "req_2": 200} + + external_ids = list(ext._current_request_metadata.keys()) + token_counts = list(ext._current_request_metadata.values()) + total_expected = sum(token_counts) # 300 + total_captured = 150 # Half the expected tokens + + ratio = total_captured / total_expected # 0.5 + + # Calculate actual tokens per request + actual_tokens = {ext_id: int(tc * ratio) for ext_id, tc in zip(external_ids, token_counts)} + + assert actual_tokens == {"req_1": 50, "req_2": 100} + + def test_concatenated_tensors_shape(self): + """Test tensor concatenation from multiple iterations.""" + VllmWorkerExtension, _ = _import_vllm_worker_extension() + + ext = VllmWorkerExtension() + # Simulate 2 iterations with 5 tokens each + ext._captured_states = [ + [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 0 + [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 1 + ] + + # Concatenate (simulating _store_and_get_metadata logic) + concatenated = [torch.cat(layer_tensors, dim=0) for layer_tensors in ext._captured_states] + + assert concatenated[0].shape == (10, 4096) + assert concatenated[1].shape == (10, 4096) + + +# ============================================================================= +# VllmEngine.generate() metadata flow tests +# ============================================================================= + + +def _make_mock_output(request_id: str, prompt_token_ids: list[int]): + """Create a mock vLLM RequestOutput.""" + out = MagicMock() + out.request_id = request_id + out.prompt_token_ids = prompt_token_ids + return out + + +def _build_engine_with_mock_vllm(metadata_by_request: dict): + """Build a VllmEngine whose _engine is a mock vLLM LLM. + + Returns (engine, mock_llm) so tests can inspect collective_rpc calls. + """ + try: + from torchspec.inference.engine.vllm_engine import VllmEngine + except ImportError as e: + pytest.skip(f"VllmEngine import failed: {e}") + + args = MagicMock() + args.target_model_path = "mock-model" + args.trust_remote_code = True + engine = VllmEngine.__new__(VllmEngine) + engine.args = args + engine.rank = 0 + engine.base_gpu_id = 0 + engine._hidden_size = 4096 + engine.aux_hidden_state_layer_ids = [2, 4] + + mock_llm = MagicMock() + + def _collective_rpc(method, args=(), kwargs=None): + if method == "_store_and_get_metadata": + return [metadata_by_request] + return [None] + + mock_llm.collective_rpc = MagicMock(side_effect=_collective_rpc) + engine._engine = mock_llm + return engine, mock_llm + + +class TestGenerateMetadataFlow: + """Test that generate() builds and sends request_metadata for both + the input_ids path and the formatted_prompts (defer_tokenization) path. + """ + + def test_input_ids_path_sends_metadata_twice(self): + """input_ids path: _set_request_metadata is called both pre- and + post-generation with correct token counts.""" + ids_a = torch.tensor([10, 20, 30]) + ids_b = torch.tensor([40, 50, 60, 70]) + data_ids = ["a", "b"] + + worker_meta = { + "a": { + "mooncake_key": "a", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": ids_a.tolist(), + }, + "b": { + "mooncake_key": "b", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": ids_b.tolist(), + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + + mock_llm.generate.return_value = [ + _make_mock_output("0", ids_a.tolist()), + _make_mock_output("1", ids_b.tolist()), + ] + + results = engine.generate( + data_id=data_ids, + input_ids_ref=[ids_a, ids_b], + ) + + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + ] + assert len(set_meta_calls) == 2, ( + f"Expected 2 _set_request_metadata calls, got {len(set_meta_calls)}" + ) + + # Post-gen call (last one) must carry authoritative token counts + post_gen_args = set_meta_calls[-1][1]["args"] + req_meta = post_gen_args[0] + assert req_meta == {"a": 3, "b": 4} + + input_ids_map = post_gen_args[2] + assert input_ids_map == {"a": ids_a.tolist(), "b": ids_b.tolist()} + + assert len(results) == 2 + assert results[0]["data_id"] == "a" + assert results[1]["data_id"] == "b" + + def test_formatted_prompts_path_sends_metadata_post_gen(self): + """formatted_prompts (defer_tokenization) path: _set_request_metadata + is sent after generation with token counts from vLLM outputs.""" + prompt_tokens_a = [10, 20, 30, 40, 50] + prompt_tokens_b = [60, 70, 80] + data_ids = ["p0", "p1"] + + worker_meta = { + "p0": { + "mooncake_key": "p0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": prompt_tokens_a, + }, + "p1": { + "mooncake_key": "p1", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": prompt_tokens_b, + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + + mock_llm.generate.return_value = [ + _make_mock_output("0", prompt_tokens_a), + _make_mock_output("1", prompt_tokens_b), + ] + + results = engine.generate( + data_id=data_ids, + formatted_prompts=["Hello world", "Goodbye"], + ) + + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + ] + # Only the post-gen call (pre-gen is skipped because request_metadata + # is empty before generation). + assert len(set_meta_calls) == 1 + + post_gen_args = set_meta_calls[0][1]["args"] + req_meta = post_gen_args[0] + assert req_meta == {"p0": 5, "p1": 3} + + input_ids_map = post_gen_args[2] + assert input_ids_map == {"p0": prompt_tokens_a, "p1": prompt_tokens_b} + + assert len(results) == 2 + assert results[0]["input_ids_list"] == prompt_tokens_a + assert results[1]["input_ids_list"] == prompt_tokens_b + + def test_formatted_prompts_with_no_packed_loss_mask(self): + """defer_tokenization path with packed_loss_mask_list=None works.""" + tokens = [1, 2, 3] + worker_meta = { + "d0": { + "mooncake_key": "d0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": tokens, + }, + } + engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + mock_llm.generate.return_value = [_make_mock_output("0", tokens)] + + results = engine.generate( + data_id=["d0"], + formatted_prompts=["test"], + packed_loss_mask_list=None, + ) + + set_meta_calls = [ + c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + ] + assert len(set_meta_calls) == 1 + + packed_map = set_meta_calls[0][1]["args"][1] + assert packed_map == {} + + assert len(results) == 1 + assert "packed_loss_mask" not in results[0] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_vllm_engine_integration.py b/tests/test_vllm_engine_integration.py new file mode 100644 index 0000000..258d96b --- /dev/null +++ b/tests/test_vllm_engine_integration.py @@ -0,0 +1,224 @@ +"""Standalone integration script that tests vLLM Worker Extension hidden states collection behavior. + +Tests: + 1. Short sequences via input_ids (basic capture) + 2. Longer sequences via input_ids + 3. formatted_prompts path (defer tokenization mode) +""" + +import os +import socket + +import torch +from transformers import AutoTokenizer + +from torchspec.transfer.mooncake import EagleMooncakeStore, MooncakeConfig + +# Detect local IP for Mooncake connections +try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + LOCAL_IP = s.getsockname()[0] + s.close() +except Exception: + LOCAL_IP = "localhost" + +# Mooncake env vars read by MooncakeConfig.from_env() on both sides +os.environ["MOONCAKE_MASTER_HOST"] = LOCAL_IP +os.environ["MOONCAKE_MASTER_PORT"] = "50051" +os.environ["MOONCAKE_METADATA_PORT"] = "8090" +os.environ["MOONCAKE_LOCAL_HOSTNAME"] = LOCAL_IP +os.environ["MOONCAKE_MASTER_SERVER"] = f"{LOCAL_IP}:50051" + + +def collect_metadata(engine, internal_to_external=None): + """Call _store_and_get_metadata and merge results from all TP ranks.""" + args = (internal_to_external,) if internal_to_external else () + metadata_list = engine.collective_rpc("_store_and_get_metadata", args=args) + merged = {} + if isinstance(metadata_list, list): + for m in metadata_list: + if isinstance(m, dict): + merged.update(m) + elif isinstance(metadata_list, dict): + merged = metadata_list + return merged + + +def verify_from_mooncake(mooncake_store, keys, seq_lens, hidden_dim, last_hidden_dim): + """Fetch tensors from Mooncake and verify shapes.""" + for i, key in enumerate(keys): + seq_len = seq_lens[i] + shapes = { + "hidden_states": (seq_len, hidden_dim), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, last_hidden_dim), + } + dtypes = { + "hidden_states": torch.bfloat16, + "input_ids": torch.long, + "last_hidden_states": torch.bfloat16, + } + data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") + print(f"\n Key: {key}") + print( + f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}" + ) + print(f" input_ids: {data.input_ids.tolist()[:10]}{'...' if seq_len > 10 else ''}") + print(f" last_hidden_states: shape={data.last_hidden_states.shape}") + + assert data.hidden_states.shape == (seq_len, hidden_dim), ( + f"hidden_states shape {data.hidden_states.shape} != expected {(seq_len, hidden_dim)}" + ) + assert data.input_ids.shape == (seq_len,) + assert data.last_hidden_states.shape == (seq_len, last_hidden_dim) + + +if __name__ == "__main__": + model_path = "Qwen/Qwen3-8B" + aux_layer_ids = [2, 4, 6] + tp_size = 4 + hidden_size = 4096 + num_aux_layers = len(aux_layer_ids) + hidden_dim = num_aux_layers * hidden_size + last_hidden_dim = hidden_size + + tokenizer = AutoTokenizer.from_pretrained(model_path) + + from vllm import LLM, SamplingParams + + engine = LLM( + model=model_path, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.7, + trust_remote_code=True, + distributed_executor_backend="mp", + disable_custom_all_reduce=True, + disable_log_stats=True, + worker_extension_cls="torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension", + max_model_len=2048, + enable_chunked_prefill=False, + ) + + engine.collective_rpc("_setup_hidden_states_capture", args=(aux_layer_ids,)) + + mooncake_config = MooncakeConfig.from_env() + mooncake_store = EagleMooncakeStore(mooncake_config) + mooncake_store.setup(device="cuda") + + sampling_params = SamplingParams(max_tokens=1, temperature=0) + + # ========================================================================= + # Test 1: Short sequences + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 1: Short sequences") + print("=" * 60) + + input_ids_list = [ + [1, 2345, 6789], + [100, 200, 300, 400], + [500, 600], + ] + data_ids = ["short_0", "short_1", "short_2"] + + prompts = [{"prompt_token_ids": ids} for ids in input_ids_list] + request_metadata = {data_ids[i]: len(ids) for i, ids in enumerate(input_ids_list)} + input_ids_map = {data_ids[i]: ids for i, ids in enumerate(input_ids_list)} + + engine.collective_rpc("_reset_capture") + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + outputs = engine.generate(prompts, sampling_params, use_tqdm=False) + for i, output in enumerate(outputs): + print(f" Request {i}: {len(output.prompt_token_ids)} prompt tokens") + + metadata = collect_metadata(engine) + all_keys = [metadata[did]["mooncake_key"] for did in data_ids] + seq_lens = [request_metadata[did] for did in data_ids] + assert len(metadata) == len(data_ids), f"Expected {len(data_ids)} results, got {len(metadata)}" + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 1 passed") + + # ========================================================================= + # Test 2: Longer sequences + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 2: Longer sequences") + print("=" * 60) + + long_input_ids_list = [ + list(range(1, 101)), + list(range(200, 351)), + list(range(400, 465)), + ] + long_data_ids = ["long_0", "long_1", "long_2"] + + prompts = [{"prompt_token_ids": ids} for ids in long_input_ids_list] + request_metadata = {long_data_ids[i]: len(ids) for i, ids in enumerate(long_input_ids_list)} + input_ids_map = {long_data_ids[i]: ids for i, ids in enumerate(long_input_ids_list)} + + engine.collective_rpc("_reset_capture") + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + outputs = engine.generate(prompts, sampling_params, use_tqdm=False) + for i, output in enumerate(outputs): + print(f" Request {i}: {len(output.prompt_token_ids)} prompt tokens") + + metadata = collect_metadata(engine) + all_keys = [metadata[did]["mooncake_key"] for did in long_data_ids] + seq_lens = [request_metadata[did] for did in long_data_ids] + assert len(metadata) == len(long_data_ids), ( + f"Expected {len(long_data_ids)} results, got {len(metadata)}" + ) + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 2 passed") + + # ========================================================================= + # Test 3: formatted_prompts path (defer tokenization) + # ========================================================================= + print("\n" + "=" * 60) + print("TEST 3: formatted_prompts path (defer tokenization)") + print("=" * 60) + + text_prompts = [ + "Hello, world!", + "The quick brown fox jumps over the lazy dog.", + "Once upon a time", + ] + prompt_data_ids = ["prompt_0", "prompt_1", "prompt_2"] + + engine.collective_rpc("_reset_capture") + + outputs = engine.generate(text_prompts, sampling_params, use_tqdm=False) + + # Build authoritative metadata from outputs and set on workers, + # mirroring VllmEngine.generate()'s unconditional post-generation path. + request_metadata = {} + input_ids_map = {} + internal_to_external = {} + for i, output in enumerate(outputs): + did = prompt_data_ids[i] + request_metadata[did] = len(output.prompt_token_ids) + input_ids_map[did] = list(output.prompt_token_ids) + internal_to_external[output.request_id] = did + print(f' Request {i}: "{text_prompts[i]}" -> {len(output.prompt_token_ids)} tokens') + engine.collective_rpc("_set_request_metadata", args=(request_metadata, {}, input_ids_map)) + + metadata = collect_metadata(engine, internal_to_external=internal_to_external) + all_keys = [metadata[did]["mooncake_key"] for did in prompt_data_ids] + seq_lens = [request_metadata[did] for did in prompt_data_ids] + assert len(metadata) == len(prompt_data_ids), ( + f"Expected {len(prompt_data_ids)} results, got {len(metadata)}" + ) + + verify_from_mooncake(mooncake_store, all_keys, seq_lens, hidden_dim, last_hidden_dim) + print("\n✓ Test 3 passed") + + # ========================================================================= + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + del engine diff --git a/tools/build_conda.sh b/tools/build_conda.sh index 3a38884..3668ce1 100755 --- a/tools/build_conda.sh +++ b/tools/build_conda.sh @@ -6,11 +6,31 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" PROJECT_ROOT="$(cd -- "$SCRIPT_DIR/.." && pwd)" # Parse command line arguments -# Usage: ./build_conda.sh [MODE] -# 1 - Create new micromamba env and install (default) -# current - Install into current environment -# 0 - Skip env creation and installation +# Usage: ./build_conda.sh [MODE] [BACKEND] +# MODE: +# 1 - Create new micromamba env and install (default) +# current - Install into current environment +# 0 - Skip env creation and installation +# BACKEND: +# sglang - Install SGLang only (default) +# vllm - Install vLLM only +# both - Install both backends + MODE="${1:-1}" +BACKEND="${2:-sglang}" + +# Validate backend +if [[ ! "$BACKEND" =~ ^(sglang|vllm|both)$ ]]; then + echo "Error: Invalid backend '$BACKEND'" + echo "Usage: $0 [MODE] [BACKEND]" + echo " BACKEND options: sglang (default), vllm, both" + exit 1 +fi + +echo "==========================================" +echo "TorchSpec Installation" +echo "Backend: $BACKEND" +echo "==========================================" if [ "$MODE" = "1" ]; then if ! command -v micromamba &> /dev/null; then @@ -31,41 +51,114 @@ else echo "Skipping micromamba setup (mode=0)" fi -SGLANG_VERSION="${SGLANG_VERSION:-v0.5.8.post1}" -SGLANG_COMMIT=0f2df9370a1de1b4fb11b071d39ab3ce2287a350 -SGLANG_FOLDER_NAME="_sglang" +# Install SGLang if requested +if [ "$BACKEND" = "sglang" ] || [ "$BACKEND" = "both" ]; then + echo "==========================================" + echo "Installing SGLang..." + echo "==========================================" + + SGLANG_VERSION="${SGLANG_VERSION:-v0.5.8.post1}" + SGLANG_COMMIT=0f2df9370a1de1b4fb11b071d39ab3ce2287a350 + SGLANG_FOLDER_NAME="_sglang" + + # Install sglang inside the conda environment + if [ ! -d "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" ]; then + git clone https://github.com/sgl-project/sglang.git "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + fi + + # Avoid pythonpath conflict, because we are using the offline engine. + cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + git checkout $SGLANG_COMMIT + git reset --hard HEAD + + cd "$PROJECT_ROOT" + + if [ "$MODE" = "1" ]; then + micromamba run -n torchspec pip install -e "${SGLANG_FOLDER_NAME}/python[all]" + elif [ "$MODE" = "current" ]; then + pip install -e "${SGLANG_FOLDER_NAME}/python[all]" + fi + + cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" -# Install sglang inside the conda environment -if [ ! -d "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" ]; then - git clone https://github.com/sgl-project/sglang.git "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" + # Apply sglang patch (matches Docker build behavior) + git apply "$PROJECT_ROOT/patches/sglang/$SGLANG_VERSION/sglang.patch" + + cd "$PROJECT_ROOT" fi -# Avoid pythonpath conflict, because we are using the offline engine. -cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" -git checkout $SGLANG_COMMIT -git reset --hard HEAD +# Install vLLM if requested +if [ "$BACKEND" = "vllm" ] || [ "$BACKEND" = "both" ]; then + echo "==========================================" + echo "Installing vLLM..." + echo "==========================================" -cd "$PROJECT_ROOT" + if [ "$MODE" = "1" ]; then + micromamba run -n torchspec uv pip install "vllm>=0.16.0" + elif [ "$MODE" = "current" ]; then + pip install "vllm>=0.16.0" + fi +fi +# Install torchspec with appropriate extras if [ "$MODE" = "1" ]; then - micromamba run -n torchspec pip install -e "${SGLANG_FOLDER_NAME}/python[all]" - micromamba run -n torchspec uv pip install -e ".[dev]" + echo "==========================================" + echo "Installing TorchSpec..." + echo "==========================================" + + EXTRAS="dev" + if [ "$BACKEND" = "vllm" ]; then + EXTRAS="dev,vllm" + elif [ "$BACKEND" = "both" ]; then + EXTRAS="dev,vllm" + fi + + micromamba run -n torchspec uv pip install -e ".[$EXTRAS]" - echo "torchspec environment setup complete!" + echo "" + echo "==========================================" + echo "✓ TorchSpec environment setup complete!" + echo "==========================================" echo "Activate with: micromamba activate torchspec" + echo "" + if [ "$BACKEND" = "sglang" ]; then + echo "Backend: SGLang" + echo "Run: ./examples/qwen3-8b-single-node/run.sh" + elif [ "$BACKEND" = "vllm" ]; then + echo "Backend: vLLM" + echo "Run: ./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml" + elif [ "$BACKEND" = "both" ]; then + echo "Backends: SGLang + vLLM" + echo "SGLang: ./examples/qwen3-8b-single-node/run.sh" + echo "vLLM: ./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml" + fi elif [ "$MODE" = "current" ]; then - pip install -e "${SGLANG_FOLDER_NAME}/python[all]" - pip install -e ".[dev]" + EXTRAS="dev" + if [ "$BACKEND" = "vllm" ]; then + EXTRAS="dev,vllm" + elif [ "$BACKEND" = "both" ]; then + EXTRAS="dev,vllm" + fi + + pip install -e ".[$EXTRAS]" - echo "torchspec installed into current environment!" + echo "" + echo "==========================================" + echo "✓ TorchSpec installed into current environment!" + echo "==========================================" else + echo "" echo "Skipping package installation (mode=0)" echo "Please install packages manually:" - echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" - echo " pip install -e \".[dev]\"" + if [ "$BACKEND" = "sglang" ]; then + echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" + echo " pip install -e \".[dev]\"" + elif [ "$BACKEND" = "vllm" ]; then + echo " pip install vllm>=0.16.0" + echo " pip install -e \".[dev,vllm]\"" + elif [ "$BACKEND" = "both" ]; then + echo " pip install -e \"${SGLANG_FOLDER_NAME}/python[all]\"" + echo " pip install vllm>=0.16.0" + echo " pip install -e \".[dev,vllm]\"" + fi fi - -cd "$PROJECT_ROOT/$SGLANG_FOLDER_NAME" - -# Apply sglang patch (matches Docker build behavior) -git apply "$PROJECT_ROOT/patches/sglang/$SGLANG_VERSION/sglang.patch" diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index d097511..9375dd8 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -67,6 +67,49 @@ class SGLangConfig: extra_args: Dict[str, Any] = field(default_factory=dict) +@dataclass +class VllmConfig: + """Essential vLLM engine configuration. + + Only fields that TorchSpec explicitly uses are listed here. + Any additional vLLM engine kwargs can be supplied via ``extra_args`` + and will be forwarded as-is. + + Uses vLLM's extract_hidden_states speculative config for hidden states retrieval. + """ + + # Parallelism + tp_size: int = 8 + pp_size: int = 1 + nnodes: int = 1 + + # Memory + mem_fraction_static: float = 0.8 + + # Observability + enable_metrics: bool = False + + # Multimodal + enable_multimodal: bool = False + + # Networking (port is auto-selected by VllmEngine) + dist_init_addr: Optional[str] = None + dist_timeout: int = 60 + init_timeout: int = 300 + + # Hidden states extraction + num_speculative_tokens: int = 1 + + # Use worker extension for hidden states capture (new implementation) + # If False, falls back to LLM class with speculative_config + use_worker_extension: bool = True + + # Passthrough: forwarded as-is to vLLM LLM. + # Use this for any vLLM kwarg that TorchSpec doesn't need to + # inspect (e.g. quantization, max_model_len, trust_remote_code, ...). + extra_args: Dict[str, Any] = field(default_factory=dict) + + @dataclass class InferenceConfig: aux_hidden_states_layers: Optional[list] = None @@ -79,6 +122,7 @@ class InferenceConfig: inference_num_gpus_per_node: int = 8 max_sample_pool_size: int = 0 sglang: SGLangConfig = field(default_factory=SGLangConfig) + vllm: VllmConfig = field(default_factory=VllmConfig) @dataclass diff --git a/torchspec/config/mooncake_config.py b/torchspec/config/mooncake_config.py index c8b8479..a8bd95c 100644 --- a/torchspec/config/mooncake_config.py +++ b/torchspec/config/mooncake_config.py @@ -177,6 +177,16 @@ def export_env(self) -> None: os.environ["MOONCAKE_ENABLE_GPU_DIRECT"] = "1" if self.enable_gpu_direct else "0" if self.async_put_pool_size is not None: os.environ["MOONCAKE_ASYNC_PUT_POOL_SIZE"] = str(self.async_put_pool_size) + os.environ["MOONCAKE_STORE_FULL_WAIT_SECONDS"] = str(self.store_full_wait_seconds) + os.environ["MOONCAKE_STORE_FULL_LOG_INTERVAL_SECONDS"] = str( + self.store_full_log_interval_seconds + ) + os.environ["MOONCAKE_STORE_FULL_MAX_WAIT_SECONDS"] = str(self.store_full_max_wait_seconds) + os.environ["MOONCAKE_GET_RETRY_WAIT_SECONDS"] = str(self.get_retry_wait_seconds) + os.environ["MOONCAKE_GET_RETRY_LOG_INTERVAL_SECONDS"] = str( + self.get_retry_log_interval_seconds + ) + os.environ["MOONCAKE_GET_RETRY_MAX_WAIT_SECONDS"] = str(self.get_retry_max_wait_seconds) @classmethod def from_env(cls) -> "MooncakeConfig": diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index e4a4530..20da08b 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -194,6 +194,19 @@ def _resolve_relative_paths( OmegaConf.update(config, dotted_key, os.path.abspath(os.path.join(base_dir, expanded))) +def _validate_vllm_config(config: DictConfig) -> None: + """Raise if the vllm backend is selected with unsupported feature flags.""" + if config.model.target_model_backend != "vllm": + return + unsupported_flags = { + "inference.vllm.enable_multimodal": "enable_multimodal", + "training.train_with_decode": "train_with_decode", + } + for key, label in unsupported_flags.items(): + if OmegaConf.select(config, key): + raise NotImplementedError(f"{label} is not yet supported with the vllm backend!") + + def _save_config_snapshot(config: DictConfig) -> None: """Save the resolved config to output_dir/config.yaml if output_dir is set.""" output_dir = OmegaConf.select(config, "output_dir", default=None) @@ -236,6 +249,9 @@ def load_config( config = OmegaConf.merge(*configs_to_merge) _resolve_relative_paths(config, os.getcwd()) + + _validate_vllm_config(config) + if save_snapshot: _save_config_snapshot(config) @@ -247,6 +263,7 @@ def load_config( "decode": "decode_", "mooncake": "mooncake_", "sglang": "sglang_", + "vllm": "vllm_", } diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index b43a59c..4e032dd 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -408,7 +408,7 @@ def _prepare_engine_inputs(self, entries: list[InferenceInput]) -> dict: if self._defer_tokenization: input_ids_ref = None - packed_loss_mask_list = [None] * len(entries) + packed_loss_mask_list = None assert all(e.formatted_prompt is not None for e in entries), ( "formatted_prompt is required when defer_tokenization is True" ) diff --git a/torchspec/controller/loop.py b/torchspec/controller/loop.py index 183cbea..a8ad9b0 100644 --- a/torchspec/controller/loop.py +++ b/torchspec/controller/loop.py @@ -278,7 +278,7 @@ def training_loop( if step_time > 0: metrics["perf/train_capacity"] = args.global_batch_size / step_time - if wandb.run is not None: + if getattr(wandb, "run", None) is not None: wandb.log(metrics) # ── Eval at explicit interval (if configured) ───────── diff --git a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py index 7ebc757..fdf43f4 100644 --- a/torchspec/inference/engine/__init__.py +++ b/torchspec/inference/engine/__init__.py @@ -22,10 +22,12 @@ from torchspec.inference.engine.hf_engine import HFEngine from torchspec.inference.engine.hf_runner import HFRunner from torchspec.inference.engine.sgl_engine import SglEngine +from torchspec.inference.engine.vllm_engine import VllmEngine __all__ = [ "InferenceEngine", "HFEngine", "HFRunner", "SglEngine", + "VllmEngine", ] diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py new file mode 100644 index 0000000..a6681d1 --- /dev/null +++ b/torchspec/inference/engine/vllm_engine.py @@ -0,0 +1,501 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +""" +VLLM Ray actor engine for distributed deployment. + +Uses Worker Extension mode with MultiprocExecutor for reliable hidden states +extraction via model.forward patching in worker processes. +""" + +import socket + +import ray +import torch +from omegaconf import DictConfig, OmegaConf + +from torchspec.inference.engine.base import InferenceEngine +from torchspec.ray.ray_actor import RayActor +from torchspec.utils.logging import logger, setup_file_logging +from torchspec.utils.misc import get_default_eagle3_aux_layer_ids + +_PROTECTION_ENGINE_KEYS = frozenset( + { + "model", + "tensor_parallel_size", + "gpu_memory_utilization", + "nnodes", + "node_rank", + "distributed_backend", + } +) + + +class VllmEngine(InferenceEngine, RayActor): + """Ray actor wrapper for vLLM LLM engine with distributed deployment support. + + Uses Worker Extension mode with MultiprocExecutor and VllmWorkerExtension + for reliable hidden states extraction by patching model.forward in worker processes. + """ + + def __init__( + self, + args, + rank: int, + base_gpu_id: int | None = None, + num_gpus_per_engine: int = 1, + node_rank: int = 0, + engine_group: int = 0, + ): + self.args = args + self.rank = rank + self.base_gpu_id = base_gpu_id + self.num_gpus_per_engine = num_gpus_per_engine + self.node_rank = node_rank + self._engine = None + self._mooncake_config = None + self._mooncake_store = None + self._hidden_size = None + self.local_gpu_id = None + + setup_file_logging("inference", self.rank, group=engine_group) + + def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None: + if self.base_gpu_id is not None: + self.local_gpu_id = self.setup_gpu(self.base_gpu_id) + logger.info( + f"VllmEngine rank {self.rank}: base_gpu_id={self.base_gpu_id}, " + f"using local GPU {self.local_gpu_id}" + ) + + self._mooncake_config = mooncake_config + + if mooncake_config is not None: + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + except Exception: + local_ip = "localhost" + logger.warning( + f"VllmEngine rank {self.rank}: failed to get local IP, using localhost" + ) + + mooncake_config.local_hostname = local_ip + mooncake_config.export_env() + + from torchspec.transfer.mooncake.utils import ( + check_mooncake_master_available, + ) + + check_mooncake_master_available( + mooncake_config.master_server_address, + mooncake_config.metadata_server, + ) + + mem_fraction = getattr(self.args, "vllm_mem_fraction_static", 0.8) + pp_size = getattr(self.args, "vllm_pp_size", 1) + + if self.args.aux_hidden_states_layers is not None: + self.aux_hidden_state_layer_ids = self.args.aux_hidden_states_layers + else: + self.aux_hidden_state_layer_ids = get_default_eagle3_aux_layer_ids( + self.args.target_model_path + ) + if self.rank == 0: + logger.info( + f"Using default aux hidden state layer ids: {self.aux_hidden_state_layer_ids}" + ) + + nnodes = getattr(self.args, "vllm_nnodes", 1) + tp_size = nnodes * self.num_gpus_per_engine + + logger.info( + f"VllmEngine rank {self.rank}: BEFORE init - " + f"base_gpu_id={self.base_gpu_id}, num_gpus={self.num_gpus_per_engine}, " + f"tp_size={tp_size}, pp_size={pp_size}, nnodes={nnodes}, node_rank={self.node_rank}, " + f"aux_hidden_state_layer_ids={self.aux_hidden_state_layer_ids}" + ) + + self._init_engine(tp_size, pp_size, nnodes, mem_fraction, dist_init_addr) + + self._hidden_size = self._get_hidden_size_from_engine() + + if self._mooncake_config is not None: + self._init_mooncake_store() + + logger.info( + f"VllmEngine rank {self.rank}: initialized from {self.args.target_model_path} " + f"(tp_size={tp_size}, aux_layers={self.aux_hidden_state_layer_ids}, hidden_size={self._hidden_size})" + ) + + def _init_engine( + self, + tp_size: int, + pp_size: int, + nnodes: int, + mem_fraction: float, + dist_init_addr: str | None, + ) -> None: + """Initialize LLM with worker extension enabled.""" + from vllm import LLM + + engine_kwargs = { + "model": self.args.target_model_path, + "tensor_parallel_size": tp_size, + "gpu_memory_utilization": mem_fraction, + "trust_remote_code": getattr(self.args, "trust_remote_code", True), + "distributed_executor_backend": "mp", + "disable_custom_all_reduce": True, + "worker_extension_cls": ( + "torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension" + ), + } + + extra_args = getattr(self.args, "vllm_extra_args", None) + if extra_args: + if isinstance(extra_args, DictConfig): + extra = OmegaConf.to_container(extra_args, resolve=True) + else: + extra = dict(extra_args) if not isinstance(extra_args, dict) else extra_args + blocked = extra.keys() & _PROTECTION_ENGINE_KEYS + if blocked: + logger.warning( + f"vllm extra_args contains protected keys that will be ignored: " + f"{sorted(blocked)}. These are managed internally by TorchSpec." + ) + extra = {k: v for k, v in extra.items() if k not in _PROTECTION_ENGINE_KEYS} + engine_kwargs.update(extra) + + inference_batch_size = getattr(self.args, "inference_batch_size", None) + if inference_batch_size is not None: + comp_cfg = engine_kwargs.get("compilation_config", {}) + if isinstance(comp_cfg, dict) and "max_cudagraph_capture_size" not in comp_cfg: + comp_cfg["max_cudagraph_capture_size"] = inference_batch_size + engine_kwargs["compilation_config"] = comp_cfg + logger.info( + f"VllmEngine rank {self.rank}: defaulting " + f"max_cudagraph_capture_size={inference_batch_size} from inference_batch_size" + ) + + # Disable prefix caching and chunked prefill + engine_kwargs["enable_prefix_caching"] = False + engine_kwargs["enable_chunked_prefill"] = False + + max_seq_length = getattr(self.args, "max_seq_length", None) + if max_seq_length: + engine_kwargs["max_model_len"] = max_seq_length + + if nnodes > 1: + engine_kwargs["nnodes"] = nnodes + engine_kwargs["node_rank"] = self.node_rank + if dist_init_addr: + engine_kwargs["distributed_backend"] = "nccl" + engine_kwargs["distributed_init_address"] = dist_init_addr + + self._engine = LLM(**engine_kwargs) + self._setup_rpc_hidden_states_capture() + logger.info( + f"VllmEngine rank {self.rank}: initialized worker extension mode " + f"with layers={self.aux_hidden_state_layer_ids}" + ) + + def _setup_rpc_hidden_states_capture(self) -> None: + """Initialize worker-side hidden-state capture hooks.""" + if self._engine is None: + raise RuntimeError("VllmEngine not initialized. Call init() first.") + if not hasattr(self._engine, "collective_rpc"): + raise RuntimeError("vLLM LLM.collective_rpc is required for worker extension mode") + + if self._mooncake_config is not None: + self._mooncake_config.export_env() + logger.info( + f"VllmEngine rank {self.rank}: Set Mooncake env vars for workers: " + f"master={self._mooncake_config.master_server_address}" + ) + + layer_ids = list(self.aux_hidden_state_layer_ids) + results = self._engine.collective_rpc( + "_setup_hidden_states_capture", + args=(layer_ids,), + ) + logger.info(f"VllmEngine rank {self.rank}: worker capture setup replies={results}") + + def generate( + self, + data_id: str | list[str], + input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None = None, + packed_loss_mask_list: list[str | None] | None = None, + formatted_prompts: list[str] | None = None, + return_last_hidden_states: bool = False, + return_logits: bool = True, + multimodal_inputs: list[dict] | None = None, + ) -> list[dict]: + """Generate hidden states for training data using Worker Extension mode.""" + if self._engine is None: + raise RuntimeError("VllmEngine not initialized. Call init() first.") + + if (input_ids_ref is None) == (formatted_prompts is None): + raise ValueError("Exactly one of input_ids_ref or formatted_prompts must be set") + + use_prompts = formatted_prompts is not None + input_ids_list: list[torch.Tensor] | None = None + + if use_prompts: + batch_size = len(formatted_prompts) + prompts = formatted_prompts + else: + if isinstance(input_ids_ref, ray.ObjectRef): + input_ids_list = ray.get(input_ids_ref) + else: + input_ids_list = input_ids_ref + if input_ids_list is None: + raise ValueError("input_ids_ref resolved to None") + batch_size = len(input_ids_list) + prompts = self._format_input_ids_for_vllm(input_ids_list) + + if isinstance(data_id, str): + data_ids = [f"{data_id}_{i}" for i in range(batch_size)] + elif len(data_id) == batch_size: + data_ids = data_id + else: + raise ValueError( + f"data_id length {len(data_id)} does not match batch size {batch_size}" + ) + + from vllm import SamplingParams + + sampling_params = SamplingParams(max_tokens=1, temperature=0) + request_metadata = {} + if input_ids_list is not None: + for i, ids in enumerate(input_ids_list): + request_metadata[data_ids[i]] = int(self._normalize_input_ids(ids).numel()) + + # Build packed_loss_mask_map for workers + packed_loss_mask_map = {} + if packed_loss_mask_list is not None: + for i, data_id in enumerate(data_ids): + if i < len(packed_loss_mask_list): + packed_loss_mask_map[data_id] = packed_loss_mask_list[i] + + # Build input_ids_map for workers (pass real input_ids via RPC) + input_ids_map = {} + if input_ids_list is not None: + for i, data_id in enumerate(data_ids): + if i < len(input_ids_list): + ids = self._normalize_input_ids(input_ids_list[i]) + input_ids_map[data_id] = ids.cpu().tolist() + + try: + self._engine.collective_rpc("_reset_capture") + if request_metadata: + self._engine.collective_rpc( + "_set_request_metadata", + args=(request_metadata, packed_loss_mask_map, input_ids_map), + ) + except Exception as e: + logger.warning(f"Could not reset capture via worker extension: {e}") + + outputs = self._engine.generate(prompts, sampling_params, use_tqdm=False) + + # outputs are sorted by int(request_id), matching submission order. + # Build mapping from vLLM's internal worker IDs ("{request_id}-{uuid}") + # to our external data_ids. + internal_to_external = {} + for i, output in enumerate(outputs): + internal_to_external[output.request_id] = data_ids[i] + + # Always build request_metadata and input_ids_map from the + # outputs. + for i, output in enumerate(outputs): + did = data_ids[i] + request_metadata[did] = len(output.prompt_token_ids) + input_ids_map[did] = list(output.prompt_token_ids) + try: + self._engine.collective_rpc( + "_set_request_metadata", + args=(request_metadata, packed_loss_mask_map, input_ids_map), + ) + except Exception as e: + logger.warning( + f"VllmEngine rank {self.rank}: Could not set post-generation request metadata: {e}" + ) + + # Get metadata from workers (tensors are already stored in Mooncake by workers) + metadata_by_request: dict[str, dict] = {} + try: + metadata_list = self._engine.collective_rpc( + "_store_and_get_metadata", args=(internal_to_external,) + ) + if isinstance(metadata_list, list): + for metadata in metadata_list: + if isinstance(metadata, dict): + metadata_by_request.update(metadata) + elif isinstance(metadata_list, dict): + metadata_by_request = metadata_list + except Exception as e: + logger.warning(f"Could not get metadata from worker extension: {e}") + + if not metadata_by_request: + logger.error( + f"VllmEngine rank {self.rank}: metadata_by_request is EMPTY for " + f"data_ids={data_ids}. Worker returned metadata_list={metadata_list!r}. " + f"use_prompts={use_prompts}, request_metadata_keys={list(request_metadata.keys())}, " + f"internal_to_external={internal_to_external}" + ) + + results = [] + for i, output in enumerate(outputs): + seq_len = len(output.prompt_token_ids) + data_id = data_ids[i] + + # Get metadata for this request + metadata = metadata_by_request.get(data_id) + if metadata is None: + logger.error( + f"VllmEngine rank {self.rank}: No metadata for data_id={data_id}. " + f"metadata_by_request has keys={list(metadata_by_request.keys())}. " + f"Training may be corrupted." + ) + continue + + # Extract info from metadata (tensors are already in Mooncake) + mooncake_key = metadata.get("mooncake_key", data_id) + tensor_shapes = metadata.get("tensor_shapes", {}) + tensor_dtypes = metadata.get("tensor_dtypes", {}) + + result = { + "mooncake_key": mooncake_key, + "tensor_shapes": tensor_shapes, + "tensor_dtypes": tensor_dtypes, + "data_id": data_id, + "seq_len": seq_len, + } + # Get packed_loss_mask from metadata (returned by worker) + packed_loss_mask = metadata.get("packed_loss_mask") + if packed_loss_mask is not None: + result["packed_loss_mask"] = packed_loss_mask + # Get input_ids_list from metadata (returned by worker via RPC) + input_ids_list = metadata.get("input_ids_list") + if input_ids_list is not None: + result["input_ids_list"] = input_ids_list + results.append(result) + + # No need to flush here - workers already flushed after storing + + logger.debug( + f"VllmEngine rank {self.rank}: generated {len(results)} mooncake results " + f"for data_ids={data_ids}" + ) + return results + + def _init_mooncake_store(self) -> None: + if self._mooncake_store is not None or self._mooncake_config is None: + return + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + self._mooncake_store = EagleMooncakeStore(self._mooncake_config) + if torch.cuda.is_available(): + self._mooncake_store.setup(device=torch.cuda.current_device()) + else: + self._mooncake_store.setup() + + def _normalize_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + if input_ids.dim() == 2 and input_ids.shape[0] == 1: + return input_ids.squeeze(0) + if input_ids.dim() == 1: + return input_ids + raise ValueError(f"Unexpected input_ids shape: {input_ids.shape}") + + def _format_input_ids_for_vllm( + self, input_ids_list: list[torch.Tensor] + ) -> list[dict[str, list[int]]]: + prompts = [] + for ids in input_ids_list: + prompts.append({"prompt_token_ids": self._normalize_input_ids(ids).tolist()}) + return prompts + + def health_check(self, timeout: float = 5.0) -> bool: + return self._engine is not None + + def shutdown(self) -> None: + if self._mooncake_store is not None: + try: + self._mooncake_store.close() + except Exception as e: + logger.warning(f"VllmEngine rank {self.rank}: Error closing mooncake store: {e}") + self._mooncake_store = None + + if self._engine is not None: + del self._engine + self._engine = None + + logger.info(f"VllmEngine rank {self.rank}: shutdown complete") + + def get_status(self) -> dict: + return { + "rank": self.rank, + "initialized": self._engine is not None, + "base_gpu_id": self.base_gpu_id, + "hidden_size": self._hidden_size, + } + + def _get_hidden_size_from_engine(self) -> int: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained( + self.args.target_model_path, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + ) + hidden_size = getattr(config, "hidden_size", None) + if hidden_size is None: + text_config = getattr(config, "text_config", None) + if text_config is not None: + hidden_size = getattr(text_config, "hidden_size", None) + if hidden_size is None: + raise ValueError( + f"Could not determine hidden_size from model config: {self.args.target_model_path}" + ) + return hidden_size + + def _get_tensor_shapes(self, seq_len: int) -> dict: + aux_hidden_state_layer_ids = self.aux_hidden_state_layer_ids + num_aux_layers = len(aux_hidden_state_layer_ids) + if self._hidden_size is None: + raise ValueError( + f"VllmEngine rank {self.rank}: hidden_size not initialized. Call init() first." + ) + hidden_size = self._hidden_size + + concat_hidden_size = num_aux_layers * hidden_size + + return { + "hidden_states": (seq_len, concat_hidden_size), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, hidden_size), + } + + def _get_tensor_dtypes(self) -> dict: + return { + "hidden_states": torch.bfloat16, + "input_ids": torch.long, + "last_hidden_states": torch.bfloat16, + } diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py new file mode 100644 index 0000000..99b8a4d --- /dev/null +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -0,0 +1,778 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""vLLM Worker Extension for Hidden States Capture. + +This module provides a TorchSpec-style worker extension for vLLM that enables +reliable hidden states extraction during inference. It patches the model's +forward method in each worker process to capture intermediate layer activations +and store them directly to Mooncake to avoid RPC serialization issues. + +Based on the vllm-speculators approach but integrated into TorchSpec's +architecture with Ray Actors and Mooncake storage. +""" + +import logging +import os +import re +import types +from collections import defaultdict +from itertools import islice +from typing import Any, Dict, List, Optional + +import torch +from vllm.distributed import get_pp_group, get_tp_group +from vllm.sequence import IntermediateTensors + +logger = logging.getLogger(__name__) + + +def _sanitize_mooncake_key(key: str) -> str: + """Sanitize a key for use with Mooncake store. + + Mooncake keys should only contain alphanumeric characters, hyphens, and underscores. + This function replaces invalid characters with underscores. + + Args: + key: The original key (e.g., vLLM req_id) + + Returns: + A sanitized key safe for Mooncake operations + """ + sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) + if sanitized and sanitized[0].isdigit(): + sanitized = "k" + sanitized + return sanitized + + +def _patched_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, +) -> Any: + """Patched forward pass that captures hidden states from specified layers. + + This function is dynamically bound to base_model instances via types.MethodType. + It expects base_model to have an _extension attribute pointing to the + VllmWorkerExtension instance. + + Args: + input_ids: Input token IDs + positions: Position IDs + intermediate_tensors: For pipeline parallelism + inputs_embeds: Pre-computed input embeddings (for multimodal) + **kwargs: Additional arguments + + Returns: + Hidden states or IntermediateTensors (for PP) + """ + # Get extension reference + extension = self._extension # noqa: SLF001 + + # Handle pipeline parallelism - first rank does embedding + if get_pp_group().is_first_rank: + hidden_states = ( + inputs_embeds if inputs_embeds is not None else self.embed_input_ids(input_ids) + ) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Track auxiliary hidden states for capture + aux_hidden_states: List[torch.Tensor] = [] + + # Only capture on TP rank 0 to avoid duplicates + should_capture = get_tp_group().rank_in_group == 0 + target_layers = extension._layer_ids if should_capture else frozenset() # noqa: SLF001 + + # Capture input_ids only on first call (prefill phase) to avoid including generated tokens + if should_capture and get_pp_group().is_first_rank and extension._captured_input_ids is None: + # input_ids shape: (batch_size, seq_len) or (seq_len,) + if input_ids.dim() == 2: + # Flatten batch dimension + extension._captured_input_ids = input_ids.view(-1).clone() + else: + extension._captured_input_ids = input_ids.clone() + + # Process each layer + for idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + residual=residual, + ) + absolute_layer_idx = self.start_layer + idx + + # Capture intermediate layers (not the last) before normalization + if absolute_layer_idx in target_layers: + # Add residual before capturing (matching speculators pattern) + captured = ( + (hidden_states + residual).clone() + if residual is not None + else hidden_states.clone() + ) + aux_hidden_states.append(captured) + + # Handle pipeline parallelism - return intermediate tensors if not last rank + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + + # Final normalization (only on last PP rank) + hidden_states, _ = self.norm(hidden_states, residual) + + # Store captured states (only on last PP rank, TP rank 0, and during prefill) + if should_capture and not extension._prefill_complete: # noqa: SLF001 + if aux_hidden_states: + extension._store_captured_states(aux_hidden_states) # noqa: SLF001 + extension._store_last_hidden_states(hidden_states) # noqa: SLF001 + + return hidden_states + + +class VllmWorkerExtension: + """Worker extension that adds hidden states capture functionality to vLLM. + + This extension hooks into vLLM's Worker by being specified in the worker + initialization. It patches the model's forward pass to intercept and capture + intermediate layer hidden states during inference. + + Key behaviors: + - Only captures on tensor parallel (TP) rank 0 to avoid duplicate data when + using tensor parallelism. All TP ranks compute the same hidden states, so + capturing from rank 0 is sufficient. + - Stores captured states in GPU memory during batch processing, then writes + directly to Mooncake to avoid RPC serialization issues. + - Supports pipeline parallelism by handling IntermediateTensors correctly. + - Tracks request metadata to map captured states back to original requests + across chunked prefill iterations. + + Attributes: + _layer_ids: Frozenset of layer indices for O(1) lookup during capture + _captured_states: Accumulated hidden states per layer (GPU tensors) + _request_metadata: Metadata tracking tokens per request per iteration + _mooncake_store: EagleMooncakeStore instance for direct storage + model_runner: Reference to the vLLM model runner + """ + + def __init__(self): + """Initialize the worker extension with Mooncake store support.""" + self._layer_ids: frozenset = frozenset() + self._captured_states: Optional[List[List[torch.Tensor]]] = None + self._request_metadata: List[Dict[str, int]] = [] + self._current_request_metadata: Optional[Dict[str, int]] = None + self._mooncake_store: Optional[Any] = None + self._store_initialized: bool = False + self._store_setup_complete: bool = False + self._init_retry_count: int = 0 + self._max_init_retries: int = 3 + self.model_runner: Optional[Any] = None + + def _get_cuda_device_safe(self) -> torch.device: + """Safely get CUDA device, handling uninitialized context (V1 compatibility). + + In vLLM V1, CUDA context may not be initialized when this method is called. + This method safely handles both initialized and uninitialized contexts. + + Returns: + torch.device: The CUDA device to use. Falls back to cuda:0 if context + is not yet initialized (common in V1 engine). + """ + try: + if torch.cuda.is_initialized(): + current_device = torch.cuda.current_device() + logger.debug(f"CUDA initialized, using device cuda:{current_device}") + return torch.device(f"cuda:{current_device}") + else: + # CUDA not initialized yet (V1), use device 0 as fallback + # V1 will initialize context during model loading + logger.debug("CUDA not initialized yet (V1), falling back to cuda:0") + return torch.device("cuda:0") + except RuntimeError as e: + # CUDA context not available + logger.warning(f"Failed to get CUDA device: {e}, falling back to cuda:0") + return torch.device("cuda:0") + + def _init_mooncake_store(self) -> bool: + """Initialize Mooncake store connection in the worker. + + Uses environment variables set by the main process to connect to + the Mooncake master and metadata servers. + + Returns: + True if initialization successful, False otherwise. + """ + if self._store_initialized: + return True + + # Only initialize on TP rank 0 - other ranks don't capture hidden states + try: + if get_tp_group().rank_in_group != 0: + logger.debug("Skipping Mooncake store init on non-zero TP rank") + return False + except Exception: + # If we can't get TP group info, proceed anyway (for backward compatibility) + pass + + try: + from torchspec.config.mooncake_config import MooncakeConfig + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + if not os.environ.get("MOONCAKE_MASTER_SERVER") and not os.environ.get( + "MOONCAKE_MASTER_HOST" + ): + logger.warning( + "Mooncake master address not available in worker environment. " + "Set MOONCAKE_MASTER_SERVER environment variable." + ) + return False + + config = MooncakeConfig.from_env() + + # Create store object but don't call setup() yet + # setup() will be called lazily when CUDA context is ready + self._mooncake_store = EagleMooncakeStore(config) + # Mark as initialized but not yet setup + # setup() will be called on first put() when CUDA context is ready + self._store_initialized = True + self._store_setup_complete = False + + logger.info( + f"Worker initialized Mooncake store (setup deferred): " + f"master={config.master_server_address}, protocol={config.protocol}" + ) + return True + + except Exception as e: + logger.error(f"Failed to initialize Mooncake store in worker: {e}", exc_info=True) + self._mooncake_store = None + self._store_initialized = False + return False + + def _ensure_mooncake_store(self) -> bool: + """Ensure Mooncake store is initialized and setup, with retry logic. + + This method handles lazy initialization for vLLM V1 compatibility. + In V1, CUDA context may not be ready during initial Worker initialization, + so we defer the actual setup() call until first use. + + Returns: + True if store is ready for use, False otherwise. + """ + # Ensure attributes exist (for vLLM V1 compatibility where __init__ may not be called) + if not hasattr(self, "_store_initialized"): + self._store_initialized = False + if not hasattr(self, "_store_setup_complete"): + self._store_setup_complete = False + if not hasattr(self, "_init_retry_count"): + self._init_retry_count = 0 + if not hasattr(self, "_max_init_retries"): + self._max_init_retries = 3 + if not hasattr(self, "_mooncake_store"): + self._mooncake_store = None + + # Already fully initialized and setup + if self._store_initialized and self._store_setup_complete: + return True + + # Check retry limit + if self._init_retry_count >= self._max_init_retries: + logger.error( + f"Max retries ({self._max_init_retries}) exceeded for Mooncake store initialization" + ) + return False + + try: + # Initialize store if not already done + if not self._store_initialized: + if not self._init_mooncake_store(): + self._init_retry_count += 1 + logger.warning( + f"Mooncake store init failed (attempt {self._init_retry_count}/{self._max_init_retries})" + ) + return False + + # Setup store if not already done + if not self._store_setup_complete and self._mooncake_store is not None: + try: + # Use safe CUDA device detection for V1 compatibility + device = self._get_cuda_device_safe() + logger.info(f"Setting up Mooncake store on device {device}") + self._mooncake_store.setup(device=device) + + try: + logger.info("Warming up Mooncake RDMA path...") + self._mooncake_store.warmup_rdma() + logger.info("Mooncake RDMA warmup completed successfully") + except Exception as warmup_error: + logger.warning(f"Mooncake RDMA warmup failed: {warmup_error}") + + self._store_setup_complete = True + logger.info("Mooncake store setup completed successfully") + return True + except Exception as e: + self._init_retry_count += 1 + # Check if this is a CUDA context error (common in V1) + error_msg = str(e).lower() + if "cuda" in error_msg or "device" in error_msg: + logger.warning( + f"CUDA context not ready (attempt {self._init_retry_count}/{self._max_init_retries}): {e}. " + f"Will retry on next put." + ) + else: + logger.error(f"Mooncake store setup failed: {e}") + return False + + return True + + except Exception as e: + self._init_retry_count += 1 + logger.error(f"Unexpected error in _ensure_mooncake_store: {e}", exc_info=True) + return False + + def _store_last_hidden_states(self, hidden_states: torch.Tensor) -> None: + """Store post-norm hidden states from a forward pass for use as last_hidden_states""" + if getattr(self, "_captured_last_hs", None) is None: + self._captured_last_hs = [hidden_states.clone()] + else: + self._captured_last_hs.append(hidden_states.clone()) + + def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: + """Store captured hidden states from a forward pass. + + Args: + aux_hidden_states: List of tensors, one per target layer + """ + if self._captured_states is None: + self._captured_states = [[h] for h in aux_hidden_states] + else: + for i, h in enumerate(aux_hidden_states): + self._captured_states[i].append(h) + + # Track per-request token counts for this scheduler step + model_runner = getattr(self, "model_runner", None) + input_batch = getattr(model_runner, "input_batch", None) + step_tokens: Dict[str, int] = {} + if input_batch is not None and hasattr(input_batch, "req_ids"): + for req_id in input_batch.req_ids: + num_tokens = 0 + req_idx = getattr(input_batch, "req_id_to_index", {}).get(req_id) + if req_idx is not None: + num_computed = getattr( + input_batch, "num_computed_tokens", [0] * len(input_batch.req_ids) + )[req_idx] + num_total = getattr(input_batch, "num_tokens", [0] * len(input_batch.req_ids))[ + req_idx + ] + num_tokens = num_total - num_computed + step_tokens[req_id] = num_tokens + self._request_metadata.append(step_tokens) + + # With max_tokens=1 the prefill forward pass already generates the + # single allowed token, so no decode step is scheduled by vLLM. + # This check handles chunked prefill where multiple forward calls + # sum up to the total prefill token count. + if self._current_request_metadata and not self._prefill_complete: + expected = sum(self._current_request_metadata.values()) + captured = sum(t.shape[0] for t in self._captured_states[0]) + if captured == expected: + self._prefill_complete = True + elif captured > expected: + logger.warning(f"Captured more tokens than expected: {captured} > {expected}") + + def _store_input_ids(self, input_ids: torch.Tensor) -> None: + """Store input_ids from a forward pass. + + Args: + input_ids: Input token IDs tensor (batch_size, seq_len) or (seq_len,) + """ + # Flatten if needed and store + if input_ids.dim() == 2: + # (batch_size, seq_len) - flatten to (batch_size * seq_len,) + input_ids = input_ids.view(-1) + if getattr(self, "_captured_input_ids", None) is None: + self._captured_input_ids = input_ids.clone() + else: + self._captured_input_ids = torch.cat([self._captured_input_ids, input_ids], dim=0) + + def _setup_hidden_states_capture(self, layer_ids: List[int]) -> None: + """Setup model to capture auxiliary hidden states from specific layers. + + This method patches the model's forward method to intercept hidden states + during the forward pass. + + Args: + layer_ids: List of layer indices to capture from + """ + self._layer_ids = frozenset(layer_ids) + self._captured_states = None + self._request_metadata = [] + self._current_request_metadata = None + self._packed_loss_mask_map: Dict[str, Optional[str]] = {} + self._store_initialized = False + self._store_setup_complete = False + self._init_retry_count = 0 + self._mooncake_store = None + + model_runner = getattr(self, "model_runner", None) + if model_runner is None and hasattr(self, "model"): + model_runner = self + if model_runner is None: + raise AttributeError("Could not find model_runner for worker extension setup") + + self.model_runner = model_runner + model = self.model_runner.model # type: ignore[attr-defined] + + # Handle vision-language models (e.g., Qwen-VL) + if hasattr(model, "get_language_model"): + base_model = model.get_language_model().model + # Handle standard text models + elif hasattr(model, "model") and hasattr(model.model, "layers"): + base_model = model.model + else: + # Try to find model with layers attribute + attrs = [a for a in dir(model) if not a.startswith("_")] + raise AttributeError( + f"Could not find base model with 'layers' attribute. " + f"Model type: {type(model).__name__}, " + f"Available attributes: {attrs}" + ) + + # Attach extension reference and patch forward method + base_model._extension = self # noqa: SLF001 + base_model.forward = types.MethodType(_patched_forward, base_model) + + logger.info(f"Hidden states capture setup complete for layers {layer_ids}") + + def _set_request_metadata( + self, + request_metadata: Dict[str, int], + packed_loss_mask_map: Optional[Dict[str, Optional[str]]] = None, + input_ids_map: Optional[Dict[str, List[int]]] = None, + ) -> None: + """Set request metadata for the next forward pass. + + This is called before each scheduler iteration to track which tokens + belong to which request. + + Args: + request_metadata: Dict mapping request_id -> num_prefill_tokens + packed_loss_mask_map: Optional dict mapping request_id -> packed_loss_mask + string (values may be None when loss masks are not available). + input_ids_map: Optional dict mapping request_id -> input_ids list (passed via RPC) + """ + self._current_request_metadata = request_metadata + self._packed_loss_mask_map = packed_loss_mask_map or {} + self._input_ids_map = input_ids_map or {} + + def _reset_capture(self) -> None: + """Reset captured states before starting a new batch. + + Must be called before processing a new batch of requests. + """ + if not hasattr(self, "_layer_ids") or len(self._layer_ids) == 0: + raise RuntimeError("Must call _setup_hidden_states_capture before capturing states") + self._captured_states = None + self._captured_last_hs: Optional[List[torch.Tensor]] = None + self._captured_input_ids: Optional[torch.Tensor] = None + self._prefill_complete = False + self._request_metadata = [] + self._current_request_metadata = None + self._packed_loss_mask_map = {} + self._input_ids_map = {} + + def _store_and_get_metadata( + self, internal_to_external: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, Dict[str, Any]]]: + """Store captured hidden states to Mooncake and return metadata. + + This method stores tensors directly to Mooncake from the worker process, + avoiding RPC serialization issues. It returns only lightweight metadata + that can be safely serialized and returned via collective_rpc. + + Returns: + Dict mapping request_id to metadata dict with keys: + - 'mooncake_key': str, the base key used for storage + - 'tensor_shapes': dict of tensor shapes + - 'tensor_dtypes': dict of dtype names + - 'num_layers': int, number of captured layers + or None if no states captured or not on TP rank 0. + """ + # Only TP rank 0 has captured data + if get_tp_group().rank_in_group != 0: + return None + if self._captured_states is None: + logger.warning( + "_store_and_get_metadata: captured_states is None " + "(forward patch may not be running or no prefill occurred)" + ) + return None + + # Ensure Mooncake store is initialized and setup (with retry for V1 compatibility) + if not self._ensure_mooncake_store(): + logger.warning( + "Failed to initialize/setup Mooncake store, cannot store hidden states. " + "This may be due to CUDA context not being ready in V1 engine." + ) + return None + + # Concatenate captured states from all scheduler iterations + concatenated_layers = [ + torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states + ] + total_captured_tokens = concatenated_layers[0].shape[0] + + # Concatenate post-norm hidden states for last_hidden_states + concatenated_last_hs = None + if getattr(self, "_captured_last_hs", None): + concatenated_last_hs = torch.cat(self._captured_last_hs, dim=0) + + internal_to_external = internal_to_external or {} + ext_token_counts = ( + dict(self._current_request_metadata) if self._current_request_metadata else {} + ) + + # Build worker-visible ID -> external ID lookup once. + # In V1, the worker sees "{counter}-{uuid8}" while internal_to_external + # maps bare counter strings (from output.request_id) to external data_ids. + worker_to_ext: Dict[str, str] = dict(internal_to_external) + for step_meta in self._request_metadata: + for worker_id in step_meta: + if worker_id not in worker_to_ext: + for counter, ext_id in internal_to_external.items(): + if worker_id.startswith(f"{counter}-"): + worker_to_ext[worker_id] = ext_id + break + + request_slices: List[tuple] = [] # (external_id, num_tokens) + seen_ext_ids: set = set() + + for step_meta in self._request_metadata: + for int_id in step_meta.keys(): + ext_id = worker_to_ext.get(int_id, int_id) + if ext_id not in seen_ext_ids: + n_tokens = ext_token_counts.get(ext_id, 0) + if n_tokens > 0: + request_slices.append((ext_id, n_tokens)) + seen_ext_ids.add(ext_id) + + # Fallback if _request_metadata didn't produce results + if not request_slices and ext_token_counts: + logger.warning( + "Internal request metadata mapping failed; falling back to external order" + ) + for ext_id, n_tokens in ext_token_counts.items(): + request_slices.append((ext_id, n_tokens)) + + if not request_slices: + logger.warning( + f"_store_and_get_metadata: request_slices is empty — cannot map " + f"captured tokens to requests. " + f"total_captured_tokens={total_captured_tokens}, " + f"_request_metadata steps={len(self._request_metadata)}, " + f"internal_to_external keys={list(internal_to_external.keys())[:5]}, " + f"ext_token_counts keys={list(ext_token_counts.keys())[:5]}, " + f"current_request_metadata={self._current_request_metadata is not None}" + ) + + total_expected_tokens = sum(n for _, n in request_slices) + + if total_captured_tokens != total_expected_tokens and total_expected_tokens > 0: + logger.warning( + f"Token count mismatch: captured={total_captured_tokens}, " + f"expected={total_expected_tokens}" + ) + + num_aux_layers = len(concatenated_layers) + request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( + lambda: [[] for _ in range(num_aux_layers)] + ) + request_last_hs: defaultdict[str, List[torch.Tensor]] = defaultdict(list) + current_idx = 0 + + for external_id, num_tokens in request_slices: + if current_idx >= total_captured_tokens: + break + actual_tokens = min(num_tokens, total_captured_tokens - current_idx) + if actual_tokens > 0: + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = layer_tensor[current_idx : current_idx + actual_tokens] + request_chunks[external_id][layer_idx].append(chunk) + if concatenated_last_hs is not None: + request_last_hs[external_id].append( + concatenated_last_hs[current_idx : current_idx + actual_tokens] + ) + current_idx += actual_tokens + + # Store to Mooncake and collect metadata + result: Dict[str, Dict[str, Any]] = {} + for req_id, layer_chunks in request_chunks.items(): + mooncake_key = _sanitize_mooncake_key(req_id) + if mooncake_key != req_id: + logger.debug(f"Sanitized key '{req_id}' -> '{mooncake_key}'") + + layer_tensors = [torch.cat(chunks, dim=0) for chunks in layer_chunks] + + if len(layer_tensors) > 1: + hidden_states = torch.cat(layer_tensors, dim=-1) + else: + hidden_states = layer_tensors[0] + + if req_id in request_last_hs and request_last_hs[req_id]: + last_hidden_states = torch.cat(request_last_hs[req_id], dim=0) + else: + last_hidden_states = layer_tensors[-1] + + # Use real input_ids from RPC, otherwise create dummy + if req_id in self._input_ids_map: + input_ids_list = self._input_ids_map[req_id] + input_ids = torch.tensor( + input_ids_list, dtype=torch.long, device=hidden_states.device + ) + else: + seq_len = hidden_states.shape[0] + input_ids = torch.zeros(seq_len, dtype=torch.long, device=hidden_states.device) + + # Skip empty tensors + if hidden_states.numel() == 0: + logger.error(f"Request {req_id}: hidden_states is empty! Skipping.") + continue + + try: + logger.debug( + f"Storing to Mooncake: key={mooncake_key}, " + f"hidden_states_shape={hidden_states.shape}" + ) + + # Store to Mooncake + tensor_shapes = self._mooncake_store.put( + key=mooncake_key, + hidden_states=hidden_states, + input_ids=input_ids, + last_hidden_states=last_hidden_states, + target=None, + ) + + logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") + + # Convert dtype to string for RPC serialization + # Include input_ids as a list for reconstruction (avoids Mooncake storage issues) + result[req_id] = { + "mooncake_key": mooncake_key, + "tensor_shapes": tensor_shapes, + "tensor_dtypes": { + "hidden_states": str(hidden_states.dtype).replace("torch.", ""), + "input_ids": str(input_ids.dtype).replace("torch.", ""), + "last_hidden_states": str(last_hidden_states.dtype).replace("torch.", ""), + }, + "num_layers": len(layer_tensors), + "packed_loss_mask": self._packed_loss_mask_map.get(req_id), + "input_ids_list": input_ids.cpu().tolist(), # Serialize via RPC instead of Mooncake + } + except Exception as e: + logger.warning( + f"Failed to store tensors to Mooncake for {req_id} (key={mooncake_key}): {e}" + ) + continue + + # Flush to ensure all writes are complete before returning + if self._mooncake_store is not None: + self._mooncake_store.flush() + + # Clear intermediate storage to free memory + self._captured_states = None + self._captured_last_hs = None + self._captured_input_ids = None + self._request_metadata = [] + self._input_ids_map = {} + + return result if result else None + + def _get_captured_states(self) -> Optional[Dict[str, List[torch.Tensor]]]: + """Legacy method - now delegates to _store_and_get_metadata. + + This method is kept for backward compatibility but should not be used + in production due to RPC serialization issues. Use _store_and_get_metadata + instead which stores tensors directly to Mooncake. + + Returns: + Dict mapping request_id to list of tensors (one per layer), + or None if no states captured. + """ + # If Mooncake store is available, use the new method + if self._store_initialized or self._init_mooncake_store(): + metadata = self._store_and_get_metadata() + if metadata is None: + return None + # Return empty dict to signal success - actual data is in Mooncake + return {} + + # Fallback to old behavior if Mooncake not available + if self._captured_states is None: + return None + + # Concatenate captured states from all scheduler iterations + concatenated_layers = [ + torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states + ] + + # Slice and group by request + request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( + lambda: [[] for _ in range(len(concatenated_layers))] + ) + current_idx = 0 + + # Use external IDs for slicing + external_ids = ( + list(self._current_request_metadata.keys()) if self._current_request_metadata else [] + ) + token_counts = ( + list(self._current_request_metadata.values()) if self._current_request_metadata else [] + ) + + req_idx = 0 + for step_metadata in self._request_metadata: + step_tokens = sum(step_metadata.values()) if step_metadata else 0 + if step_tokens == 0 and req_idx < len(token_counts): + step_tokens = token_counts[req_idx] + + if req_idx < len(external_ids): + external_id = external_ids[req_idx] + for layer_idx, layer_tensor in enumerate(concatenated_layers): + chunk = layer_tensor[current_idx : current_idx + step_tokens].clone().cpu() + request_chunks[external_id][layer_idx].append(chunk) + current_idx += step_tokens + req_idx += 1 + + # Concatenate chunks for each request across iterations + result: Dict[str, List[torch.Tensor]] = { + req_id: [torch.cat(chunks, dim=0) for chunks in layer_chunks] + for req_id, layer_chunks in request_chunks.items() + } + + # Clear intermediate storage to free memory + self._captured_states = None + self._request_metadata = [] + + return result diff --git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index eefc40c..348eb05 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -25,6 +25,7 @@ from torchspec.inference.engine.hf_engine import HFEngine from torchspec.inference.engine.sgl_engine import SglEngine +from torchspec.inference.engine.vllm_engine import VllmEngine from torchspec.utils.env import get_torchspec_env_vars from torchspec.utils.logging import logger @@ -36,7 +37,7 @@ def create_inference_engines(args, inference_pg, mooncake_config, engine_group: int = 0): """Create inference engines based on configured engine type (blocking). - Supports "hf" and "sgl" engine types via inference_engine_type config. + Supports "hf", "sgl", and "vllm" engine types via inference_engine_type config. Returns: List of head engines used for dispatching requests. Multi-node TP @@ -44,7 +45,7 @@ def create_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl"): + if engine_type not in ("hf", "sgl", "vllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Using {engine_type} engine for inference") @@ -76,15 +77,19 @@ def prepare_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl"): + if engine_type not in ("hf", "sgl", "vllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Preparing {engine_type} inference engines...") if engine_type == "hf": engines, init_refs = _prepare_hf_engines(args, inference_pg, mooncake_config, engine_group) - else: + elif engine_type == "sgl": engines, init_refs = _prepare_sgl_engines(args, inference_pg, mooncake_config, engine_group) + else: + engines, init_refs = _prepare_vllm_engines( + args, inference_pg, mooncake_config, engine_group + ) return engines, init_refs @@ -95,7 +100,7 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: Args: args: Configuration arguments. pg: Placement group tuple (pg, reordered_bundle_indices, reordered_gpu_ids). - engine_type: Engine type ("hf" or "sgl"). + engine_type: Engine type ("hf", "sgl", or "vllm"). mooncake_config: MooncakeConfig object. Returns: @@ -105,6 +110,8 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: return _init_hf_engines(args, pg, mooncake_config, engine_group) elif engine_type == "sgl": return _init_sgl_engines(args, pg, mooncake_config, engine_group) + elif engine_type == "vllm": + return _init_vllm_engines(args, pg, mooncake_config, engine_group) else: raise ValueError(f"Unknown engine_type: {engine_type}") @@ -160,6 +167,7 @@ def _prepare_sgl_engines( accept generate() calls. init_handles are ObjectRefs for ALL engines (head + worker) that must be waited on before use. """ + nnodes = getattr(args, "sglang_nnodes", 1) num_gpus_total = getattr(args, "inference_num_gpus", 1) @@ -268,6 +276,125 @@ def _init_sgl_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> return head_engines +def _prepare_vllm_engines( + args, pg, mooncake_config=None, engine_group: int = 0 +) -> tuple[list, list]: + """Create vLLM engine actors and fire init calls without waiting. + + Handles three cases: + - Single-node, multiple engines: one engine per group of GPUs + - Multi-node, single replica: one engine per node, all forming one TP group + - Multi-node, multiple replicas: N independent TP groups, each spanning nnodes + + For multi-node, worker engines are stored in a module-level list to prevent GC. + + Returns: + Tuple of (head_engines, init_handles). head_engines are the engines that + accept generate() calls. init_handles are ObjectRefs for ALL engines + (head + worker) that must be waited on before use. + """ + nnodes = getattr(args, "vllm_nnodes", 1) + num_gpus_total = getattr(args, "inference_num_gpus", 1) + + if nnodes > 1: + gpus_per_node = getattr(args, "inference_num_gpus_per_node", 8) + gpus_per_replica = nnodes * gpus_per_node + num_replicas = num_gpus_total // gpus_per_replica + num_engines = num_replicas * nnodes + gpus_per_engine = gpus_per_node + else: + gpus_per_engine = getattr(args, "inference_num_gpus_per_engine", 1) + num_replicas = num_gpus_total // gpus_per_engine + num_engines = num_replicas + + logger.info( + f"Initializing {num_engines} vLLM engines " + f"({gpus_per_engine} GPU(s) each, nnodes={nnodes}, replicas={num_replicas})" + ) + + pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg + VllmRayActor = ray.remote(VllmEngine) + env_vars = get_torchspec_env_vars() + + engines = [] + for i in range(num_engines): + node_rank = i % nnodes if nnodes > 1 else 0 + + bundle_offset = i * gpus_per_engine + base_gpu_id = int(reordered_gpu_ids[bundle_offset]) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg_obj, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[bundle_offset], + ) + + engine = VllmRayActor.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=scheduling_strategy, + runtime_env={"env_vars": env_vars}, + ).remote( + args=args, + rank=i, + base_gpu_id=base_gpu_id, + num_gpus_per_engine=gpus_per_engine, + node_rank=node_rank, + engine_group=engine_group, + ) + engines.append(engine) + + dist_init_addrs: dict[int, str] = {} + if nnodes > 1: + configured_addr = getattr(args, "vllm_dist_init_addr", None) + for replica_idx in range(num_replicas): + if configured_addr and num_replicas == 1: + dist_init_addrs[replica_idx] = configured_addr + logger.info( + f"Replica {replica_idx}: using configured dist_init_addr: {configured_addr}" + ) + else: + head_engine = engines[replica_idx * nnodes] + ip, port = ray.get( + [head_engine.get_node_ip.remote(), head_engine.find_free_port.remote()], + timeout=30, + ) + addr = f"{ip}:{port}" + dist_init_addrs[replica_idx] = addr + logger.info(f"Replica {replica_idx}: auto-negotiated dist_init_addr: {addr}") + + init_handles = [] + for i, engine in enumerate(engines): + replica_idx = i // nnodes if nnodes > 1 else i + init_handles.append( + engine.init.remote( + mooncake_config=mooncake_config, + dist_init_addr=dist_init_addrs.get(replica_idx), + ) + ) + + if nnodes > 1: + head_engines = [engines[i] for i in range(num_engines) if i % nnodes == 0] + worker_engines = [engines[i] for i in range(num_engines) if i % nnodes != 0] + _alive_worker_engines.extend(worker_engines) + logger.info( + f"Prepared multi-node vLLM engines: {len(head_engines)} heads + " + f"{len(worker_engines)} workers ({num_replicas} replicas)" + ) + return head_engines, init_handles + + return engines, init_handles + + +def _init_vllm_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> list: + """Initialize vLLM engines with Ray placement groups (blocking).""" + head_engines, init_handles = _prepare_vllm_engines(args, pg, mooncake_config, engine_group) + nnodes = getattr(args, "vllm_nnodes", 1) + init_timeout = getattr(args, "vllm_init_timeout", 300 if nnodes == 1 else 600) + _wait_for_init(init_handles, "Vllm", timeout=init_timeout) + return head_engines + + def _create_and_init_actors( args, pg, diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index e41f5e1..18d3f5e 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -74,7 +74,17 @@ def __init__( def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: """Load tensors from mooncake key into device memory.""" - dtypes = sample.tensor_dtypes or {} + dtypes_raw = sample.tensor_dtypes or {} + + # Convert string dtypes to torch.dtype objects + dtypes = {} + for key, dtype_val in dtypes_raw.items(): + if isinstance(dtype_val, str): + # Handle "bfloat16" or "torch.bfloat16" format + dtype_str = dtype_val.replace("torch.", "") + dtypes[key] = getattr(torch, dtype_str) + else: + dtypes[key] = dtype_val # DEBUG: Print the shapes we're requesting logger.debug(