diff --git a/inference/trillium/SGLang-Jax/Qwen3-MoE/README.md b/inference/trillium/SGLang-Jax/Qwen3-MoE/README.md
new file mode 100644
index 0000000..ac7f39e
--- /dev/null
+++ b/inference/trillium/SGLang-Jax/Qwen3-MoE/README.md
@@ -0,0 +1,177 @@
+# Serve Qwen3-MoE with SGLang-Jax on TPU
+
+SGLang-Jax supports multiple Mixture-of-Experts (MoE) models from the Qwen3 family, with varying hardware requirements:
+
+- **[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)**: Runs on 4 TPU v6e chips
+- **[Qwen3-Coder-480B-A35B-Instruct](https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct)**: Requires 64 TPU v6e chips (16 nodes × 4 chips)
+- Other Qwen3 MoE variants with different scale requirements
+
+**This tutorial focuses on deploying Qwen3-Coder-480B**, the largest of these models, which requires a multi-node distributed setup. For smaller models such as Qwen3-30B-A3B, you can follow the same steps with adjusted node counts and parallelism settings.
+
+## Hardware Requirements
+
+Running Qwen3-Coder-480B requires a multi-node TPU cluster:
+
+- **Total nodes**: 16
+- **TPU chips per node**: 4 (v6e)
+- **Total TPU chips**: 64
+- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
+- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)
+
+## Installation
+
+### Option 1: Install from PyPI
+
+```bash
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install sglang-jax
+```
+
+### Option 2: Install from Source
+
+```bash
+git clone https://github.com/sgl-project/sglang-jax
+cd sglang-jax
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install -e python/
+```
+
+## Launch Distributed Server
+
+### Preparation
+
+1. **Get the Node 0 IP address** (coordinator):
+
+```bash
+# On node 0
+hostname -I | awk '{print $1}'
+```
+
+Save this IP as `NODE_RANK_0_IP`.
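+
+It can be convenient to export this address as an environment variable on every node. A minimal sketch (the variable name `NODE_RANK_0_IP` is simply the convention used in this guide; `10.0.0.2` is the example coordinator address used below):
+
+```bash
+# On every node: substitute the address printed by node 0
+export NODE_RANK_0_IP=10.0.0.2
+```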
+
+2. **Download the model** (use shared storage, or pre-download it on every node):
+
+```bash
+export HF_TOKEN=your_huggingface_token
+huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model
+```
+
+### Launch Command
+
+Run the following command **on each node**, replacing:
+
+- `<NODE_RANK_0_IP>`: IP address of node 0
+- `<NODE_RANK>`: Current node rank (0-15)
+- `<MODEL_PATH>`: Path to the downloaded model
+
+```bash
+JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
+python3 -u -m sgl_jax.launch_server \
+  --model-path <MODEL_PATH> \
+  --trust-remote-code \
+  --dist-init-addr=<NODE_RANK_0_IP>:10011 \
+  --nnodes=16 \
+  --tp-size=32 \
+  --device=tpu \
+  --random-seed=3 \
+  --mem-fraction-static=0.8 \
+  --chunked-prefill-size=2048 \
+  --download-dir=/dev/shm \
+  --dtype=bfloat16 \
+  --max-running-requests=128 \
+  --skip-server-warmup \
+  --page-size=128 \
+  --tool-call-parser=qwen3_coder \
+  --node-rank=<NODE_RANK>
+```
+
+### Example for Specific Nodes
+
+**Node 0 (coordinator):**
+
+```bash
+JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
+python3 -u -m sgl_jax.launch_server \
+  --model-path /path/to/model \
+  --trust-remote-code \
+  --dist-init-addr=10.0.0.2:10011 \
+  --nnodes=16 \
+  --tp-size=32 \
+  --device=tpu \
+  --random-seed=3 \
+  --mem-fraction-static=0.8 \
+  --chunked-prefill-size=2048 \
+  --download-dir=/dev/shm \
+  --dtype=bfloat16 \
+  --max-running-requests=128 \
+  --skip-server-warmup \
+  --page-size=128 \
+  --tool-call-parser=qwen3_coder \
+  --node-rank=0
+```
+
+**Node 1:**
+
+```bash
+# Same command, but with --node-rank=1
+JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
+python3 -u -m sgl_jax.launch_server \
+  --model-path /path/to/model \
+  --trust-remote-code \
+  --dist-init-addr=10.0.0.2:10011 \
+  --nnodes=16 \
+  --tp-size=32 \
+  --device=tpu \
+  --random-seed=3 \
+  --mem-fraction-static=0.8 \
+  --chunked-prefill-size=2048 \
+  --download-dir=/dev/shm \
+  --dtype=bfloat16 \
+  --max-running-requests=128 \
+  --skip-server-warmup \
+  --page-size=128 \
+  --tool-call-parser=qwen3_coder \
+  --node-rank=1
+```
+
+Repeat for all 16 nodes, incrementing `--node-rank` from 0 to 15.
+
+## Configuration Parameters
+
+### Distributed Serving
+
+- `--nnodes`: Number of nodes in the cluster (16)
+- `--node-rank`: Rank of the current node (0-15)
+- `--dist-init-addr`: Address and port of the coordinator node (node 0)
+
+### Model Parallelism
+
+- `--tp-size`: Tensor parallelism size for non-MoE layers (32)
+- **ETP**: Expert tensor parallelism, automatically configured to 64 based on the total chip count
+
+### Memory and Performance
+
+- `--mem-fraction-static`: Fraction of memory allocated for static buffers (0.8)
+- `--chunked-prefill-size`: Prefill chunk size for batching (2048)
+- `--max-running-requests`: Maximum number of concurrent requests (128)
+- `--page-size`: Page size for memory management (128)
+
+### Model-Specific
+
+- `--tool-call-parser`: Parser for tool calls; set to `qwen3_coder` for this model
+- `--dtype`: Data type for inference (bfloat16)
+- `--random-seed`: Random seed for reproducibility (3)
+
+## Verification
+
+Once all nodes are running, the server is accessible via the coordinator node (node 0).
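+
+You can poll the coordinator until it reports ready. A minimal sketch — this assumes the server exposes a `/health` route on the same port as the completion endpoint below (8000); adjust if your deployment differs:
+
+```bash
+# Block until the coordinator responds, retrying every 10 seconds.
+# NOTE: the /health route and port 8000 are assumptions; adjust for your deployment.
+until curl -sf http://<NODE_RANK_0_IP>:8000/health; do
+  echo "Server not ready yet, retrying..."
+  sleep 10
+done
+```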
+
+Then send a test completion request:
+
+```bash
+curl http://<NODE_RANK_0_IP>:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
+    "prompt": "def fibonacci(n):",
+    "max_tokens": 200,
+    "temperature": 0
+  }'
+```
diff --git a/inference/trillium/SGLang-Jax/Qwen3/README.md b/inference/trillium/SGLang-Jax/Qwen3/README.md
new file mode 100644
index 0000000..ce436e0
--- /dev/null
+++ b/inference/trillium/SGLang-Jax/Qwen3/README.md
@@ -0,0 +1,152 @@
+# Serve Qwen3 with SGLang-Jax on TPU
+
+This guide demonstrates how to serve [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) and [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) using SGLang-Jax on TPU.
+
+## Provision TPU Resources
+
+For **Qwen3-8B**, a single v6e chip is sufficient. For **Qwen3-32B**, use 4 or more chips.
+
+### Option 1: Using gcloud CLI
+
+Install and configure the gcloud CLI by following the [official installation guide](https://cloud.google.com/sdk/docs/install).
+
+**Create TPU VM:**
+
+```bash
+gcloud compute tpus tpu-vm create sgl-jax \
+  --zone=us-east5-a \
+  --version=v2-alpha-tpuv6e \
+  --accelerator-type=v6e-4
+```
+
+**Connect to TPU VM:**
+
+```bash
+gcloud compute tpus tpu-vm ssh sgl-jax --zone us-east5-a
+```
+
+### Option 2: Using SkyPilot (Recommended for Development)
+
+SkyPilot simplifies TPU provisioning and offers automatic cost optimization, instance management, and environment setup.
+
+**Prerequisites:**
+
+- [Install SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html)
+- [Configure GCP credentials](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)
+
+**Create configuration file `sgl-jax.yaml`:**
+
+```yaml
+resources:
+  accelerators: tpu-v6e-4
+  accelerator_args:
+    tpu_vm: True
+    runtime_version: v2-alpha-tpuv6e
+
+setup: |
+  uv venv --python 3.12
+  source .venv/bin/activate
+  uv pip install sglang-jax
+```
+
+**Launch TPU cluster:**
+
+```bash
+sky launch sgl-jax.yaml \
+  --cluster=sgl-jax-skypilot-v6e-4 \
+  --infra=gcp \
+  -i 30 \
+  --down \
+  -y \
+  --use-spot
+```
+
+This command will:
+- Find the lowest-cost spot instance across regions
+- Automatically shut down the cluster after 30 minutes of idleness
+- Set up the SGLang-Jax environment
+
+**Connect to cluster:**
+
+```bash
+ssh sgl-jax-skypilot-v6e-4
+```
+
+> **Note:** SkyPilot manages the external IP automatically, so you don't need to track it manually.
+
+## Installation
+
+> **Note:** If you used SkyPilot to provision resources, the environment is already set up. Skip to the [Launch Server](#launch-server) section.
+
+For gcloud CLI users, install SGLang-Jax using one of the following methods:
+
+### Option 1: Install from PyPI
+
+```bash
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install sglang-jax
+```
+
+### Option 2: Install from Source
+
+```bash
+git clone https://github.com/sgl-project/sglang-jax
+cd sglang-jax
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install -e python/
+```
+
+## Launch Server
+
+Set the model name and start the SGLang-Jax server:
+
+```bash
+export MODEL_NAME="Qwen/Qwen3-8B"  # or "Qwen/Qwen3-32B"
+
+JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
+uv run python -u -m sgl_jax.launch_server \
+  --model-path ${MODEL_NAME} \
+  --trust-remote-code \
+  --tp-size=4 \
+  --device=tpu \
+  --mem-fraction-static=0.8 \
+  --chunked-prefill-size=2048 \
+  --download-dir=/tmp \
+  --dtype=bfloat16 \
+  --max-running-requests 256 \
+  --skip-server-warmup \
+  --page-size=128
+```
+
+### Configuration Parameters
+
+- `--tp-size`: Tensor parallelism size; should equal the number of TPU chips in your instance (4 for a v6e-4; use 1 to run Qwen3-8B on a single chip)
+- `--mem-fraction-static`: Fraction of memory allocated for static buffers
+- `--chunked-prefill-size`: Size of prefill chunks for batching
+- `--max-running-requests`: Maximum number of concurrent requests
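+
+Once the server is up, you can confirm it responds with a quick completion request. This mirrors the verification step in the Qwen3-MoE guide and assumes the same port (8000); adjust if your server listens elsewhere:
+
+```bash
+# Port 8000 follows the Qwen3-MoE example in this repo; change it if needed
+curl http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "Qwen/Qwen3-8B",
+    "prompt": "def fibonacci(n):",
+    "max_tokens": 200,
+    "temperature": 0
+  }'
+```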
+
+## Run Benchmark
+
+Test serving performance with different workload configurations:
+
+```bash
+uv run python -m sgl_jax.bench_serving \
+  --backend sgl-jax \
+  --dataset-name random \
+  --num-prompts 256 \
+  --random-input 4096 \
+  --random-output 1024 \
+  --max-concurrency 64 \
+  --random-range-ratio 1 \
+  --warmup-requests 0
+```
+
+### Benchmark Parameters
+
+- `--backend`: Backend engine (use `sgl-jax`)
+- `--random-input`: Input sequence length (e.g., 1024, 4096, 8192)
+- `--random-output`: Output sequence length (e.g., 1, 1024)
+- `--max-concurrency`: Maximum number of concurrent requests (e.g., 8, 16, 32, 64, 128, 256)
+- `--num-prompts`: Total number of prompts to send
+
+You can test various combinations of input/output lengths and concurrency levels to evaluate throughput and latency characteristics.
diff --git a/inference/trillium/SGLang-Jax/README.md b/inference/trillium/SGLang-Jax/README.md
new file mode 100644
index 0000000..1749c66
--- /dev/null
+++ b/inference/trillium/SGLang-Jax/README.md
@@ -0,0 +1,8 @@
+# Serve SGLang-Jax on Trillium TPUs (v6e)
+
+This repository provides examples demonstrating how to deploy and serve SGLang-Jax on Trillium TPUs (v6e) using Google Compute Engine (GCE) for a select set of models:
+
+- [Qwen3-8B / Qwen3-32B](./Qwen3/README.md)
+- [Qwen3-30B-A3B / Qwen3-Coder-480B-A35B-Instruct](./Qwen3-MoE/README.md)
+
+The SGLang-Jax project continues to add support for new models. For the current model list, see https://github.com/sgl-project/sglang-jax/tree/main/python/sgl_jax/srt/models.
\ No newline at end of file