From 50fd543255aac7d3bb28cd2c407b63dc01aa5871 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Wed, 5 Nov 2025 05:19:08 +0000 Subject: [PATCH 1/2] add simple router and refine splitwise deployment --- benchmarks/backend_request_func.py | 10 +- benchmarks/benchmark_dataset.py | 6 +- benchmarks/benchmark_serving.py | 8 +- benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml | 2 +- .../yaml/request_yaml/qwen25-vl-32k.yaml | 2 +- docs/features/multi-node_deployment.md | 4 +- docs/zh/features/multi-node_deployment.md | 4 +- examples/splitwise/start_mixed.sh | 71 +++ examples/splitwise/start_v0_tp1.sh | 66 +++ examples/splitwise/start_v1_tp1.sh | 96 ++++ examples/splitwise/start_v1_tp2.sh | 98 ++++ examples/splitwise/start_v2_tp1.sh | 93 ++++ examples/splitwise/start_v2_tp2.sh | 96 ++++ examples/splitwise/stop.sh | 7 + examples/splitwise/test.sh | 20 + fastdeploy/cache_manager/cache_messager.py | 6 +- fastdeploy/config.py | 82 ++- fastdeploy/engine/args_utils.py | 68 ++- fastdeploy/engine/common_engine.py | 128 ++++- fastdeploy/engine/engine.py | 2 - fastdeploy/engine/request.py | 2 + fastdeploy/entrypoints/openai/protocol.py | 13 +- fastdeploy/entrypoints/openai/serving_chat.py | 6 +- .../entrypoints/openai/serving_completion.py | 6 +- .../inter_communicator/engine_worker_queue.py | 10 +- fastdeploy/output/token_processor.py | 13 +- fastdeploy/router/__init__.py | 15 + fastdeploy/router/launch.py | 58 ++ fastdeploy/router/router.py | 317 +++++++++++ fastdeploy/router/utils.py | 131 +++++ fastdeploy/scheduler/config.py | 11 +- fastdeploy/scheduler/local_scheduler.py | 19 + fastdeploy/splitwise/splitwise_connector.py | 33 +- fastdeploy/utils.py | 1 + fastdeploy/worker/worker_process.py | 2 + requirements.txt | 2 + tests/e2e/test_ernie_03b_pd.py | 73 +-- tests/e2e/test_ernie_03b_pd_multi_node.py | 500 ++++++++++++++++++ 38 files changed, 1910 insertions(+), 171 deletions(-) create mode 100644 examples/splitwise/start_mixed.sh create mode 100644 examples/splitwise/start_v0_tp1.sh create mode 100644 examples/splitwise/start_v1_tp1.sh create mode 100644 examples/splitwise/start_v1_tp2.sh create mode 100644 examples/splitwise/start_v2_tp1.sh create mode 100644 examples/splitwise/start_v2_tp2.sh create mode 100644 examples/splitwise/stop.sh create mode 100644 examples/splitwise/test.sh create mode 100644 fastdeploy/router/__init__.py create mode 100644 fastdeploy/router/launch.py create mode 100644 fastdeploy/router/router.py create mode 100644 fastdeploy/router/utils.py create mode 100644 tests/e2e/test_ernie_03b_pd_multi_node.py diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 09eedeb8ff6..2ccb4e3452b 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -94,10 +94,11 @@ async def async_request_eb_openai_chat_completions( "stream_options": { "include_usage": True, "continuous_usage_stats": True, - } + }, + "max_tokens": request_func_input.output_len, } if request_func_input.response_format: - payload["response_format"] =request_func_input.response_format + payload["response_format"] = request_func_input.response_format # 超参由yaml传入 payload.update(request_func_input.hyper_parameters) @@ -132,13 +133,13 @@ async def async_request_eb_openai_chat_completions( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": - #print("####chunk:", chunk, type(chunk)) + # print("####chunk:", chunk, type(chunk)) timestamp = time.perf_counter() data = json.loads(chunk) if request_id == "None" and "id" in 
data: request_id = data["id"] - + if choices := data.get("choices"): content = choices[0]["delta"].get("content") reason_content = choices[0]["delta"].get("reasoning_content") @@ -164,7 +165,6 @@ async def async_request_eb_openai_chat_completions( elif usage := data.get("usage", {}): output.output_tokens = usage.get("completion_tokens", 0) output.prompt_tokens = usage.get("prompt_tokens", 0) - most_recent_timestamp = timestamp diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 0bc4750623c..e9552c6d2ad 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -46,7 +46,7 @@ class SampleRequest: prompt_len: int expected_output_len: int response_format: Optional[dict] = None - + class BenchmarkDataset(ABC): """BenchmarkDataset""" @@ -299,7 +299,7 @@ def sample( prompt = entry["messages"][-1].get("content", "") history_QA = entry.get("messages", []) response_format = entry.get("response_format") - new_output_len = int(entry.get("max_tokens", 12288)) + new_output_len = int(entry.get("max_tokens", output_len if output_len else 12288)) if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation(prompt, None) @@ -311,7 +311,7 @@ def sample( prompt_len=0, history_QA=history_QA, expected_output_len=new_output_len, - response_format=response_format + response_format=response_format, ) ) cnt += 1 diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3b779d99c3e..74ae0e37b93 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -352,7 +352,7 @@ async def benchmark( ignore_eos=ignore_eos, debug=debug, extra_body=extra_body, - response_format=response_format + response_format=response_format, ) print("test_input:", test_input) @@ -384,7 +384,7 @@ async def benchmark( logprobs=logprobs, ignore_eos=ignore_eos, extra_body=extra_body, - response_format=response_format + response_format=response_format, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -444,7 +444,7 @@ async def limited_request_func(request_func_input, pbar): debug=debug, ignore_eos=ignore_eos, extra_body=extra_body, - response_format=response_format + response_format=response_format, ) tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar))) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) @@ -460,7 +460,7 @@ async def limited_request_func(request_func_input, pbar): api_url=base_url + "/stop_profile", output_len=test_output_len, logprobs=logprobs, - response_format=response_format + response_format=response_format, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: diff --git a/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml b/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml index d159e676f60..a946c0f9859 100644 --- a/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml +++ b/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml @@ -3,4 +3,4 @@ max_num_seqs: 128 gpu_memory_utilization: 0.85 tensor_parallel_size: 1 limit_mm_per_prompt: '{"image": 100, "video": 100}' -enable_mm: True \ No newline at end of file +enable_mm: True diff --git a/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml b/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml index 0c9a944e699..b26e6874970 100644 --- a/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml +++ b/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml @@ -5,4 +5,4 @@ metadata: max_tokens: 32768 repetition_penalty: 1.05 frequency_penalty: 0 
-presence_penalty: 0 \ No newline at end of file +presence_penalty: 0 diff --git a/docs/features/multi-node_deployment.md b/docs/features/multi-node_deployment.md index cf0920058d1..ca1ee94f532 100644 --- a/docs/features/multi-node_deployment.md +++ b/docs/features/multi-node_deployment.md @@ -26,7 +26,7 @@ We recommend using mpirun for one-command startup without manually starting each 4. Ensure all nodes can resolve each other's hostnames * Online inference startup example: - + ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model baidu/ERNIE-4.5-300B-A47B-Paddle \ @@ -40,7 +40,7 @@ We recommend using mpirun for one-command startup without manually starting each ``` * Offline startup example: - + ```python from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.llm import LLM diff --git a/docs/zh/features/multi-node_deployment.md b/docs/zh/features/multi-node_deployment.md index 81ca1dc7218..7789f588fea 100644 --- a/docs/zh/features/multi-node_deployment.md +++ b/docs/zh/features/multi-node_deployment.md @@ -26,7 +26,7 @@ 4. 确保所有节点能够解析彼此的主机名 * 在线推理启动示例: - + ```shell python -m fastdeploy.entrypoints.openai.api_server \ --model baidu/ERNIE-4.5-300B-A47B-Paddle \ @@ -40,7 +40,7 @@ ``` * 离线启动示例: - + ```python from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.llm import LLM diff --git a/examples/splitwise/start_mixed.sh b/examples/splitwise/start_mixed.sh new file mode 100644 index 00000000000..bf3e78ab058 --- /dev/null +++ b/examples/splitwise/start_mixed.sh @@ -0,0 +1,71 @@ +#!/bin/bash +set -e + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." 
+ sleep 2 + fi + done +} + +# prepare environment +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" + +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 +export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start router +export FD_LOG_DIR="log_router" +mkdir -p ${FD_LOG_DIR} + +router_port=9000 +nohup python -m fastdeploy.router.launch \ + --port ${router_port} \ + 2>&1 >${FD_LOG_DIR}/nohup & +sleep 1 + +# start modelserver 0 +export CUDA_VISIBLE_DEVICES=0 +export FD_LOG_DIR="log_server_0" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --max-model-len 32768 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & +sleep 1 + +wait_for_health 8100 + +# start modelserver 1 +export CUDA_VISIBLE_DEVICES=1 +export FD_LOG_DIR="log_server_1" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8200 \ + --metrics-port 8201 \ + --engine-worker-queue-port 8202 \ + --cache-queue-port 8203 \ + --max-model-len 32768 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & + +wait_for_health 8200 diff --git a/examples/splitwise/start_v0_tp1.sh b/examples/splitwise/start_v0_tp1.sh new file mode 100644 index 00000000000..30dbb5a906d --- /dev/null +++ b/examples/splitwise/start_v0_tp1.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set -e + +# Test splitwise deployment +# v0 requires prefill and decode in one node and it uses local scheduler +# v1 supports prefill and decode in multi node and it uses splitwise scheduler +# v2 supports prefill and decode in multi node and it uses router and local scheduler + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." 
+ sleep 2 + fi + done +} + +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" +aistudio download --model ${MODEL_NAME} + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start prefill +export FD_LOG_DIR="log_prefill" +mkdir -p ${FD_LOG_DIR} + +export CUDA_VISIBLE_DEVICES=0 +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --max-model-len 32768 \ + --splitwise-role "prefill" \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 8100 + +# start decode +export FD_LOG_DIR="log_decode" +mkdir -p ${FD_LOG_DIR} + +export CUDA_VISIBLE_DEVICES=1 +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 9000 \ + --metrics-port 9001 \ + --engine-worker-queue-port 9002 \ + --cache-queue-port 9003 \ + --max-model-len 32768 \ + --splitwise-role "decode" \ + --innode-prefill-ports 8102 \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 9000 diff --git a/examples/splitwise/start_v1_tp1.sh b/examples/splitwise/start_v1_tp1.sh new file mode 100644 index 00000000000..12377404c1d --- /dev/null +++ b/examples/splitwise/start_v1_tp1.sh @@ -0,0 +1,96 @@ +#!/bin/bash +set -e + +# Test splitwise deployment +# v0 requires prefill and decode in one node and it uses local scheduler +# v1 supports prefill and decode in multi node and it uses splitwise scheduler +# v2 supports prefill and decode in multi node and it uses router and local scheduler + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." + sleep 2 + fi + done +} + +# prepare environment +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" + +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 +export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 + +SCRIPT_PATH=$(readlink -f "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) +echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" +if [ -z "${KVCACHE_RDMA_NICS}" ]; then + echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" + exit 1 +fi + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start redis +if ! redis-cli ping &>/dev/null; then + echo "Redis is not running. Starting redis-server..." + redis-server --daemonize yes + sleep 1 +else + echo "Redis is already running." 
+fi +sleep 1 + +# start prefill +export CUDA_VISIBLE_DEVICES=0 +export FD_LOG_DIR="log_prefill" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --max-model-len 32768 \ + --splitwise-role "prefill" \ + --cache-transfer-protocol "rdma,ipc" \ + --rdma-comm-ports 8104 \ + --pd-comm-port 8105 \ + --scheduler-name "splitwise" \ + --scheduler-host "127.0.0.1" \ + --scheduler-port 6379 \ + --scheduler-ttl 9000 \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 8100 + +# start decode +export CUDA_VISIBLE_DEVICES=1 +export FD_LOG_DIR="log_decode" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 9000 \ + --metrics-port 9001 \ + --engine-worker-queue-port 9002 \ + --cache-queue-port 9003 \ + --max-model-len 32768 \ + --splitwise-role "decode" \ + --cache-transfer-protocol "rdma,ipc" \ + --rdma-comm-ports 9004 \ + --pd-comm-port 9005 \ + --scheduler-name "splitwise" \ + --scheduler-host "127.0.0.1" \ + --scheduler-port 6379 \ + --scheduler-ttl 9000 \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 9000 diff --git a/examples/splitwise/start_v1_tp2.sh b/examples/splitwise/start_v1_tp2.sh new file mode 100644 index 00000000000..cf0b728064a --- /dev/null +++ b/examples/splitwise/start_v1_tp2.sh @@ -0,0 +1,98 @@ +#!/bin/bash +set -e + +# Test splitwise deployment +# v0 requires prefill and decode in one node and it uses local scheduler +# v1 supports prefill and decode in multi node and it uses splitwise scheduler +# v2 supports prefill and decode in multi node and it uses router and local scheduler + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." + sleep 2 + fi + done +} + +# prepare environment +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" + +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 +export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 + +SCRIPT_PATH=$(readlink -f "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) +echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" +if [ -z "${KVCACHE_RDMA_NICS}" ]; then + echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" + exit 1 +fi + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start redis +if ! redis-cli ping &>/dev/null; then + echo "Redis is not running. Starting redis-server..." + redis-server --daemonize yes + sleep 1 +else + echo "Redis is already running." 
+fi +sleep 1 + +# start prefill +export CUDA_VISIBLE_DEVICES=0,1 +export FD_LOG_DIR="log_prefill" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --max-model-len 32768 \ + --tensor-parallel-size 2 \ + --splitwise-role "prefill" \ + --cache-transfer-protocol "rdma,ipc" \ + --pd-comm-port 8104 \ + --rdma-comm-ports 8105,8106 \ + --scheduler-name "splitwise" \ + --scheduler-host "127.0.0.1" \ + --scheduler-port 6379 \ + --scheduler-ttl 9000 \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 8100 + +# start decode +export CUDA_VISIBLE_DEVICES=2,3 +export FD_LOG_DIR="log_decode" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 9000 \ + --metrics-port 9001 \ + --engine-worker-queue-port 9002 \ + --cache-queue-port 9003 \ + --max-model-len 32768 \ + --tensor-parallel-size 2 \ + --splitwise-role "decode" \ + --cache-transfer-protocol "rdma,ipc" \ + --pd-comm-port 9004 \ + --rdma-comm-ports 9005,9006 \ + --scheduler-name "splitwise" \ + --scheduler-host "127.0.0.1" \ + --scheduler-port 6379 \ + --scheduler-ttl 9000 \ + 2>&1 >${FD_LOG_DIR}/nohup & +wait_for_health 9000 diff --git a/examples/splitwise/start_v2_tp1.sh b/examples/splitwise/start_v2_tp1.sh new file mode 100644 index 00000000000..78a0358f957 --- /dev/null +++ b/examples/splitwise/start_v2_tp1.sh @@ -0,0 +1,93 @@ +#!/bin/bash +set -e + +# Test splitwise deployment +# v0 requires prefill and decode in one node and it uses local scheduler +# v1 supports prefill and decode in multi node and it uses splitwise scheduler +# v2 supports prefill and decode in multi node and it uses router and local scheduler + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." 
+ sleep 2 + fi + done +} + +# prepare environment +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" + +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 +export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 + +SCRIPT_PATH=$(readlink -f "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) +echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" +if [ -z "${KVCACHE_RDMA_NICS}" ]; then + echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" + exit 1 +fi + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start router +export FD_LOG_DIR="log_router" +mkdir -p ${FD_LOG_DIR} + +router_port=9000 +nohup python -m fastdeploy.router.launch \ + --port ${router_port} \ + --splitwise \ + 2>&1 >${FD_LOG_DIR}/nohup & +sleep 1 + +# start prefill +export CUDA_VISIBLE_DEVICES=0 +export FD_LOG_DIR="log_prefill" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --max-model-len 32768 \ + --splitwise-role "prefill" \ + --cache-transfer-protocol "ipc,rdma" \ + --rdma-comm-ports 8104 \ + --pd-comm-port 8105 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & + +wait_for_health 8100 + +# start decode +export CUDA_VISIBLE_DEVICES=1 +export FD_LOG_DIR="log_decode" +mkdir -p ${FD_LOG_DIR} + +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8200 \ + --metrics-port 8201 \ + --engine-worker-queue-port 8202 \ + --cache-queue-port 8203 \ + --max-model-len 32768 \ + --splitwise-role "decode" \ + --cache-transfer-protocol "ipc,rdma" \ + --rdma-comm-ports 8204 \ + --pd-comm-port 8205 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & + +wait_for_health 8200 diff --git a/examples/splitwise/start_v2_tp2.sh b/examples/splitwise/start_v2_tp2.sh new file mode 100644 index 00000000000..5563b2f4c98 --- /dev/null +++ b/examples/splitwise/start_v2_tp2.sh @@ -0,0 +1,96 @@ +#!/bin/bash +set -e + +# Test splitwise deployment +# v0 requires prefill and decode in one node and it uses local scheduler +# v1 supports prefill and decode in multi node and it uses splitwise scheduler +# v2 supports prefill and decode in multi node and it uses router and local scheduler + +wait_for_health() { + local server_port=$1 + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + break + else + echo "Service not ready. Retrying in 2s..." 
+ sleep 2 + fi + done +} + +# prepare environment +MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" +# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle" + +export FD_DEBUG=1 +export ENABLE_V1_KVCACHE_SCHEDULER=0 +export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 + +SCRIPT_PATH=$(readlink -f "$0") +SCRIPT_DIR=$(dirname "$SCRIPT_PATH") +export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) +echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" +if [ -z "${KVCACHE_RDMA_NICS}" ]; then + echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" + exit 1 +fi + +unset http_proxy && unset https_proxy +rm -rf log_* + +# start router +export FD_LOG_DIR="log_router" +mkdir -p ${FD_LOG_DIR} + +echo "start router" +router_port=9000 +nohup python -m fastdeploy.router.launch \ + --port ${router_port} \ + --splitwise \ + 2>&1 >${FD_LOG_DIR}/nohup & +sleep 1 + +# start prefill +export CUDA_VISIBLE_DEVICES=0,1 +export FD_LOG_DIR="log_prefill" +mkdir -p ${FD_LOG_DIR} + +echo "start prefill" +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8100 \ + --metrics-port 8101 \ + --engine-worker-queue-port 8102 \ + --cache-queue-port 8103 \ + --tensor-parallel-size 2 \ + --max-model-len 32768 \ + --splitwise-role "prefill" \ + --pd-comm-port 8104 \ + --rdma-comm-ports 8105,8106 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & + +wait_for_health 8100 + +# start decode +export CUDA_VISIBLE_DEVICES=2,3 +export FD_LOG_DIR="log_decode" +mkdir -p ${FD_LOG_DIR} + +echo "start decode" +nohup python -m fastdeploy.entrypoints.openai.api_server \ + --model ${MODEL_NAME} \ + --port 8200 \ + --metrics-port 8201 \ + --engine-worker-queue-port 8202 \ + --cache-queue-port 8203 \ + --max-model-len 32768 \ + --tensor-parallel-size 2 \ + --splitwise-role "decode" \ + --pd-comm-port 8204 \ + --rdma-comm-ports 8205,8206 \ + --router "0.0.0.0:${router_port}" \ + 2>&1 >${FD_LOG_DIR}/nohup & + +wait_for_health 8200 diff --git a/examples/splitwise/stop.sh b/examples/splitwise/stop.sh new file mode 100644 index 00000000000..5b0f13c5d95 --- /dev/null +++ b/examples/splitwise/stop.sh @@ -0,0 +1,7 @@ +pkill -9 -f python +pkill -9 -f fastdeploy +pkill -f -9 gunicorn + +if redis-cli ping >/dev/null 2>&1; then + redis-cli shutdown +fi diff --git a/examples/splitwise/test.sh b/examples/splitwise/test.sh new file mode 100644 index 00000000000..090e2ec2e10 --- /dev/null +++ b/examples/splitwise/test.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# using v0 version, the request must be sent to the decode instance +# using v1 version, the request can be sent to the prefill or decode instance +# using v2 version, the request must be sent to the router + +port=${1:-9000} +echo "port: ${port}" + +unset http_proxy && unset https_proxy + +curl -X POST "http://0.0.0.0:${port}/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "messages": [ + {"role": "user", "content": "Introduce shenzhen"} + ], + "max_tokens": 20, + "stream": true +}' diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index e6e6aa15218..0e3469b2267 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -344,6 +344,7 @@ def prefill_layerwise_send_cache_thread(self): ) item["layer_idx"] = current_layer_idx if item["layer_idx"] == self.num_layers: + item["status"] = "finished" if item["transfer_protocol"] == "ipc": self.messager["ipc"].write_block_by_sync(target_id) logger.info(f"finish write cache 
{item['request_id']}") @@ -359,7 +360,7 @@ def prefill_layerwise_send_cache_thread(self): def _handle_connect_task(self): while True: try: - task = self.engine_worker_queue.get_connect_rdma_task() + task, _ = self.engine_worker_queue.get_connect_rdma_task() if task is None: time.sleep(0.001) continue @@ -376,7 +377,8 @@ def _handle_connect_task(self): self.engine_worker_queue.connect_task_response_barrier.wait() self.engine_worker_queue.put_connect_rdma_task_response(response) except Exception as e: - logger.error(f"handle_connect_task has exception: {e}") + time.sleep(0.001) + logger.error(f"handle_connect_task has exception: {e}, {str(traceback.format_exc())}") class CacheMessagerV1: diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 14e5579764a..c80350db0e5 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1308,6 +1308,24 @@ def print(self): logger.info("=============================================================") +class RouterConfig: + """ + Configuration for router + Attributes: + router: the url of router, such as http://127.0.0.1:8000 + api_server_host: the host ip of model server + api_server_port: the http port of model server + """ + + def __init__(self, args: dict): + self.router = args["router"] + if self.router is not None and not self.router.startswith(("http://", "https://")): + self.router = f"http://{self.router}" + + self.api_server_host = get_host_ip() + self.api_server_port = args["port"] + + class CommitConfig: """ Configuration for tracking version information from version.txt @@ -1409,6 +1427,7 @@ def __init__( speculative_config: SpeculativeConfig = None, eplb_config: EPLBConfig = None, structured_outputs_config: StructuredOutputsConfig = None, + router_config: RouterConfig = None, tokenizer: str = None, ips: str = None, use_warmup: bool = False, @@ -1436,6 +1455,7 @@ def __init__( self.cache_config: CacheConfig = cache_config # type: ignore self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config + self.router_config: RouterConfig = router_config # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs @@ -1515,6 +1535,7 @@ def __init__( self.read_from_config() self.postprocess() + self.init_cache_info() if test_mode: return self.check() @@ -1732,29 +1753,66 @@ def init_cache_info(self): """ initialize cache info """ - disaggregate_info = {} + # TODO: group the splitiwse params, remove code of v0 + # v0 requires prefill and decode in one node and it uses local scheduler + # v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler + # v2 supports prefill and decode in multi node and it uses router and local scheduler + self.splitwise_version = None + if self.scheduler_config.name == "local" and (self.router_config is None or self.router_config.router is None): + self.splitwise_version = "v0" + elif self.scheduler_config.name in ("splitwise", "dp"): + self.splitwise_version = "v1" + elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router: + self.splitwise_version = "v2" + else: + raise ValueError( + f"Unsupported scheduler mode, scheduler_name: {self.scheduler_config.name}, " + f"router_config: {self.router_config}" + ) + logger.info(f"splitwise_version: {self.splitwise_version}") + + if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)): + engine_worker_queue_port = self.parallel_config.engine_worker_queue_port 
+ else: + engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[ + self.parallel_config.local_data_parallel_id + ] + connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None + + self.disaggregate_info = {} if self.scheduler_config.splitwise_role != "mixed": - disaggregate_info["role"] = self.scheduler_config.splitwise_role - disaggregate_info["cache_info"] = dict() + self.disaggregate_info["role"] = self.scheduler_config.splitwise_role + self.disaggregate_info["cache_info"] = dict() current_protocol = self.cache_config.cache_transfer_protocol.split(",") - disaggregate_info["transfer_protocol"] = current_protocol + self.disaggregate_info["transfer_protocol"] = current_protocol + for protocol in current_protocol: if protocol == "ipc": - disaggregate_info["cache_info"][protocol] = { + self.disaggregate_info["cache_info"][protocol] = { "ip": self.host_ip, - "port": self.parallel_config.engine_worker_queue_port[ - self.parallel_config.local_data_parallel_id - ], + "port": engine_worker_queue_port, "device_ids": self.local_device_ids, } elif protocol == "rdma": - disaggregate_info["cache_info"][protocol] = { + self.disaggregate_info["cache_info"][protocol] = { "ip": self.host_ip, - "port": self.cache_config.pd_comm_port[0], + "port": connector_port, "rdma_port": self.cache_config.rdma_comm_ports, } - self.disaggregate_info = disaggregate_info - logger.info(f"disaggregate_info: {self.disaggregate_info}") + logger.info(f"disaggregate_info: {self.disaggregate_info}") + + if self.router_config: + self.register_info = { + "role": self.scheduler_config.splitwise_role, + "host_ip": self.host_ip, + "port": self.router_config.api_server_port, + "connector_port": connector_port, + "rdma_ports": self.cache_config.rdma_comm_ports, + "engine_worker_queue_port": engine_worker_queue_port, + "device_ids": self.local_device_ids, + "transfer_protocol": self.cache_config.cache_transfer_protocol.split(","), + } + logger.info(f"register_info: {self.register_info}") def read_from_config(self): """ diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 04c4a9b232e..e6eb22f95e5 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -34,6 +34,7 @@ ParallelConfig, PlasAttentionConfig, PoolerConfig, + RouterConfig, RunnerOption, SpeculativeConfig, StructuredOutputsConfig, @@ -74,6 +75,10 @@ class EngineArgs: """ The name or path of the model to be used. """ + port: Optional[str] = None + """ + Port for api server. + """ served_model_name: Optional[str] = None """ The name of the model being served. @@ -445,6 +450,11 @@ class EngineArgs: - To enable custom logits processors, add your dotted paths to module and class names to the list. """ + router: Optional[str] = None + """ + Url for router server, such as `0.0.0.0:30000`. + """ + def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. @@ -859,21 +869,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Flag to enable prefix caching.", ) - perf_group.add_argument( - "--splitwise-role", - type=str, - default=EngineArgs.splitwise_role, - help="Role of splitwise. Default is \ - 'mixed'. 
(prefill, decode, mixed)", - ) - - perf_group.add_argument( - "--innode-prefill-ports", - type=lambda s: s.split(",") if s else None, - default=EngineArgs.innode_prefill_ports, - help="port for innode prefill", - ) - perf_group.add_argument( "--enable-chunked-prefill", action="store_true", @@ -903,27 +898,53 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=("For chunked prefill, the threshold number of" " tokens for a prompt to be considered long."), ) - perf_group.add_argument( + # Splitwise deployment parameters group + splitwise_group = parser.add_argument_group("Splitwise Deployment") + splitwise_group.add_argument( + "--splitwise-role", + type=str, + default=EngineArgs.splitwise_role, + help="Role of splitwise. Default is \ + 'mixed'. (prefill, decode, mixed)", + ) + + splitwise_group.add_argument( + "--innode-prefill-ports", + type=lambda s: s.split(",") if s else None, + default=EngineArgs.innode_prefill_ports, + help="port for innode prefill, only used in single machine splitwise deployment", + ) + + splitwise_group.add_argument( "--cache-transfer-protocol", type=str, default=EngineArgs.cache_transfer_protocol, - help="support protocol list, comma separated, default is ipc", + help="support protocol list (ipc or rdma), comma separated, default is ipc", ) - perf_group.add_argument( + splitwise_group.add_argument( "--pd-comm-port", type=lambda s: s.split(",") if s else None, default=EngineArgs.pd_comm_port, help="port for splitwise communication.", ) - perf_group.add_argument( + splitwise_group.add_argument( "--rdma-comm-ports", type=lambda s: s.split(",") if s else None, default=EngineArgs.rdma_comm_ports, help="ports for rdma communication.", ) + # Router parameters group + router_group = parser.add_argument_group("Router") + router_group.add_argument( + "--router", + type=str, + default=EngineArgs.router, + help="url for router server.", + ) + # Scheduler parameters group scheduler_group = parser.add_argument_group("Scheduler") scheduler_group.add_argument( @@ -1044,7 +1065,11 @@ def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs": """ Create an instance of EngineArgs from command line arguments. 
""" - return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)}) + args_dict = {} + for field in dataclass_fields(cls): + if hasattr(args, field.name): + args_dict[field.name] = getattr(args, field.name) + return cls(**args_dict) def create_speculative_config(self) -> SpeculativeConfig: """ """ @@ -1063,6 +1088,7 @@ def create_scheduler_config(self) -> SchedulerConfig: prefix_len = len(prefix) all = asdict(self) + all.pop("port") # port and scheduler_port are not the same params = dict() for k, v in all.items(): if k[:prefix_len] == prefix: @@ -1151,6 +1177,7 @@ def create_engine_config(self, port_availability_check=True) -> FDConfig: scheduler_cfg = self.create_scheduler_config() graph_opt_cfg = self.create_graph_optimization_config() plas_attention_config = self.create_plas_attention_config() + router_config = RouterConfig(all_dict) early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) @@ -1170,6 +1197,7 @@ def create_engine_config(self, port_availability_check=True) -> FDConfig: speculative_config=speculative_cfg, eplb_config=eplb_cfg, structured_outputs_config=structured_outputs_config, + router_config=router_config, ips=self.ips, use_warmup=self.use_warmup, limit_mm_per_prompt=self.limit_mm_per_prompt, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index bc5f86853d6..abbab33f67e 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -23,10 +23,11 @@ import traceback import weakref from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import paddle +import requests import zmq from opentelemetry import trace @@ -45,6 +46,7 @@ from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.plugins.token_processor import load_token_processor_plugins +from fastdeploy.router.utils import check_service_health from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import ( @@ -95,6 +97,7 @@ def __init__(self, cfg, start_queue=True): self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.llm_logger.info("Use V1 KVCache Scheduler") self.resource_manager = ResourceManagerV1( cfg.scheduler_config.max_num_seqs, cfg, @@ -103,6 +106,7 @@ def __init__(self, cfg, start_queue=True): cfg.parallel_config.local_data_parallel_id, ) else: + self.llm_logger.info("Use V0 KVCache Scheduler") self.resource_manager = ResourceManager( cfg.scheduler_config.max_num_seqs, cfg, @@ -118,7 +122,6 @@ def __init__(self, cfg, start_queue=True): ] self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager) - self.waiting_requests = [] self.token_processor = TokenProcessor( cfg=cfg, cached_generated_tokens=self.scheduler, @@ -149,14 +152,18 @@ def __init__(self, cfg, start_queue=True): def start(self): self.running = True if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) + self.insert_task_to_worker_thread = threading.Thread( + target=self._schedule_request_to_worker_v1, daemon=True + ) else: - self.insert_task_to_worker_thread = 
threading.Thread(target=self._insert_task_to_worker, daemon=True) + self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) self.insert_task_to_worker_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role != "mixed": - self.split_mode_get_tasks() + self._process_splitwise_task() + + self._register_to_router() def create_data_processor(self): self.input_processor = InputPreprocessor( @@ -313,7 +320,7 @@ def start_worker_queue_service(self, start_queue): local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id, ) - def insert_tasks(self, tasks, current_id=-1, allocated=False): + def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1, allocated=False): """ Insert tasks to engine. """ @@ -358,6 +365,7 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False): current_tasks.append(cur_task) if current_tasks: self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) + self.llm_logger.debug(f"put task to engine worker queue, task:{current_tasks}") return True self.resource_manager.check_and_free_block_tables() @@ -574,7 +582,7 @@ def update_mm_requests_chunk_size(self, requests): patch_st += chunk_patch_num request.set("prefill_chunk_info", chunks_info) - def _insert_task_to_worker(self): + def _schedule_request_to_worker(self): """ Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine @@ -619,9 +627,12 @@ def _insert_task_to_worker(self): if len(tasks) == 0: time.sleep(0.001) continue + if self.cfg.splitwise_version == "v2" and self.cfg.scheduler_config.splitwise_role == "decode": + # the task in decode instance will processed in _process_splitwise_task thread + continue + llm_logger.debug(f"get tasks from scheduler: {tasks}") if self.cfg.scheduler_config.splitwise_role != "mixed": - self.llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) insert_successful = self.insert_tasks(tasks, current_id) @@ -636,7 +647,7 @@ def _insert_task_to_worker(self): err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." self.llm_logger.error(err_msg) - def _scheduler_task_to_worker_v1(self): + def _schedule_request_to_worker_v1(self): """ Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ @@ -664,6 +675,7 @@ def _fetch_request(): max_num_batched_tokens=max_num_batched_tokens, batch=num_prefill_batch, ) + self.llm_logger.debug(f"get tasks from scheduler: {tasks}") if self.cfg.scheduler_config.splitwise_role != "mixed": need_delete_tasks = [] if envs.FD_OFFLINE_PERF_TEST_FOR_PD: @@ -822,6 +834,7 @@ def _insert_zmq_task_to_scheduler(self): if envs.FD_ENABLE_INTERNAL_ADAPTER: if self.cfg.scheduler_config.splitwise_role == "decode": return + while self.running: try: block = True if len(added_requests) == 0 else False @@ -975,17 +988,38 @@ def _zmq_send_generated_tokens(self): except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") - def split_mode_get_tasks(self): + def _process_splitwise_task(self): """ - Split mode get tasks + Processing tasks from engine worker queue in splitwise deployment. + For v0 version, prefill instance gets tasks from engine worker queue. 
+ For v1 and v2 version, decode instance gets raw tasks from engine worker queue to preallocate resources, + and decode instance gets prefilled tasks from engine worker queue to generate tokens. + TODO: unifiy the communication between decode and prefill instances. """ def receiver_loop(): + waiting_resource_requests = [] + waiting_ready_tasks = [] + + # Waiting for the api_server and scheduler in decode to + # receive the request sent by the client + def _decode_process_prefilled_task_v0_scheduler(input_tasks): + ready_tasks = [] + waiting_tasks = [] + for task in input_tasks: + if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id): + ready_tasks.append(task) + else: + waiting_tasks.append(task) + self.insert_tasks(ready_tasks, allocated=True) + if self.cfg.splitwise_version in ("v0", "v2"): + self.scheduler.put_results(ready_tasks) + return waiting_tasks + while self.running: try: - processed_indices = [] - for idx, task in enumerate(self.waiting_requests): + for idx, task in enumerate(waiting_resource_requests): if envs.ENABLE_V1_KVCACHE_SCHEDULER: if self.resource_manager.preallocate_resource_in_d(task): self.llm_logger.info(f"Resource available, processing task {task.request_id}") @@ -1004,21 +1038,27 @@ def receiver_loop(): break for idx in sorted(processed_indices, reverse=True): - self.waiting_requests.pop(idx) + waiting_resource_requests.pop(idx) - if not self.engine_worker_queue.disaggregate_queue_empty(): + waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler(waiting_ready_tasks) + + if self.engine_worker_queue.disaggregate_queue_empty(): + time.sleep(0.001) + else: items = self.engine_worker_queue.get_disaggregated_tasks() for item in items: role = item[0] tasks = item[1] + # prefill instance gets tasks from engine worker queue if role == "prefill": for task in tasks: task.max_tokens = task.min_tokens = 2 self.insert_tasks(tasks) - + # decode instance gets tasks from engine worker queue elif role == "decode": - if hasattr(tasks[0], "finished"): + if isinstance(tasks[0], RequestOutput): + self.llm_logger.debug(f"receive prefilled tasks, {tasks}") if not isinstance(tasks, list): tasks = [tasks] for task in tasks: @@ -1057,13 +1097,12 @@ def receiver_loop(): self.resource_manager.insert_task_for_decoding(task) else: - self.insert_tasks(tasks, allocated=True) - if self.cfg.innode_prefill_ports is not None: - self.scheduler.put_results(tasks) - else: - if len(self.waiting_requests): + waiting_ready_tasks.extend(_decode_process_prefilled_task_v0_scheduler(tasks)) + elif isinstance(tasks[0], Request): + self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}") + if len(waiting_resource_requests): self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}") - self.waiting_requests.extend(tasks) + waiting_resource_requests.extend(tasks) else: new_waiting = [] for task in tasks: @@ -1087,13 +1126,12 @@ def receiver_loop(): if not self.enable_decode_cache_task: self.split_connector.send_cache_infos(new_waiting, -1) else: - self.waiting_requests.extend(new_waiting) + waiting_resource_requests.extend(new_waiting) self.llm_logger.info( f"Added {len(new_waiting)} tasks to waiting queue" ) - - else: - time.sleep(0.001) + else: + raise ValueError(f"Unsupported task type: {type(tasks[0])}") except Exception as e: self.llm_logger.error(f"Error in main loop: {e}") @@ -1130,6 +1168,42 @@ def clear_data(self): llm_logger.error(f"Clear data error: {e}") return False + def _register_to_router(self): + """If use router, 
register this server to router""" + timeout = 5 + sleep_seconds = 10 + + def _register(): + while True: + try: + time.sleep(sleep_seconds) + + api_server_host = self.cfg.router_config.api_server_host + api_server_port = self.cfg.router_config.api_server_port + api_server_url = f"http://{api_server_host}:{api_server_port}" + if not check_service_health(api_server_url): + continue + + router_url = self.cfg.router_config.router + resp = requests.post( + f"{router_url}/register", + json=self.cfg.register_info, + timeout=timeout, + ) + if not resp.ok: + llm_logger.error( + f"Router registration failed: {resp.status_code}, " + f"{resp.text}, {self.cfg.register_info}" + ) + except requests.exceptions.RequestException as e: + llm_logger.error(f"Register to router request error: {e}") + except Exception as e: + llm_logger.exception(f"Unexpected error during router registration: {e}") + + if self.cfg.router_config.router is not None: + register_thread = threading.Thread(target=_register, daemon=True) + register_thread.start() + def _exit_sub_services(self): """ exit sub services diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index fb57f2abe9b..5a4b7b39a1b 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -694,8 +694,6 @@ def launch_components(self): self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.start() - self.cfg.init_cache_info() - role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index d0906685502..96a36aa4813 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -527,6 +527,8 @@ def __repr__(self) -> str: f"num_input_image_tokens={self.num_input_image_tokens}, " f"num_input_video_tokens={self.num_input_video_tokens}, " f"metrics={self.metrics}, " + f"error_code={self.error_code}, " + f"error_msg={self.error_msg}," ) @classmethod diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index f6e9e5dca51..4a1e4ef647f 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -451,6 +451,8 @@ class CompletionRequest(BaseModel): temperature: Optional[float] = Field(default=None, ge=0) top_p: Optional[float] = Field(default=None, ge=0, le=1) user: Optional[str] = None + request_id: Optional[str] = None + disaggregate_info: Optional[dict] = None # doc: begin-completion-sampling-params top_k: Optional[int] = None @@ -486,8 +488,6 @@ def to_dict_for_infer(self, request_id=None, prompt=None): dict: request parameters in dict format """ req_dict = {} - if request_id is not None: - req_dict["request_id"] = request_id # parse request model into dict if self.suffix is not None: @@ -497,6 +497,8 @@ def to_dict_for_infer(self, request_id=None, prompt=None): if value is not None: req_dict[key] = value + if request_id is not None: + req_dict["request_id"] = request_id if prompt is not None: req_dict["prompt"] = prompt @@ -604,6 +606,8 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None metadata: Optional[dict] = None response_format: Optional[AnyResponseFormat] = None + request_id: Optional[str] = None + disaggregate_info: Optional[dict] = None # doc: begin-chat-completion-sampling-params top_k: Optional[int] = None @@ -644,8 +648,6 @@ def to_dict_for_infer(self, request_id=None): dict: request parameters in dict format """ req_dict = {} - if request_id is not None: 
- req_dict["request_id"] = request_id req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens req_dict["logprobs"] = self.top_logprobs if self.logprobs else None @@ -666,6 +668,9 @@ def to_dict_for_infer(self, request_id=None): if value is not None: req_dict[key] = value + if request_id is not None: + req_dict["request_id"] = request_id + if "prompt_token_ids" in req_dict: if "messages" in req_dict: del req_dict["messages"] diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index b0b407e05ad..0cc760ebf66 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -114,7 +114,11 @@ async def create_chat_completion(self, request: ChatCompletionRequest): await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time) api_server_logger.info(f"current {self.engine_client.semaphore.status()}") - if request.user is not None: + if request.request_id is not None: + request_id = request.request_id + if not request_id.startswith("chatcmpl-"): + request_id = f"chatcmpl-{request_id}" + elif request.user is not None: request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}" else: request_id = f"chatcmpl-{uuid.uuid4()}" diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index c27375305ed..56c58e04b0c 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -85,7 +85,11 @@ async def create_completion(self, request: CompletionRequest): error=ErrorInfo(message=err_msg, type=ErrorType.INTERNAL_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT) ) created_time = int(time.time()) - if request.user is not None: + if request.request_id is not None: + request_id = request.request_id + if not request_id.startswith("cmpl-"): + request_id = f"cmpl-{request_id}" + elif request.user is not None: request_id = f"cmpl-{request.user}-{uuid.uuid4()}" else: request_id = f"cmpl-{uuid.uuid4()}" diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index be4880e17a5..7544db6fdc5 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -662,7 +662,10 @@ def put_cache_info(self, cache_info) -> None: self.client_read_info_flag[:] = [0] * self.num_client self.cache_infos.extend(cache_info) - llm_logger.debug(f"cache_infos: {self.cache_infos} local_data_parallel_id:{self.local_data_parallel_id}") + llm_logger.debug( + f"put cache_infos to engine worker queue: {self.cache_infos}, " + f"local_data_parallel_id:{self.local_data_parallel_id}" + ) self.lock_info.release() def get_cache_info(self) -> List[Any]: @@ -684,7 +687,10 @@ def get_cache_info(self) -> List[Any]: self.cache_infos[:] = list() self.lock_info.release() if len(cache_infos) != 0: - llm_logger.debug(f"get cache infos: {cache_infos} local_data_parallel_id:{self.local_data_parallel_id}") + llm_logger.debug( + f"get cache infos from engine worker queue: {cache_infos}, " + f"local_data_parallel_id:{self.local_data_parallel_id}" + ) return cache_infos def num_cache_infos(self) -> int: diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 681f1b0d07b..f14d10d9463 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -456,6 +456,7 @@ def _recycle_resources(self, task_id, index, task, 
result=None, is_prefill=False recycle resources """ if is_prefill: + start_time = time.time() while True: finished_task_ids = self.engine_worker_queue.get_finished_req() if len(finished_task_ids) > 0: @@ -474,6 +475,9 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False if self.prefill_result_status[task_id] != "finished": result.error_code = 400 result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" + llm_logger.info( + f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" + ) self.split_connector.send_first_token(task.disaggregate_info, [result]) break else: @@ -731,11 +735,10 @@ def _process_batch_output(self): self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if ( - not is_prefill - or self.cfg.scheduler_config.name == "splitwise" - or self.cfg.scheduler_config.name == "dp" - ): + + if not (is_prefill and self.cfg.splitwise_version == "v0"): + # NOTE: prefill instance in v0 version does not return result to scheduler + llm_logger.debug(f"get response from infer: {result}") batch_result.append(result) self.postprocess(batch_result, mtype) diff --git a/fastdeploy/router/__init__.py b/fastdeploy/router/__init__.py new file mode 100644 index 00000000000..31be300c18e --- /dev/null +++ b/fastdeploy/router/__init__.py @@ -0,0 +1,15 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" diff --git a/fastdeploy/router/launch.py b/fastdeploy/router/launch.py new file mode 100644 index 00000000000..421baa65e82 --- /dev/null +++ b/fastdeploy/router/launch.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import argparse + +from fastdeploy.router.router import start_router +from fastdeploy.utils import router_logger as logger + + +def main() -> None: + parser = argparse.ArgumentParser(description="Router for splitwise deployment testing") + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host address to bind the router server.", + ) + parser.add_argument( + "--port", + type=int, + default="9000", + help="Port number to bind the router server", + ) + parser.add_argument( + "--splitwise", + action="store_true", + help="Router uses splitwise deployment", + ) + parser.add_argument( + "--request-timeout-secs", + type=int, + default=1800, + help="Request timeout in seconds", + ) + args = parser.parse_args() + + try: + start_router(args) + except Exception as e: + logger.error(f"Error starting router: {e}") + raise e + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py new file mode 100644 index 00000000000..3ff4cd37125 --- /dev/null +++ b/fastdeploy/router/router.py @@ -0,0 +1,317 @@ +""" +Async Router server for FastDeploy. +Handles client requests and manages prefill/decode/mixed instances. +This module references the router implementation of slglang and vllm. +""" + +import asyncio +import random +from itertools import chain +from uuid import uuid4 + +import aiohttp +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import ORJSONResponse, Response, StreamingResponse + +from fastdeploy.router.utils import ( + InstanceInfo, + InstanceRole, + check_service_health_async, +) +from fastdeploy.utils import router_logger as logger + +app = FastAPI() + + +class Router: + """ + Router class that handles requests from client and + collects prefill/decode instance information + """ + + def __init__(self, args): + self.args = args + self.host = args.host + self.port = args.port + self.splitwise = args.splitwise + self.timeout = args.request_timeout_secs + + self.mixed_servers = [] + self.prefill_servers = [] + self.decode_servers = [] + self.lock = asyncio.Lock() # async-safe lock + + async def register_instance(self, instance_info_dict: dict): + """Register an instance asynchronously""" + try: + inst_info = InstanceInfo(**instance_info_dict) + except Exception as e: + logger.error(f"register instance failed: {e}") + raise + + if (self.splitwise and inst_info.role == InstanceRole.MIXED) or ( + not self.splitwise and inst_info.role != InstanceRole.MIXED + ): + raise ValueError(f"Invalid instance role: {inst_info.role}, splitwise: {self.splitwise}") + + if not await check_service_health_async(inst_info.url()): + raise RuntimeError(f"Instance {inst_info} is not healthy") + + async with self.lock: + if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers: + self.mixed_servers.append(inst_info) + logger.info( + f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}" + ) + elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers: + self.prefill_servers.append(inst_info) + logger.info( + f"Register prefill instance success: {inst_info}, " + f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}" + ) + elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers: + self.decode_servers.append(inst_info) + logger.info( + f"Register decode instance success: {inst_info}, " + f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}" + ) + + async 
def registered_number(self): + """Get number of registered instances""" + return { + "mixed": len(self.mixed_servers), + "prefill": len(self.prefill_servers), + "decode": len(self.decode_servers), + } + + async def select_pd(self): + """Select one prefill and one decode server""" + async with self.lock: + if not self.prefill_servers: + raise RuntimeError("No prefill servers available") + if not self.decode_servers: + raise RuntimeError("No decode servers available") + pidx = random.randint(0, len(self.prefill_servers) - 1) + didx = random.randint(0, len(self.decode_servers) - 1) + return self.prefill_servers[pidx], self.decode_servers[didx] + + async def select_mixed(self): + """Select one mixed server""" + async with self.lock: + if not self.mixed_servers: + raise RuntimeError("No mixed servers available") + idx = random.randint(0, len(self.mixed_servers) - 1) + return self.mixed_servers[idx] + + async def handle_request(self, request_data: dict, endpoint_name: str): + if self.splitwise: + return await self.handle_splitwise_request(request_data, endpoint_name) + else: + return await self.handle_mixed_request(request_data, endpoint_name) + + async def handle_mixed_request(self, request_data: dict, endpoint_name: str): + logger.debug(f"Received request: {request_data}") + mixed_server = await self.select_mixed() + + if request_data.get("stream", False): + return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name) + else: + return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name) + + async def handle_splitwise_request(self, request_data: dict, endpoint_name: str): + logger.debug(f"Received request: {request_data}") + prefill_server, decode_server = await self.select_pd() + + # TODO: unify the disaggregate_info in server and remove redundancy params + is_same_node = prefill_server.host_ip == decode_server.host_ip + use_ipc = ( + is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol + ) + + cache_info = {} + if use_ipc: + cache_info["ipc"] = { + "ip": decode_server.host_ip, + "port": decode_server.engine_worker_queue_port, + "device_ids": decode_server.device_ids, + } + else: + cache_info["rdma"] = { + "ip": decode_server.host_ip, + "port": decode_server.connector_port, + "rdma_port": decode_server.rdma_ports, + } + + disaggregate_info = { + "prefill": prefill_server.to_dict(), + "decode": decode_server.to_dict(), + "role": "decode", + "cache_info": cache_info, + "transfer_protocol": "ipc" if use_ipc else "rdma", + } + + modified_request = request_data.copy() + modified_request["disaggregate_info"] = disaggregate_info + if "request_id" not in modified_request: + modified_request["request_id"] = str(uuid4()) + + logger.debug(f"Modified request: {modified_request}") + + if request_data.get("stream", False): + return await self._generate_stream( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + else: + return await self._generate( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + + async def _generate( + self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" + ) -> ORJSONResponse: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session: + tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls] + results = await asyncio.gather(*tasks) + ret_json = await results[return_result_url_index].json() + 
return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status) + + async def _generate_stream( + self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" + ): + async def stream_results(): + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session: + tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls] + results = await asyncio.gather(*tasks) + + AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64 # prevent aiohttp's "Chunk too big" error + async for chunk in results[return_result_url_index].content.iter_chunked( + AIOHTTP_STREAM_READ_CHUNK_SIZE + ): + logger.debug(f"receive response chunk: {chunk}") + yield chunk + + return StreamingResponse(stream_results(), media_type="text/event-stream") + + async def monitor_instance_health(self, interval_secs: float = 5.0): + """ + Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones. + """ + while True: + try: + prefill_to_remove = [] + decode_to_remove = [] + mixed_to_remove = [] + + async with aiohttp.ClientSession() as session: + # check servers + prefill_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.prefill_servers] + decode_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.decode_servers] + mixed_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.mixed_servers] + + # gather all tasks concurrently + all_tasks = prefill_tasks + decode_tasks + mixed_tasks + for inst, coro in all_tasks: + try: + resp = await coro + if resp.status != 200: + logger.warning(f"Instance {inst.url()} unhealthy: {resp.status}") + if inst in self.prefill_servers: + prefill_to_remove.append(inst) + elif inst in self.decode_servers: + decode_to_remove.append(inst) + elif inst in self.mixed_servers: + mixed_to_remove.append(inst) + except Exception as e: + logger.warning(f"Instance {inst.url()} check failed: {e}") + if inst in self.prefill_servers: + prefill_to_remove.append(inst) + elif inst in self.decode_servers: + decode_to_remove.append(inst) + elif inst in self.mixed_servers: + mixed_to_remove.append(inst) + + # remove unhealthy instances under lock + async with self.lock: + if prefill_to_remove: + for inst in prefill_to_remove: + self.prefill_servers.remove(inst) + logger.info(f"Removed unhealthy prefill instance: {inst.url()}") + if decode_to_remove: + for inst in decode_to_remove: + self.decode_servers.remove(inst) + logger.info(f"Removed unhealthy decode instance: {inst.url()}") + if mixed_to_remove: + for inst in mixed_to_remove: + self.mixed_servers.remove(inst) + logger.info(f"Removed unhealthy mixed instance: {inst.url()}") + + await asyncio.sleep(interval_secs) + + prefill_instances = [inst.url() for inst in self.prefill_servers] + decode_instances = [inst.url() for inst in self.decode_servers] + mixed_instance = [inst.url() for inst in self.mixed_servers] + logger.debug( + f"Healthy prefill instances: {prefill_instances}, " + f"Healthy decode instances: {decode_instances}, " + f"Healthy mixed instance: {mixed_instance}" + ) + + except Exception as e: + logger.exception(f"Failed to monitor instance health: {e}") + + +@app.post("/register") +async def register(instance_info_dict: dict): + """Register prefill/decode/mixed servers""" + try: + await app.state.router.register_instance(instance_info_dict) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + return {"status": "success"} + + 
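+# Illustrative only (hypothetical values, not part of the original patch): an
+# api_server launched with "--router <router_host>:<router_port>" registers
+# itself by POSTing an InstanceInfo-shaped payload to this endpoint, e.g.
+#
+#   requests.post("http://0.0.0.0:9000/register", json={
+#       "role": "prefill",                      # "prefill" / "decode" / "mixed"
+#       "host_ip": "10.0.0.2",
+#       "port": 8188,
+#       "connector_port": 8433,
+#       "engine_worker_queue_port": 8133,
+#       "transfer_protocol": ["ipc", "rdma"],
+#       "rdma_ports": ["7001"],
+#       "device_ids": ["0"],
+#   })
+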
+@app.get("/registered_number") +async def registered_number(): + """Get the number of registered prefill/decode/mixed servers""" + return await app.state.router.registered_number() + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request_data: dict): + return await app.state.router.handle_request(request_data, "v1/chat/completions") + + +@app.post("/v1/completions") +async def create_completion(request_data: dict): + return await app.state.router.handle_request(request_data, "v1/completions") + + +@app.get("/health") +async def health_check(): + """Basic health check""" + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate(): + """Check all prefill and decode servers are healthy""" + router = app.state.router + async with aiohttp.ClientSession() as session: + tasks = [session.get(f"{s.url()}/health") for s in chain(router.prefill_servers, router.decode_servers)] + for coro in asyncio.as_completed(tasks): + resp = await coro + if resp.status != 200: + logger.warning(f"Server {resp.url} not healthy: {resp.status}") + return Response(status_code=200) + + +def start_router(router_args): + app.state.router_args = router_args + + @app.on_event("startup") + async def startup_event(): + app.state.router = Router(app.state.router_args) + asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5)) + + uvicorn.run(app, host=router_args.host, port=router_args.port) diff --git a/fastdeploy/router/utils.py b/fastdeploy/router/utils.py new file mode 100644 index 00000000000..7c83db90f7e --- /dev/null +++ b/fastdeploy/router/utils.py @@ -0,0 +1,131 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import asyncio +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import List, Union + +import aiohttp +import requests + + +class InstanceRole(Enum): + MIXED = 0 + PREFILL = 1 + DECODE = 2 + + +@dataclass +class InstanceInfo: + role: Union[InstanceRole, str] + host_ip: str + port: Union[int, str] + connector_port: Union[int, str] = 0 + engine_worker_queue_port: Union[int, str] = 0 + transfer_protocol: List[str] = field(default_factory=list) + rdma_ports: Union[List[str], List[int]] = field(default_factory=list) + device_ids: Union[List[str], List[int]] = field(default_factory=list) + + def __post_init__(self): + """check and unify fields""" + if isinstance(self.role, str): + try: + self.role = InstanceRole[self.role.upper()] + except KeyError: + raise ValueError(f"Invalid role string: {self.role}") + elif not isinstance(self.role, InstanceRole): + raise TypeError(f"role must be InstanceRole or str, got {type(self.role)}") + + for t in self.transfer_protocol: + assert t in ["ipc", "rdma"], f"Invalid transfer_protocol: {self.transfer_protocol}" + + self.port = str(self.port) + self.connector_port = str(self.connector_port) + self.engine_worker_queue_port = str(self.engine_worker_queue_port) + if self.rdma_ports: + self.rdma_ports = [str(p) for p in self.rdma_ports] + if self.device_ids: + self.device_ids = [str(i) for i in self.device_ids] + + def to_dict(self): + return {k: (v.name if isinstance(v, Enum) else v) for k, v in asdict(self).items()} + + def url(self) -> str: + url = f"{self.host_ip}:{self.port}" + if not url.startswith(("http://", "https://")): + url = f"http://{url}" + return url + + +def check_service_health(base_url: str, timeout: int = 3) -> bool: + """ + Check the health status of a service. + + Args: + base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080" + timeout (int): Request timeout in seconds. + + Returns: + bool: True if the service is healthy, False otherwise. + """ + if not base_url.startswith(("http://", "https://")): + base_url = f"http://{base_url}" + + url = f"{base_url.rstrip('/')}/health" + try: + resp = requests.get(url, timeout=timeout) + if resp.status_code == 200: + return True + else: + return False + except Exception: + return False + + +async def check_service_health_async(base_url: str, timeout: int = 3) -> bool: + """ + Asynchronously check the health status of a service. + + Args: + base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080" + timeout (int): Request timeout in seconds. + + Returns: + bool: True if the service is healthy, False otherwise. 
+ """ + if not base_url.startswith(("http://", "https://")): + base_url = f"http://{base_url}" + + url = f"{base_url.rstrip('/')}/health" + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + async with session.get(url) as resp: + status = resp.status + text = await resp.text() + + if status == 200: + print(f"[OK] Service is healthy ({status})") + return True + else: + print(f"[WARN] Service not healthy ({status}): {text}") + return False + except aiohttp.ClientError as e: + print(f"[ERROR] Failed to connect to {url}: {e}") + return False + except asyncio.TimeoutError: + print(f"[ERROR] Request to {url} timed out after {timeout}s") + return False diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index e992933be3e..83ee476e467 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -16,7 +16,9 @@ import redis -from fastdeploy.utils import llm_logger +from fastdeploy.utils import get_logger, llm_logger + +config_logger = get_logger("config", "config.log") from .dp_scheduler import DPScheduler from .global_scheduler import GlobalScheduler @@ -84,10 +86,10 @@ def print(self): """ Print the current configuration to logs. """ - llm_logger.info("LocalScheduler Configuration Information :") + config_logger.info("LocalScheduler Configuration Information :") for k, v in self.__dict__.items(): - llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) - llm_logger.info("=============================================================") + config_logger.info("{:<20}:{:<6}{}".format(k, "", v)) + config_logger.info("=============================================================") class DPLocalSchedulerConfig(LocalSchedulerConfig): @@ -312,6 +314,7 @@ def scheduler(self): Returns: Initialized scheduler instance (LocalScheduler or GlobalScheduler) """ + llm_logger.info("Scheduler Type: %s" % self.name) if self.name == "global": return GlobalScheduler( diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index b246ca09ce0..548789f7a79 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -195,6 +195,20 @@ def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str] results += [(request_id, "duplicated request_id") for request_id in duplicated_ids] return results + def has_request(self, request_id: str) -> bool: + """ + Check if there are any pending requests in the scheduler. + + Args: + request_id: Optional specific request ID to check. + If None, checks whether there are any pending requests. + + Returns: + True if there are pending requests, False otherwise. + """ + with self.mutex: + return request_id in self.requests + def calc_required_blocks(self, token_num, block_size): """ Calculate the number of blocks needed for a given number of tokens. 
@@ -292,6 +306,7 @@ def put_results(self, results: List[RequestOutput]): Args: results: List of RequestOutput objects containing results """ + scheduler_logger.debug(f"put results: {results}") responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results] finished_responses = [response.request_id for response in responses if response.finished] @@ -354,4 +369,8 @@ def _get_results(): if finished: self._recycle(request_id) scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}") + + if results: + scheduler_logger.debug(f"get responses, {results}") + return results diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index ea402239072..a56ecefecd2 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -18,12 +18,12 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import Dict +from typing import Dict, List import zmq from fastdeploy import envs -from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput +from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger @@ -241,7 +241,7 @@ def dispatch_innode_splitwise_tasks(self, tasks, current_id): }, } - def send_splitwise_tasks(self, tasks, current_id): + def send_splitwise_tasks(self, tasks: List[Request], current_id): """ Send splitwise tasks to all connected addresses. @@ -276,6 +276,7 @@ def send_splitwise_tasks(self, tasks, current_id): task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id task.disaggregate_info["role"] = "decode" + self.logger.debug(f"send task to coupled instance, {addr}, {task}") self._send_message(addr, "prefill", [task]) task.disaggregate_info["cache_info"] = decode_diagg task.disaggregate_info["role"] = "prefill" @@ -311,7 +312,7 @@ def send_first_token(self, prefill_msg, tasks_list): """ if not isinstance(tasks_list, list): tasks_list = [tasks_list] - self.logger.info("send first token to port decode") + self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}") if prefill_msg["transfer_protocol"] == "ipc": port = prefill_msg["cache_info"]["ipc"]["port"] if port not in self.connect_innode_instances: @@ -355,7 +356,7 @@ def check_decode_allocated(self, task): self.logger.error(f"Receive_decode_allocated error: {msg}") return False, msg - def send_cache_infos(self, tasks, current_id): + def send_cache_infos(self, tasks: List[Request], current_id): """ Send cache information to specific port. @@ -432,8 +433,10 @@ def send_cache_infos(self, tasks, current_id): if not is_decode and len(temp_cache_info): for k, v in temp_cache_info.items(): + self.logger.debug(f"send cache info to cachemessager, {v}") self.engine_worker_queue.put_cache_info(v) else: + self.logger.debug(f"send cache info to coupled instance, {temp_cache_info}") if len(temp_cache_info): for k, v in temp_cache_info.items(): self.logger.info(f"{k} {v}") @@ -490,7 +493,7 @@ def _handle_prefill(self, tasks): """ Handle prefill tasks from other nodes. 
""" - + self.logger.debug(f"_handle_prefill function receive {tasks}") tasks_data = [Request.from_dict(task) for task in tasks] self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data)) @@ -498,21 +501,9 @@ def _handle_decode(self, payload): """ Handle decode tasks from other nodes. """ + self.logger.debug(f"_handle_decode function receive {payload}") tasks = [] for task in payload: - tasks.append( - RequestOutput( - request_id=task["request_id"], - outputs=CompletionOutput( - index=task["outputs"]["index"], - send_idx=0, - token_ids=task["outputs"]["token_ids"], - draft_token_ids=task["outputs"]["draft_token_ids"], - ), - finished=True, - num_cached_tokens=task["num_cached_tokens"], - error_code=task["error_code"], - error_msg=task["error_msg"], - ) - ) + output = RequestOutput.from_dict(task) + tasks.append(output) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks)) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index a5a45a8a287..89ac8e20ae8 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -982,6 +982,7 @@ def init_bos_client(): console_logger = get_logger("console", "console.log", print_to_console=True) spec_logger = get_logger("speculate", "speculate.log") zmq_client_logger = get_logger("zmq_client", "zmq_client.log") +router_logger = get_logger("router", "router.log") def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f84ad66d239..e09e00c5ef0 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -475,8 +475,10 @@ def event_loop_normal(self) -> None: # Execute model to generate token. The generated token will be written to the buffer. # These generated tokens can be obtained through get_output op. 
+ start_execute_time = time.time() self.worker.execute_model(req_dicts, num_running_requests) self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill() + logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s") def initialize_kv_cache(self) -> None: """Profiles the peak memory usage of the model to determine how many diff --git a/requirements.txt b/requirements.txt index d07936e1bb2..ab726f74511 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,3 +42,5 @@ opentelemetry-instrumentation-fastapi partial_json_parser msgspec einops +setproctitle +aistudio_sdk diff --git a/tests/e2e/test_ernie_03b_pd.py b/tests/e2e/test_ernie_03b_pd.py index 8fcbfb4ee2b..f8a53e9dc07 100644 --- a/tests/e2e/test_ernie_03b_pd.py +++ b/tests/e2e/test_ernie_03b_pd.py @@ -97,18 +97,24 @@ def setup_and_run_server(): clean_ports() print("log dir clean ") - if os.path.exists("log") and os.path.isdir("log"): - shutil.rmtree("log") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") base_path = os.getenv("MODEL_PATH") if base_path: model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") else: - model_path = "./ERNIE-4.5-0.3B-Paddle" + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") # prefill实例 + print("start prefill...") env_prefill = os.environ.copy() env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" + env_prefill["FD_LOG_DIR"] = "log_prefill" env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT) prefill_log_path = "server.log" prefill_cmd = [ @@ -146,12 +152,15 @@ def setup_and_run_server(): start_new_session=True, # Enables killing full group via os.killpg env=env_prefill, ) + time.sleep(3) # decode实例 + print("start decode...") env_decode = os.environ.copy() env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1) - env_decode["FD_LOG_DIR"] = "decode_log" + env_decode["FD_LOG_DIR"] = "log_decode" decode_log_path = "decode_server.log" decode_cmd = [ sys.executable, @@ -177,6 +186,8 @@ def setup_and_run_server(): "wint8", "--splitwise-role", "decode", + "--innode-prefill-ports", + str(FD_ENGINE_QUEUE_PORT), ] # Start subprocess in new process group @@ -312,18 +323,7 @@ def test_chat_usage_stream(api_url): "stream_options": {"include_usage": True, "continuous_usage_stats": True}, "metadata": {"min_tokens": 10}, } - p_url, d_url = api_url - - response = send_request(url=p_url, payload=payload) - chunks = get_stream_chunks(response) - result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) - print("Prefill Response:", result) - assert result != "", "结果为空" - usage = chunks[-1]["usage"] - total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] - assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" - assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" - assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + _, d_url = api_url # Only the decode server receives the request response = send_request(url=d_url, payload=payload) chunks = get_stream_chunks(response) @@ -354,16 +354,7 @@ def test_chat_usage_non_stream(api_url): "stream": False, "metadata": {"min_tokens": 10}, } - p_url, d_url = api_url - - response 
= send_request(url=p_url, payload=payload).json() - usage = response["usage"] - result = response["choices"][0]["message"]["content"] - assert result != "", "结果为空" - total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] - assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" - assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" - assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + _, d_url = api_url response = send_request(url=d_url, payload=payload).json() usage = response["usage"] @@ -388,25 +379,13 @@ def test_non_chat_usage_stream(api_url): "stream_options": {"include_usage": True, "continuous_usage_stats": True}, "metadata": {"min_tokens": 10}, } - p_url, d_url = api_url - p_url = p_url.replace("chat/completions", "completions") + _, d_url = api_url d_url = d_url.replace("chat/completions", "completions") - response = send_request(url=p_url, payload=payload) - chunks = get_stream_chunks(response) - result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) - # print("Prefill Response:", result) - assert result != "", "结果为空" - usage = chunks[-1]["usage"] - total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] - assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" - assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" - assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" - response = send_request(url=d_url, payload=payload) chunks = get_stream_chunks(response) result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) - # print("Decode Response:", result) + print("Decode Response:", result) assert result != "", "结果为空" usage = chunks[-1]["usage"] total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] @@ -427,23 +406,13 @@ def test_non_chat_usage_non_stream(api_url): "stream": False, "metadata": {"min_tokens": 10}, } - p_url, d_url = api_url - p_url = p_url.replace("chat/completions", "completions") + _, d_url = api_url d_url = d_url.replace("chat/completions", "completions") - response = send_request(url=p_url, payload=payload).json() - usage = response["usage"] - result = response["choices"][0]["text"] - # print("Prefill Response:", result) - assert result != "", "结果为空" - total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] - assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" - assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" - assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" - response = send_request(url=d_url, payload=payload).json() usage = response["usage"] result = response["choices"][0]["text"] + print("Decode Response:", result) assert result != "", "结果为空" total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" diff --git a/tests/e2e/test_ernie_03b_pd_multi_node.py b/tests/e2e/test_ernie_03b_pd_multi_node.py new file mode 100644 index 00000000000..0417fdacd63 --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_multi_node.py @@ -0,0 +1,500 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import signal +import socket +import subprocess +import sys +import time + +import pytest +import requests + +# Read ports from environment variables; use default values if not set +FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) +FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) +FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) +FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_ROUTER_PORT, +] + + +def is_port_open(host: str, port: int, timeout=1.0): + """ + Check if a TCP port is open on the given host. + Returns True if connection succeeds, False otherwise. + """ + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +def check_service_health(base_url: str, timeout: int = 3) -> bool: + """ + Check the health status of a service. + + Args: + base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080" + timeout (int): Request timeout in seconds. + + Returns: + bool: True if the service is healthy, False otherwise. + """ + if not base_url.startswith("http"): + base_url = f"http://{base_url}" + url = f"{base_url.rstrip('/')}/health" + try: + resp = requests.get(url, timeout=timeout) + if resp.status_code == 200: + return True + else: + return False + except Exception: + return False + + +def get_registered_number(router_url) -> list: + """ + Get the number of registered models in the router. + + Args: + router_url (str): The base URL of the router, e.g. "http://localhost:8080". + + Returns: + int: The number of registered models. + """ + if not router_url.startswith("http"): + router_url = f"http://{router_url}" + + try: + response = requests.get(f"{router_url}/registered_number", timeout=60) + registered_numbers = response.json() + return registered_numbers + except Exception: + return {"mixed": 0, "prefill": 0, "decode": 0} + + +def kill_process_on_port(port: int): + """ + Kill processes that are listening on the given port. + Uses `lsof` to find process ids and sends SIGKILL. 
+ """ + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + current_pid = os.getpid() + parent_pid = os.getppid() + for pid in output.splitlines(): + pid = int(pid) + if pid in (current_pid, parent_pid): + print(f"Skip killing current process (pid={pid}) on port {port}") + continue + os.kill(pid, signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + +def clean_ports(): + """ + Kill all processes occupying the ports listed in PORTS_TO_CLEAN. + """ + for port in PORTS_TO_CLEAN: + kill_process_on_port(port) + time.sleep(2) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean_ports() + + print("log dir clean ") + if os.path.exists("log_router") and os.path.isdir("log_router"): + shutil.rmtree("log_router") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + # router + print("start router...") + env_router = os.environ.copy() + env_router["FD_LOG_DIR"] = "log_router" + router_log_path = "router.log" + + router_cmd = [ + sys.executable, + "-m", + "fastdeploy.router.launch", + "--port", + str(FD_ROUTER_PORT), + "--splitwise", + ] + + with open(router_log_path, "w") as logfile: + process_router = subprocess.Popen( + router_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_router, + ) + + # prefill实例 + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" + env_prefill["FD_LOG_DIR"] = "log_prefill" + env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT) + prefill_log_path = "server.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_prefill, + ) + time.sleep(1) + + # decode实例 + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" + env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1) + env_decode["FD_LOG_DIR"] = "log_decode" + decode_log_path = 
"decode_server.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_decode, + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(60): + registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}") + if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean_ports() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean_ports() + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. 
+ """ + return {"Content-Type": "application/json"} + + +def test_metrics_config(metrics_url): + timeout = 600 + url = metrics_url.replace("metrics", "config-info") + res = requests.get(url, timeout=timeout) + assert res.status_code == 200 + + +def send_request(url, payload, timeout=600): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = 
api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" From 7cbe8fa2c683cdab08ead0bfa0d448b49420e4c5 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Thu, 6 Nov 2025 02:58:50 +0000 Subject: [PATCH 2/2] fix --- fastdeploy/cache_manager/cache_messager.py | 3 +- tests/e2e/test_ernie_03b_router.py | 486 +++++++++++++++++++++ 2 files changed, 488 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/test_ernie_03b_router.py diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 0e3469b2267..bd02a4c4f96 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -344,7 +344,8 @@ def prefill_layerwise_send_cache_thread(self): ) item["layer_idx"] = current_layer_idx if item["layer_idx"] == self.num_layers: - item["status"] = "finished" + if "error" not in item["status"]: + item["status"] = "finished" if item["transfer_protocol"] == "ipc": self.messager["ipc"].write_block_by_sync(target_id) logger.info(f"finish write cache {item['request_id']}") diff --git a/tests/e2e/test_ernie_03b_router.py b/tests/e2e/test_ernie_03b_router.py new file mode 100644 index 00000000000..4037f6b775a --- /dev/null +++ b/tests/e2e/test_ernie_03b_router.py @@ -0,0 +1,486 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Test for router and mixed server + +import json +import os +import shutil +import signal +import socket +import subprocess +import sys +import time + +import pytest +import requests + +# Read ports from environment variables; use default values if not set +FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) +FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) +FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) +FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) +FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_ROUTER_PORT, +] + + +def is_port_open(host: str, port: int, timeout=1.0): + """ + Check if a TCP port is open on the given host. + Returns True if connection succeeds, False otherwise. + """ + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +def check_service_health(base_url: str, timeout: int = 3) -> bool: + """ + Check the health status of a service. + + Args: + base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080" + timeout (int): Request timeout in seconds. + + Returns: + bool: True if the service is healthy, False otherwise. + """ + if not base_url.startswith("http"): + base_url = f"http://{base_url}" + url = f"{base_url.rstrip('/')}/health" + try: + resp = requests.get(url, timeout=timeout) + if resp.status_code == 200: + return True + else: + return False + except Exception: + return False + + +def get_registered_number(router_url) -> list: + """ + Get the number of registered models in the router. + + Args: + router_url (str): The base URL of the router, e.g. "http://localhost:8080". + + Returns: + int: The number of registered models. + """ + if not router_url.startswith("http"): + router_url = f"http://{router_url}" + + try: + response = requests.get(f"{router_url}/registered_number", timeout=60) + registered_numbers = response.json() + return registered_numbers + except Exception: + return {"mixed": 0, "prefill": 0, "decode": 0} + + +def kill_process_on_port(port: int): + """ + Kill processes that are listening on the given port. + Uses `lsof` to find process ids and sends SIGKILL. + """ + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + current_pid = os.getpid() + parent_pid = os.getppid() + for pid in output.splitlines(): + pid = int(pid) + if pid in (current_pid, parent_pid): + print(f"Skip killing current process (pid={pid}) on port {port}") + continue + os.kill(pid, signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + +def clean_ports(): + """ + Kill all processes occupying the ports listed in PORTS_TO_CLEAN. 
+ """ + for port in PORTS_TO_CLEAN: + kill_process_on_port(port) + time.sleep(2) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean_ports() + + print("log dir clean ") + if os.path.exists("log_router") and os.path.isdir("log_router"): + shutil.rmtree("log_router") + if os.path.exists("log_server_0") and os.path.isdir("log_server_0"): + shutil.rmtree("log_server_0") + if os.path.exists("log_server_1") and os.path.isdir("log_server_1"): + shutil.rmtree("log_server_1") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + # router + print("start router...") + env_router = os.environ.copy() + env_router["FD_LOG_DIR"] = "log_router" + router_log_path = "router.log" + + router_cmd = [ + sys.executable, + "-m", + "fastdeploy.router.launch", + "--port", + str(FD_ROUTER_PORT), + ] + + with open(router_log_path, "w") as logfile: + process_router = subprocess.Popen( + router_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_router, + ) + + # server0 + print("start server0...") + env_server_0 = os.environ.copy() + env_server_0["CUDA_VISIBLE_DEVICES"] = "0" + env_server_0["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" + env_server_0["FD_LOG_DIR"] = "log_server_0" + env_server_0["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT) + log_path = "server_0.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(log_path, "w") as logfile: + process_server_0 = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_server_0, + ) + time.sleep(1) + + # server 1 + print("start server 1...") + env_server_1 = os.environ.copy() + env_server_1["CUDA_VISIBLE_DEVICES"] = "1" + env_server_1["ENABLE_V1_KVCACHE_SCHEDULER"] = "0" + env_server_1["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1) + env_server_1["FD_LOG_DIR"] = "log_server_1" + log_path = "server_1.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(log_path, "w") as logfile: + process_server_1 = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # 
Enables killing full group via os.killpg + env=env_server_1, + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(60): + registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}") + if registered_numbers["mixed"] >= 2: + print("Mixed servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_server_0.pid, signal.SIGTERM) + os.killpg(process_server_1.pid, signal.SIGTERM) + clean_ports() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_server_0.pid, signal.SIGTERM) + os.killpg(process_server_1.pid, signal.SIGTERM) + clean_ports() + print(f"server (pid={process_server_0.pid}) terminated") + print(f"server (pid={process_server_1.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. + """ + return {"Content-Type": "application/json"} + + +def test_metrics_config(metrics_url): + timeout = 600 + url = metrics_url.replace("metrics", "config-info") + res = requests.get(url, timeout=timeout) + assert res.status_code == 200 + + +def send_request(url, payload, timeout=600): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= 
usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + print("Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"