Merged
36 changes: 35 additions & 1 deletion README.md
@@ -8,19 +8,53 @@ TorchSpec is a torch-native speculative decoding training framework. We introduc

## Setup

### Choose Your Backend

TorchSpec supports two inference backends:

| Backend | Best For | Installation |
|---------|----------|--------------|
| **SGLang** | Production workloads, high throughput | `./tools/build_conda.sh 1 sglang` (default) |
| **vLLM** | Flexibility, easier deployment | `./tools/build_conda.sh 1 vllm` |
| **Both** | Development, comparison testing | `./tools/build_conda.sh 1 both` |

### Quick Setup

```bash
# Install with SGLang (default)
./tools/build_conda.sh
micromamba activate torchspec

# Or install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec
```

To install into your current environment instead: `./tools/build_conda.sh current`
To install into your current environment instead:
```bash
./tools/build_conda.sh current sglang # or 'vllm' or 'both'
```

Optional: install Flash Attention:

```bash
pip install -e ".[fa]"
```

### Backend-Specific Usage

**SGLang (default):**
```bash
./examples/qwen3-8b-single-node/run.sh
```

**vLLM:**
```bash
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```

TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden states extraction.
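The capture pattern can be sketched in plain Python (the class and function names below are illustrative, not TorchSpec's actual API): the extension lives inside the worker process, records each layer's output as the forward pass runs, and hands the accumulated states back only when asked, so no tensor has to be serialized mid-forward.

```python
class HiddenStatesExtension:
    """Toy sketch of a worker-side extension that captures hidden states."""

    def __init__(self):
        self._captured = []

    def hook(self, layer_name, hidden):
        # Called from inside the worker's forward pass; keeps a local
        # reference instead of shipping the tensor over RPC.
        self._captured.append((layer_name, hidden))

    def drain(self):
        # Hand back everything captured so far and reset the buffer.
        out, self._captured = self._captured, []
        return out


def forward(extension, layers, x):
    # Toy "worker" forward pass that invokes the hook after each layer.
    for name, fn in layers:
        x = fn(x)
        extension.hook(name, x)
    return x


ext = HiddenStatesExtension()
layers = [("layer0", lambda v: v + 1), ("layer1", lambda v: v * 2)]
result = forward(ext, layers, 3)   # hidden states recorded in-process
states = ext.drain()               # drained once, buffer resets
```

In the real backend the hook runs inside each vLLM worker, which is what sidesteps RPC serialization of large activations.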

## Quick Start

Train an Eagle3 draft model for Qwen3-8B using an inference engine (4 GPUs: 2 training + 2 inference):
74 changes: 74 additions & 0 deletions configs/vllm_qwen3_8b.yaml
@@ -0,0 +1,74 @@
# Configuration for train_entry.py with vLLM Engine inference (nested config format)
#
# GPU allocation:
# - 2 GPUs for inference (duplicate mode: each engine has full model copy)
# - 2 GPUs for training (DP/FSDP: model sharded across 2 GPUs)
# - Total: 4 GPUs
#
# Installation:
# pip install -e ".[vllm]" # Install vLLM backend
#
# Usage:
# python -m torchspec.train_entry --config configs/vllm_qwen3_8b.yaml
#
# Note: Uses vLLM Worker Extension to hook into model forward pass for hidden states capture.

model:
  target_model_path: Qwen/Qwen3-8B
  trust_remote_code: true

dataset:
  train_data_path: ../examples/data/sample_conversations.jsonl
  eval_data_path: ../examples/data/eval_conversations.jsonl
  eval_interval: 100
  chat_template: qwen
  prompt_key: conversations

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 1
  learning_rate: 1e-4
  max_concurrent_batches: 1
  max_grad_norm: 0.5
  max_seq_length: 16384
  num_epochs: 1
  seed: 42
  training_num_gpus_per_node: 2
  training_num_nodes: 1
  ttt_length: 7
  save_per_epoch: true
  warmup_ratio: 0.015

inference:
  inference_engine_type: vllm
  inference_num_gpus: 2
  inference_num_gpus_per_engine: 2
  inference_num_gpus_per_node: 4
  max_sample_pool_size: 64
  inference_buffer_threshold: 32
  inference_batch_size: 8
  vllm:
    tp_size: 2
    mem_fraction_static: 0.7
    use_worker_extension: true
    extra_args:
      max_num_batched_tokens: 32768
      compilation_config:
        max_cudagraph_capture_size: 8

mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/qwen3-8b-single-node
cache_dir: ./cache
model_download_dir: null

debug:
  save_debug_train_data: null
  debug_train_only: false
  debug_inference_only: false
3 changes: 2 additions & 1 deletion docs/code_architecture.md
@@ -24,7 +24,8 @@ torchspec/
│   ├── base.py          # InferenceEngine (ABC)
│   ├── hf_engine.py     # HFEngine (Ray actor, inherits RayActor)
│   ├── hf_runner.py     # HFRunner (core inference logic)
│   └── sgl_engine.py    # SglEngine (Ray actor, inherits RayActor)
│   ├── sgl_engine.py    # SglEngine (Ray actor, inherits RayActor)
│   └── vllm_engine.py   # VllmEngine (Ray actor, uses vLLM extract_hidden_states)
├── models/              # Model definitions
│   ├── eagle3.py        # Eagle3Model (core forward/loss)
│   ├── draft/           # Draft model implementations
12 changes: 11 additions & 1 deletion examples/README.md
@@ -17,12 +17,22 @@ If you just want to try TorchSpec locally, start with **hf-quickstart** (3 GPUs,
./examples/hf-quickstart/run.sh
```

For production workloads with SGLang async inference, use **qwen3-8b-single-node**:
For production workloads with async inference, use **qwen3-8b-single-node**:

```bash
./examples/qwen3-8b-single-node/run.sh
```

## Switching inference backends

Examples use SGLang by default. To use vLLM instead:

```bash
# Use vLLM backend with qwen3-8b-single-node example
./examples/qwen3-8b-single-node/run.sh \
  --config configs/vllm_qwen3_8b.yaml
```

## Data

Sample training data is in [`data/sample_conversations.jsonl`](data/sample_conversations.jsonl). All examples that use local data point to this file by default.
6 changes: 2 additions & 4 deletions examples/qwen3-8b-single-node/run.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Train with SglEngine async inference (multi-GPU version)
# Train with SGLang/vLLM async inference (multi-GPU version)
#
# GPU allocation (default: 4 GPUs total):
# - 2 GPUs for inference (duplicate mode: each engine has full model copy)
@@ -46,7 +46,7 @@ INFERENCE_GPUS=2
LOCAL_IP=$(python3 -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); s.connect(('8.8.8.8', 80)); print(s.getsockname()[0]); s.close()")

echo "=============================================="
echo "Train with SglEngine inference"
echo "Train with async inference"
echo "=============================================="
echo "Config: $CONFIG_FILE (nested format)"
echo "Total GPUs: $TOTAL_GPUS (CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES)"
@@ -59,11 +59,9 @@ echo "=============================================="
python3 -m torchspec.train_entry \
    --config "$CONFIG_FILE" \
    training.training_num_gpus_per_node="$TRAIN_GPUS" \
    inference.inference_engine_type="sgl" \
    inference.inference_num_gpus="$INFERENCE_GPUS" \
    inference.inference_num_gpus_per_engine=2 \
    inference.inference_num_gpus_per_node="$TOTAL_GPUS" \
    inference.sglang.tp_size=2 \
    "$@"

echo "=============================================="
1 change: 0 additions & 1 deletion patches/vllm/v0.15.1/vllm.patch

This file was deleted.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -41,8 +41,12 @@ dev = [
"ruff",
]

vllm = [
"vllm>=0.16.0",
]

fa = [
"flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute",
"flash-attention-cute @ git+https://github.com/Dao-AILab/flash-attention.git@fec3a6a18460c1b40f097208d4c16fe8964a679d#subdirectory=flash_attn/cute",
"nvidia-cutlass-dsl==4.4.0.dev1",
"nvidia-cutlass-dsl-libs-base==4.4.0.dev1",
]
File renamed without changes.