177 changes: 177 additions & 0 deletions inference/trillium/SGLang-Jax/Qwen3-MoE/README.md
# Serve Qwen3-MoE with SGLang-Jax on TPU

SGLang-Jax supports multiple Mixture-of-Experts (MoE) models from the Qwen3 family with varying hardware requirements:

- **[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)**: Runs on 4 TPU v6e chips
- **[Qwen3-Coder-480B-A35B-Instruct](https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct)**: Requires 64 TPU v6e chips (16 nodes × 4 chips)
- Other Qwen3 MoE variants with different scale requirements

**This tutorial focuses on deploying Qwen3-Coder-480B**, the largest of these models, which requires a multi-node distributed setup. For smaller models such as Qwen3-30B, follow the same steps with adjusted node counts and parallelism settings.

## Hardware Requirements

Running Qwen3-Coder-480B requires a multi-node TPU cluster:

- **Total nodes**: 16
- **TPU chips per node**: 4 (v6e)
- **Total TPU chips**: 64
- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)
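
After installation (next section), you can sanity-check that each node sees its four local chips before attempting the distributed launch:

```bash
# Should print 4 on each v6e-4 host; only local chips are visible
# until the distributed runtime is initialized by the server.
python3 -c "import jax; print(jax.local_device_count())"
```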


## Installation

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```
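
With either method, verify the package imports cleanly before moving on:

```bash
# Confirms the sgl_jax module resolves inside the active virtualenv.
python3 -c "import sgl_jax; print('sgl-jax import OK')"
```
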
## Launch Distributed Server

### Preparation

1. **Get Node 0 IP address** (coordinator):

```bash
# On node 0
hostname -I | awk '{print $1}'
```

Save this IP as `NODE_RANK_0_IP`.
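
For example, you can capture and export it on node 0 directly, then reuse the variable in the commands below:

```bash
# On node 0: grab the first reported IP and export it for later steps.
export NODE_RANK_0_IP=$(hostname -I | awk '{print $1}')
echo "Coordinator address: ${NODE_RANK_0_IP}"
```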

2. **Download the model** (use shared storage, or pre-download it on every node):

```bash
export HF_TOKEN=your_huggingface_token
huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model
```
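
If the 16 nodes are workers of a single TPU slice, one way to pre-download on every node at once is gcloud's `--worker=all` fan-out. A sketch, assuming a slice named `sgl-jax-64` in `us-east5-a` (both hypothetical) and `huggingface-cli` installed on each worker:

```bash
# Run the download on all 16 workers of the slice in parallel.
# Assumes HF_TOKEN is exported in the local shell.
gcloud compute tpus tpu-vm ssh sgl-jax-64 \
  --zone=us-east5-a \
  --worker=all \
  --command="HF_TOKEN=${HF_TOKEN} huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model"
```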

### Launch Command

Run the following command **on each node**, replacing:
- `<NODE_RANK_0_IP>`: IP address of node 0
- `<NODE_RANK>`: Current node rank (0-15)
- `<QWEN3_CODER_480B_MODEL_PATH>`: Path to the downloaded model

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
--model-path <QWEN3_CODER_480B_MODEL_PATH> \
--trust-remote-code \
--dist-init-addr=<NODE_RANK_0_IP>:10011 \
--nnodes=16 \
--tp-size=32 \
--device=tpu \
--random-seed=3 \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--download-dir=/dev/shm \
--dtype=bfloat16 \
--max-running-requests=128 \
--skip-server-warmup \
--page-size=128 \
--tool-call-parser=qwen3_coder \
--node-rank=<NODE_RANK>
```

### Example for Specific Nodes

**Node 0 (coordinator):**

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
--model-path /path/to/Qwen3-Coder-480B \
--trust-remote-code \
--dist-init-addr=10.0.0.2:10011 \
--nnodes=16 \
--tp-size=32 \
--device=tpu \
--random-seed=3 \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--download-dir=/dev/shm \
--dtype=bfloat16 \
--max-running-requests=128 \
--skip-server-warmup \
--page-size=128 \
--tool-call-parser=qwen3_coder \
--node-rank=0
```

**Node 1:**

```bash
# Same command but with --node-rank=1
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
--model-path /path/to/Qwen3-Coder-480B \
--trust-remote-code \
--dist-init-addr=10.0.0.2:10011 \
--nnodes=16 \
--tp-size=32 \
--device=tpu \
--random-seed=3 \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--download-dir=/dev/shm \
--dtype=bfloat16 \
--max-running-requests=128 \
--skip-server-warmup \
--page-size=128 \
--tool-call-parser=qwen3_coder \
--node-rank=1
```

Repeat for all 16 nodes, incrementing `--node-rank` from 0 to 15.
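
Rather than SSHing into each node by hand, the rollout can be scripted. A minimal sketch, assuming the same hypothetical `sgl-jax-64` slice and a `launch.sh` wrapper (hypothetical) that runs the command above with `--node-rank=$1`:

```bash
# Start the server on all 16 workers, passing each worker's index as its node rank.
for RANK in $(seq 0 15); do
  gcloud compute tpus tpu-vm ssh sgl-jax-64 \
    --zone=us-east5-a \
    --worker=${RANK} \
    --command="bash launch.sh ${RANK}" &
done
wait
```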

## Configuration Parameters

### Distributed Serving

- `--nnodes`: Number of nodes in the cluster (16)
- `--node-rank`: Rank of the current node (0-15)
- `--dist-init-addr`: Address of the coordinator node (node 0) with port

### Model Parallelism

- `--tp-size`: Tensor parallelism size for non-MoE layers (32)
- **ETP**: Expert tensor parallelism, automatically set to 64 to match the total chip count

### Memory and Performance

- `--mem-fraction-static`: Memory allocation for static buffers (0.8)
- `--chunked-prefill-size`: Prefill chunk size for batching (2048)
- `--max-running-requests`: Maximum concurrent requests (128)
- `--page-size`: Page size for memory management (128)

### Model-Specific

- `--tool-call-parser`: Parser for tool calls, set to `qwen3_coder` for this model
- `--dtype`: Data type for inference (bfloat16)
- `--random-seed`: Random seed for reproducibility (3)

## Verification

Once all nodes are running, the server will be accessible via the coordinator node (node 0). You can test it with:

```bash
curl http://<NODE_RANK_0_IP>:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-Coder-480B",
"prompt": "def fibonacci(n):",
"max_tokens": 200,
"temperature": 0
}'
```
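
The endpoint above is OpenAI-compatible. If sglang-jax also mirrors SGLang's chat route (an assumption worth checking against your server's startup logs), a chat-style request looks like:

```bash
curl http://<NODE_RANK_0_IP>:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
    "messages": [{"role": "user", "content": "Write a Python function that reverses a linked list."}],
    "max_tokens": 200,
    "temperature": 0
  }'
```
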
152 changes: 152 additions & 0 deletions inference/trillium/SGLang-Jax/Qwen3/README.md
# Serve Qwen3 with SGLang-Jax on TPU

This guide demonstrates how to serve [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) and [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) using SGLang-Jax on TPU.


## Provision TPU Resources

For **Qwen3-8B**, a single v6e chip is sufficient. For **Qwen3-32B**, use 4 chips or more.

### Option 1: Using gcloud CLI

Install and configure gcloud CLI by following the [official installation guide](https://cloud.google.com/sdk/docs/install).

**Create TPU VM:**

```bash
gcloud compute tpus tpu-vm create sgl-jax \
--zone=us-east5-a \
--version=v2-alpha-tpuv6e \
--accelerator-type=v6e-4
```
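
Before connecting, you can confirm the VM reached the `READY` state:

```bash
# Lists TPU VMs in the zone; the new VM should report state READY.
gcloud compute tpus tpu-vm list --zone=us-east5-a
```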

**Connect to TPU VM:**

```bash
gcloud compute tpus tpu-vm ssh sgl-jax --zone us-east5-a
```

### Option 2: Using SkyPilot (Recommended for Development)

SkyPilot simplifies TPU provisioning and offers automatic cost optimization, instance management, and environment setup.

**Prerequisites:**
- [Install SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html)
- [Configure GCP credentials](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)

**Create configuration file `sgl-jax.yaml`:**

```yaml
resources:
accelerators: tpuv6e-4
accelerator_args:
tpu_vm: True
runtime_version: v2-alpha-tpuv6e

setup: |
uv venv --python 3.12
source .venv/bin/activate
uv pip install sglang-jax
```

**Launch TPU cluster:**

```bash
sky launch sgl-jax.yaml \
--cluster=sgl-jax-skypilot-v6e-4 \
--infra=gcp \
-i 30 \
--down \
-y \
--use-spot
```

This command will:
- Find the lowest-cost spot instance across regions
- Automatically shut down after 30 minutes of idleness
- Set up the SGLang-Jax environment automatically
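
You can watch provisioning and confirm the cluster's state, autostop setting, and chosen zone with:

```bash
sky status sgl-jax-skypilot-v6e-4
```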

**Connect to cluster:**

```bash
ssh sgl-jax-skypilot-v6e-4
```

> **Note:** SkyPilot manages the external IP automatically, so you don't need to track it manually.

## Installation

> **Note:** If you used SkyPilot to provision resources, the environment is already set up. Skip to the [Launch Server](#launch-server) section.

For gcloud CLI users, install SGLang-Jax using one of the following methods:

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

## Launch Server

Set the model name and start the SGLang-Jax server:

```bash
export MODEL_NAME="Qwen/Qwen3-8B" # or "Qwen/Qwen3-32B"

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
uv run python -u -m sgl_jax.launch_server \
--model-path ${MODEL_NAME} \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--download-dir=/tmp \
--dtype=bfloat16 \
--max-running-requests 256 \
--skip-server-warmup \
--page-size=128
```

### Configuration Parameters

- `--tp-size`: Tensor parallelism size, should equal the number of TPU chips in your instance
- `--mem-fraction-static`: Fraction of memory allocated for static buffers
- `--chunked-prefill-size`: Size of prefill chunks for batching
- `--max-running-requests`: Maximum number of concurrent requests
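
Once the server reports it is ready, send a quick completion request to verify it end to end. A minimal check; replace `<PORT>` with the port shown in your server's startup logs (or pin one explicitly with `--port` when launching):

```bash
curl http://localhost:<PORT>/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-8B",
    "prompt": "The capital of France is",
    "max_tokens": 32,
    "temperature": 0
  }'
```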

## Run Benchmark

Test serving performance with different workload configurations:

```bash
uv run python -m sgl_jax.bench_serving \
--backend sgl-jax \
--dataset-name random \
--num-prompts 256 \
--random-input 4096 \
--random-output 1024 \
--max-concurrency 64 \
--random-range-ratio 1 \
--warmup-requests 0
```

### Benchmark Parameters

- `--backend`: Backend engine (use `sgl-jax`)
- `--random-input`: Input sequence length (e.g., 1024, 4096, 8192)
- `--random-output`: Output sequence length (e.g., 1, 1024)
- `--max-concurrency`: Maximum number of concurrent requests (e.g., 8, 16, 32, 64, 128, 256)
- `--num-prompts`: Total number of prompts to send

You can test various combinations of input/output lengths and concurrency levels to evaluate throughput and latency characteristics.
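
For example, a small sweep over concurrency levels maps out the throughput/latency curve. A sketch that reuses the command above:

```bash
# Sweep concurrency levels; each run prints its own throughput and latency stats.
for CONC in 8 16 32 64 128; do
  echo "=== max-concurrency=${CONC} ==="
  uv run python -m sgl_jax.bench_serving \
    --backend sgl-jax \
    --dataset-name random \
    --num-prompts 256 \
    --random-input 4096 \
    --random-output 1024 \
    --max-concurrency ${CONC} \
    --random-range-ratio 1 \
    --warmup-requests 0
done
```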
8 changes: 8 additions & 0 deletions inference/trillium/SGLang-Jax/README.md
# Serve SGLang-Jax on Trillium TPUs (v6e)

This repository provides examples demonstrating how to deploy and serve SGLang-Jax on Trillium TPUs using GCE (Google Compute Engine) for a select set of models.

- [Qwen3-8B/32B](./Qwen3/README.md)
- [Qwen3-30B-A3B / Qwen3-Coder-480B-A35B-Instruct](./Qwen3-MoE/README.md)

SGLang-Jax continues to add support for new models; for the current model list, see https://github.com/sgl-project/sglang-jax/tree/main/python/sgl_jax/srt/models.