From e13feadfe0b9a9ad1f633cd100b1a8c8adfcaa1b Mon Sep 17 00:00:00 2001 From: Yingge He <157551214+yinggeh@users.noreply.github.com> Date: Tue, 21 Oct 2025 11:55:28 -0700 Subject: [PATCH 1/7] ci: Update backend and tests to V1 engine (#100) --- .../accuracy_test/accuracy_test.py | 2 - ci/L0_backend_vllm/accuracy_test/test.sh | 12 --- .../metrics_test/vllm_metrics_test.py | 8 +- ci/L0_check_health_vllm/test.sh | 19 ++--- src/model.py | 21 +++-- src/utils/metrics.py | 81 ++++++++++++++----- 6 files changed, 83 insertions(+), 60 deletions(-) diff --git a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py index 15c343a2..6000c1a2 100644 --- a/ci/L0_backend_vllm/accuracy_test/accuracy_test.py +++ b/ci/L0_backend_vllm/accuracy_test/accuracy_test.py @@ -190,7 +190,6 @@ def test_guided_decoding(self): sampling_params = SAMPLING_PARAMETERS guided_decoding_params = { "choice": ["Positive", "Negative"], - "backend": "outlines", } sampling_params["guided_decoding"] = json.dumps(guided_decoding_params) for i in range(len(GUIDED_PROMPTS)): @@ -245,7 +244,6 @@ def tearDown(self): if FLAGS.generate_guided_baseline: guided_decoding_params = { "choice": ["Positive", "Negative"], - "backend": "outlines", } guided_generation = GuidedDecodingParams(**guided_decoding_params) asyncio.run( diff --git a/ci/L0_backend_vllm/accuracy_test/test.sh b/ci/L0_backend_vllm/accuracy_test/test.sh index 8a94fff0..f575b7b1 100755 --- a/ci/L0_backend_vllm/accuracy_test/test.sh +++ b/ci/L0_backend_vllm/accuracy_test/test.sh @@ -48,17 +48,11 @@ RET=0 set +e # Need to generate baseline first, since running 2 vLLM engines causes # memory issues: https://github.com/vllm-project/vllm/issues/2248 -export VLLM_USE_V1=0 -export VLLM_WORKER_MULTIPROC_METHOD=spawn python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$! wait $BASELINE_PID python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$! wait $BASELINE_PID - -unset VLLM_USE_V1 -unset VLLM_WORKER_MULTIPROC_METHOD - set -e run_server @@ -88,12 +82,6 @@ set -e kill $SERVER_PID wait $SERVER_PID -# Check that warning about V1 Engine appears in log - this warning is expected -if ! grep -q "Engine in background thread is experimental on VLLM_USE_V1=1. Falling back to V0 Engine." $SERVER_LOG; then - echo -e "\n***\n*** ERROR: Expected warning about vLLM falling back to V0 Engine not found in logs.\n***" - RET=1 -fi - rm -rf models/ if [ $RET -eq 1 ]; then diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py index b8ddeb49..0111056c 100644 --- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py +++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py @@ -173,13 +173,11 @@ def test_vllm_metrics(self): # TODO: Revisit this test due to the removal of best_of def test_custom_sampling_params(self): # Adding sampling parameters for testing metrics. - # Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html - n, best_of = 2, 4 + # Definitions can be found here https://docs.vllm.ai/en/latest/api/vllm/sampling_params.html + n, temperature = 2, 1 custom_sampling_parameters = self.sampling_parameters.copy() - # Changing "temperature" because "best_of" must be 1 when using greedy - # sampling, i.e. "temperature": "0". 
custom_sampling_parameters.update( - {"n": str(n), "best_of": str(best_of), "temperature": "1"} + {"n": str(n), "temperature": str(temperature)} ) # Test vLLM metrics diff --git a/ci/L0_check_health_vllm/test.sh b/ci/L0_check_health_vllm/test.sh index 80668bcb..81bf4489 100755 --- a/ci/L0_check_health_vllm/test.sh +++ b/ci/L0_check_health_vllm/test.sh @@ -48,23 +48,24 @@ function enable_health_check { } VLLM_INSTALL_PATH="/usr/local/lib/python3.12/dist-packages/vllm" +VLLM_V1_ENGINE_PATH="$VLLM_INSTALL_PATH/v1/engine" function mock_vllm_async_llm_engine { # backup original file - mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup - cp $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py + mv $VLLM_V1_ENGINE_PATH/async_llm.py $VLLM_V1_ENGINE_PATH/async_llm.py.backup + cp $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py # overwrite the original check_health method - echo -e "" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py - echo -e " async def check_health(self, check_count=[0]):" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py - echo -e " check_count[0] += 1" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py - echo -e " if check_count[0] > 1:" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py - echo -e " raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py + echo -e "" >> $VLLM_V1_ENGINE_PATH/async_llm.py + echo -e " async def check_health(self, check_count=[0]):" >> $VLLM_V1_ENGINE_PATH/async_llm.py + echo -e " check_count[0] += 1" >> $VLLM_V1_ENGINE_PATH/async_llm.py + echo -e " if check_count[0] > 1:" >> $VLLM_V1_ENGINE_PATH/async_llm.py + echo -e " raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_V1_ENGINE_PATH/async_llm.py } function unmock_vllm_async_llm_engine { # restore from backup - rm -f $VLLM_INSTALL_PATH/engine/multiprocessing/client.py - mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py + rm -f $VLLM_V1_ENGINE_PATH/async_llm.py + mv $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py } function test_check_health { diff --git a/src/model.py b/src/model.py index 7a135dcf..4145b71b 100644 --- a/src/model.py +++ b/src/model.py @@ -35,7 +35,6 @@ from typing import Dict, List import numpy as np -import torch import triton_python_backend_utils as pb_utils from PIL import Image from vllm.engine.arg_utils import AsyncEngineArgs @@ -45,7 +44,7 @@ from vllm.lora.request import LoRARequest from vllm.utils import random_uuid -from utils.metrics import VllmStatLogger +from utils.metrics import VllmStatLoggerFactory from utils.vllm_backend_utils import TritonSamplingParams _VLLM_ENGINE_ARGS_FILENAME = "model.json" @@ -184,12 +183,12 @@ def initialize(self, args): and not self._aync_engine_args.disable_log_stats ) - # Starting the vLLM engine and its event thread running the AsyncIO event loop. - self._init_engine() - # Setup vLLM metrics self._setup_metrics() + # Starting the vLLM engine and its event thread running the AsyncIO event loop. + self._init_engine() + # Starting the response thread. It allows vLLM to keep making progress while # response sender(s) are sending responses to server frontend. 
self._response_queue = queue.Queue() @@ -258,6 +257,7 @@ async def _run_llm_engine(self): async with build_async_engine_client_from_engine_args( engine_args=self._aync_engine_args, disable_frontend_multiprocessing=self._enable_metrics, + stat_loggers=self._vllm_metrics, ) as engine: # Capture the engine event loop and make it visible to other threads. self._event_loop = asyncio.get_running_loop() @@ -348,7 +348,7 @@ def _setup_lora(self): ) def _setup_metrics(self): - self._vllm_metrics = None + self._vllm_metrics = [] # TODO: Do not read metrics directly from the vLLM engine, read from prometheus # client to allow the use of ZMQ process when metrics are enabled. See # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245 @@ -359,9 +359,8 @@ def _setup_metrics(self): "version": self.args["model_version"], } # Add vLLM custom metrics - vllm_config = self._llm_engine.engine.vllm_config - self._vllm_metrics = VllmStatLogger(labels, vllm_config, self.logger) - self._llm_engine.add_logger("triton", self._vllm_metrics) + factory = VllmStatLoggerFactory(labels, self.logger) + self._vllm_metrics.append(factory) except pb_utils.TritonModelException as e: if "metrics not supported" in str(e): # Metrics are disabled at the server @@ -785,8 +784,8 @@ def finalize(self): self._response_thread = None # Shutdown the metrics thread. - if self._vllm_metrics is not None: - self._vllm_metrics.finalize() + for stat_logger_factory in self._vllm_metrics: + stat_logger_factory.finalize() # When using parallel tensors, the stub process may not shutdown due to # unreleased references, so manually run the garbage collector once. diff --git a/src/utils/metrics.py b/src/utils/metrics.py index ecb044d4..644eb6d9 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -26,13 +26,12 @@ import queue import threading -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import triton_python_backend_utils as pb_utils from vllm.config import VllmConfig -from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase -from vllm.engine.metrics import Stats as VllmStats -from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets +from vllm.v1.metrics.loggers import StatLoggerBase, build_1_2_5_buckets +from vllm.v1.metrics.stats import IterationStats, SchedulerStats class TritonMetrics: @@ -161,13 +160,35 @@ def __init__(self, labels: List[str], max_model_len: int): ) -class VllmStatLogger(VllmStatLoggerBase): +# Create a partially initialized callable that adapts VllmStatLogger to StatLoggerFactory interface +class VllmStatLoggerFactory: + def __init__(self, labels, log_logger): + self._labels = labels + self._log_logger = log_logger + self._instances_list = [] + + def __call__(self, vllm_config, engine_index): + stat_logger = VllmStatLogger( + self._labels, self._log_logger, vllm_config, engine_index + ) + self._instances_list.append(stat_logger) + return stat_logger + + def finalize(self): + for stat_logger in self._instances_list: + if stat_logger is not None: + stat_logger.finalize() + + +class VllmStatLogger(StatLoggerBase): """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider.""" - def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None: + def __init__( + self, labels: Dict, log_logger, vllm_config: VllmConfig, engine_index: int + ) -> None: # Tracked stats over current local logging interval. # local_interval not used here. It's for vLLM logs to stdout. 
- super().__init__(local_interval=0, vllm_config=vllm_config) + super().__init__(vllm_config=vllm_config, engine_index=engine_index) self.metrics = TritonMetrics( labels=labels, max_model_len=vllm_config.model_config.max_model_len ) @@ -176,12 +197,9 @@ def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None: # Starting the metrics thread. It allows vLLM to keep making progress # while reporting metrics to triton metrics service. self._logger_queue = queue.Queue() - self._logger_thread = threading.Thread(target=self.logger_loop) + self._logger_thread = threading.Thread(target=self._logger_loop) self._logger_thread.start() - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - pass - def _log_counter(self, counter, data: Union[int, float]) -> None: """Convenience function for logging to counter. @@ -208,7 +226,12 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None for datum in data: self._logger_queue.put_nowait((histogram, "observe", datum)) - def log(self, stats: VllmStats) -> None: + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ) -> None: """Report stats to Triton metrics server. Args: @@ -217,38 +240,54 @@ def log(self, stats: VllmStats) -> None: Returns: None """ + + # Parse finished request stats into lists + e2e_latency: List[float] = [] + num_prompt_tokens: List[int] = [] + num_generation_tokens: List[int] = [] + for finished_req in iteration_stats.finished_requests: + e2e_latency.append(finished_req.e2e_latency) + num_prompt_tokens.append(finished_req.num_prompt_tokens) + num_generation_tokens.append(finished_req.num_generation_tokens) + # The list of vLLM metrics reporting to Triton is also documented here. # https://github.com/triton-inference-server/vllm_backend/blob/main/README.md#triton-metrics counter_metrics = [ - (self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter), - (self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter), + (self.metrics.counter_prompt_tokens, iteration_stats.num_prompt_tokens), + ( + self.metrics.counter_generation_tokens, + iteration_stats.num_generation_tokens, + ), ] histogram_metrics = [ ( self.metrics.histogram_time_to_first_token, - stats.time_to_first_tokens_iter, + iteration_stats.time_to_first_tokens_iter, ), ( self.metrics.histogram_time_per_output_token, - stats.time_per_output_tokens_iter, + iteration_stats.inter_token_latencies_iter, ), - (self.metrics.histogram_e2e_time_request, stats.time_e2e_requests), + (self.metrics.histogram_e2e_time_request, e2e_latency), ( self.metrics.histogram_num_prompt_tokens_request, - stats.num_prompt_tokens_requests, + num_prompt_tokens, ), ( self.metrics.histogram_num_generation_tokens_request, - stats.num_generation_tokens_requests, + num_generation_tokens, ), - (self.metrics.histogram_n_request, stats.n_requests), + (self.metrics.histogram_n_request, iteration_stats.n_params_iter), ] for metric, data in counter_metrics: self._log_counter(metric, data) for metric, data in histogram_metrics: self._log_histogram(metric, data) - def logger_loop(self): + def log_engine_initialized(self) -> None: + pass + + def _logger_loop(self): while True: item = self._logger_queue.get() # To signal shutdown a None item will be added to the queue. 
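
The patch above moves metrics off the V0 add_logger("triton", ...) hook and onto the V1 stat-logger-factory interface: the engine builder now receives a list of callables via stat_loggers, invokes each one as factory(vllm_config, engine_index), and calls record(scheduler_stats, iteration_stats, engine_idx) on the returned logger for each engine iteration. The sketch below is a minimal standalone illustration of that contract, based only on the interfaces exercised in the diff; PrintingStatLogger and its print-based reporting are hypothetical and are not part of the backend.

from typing import Optional

from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class PrintingStatLogger(StatLoggerBase):
    # Toy per-engine logger; the real backend forwards these values to the
    # Triton metrics service instead of printing them.
    def __init__(self, vllm_config, engine_index: int):
        super().__init__(vllm_config=vllm_config, engine_index=engine_index)

    def record(
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        engine_idx: int = 0,
    ) -> None:
        # Called by the V1 engine with the stats of the latest iteration.
        if iteration_stats is not None:
            print(
                f"prompt_tokens={iteration_stats.num_prompt_tokens} "
                f"generation_tokens={iteration_stats.num_generation_tokens}"
            )

    def log_engine_initialized(self) -> None:
        pass


class PrintingStatLoggerFactory:
    # Callable matching the factory signature the engine expects.
    def __call__(self, vllm_config, engine_index: int) -> StatLoggerBase:
        return PrintingStatLogger(vllm_config, engine_index)


# The factory list is handed to the engine builder, mirroring the change in
# src/model.py above (argument values here are placeholders):
#   build_async_engine_client_from_engine_args(
#       engine_args=engine_args,
#       disable_frontend_multiprocessing=enable_metrics,
#       stat_loggers=[PrintingStatLoggerFactory()],
#   )
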
From 460bd458071c433f9f14894641e70f25b5563cb7 Mon Sep 17 00:00:00 2001 From: Misha Chornyi <99709299+mc-nv@users.noreply.github.com> Date: Fri, 24 Oct 2025 18:01:51 -0700 Subject: [PATCH 2/7] fix: Update environment configuration (#102) --- ci/L0_backend_vllm/test.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ci/L0_backend_vllm/test.sh b/ci/L0_backend_vllm/test.sh index b4d27357..c3ff6c8e 100755 --- a/ci/L0_backend_vllm/test.sh +++ b/ci/L0_backend_vllm/test.sh @@ -28,6 +28,9 @@ RET=0 SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend metrics_test" +export C_INCLUDE_PATH=/usr/local/cuda/include:$C_INCLUDE_PATH +export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas + python3 -m pip install tritonclient[grpc] for TEST in ${SUBTESTS}; do From f41cfdb2d7d2dd569a5871a0cc3aaf50c4196aec Mon Sep 17 00:00:00 2001 From: Yingge He <157551214+yinggeh@users.noreply.github.com> Date: Mon, 27 Oct 2025 10:06:47 -0700 Subject: [PATCH 3/7] test: Allow more time for cold start on SBSA (#101) --- ci/L0_check_health_vllm/test.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ci/L0_check_health_vllm/test.sh b/ci/L0_check_health_vllm/test.sh index 81bf4489..3918d3d8 100755 --- a/ci/L0_check_health_vllm/test.sh +++ b/ci/L0_check_health_vllm/test.sh @@ -31,11 +31,12 @@ source ../common/util.sh pip3 install pytest==8.1.1 pip3 install tritonclient[grpc] +rm -f *.log *.report.xml RET=0 function setup_model_repository { - local sample_model_repo_path=${1:-"../../samples/model_repository"} - rm -rf models vllm_baseline_output.pkl && mkdir -p models + local sample_model_repo_path="../../samples/model_repository" + rm -rf models && mkdir -p models cp -r $sample_model_repo_path/vllm_model models/vllm_opt } @@ -94,8 +95,12 @@ function test_check_health { } # Test health check unspecified +# Cold start on SBSA device can take longer than default 120 seconds +PREV_SERVER_TIMEOUT=$SERVER_TIMEOUT +SERVER_TIMEOUT=240 setup_model_repository test_check_health "health_check_unspecified" "test_vllm_is_healthy" +SERVER_TIMEOUT=$PREV_SERVER_TIMEOUT # Test health check disabled setup_model_repository From d50bda112e2fa2ce5135cf9069f4e56b51c27e5f Mon Sep 17 00:00:00 2001 From: Misha Chornyi <99709299+mc-nv@users.noreply.github.com> Date: Wed, 29 Oct 2025 11:33:18 -0700 Subject: [PATCH 4/7] make utilization lower (#103) Co-authored-by: Yingge He <157551214+yinggeh@users.noreply.github.com> --- samples/model_repository/vllm_model/1/model.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/model_repository/vllm_model/1/model.json b/samples/model_repository/vllm_model/1/model.json index 50ed9637..657953c4 100644 --- a/samples/model_repository/vllm_model/1/model.json +++ b/samples/model_repository/vllm_model/1/model.json @@ -1,5 +1,5 @@ { "model":"facebook/opt-125m", - "gpu_memory_utilization": 0.5, + "gpu_memory_utilization": 0.1, "enforce_eager": true } From 2c3e148868fea7d400303aae6371fe605c0a46f6 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Wed, 29 Oct 2025 18:13:55 -0700 Subject: [PATCH 5/7] Support embedding endpoint in OpenAI API frontend --- src/model.py | 310 ++++++++-------------------------- src/utils/request.py | 385 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 458 insertions(+), 237 deletions(-) create mode 100644 src/utils/request.py diff --git a/src/model.py b/src/model.py index 7a135dcf..7663b0db 100644 --- a/src/model.py +++ b/src/model.py @@ -25,25 +25,20 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 
DAMAGE. import asyncio -import base64 import gc import json import os import queue import threading -from io import BytesIO from typing import Dict, List import numpy as np import torch import triton_python_backend_utils as pb_utils -from PIL import Image from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args, ) -from vllm.lora.request import LoRARequest -from vllm.utils import random_uuid from utils.metrics import VllmStatLogger from utils.vllm_backend_utils import TritonSamplingParams @@ -74,6 +69,7 @@ def auto_complete_config(cls, auto_complete_model_config): def _auto_complete_inputs_and_outputs(auto_complete_model_config): # Inputs expected by the backend. inputs = [ + # TODO: Support array input {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, { "name": "image", @@ -129,6 +125,13 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "dims": [1], "optional": True, }, + # Tentative input reserved for embedding requests in OpenAI frontend. May change in the future. + { + "name": "embedding_request", + "data_type": "TYPE_STRING", + "dims": [1], + "optional": True, + }, ] # Outputs expected by the backend. outputs = [ @@ -396,6 +399,35 @@ def _response_loop(self): if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: self._ongoing_request_count -= 1 + def respond_error(self, request, error_message, triton_error): + output_tensor = pb_utils.Tensor( + "text_output", + np.asarray([error_message], dtype=self.output_dtype), + ) + response = pb_utils.InferenceResponse( + output_tensors=[output_tensor], error=triton_error + ) + response_sender = request.get_response_sender() + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + + def _validate_request_task_name(self, request): + embedding_request = pb_utils.get_input_tensor_by_name( + request, "embedding_request" + ) + if embedding_request is None: + request_task_name = "generate" + else: + request_task_name = "embed" + + if request_task_name not in self.supported_tasks: + raise ValueError( + f"Model {self.args['model_name']} does not support '{request_task_name}' request" + ) + + return request_task_name + def execute(self, requests): if self._enable_health_check and not self._check_health(requests): return None @@ -405,11 +437,11 @@ def execute(self, requests): assert ( self._llm_engine_shutdown_event.is_set() is False ), "Cannot create tasks after shutdown has been requested" - coro = self._generate(request) + coro = self._infer(request) asyncio.run_coroutine_threadsafe(coro, self._event_loop) return None - async def _generate(self, request): + async def _infer(self, request): response_sender = request.get_response_sender() response_state = { "response_sender": response_sender, @@ -419,27 +451,21 @@ async def _generate(self, request): self._ongoing_request_count += 1 decrement_ongoing_request_count = True try: - request_id = random_uuid() - ( - prompt, - stream, - prepend_input, - parameters, - additional_outputs, - ) = self._get_input_tensors(request) - - sampling_params = TritonSamplingParams.from_dict(parameters, self.logger) - lora_name = sampling_params.lora_name - lora_request = None - if lora_name is not None: - lora_id = str(self.supported_loras.index(lora_name) + 1) - lora_int_id = int(lora_id) - lora_local_path = self.lora_repository[lora_name] - lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) - - response_iterator = self._llm_engine.generate( - 
prompt, sampling_params, request_id, lora_request=lora_request - ) + request_task_name = self._validate_request_task_name(request) + if request_task_name == "generate": + request = GenerateRequest( + request, self._llm_engine.generate, self.output_dtype + ) + elif request_task_name == "embed": + request = EmbedRequest( + request, self._llm_engine.encode, self.output_dtype + ) + else: + raise ValueError( + f"VLLM backend does not support '{request_task_name}' request" + ) + + response_iterator = request.execute() request_output_state = {} async for request_output in response_iterator: @@ -447,14 +473,14 @@ async def _generate(self, request): # the response state if streaming. If not streaming, cancellation state # needs to be checked here. is_cancelled = response_state["is_cancelled"] - if not stream: + if not request.stream: is_cancelled = response_sender.is_cancelled() if is_cancelled: self.logger.log_info("[vllm] Cancelling the request") - await self._llm_engine.abort(request_id) + await self._llm_engine.abort(request.id) self.logger.log_info("[vllm] Successfully cancelled the request") - if stream: + if request.stream: # Add cancelled final response to response loop. response_state["last_response_generated"] = True response = pb_utils.InferenceResponse( @@ -472,12 +498,12 @@ async def _generate(self, request): break # Send each response if streaming. - if stream: - response = self._create_response( + if request.stream: + response = request.create_response( request_output_state, request_output, prepend_input=False, - additional_outputs=additional_outputs, + additional_outputs=request.additional_outputs, ) flags = 0 if request_output.finished: @@ -487,15 +513,20 @@ async def _generate(self, request): self._response_queue.put_nowait((response_state, response, flags)) # Send the last response which contains all the outputs if not streaming. 
- if not stream: - response_sender.send( - self._create_response( + if not request.stream: + if request_task_name == "generate": + response = request.create_response( request_output_state={}, request_output=request_output, - prepend_input=prepend_input, - additional_outputs=additional_outputs, - ), - flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + prepend_input=request.prepend_input, + additional_outputs=request.additional_outputs, + ) + else: + response = request.create_response( + request_output=request_output, + ) + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL ) except Exception as e: @@ -516,191 +547,6 @@ async def _generate(self, request): if decrement_ongoing_request_count: self._ongoing_request_count -= 1 - def _get_input_tensors(self, request): - # prompt - prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0] - if isinstance(prompt, bytes): - prompt = prompt.decode("utf-8") - - # image - images = pb_utils.get_input_tensor_by_name(request, "image") - if images: - images_vllm = [] - for image_np in images.as_numpy(): - image_b = base64.b64decode(image_np.decode("utf-8")) - image_rgb = Image.open(BytesIO(image_b)).convert("RGB") - images_vllm.append(image_rgb) - if len(images_vllm) > 0: - prompt = { - "prompt": prompt, - "multi_modal_data": {"image": images_vllm}, - } - - # stream - stream = pb_utils.get_input_tensor_by_name(request, "stream") - if stream: - stream = stream.as_numpy()[0] - else: - stream = False - - # prepend_input / exclude_input_in_output - prepend_input = pb_utils.get_input_tensor_by_name( - request, "exclude_input_in_output" - ) - if prepend_input: - # When `exclude_input_in_output` is False, we want to prepend input prompt - # to output, thus prepend_input should be True, and vice versa. - prepend_input = not prepend_input.as_numpy()[0] - elif prepend_input is None and stream: - prepend_input = False - else: - prepend_input = True - if prepend_input and stream: - raise ValueError( - "When streaming, `exclude_input_in_output` = False is not allowed." - ) - - # parameters / sampling_parameters - # An alternative mechanism to receive serialized parameters as an input - # tensor, because request parameters are not yet supported via BLS. 
- sampling_parameters = pb_utils.get_input_tensor_by_name( - request, "sampling_parameters" - ) - if sampling_parameters: - parameters = sampling_parameters.as_numpy()[0].decode("utf-8") - else: - parameters = request.parameters() - - # additional outputs - additional_outputs = { - "return_finish_reason": None, - "return_cumulative_logprob": None, - "return_logprobs": None, - "return_num_input_tokens": None, - "return_num_output_tokens": None, - } - for tensor_name in additional_outputs.keys(): - tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) - if tensor: - tensor = bool(tensor.as_numpy()[0]) - else: - tensor = False - additional_outputs[tensor_name] = tensor - - return prompt, stream, prepend_input, parameters, additional_outputs - - def _create_response( - self, request_output_state, request_output, prepend_input, additional_outputs - ): - output_tensors = [] - - # text_output - prepend_prompt = "" - if "prev_lens_text_output" not in request_output_state: - # this is the first response - if prepend_input: - prepend_prompt = request_output.prompt - request_output_state["prev_lens_text_output"] = [0] * len( - request_output.outputs - ) - prev_lens = request_output_state["prev_lens_text_output"] - text_output = [ - (prepend_prompt + output.text[prev_len:]).encode("utf-8") - for output, prev_len in zip(request_output.outputs, prev_lens) - ] - request_output_state["prev_lens_text_output"] = [ - len(output.text) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "text_output", np.asarray(text_output, dtype=self.output_dtype) - ) - ) - - # finish_reason - if additional_outputs["return_finish_reason"]: - finish_reason = [ - str(output.finish_reason) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "finish_reason", np.asarray(finish_reason, dtype=np.object_) - ) - ) - - # cumulative_logprob - if additional_outputs["return_cumulative_logprob"]: - cumulative_logprob = [ - output.cumulative_logprob for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "cumulative_logprob", - np.asarray(cumulative_logprob, dtype=np.float32), - ) - ) - - # logprobs - # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58 - if additional_outputs["return_logprobs"]: - if "prev_lens_logprobs" not in request_output_state: - request_output_state["prev_lens_logprobs"] = [0] * len( - request_output.outputs - ) - logprobs = [] - for i in range(len(request_output.outputs)): - output = request_output.outputs[i] - if output.logprobs is None: - logprobs.append("null".encode("utf-8")) - continue - prev_len = request_output_state["prev_lens_logprobs"][i] - request_output_state["prev_lens_logprobs"][i] = len(output.logprobs) - logprobs_py = [] - for logprob_d_vllm in output.logprobs[prev_len:]: - logprob_d_py = {} - for token_id, logprob_vllm in logprob_d_vllm.items(): - logprob_d_py[token_id] = { - "logprob": logprob_vllm.logprob, - "rank": logprob_vllm.rank, - "decoded_token": logprob_vllm.decoded_token, - } - logprobs_py.append(logprob_d_py) - logprobs.append(json.dumps(logprobs_py).encode("utf-8")) - output_tensors.append( - pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_)) - ) - - # num_input_tokens - if additional_outputs["return_num_input_tokens"]: - num_input_tokens = len(request_output.prompt_token_ids) - output_tensors.append( - pb_utils.Tensor( - "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32) - ) - ) - - # num_output_tokens - if 
additional_outputs["return_num_output_tokens"]: - if "prev_lens_num_output_tokens" not in request_output_state: - request_output_state["prev_lens_num_output_tokens"] = [0] * len( - request_output.outputs - ) - prev_lens = request_output_state["prev_lens_num_output_tokens"] - num_output_tokens = [ - (len(output.token_ids) - prev_len) - for output, prev_len in zip(request_output.outputs, prev_lens) - ] - request_output_state["prev_lens_num_output_tokens"] = [ - len(output.token_ids) for output in request_output.outputs - ] - output_tensors.append( - pb_utils.Tensor( - "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) - ) - ) - - return pb_utils.InferenceResponse(output_tensors=output_tensors) - def _verify_loras(self, request): # We will check if the requested lora exists here, if not we will send a # response with `LoRA not found` information. In this way we may avoid @@ -730,17 +576,7 @@ def _verify_loras(self, request): self.logger.log_info(f"[vllm] LoRA {lora_name} not found.") if lora_error is not None: - output_tensor = pb_utils.Tensor( - "text_output", - np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype), - ) - response = pb_utils.InferenceResponse( - output_tensors=[output_tensor], error=lora_error - ) - response_sender = request.get_response_sender() - response_sender.send( - response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL - ) + self.respond_error(request, lora_error.message, lora_error) else: verified_request = request return verified_request diff --git a/src/utils/request.py b/src/utils/request.py new file mode 100644 index 00000000..e8ae26aa --- /dev/null +++ b/src/utils/request.py @@ -0,0 +1,385 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import json +from abc import abstractmethod +from io import BytesIO + +import numpy as np +import triton_python_backend_utils as pb_utils +from PIL import Image +from vllm.inputs.data import TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.utils import random_uuid + +from utils.vllm_backend_utils import TritonSamplingParams + + +class RequestBase: + def __init__(self, request, executor_callback, output_dtype): + self.request = request + self.executor_callback = executor_callback + self.output_dtype = output_dtype + self.id = random_uuid() + self.stream = False + self.prepend_input = False + + @abstractmethod + def _get_input_tensors(self): + raise NotImplementedError + + @abstractmethod + def execute(self): + raise NotImplementedError + + @abstractmethod + def create_response(self, *args, **kwargs): + raise NotImplementedError + + +class GenerateRequest(RequestBase): + def __init__(self, request, executor_callback, output_dtype): + super().__init__(request, executor_callback, output_dtype) + + def _get_input_tensors(self): + # prompt + prompt = pb_utils.get_input_tensor_by_name( + self.request, "text_input" + ).as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # image + images = pb_utils.get_input_tensor_by_name(self.request, "image") + if images: + images_vllm = [] + for image_np in images.as_numpy(): + image_b = base64.b64decode(image_np.decode("utf-8")) + image_rgb = Image.open(BytesIO(image_b)).convert("RGB") + images_vllm.append(image_rgb) + if len(images_vllm) > 0: + prompt = { + "prompt": prompt, + "multi_modal_data": {"image": images_vllm}, + } + + # stream + stream = pb_utils.get_input_tensor_by_name(self.request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + self.request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. + prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + prepend_input = True + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # parameters / sampling_parameters + # An alternative mechanism to receive serialized parameters as an input + # tensor, because request parameters are not yet supported via BLS. 
+ sampling_parameters = pb_utils.get_input_tensor_by_name( + self.request, "sampling_parameters" + ) + if sampling_parameters: + parameters = sampling_parameters.as_numpy()[0].decode("utf-8") + else: + parameters = self.request.parameters() + + # additional outputs + additional_outputs = { + "return_finish_reason": None, + "return_cumulative_logprob": None, + "return_logprobs": None, + "return_num_input_tokens": None, + "return_num_output_tokens": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, parameters, additional_outputs + + async def execute(self): + ( + prompt, + self.stream, + self.prepend_input, + parameters, + self.additional_outputs, + ) = self._get_input_tensors() + + sampling_params = TritonSamplingParams.from_dict(parameters, self.logger) + lora_name = sampling_params.lora_name + lora_request = None + if lora_name is not None: + lora_id = str(self.supported_loras.index(lora_name) + 1) + lora_int_id = int(lora_id) + lora_local_path = self.lora_repository[lora_name] + lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path) + + response_iterator = self.executor_callback( + prompt, sampling_params, self.id, lora_request=lora_request + ) + + async for response in response_iterator: + yield response + + def create_response(self, request_output_state, request_output, prepend_input): + output_tensors = [] + + # text_output + prepend_prompt = "" + if "prev_lens_text_output" not in request_output_state: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + request_output_state["prev_lens_text_output"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_text_output"] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_text_output"] = [ + len(output.text) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) + ) + + # finish_reason + if self.additional_outputs["return_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) + + # cumulative_logprob + if self.additional_outputs["return_cumulative_logprob"]: + cumulative_logprob = [ + output.cumulative_logprob for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "cumulative_logprob", + np.asarray(cumulative_logprob, dtype=np.float32), + ) + ) + + # logprobs + # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58 + if self.additional_outputs["return_logprobs"]: + if "prev_lens_logprobs" not in request_output_state: + request_output_state["prev_lens_logprobs"] = [0] * len( + request_output.outputs + ) + logprobs = [] + for i in range(len(request_output.outputs)): + output = request_output.outputs[i] + if output.logprobs is None: + logprobs.append("null".encode("utf-8")) + continue + prev_len = request_output_state["prev_lens_logprobs"][i] + request_output_state["prev_lens_logprobs"][i] = len(output.logprobs) + logprobs_py = [] + for logprob_d_vllm in 
output.logprobs[prev_len:]: + logprob_d_py = {} + for token_id, logprob_vllm in logprob_d_vllm.items(): + logprob_d_py[token_id] = { + "logprob": logprob_vllm.logprob, + "rank": logprob_vllm.rank, + "decoded_token": logprob_vllm.decoded_token, + } + logprobs_py.append(logprob_d_py) + logprobs.append(json.dumps(logprobs_py).encode("utf-8")) + output_tensors.append( + pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_)) + ) + + # num_input_tokens + if self.additional_outputs["return_num_input_tokens"]: + num_input_tokens = len(request_output.prompt_token_ids) + output_tensors.append( + pb_utils.Tensor( + "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32) + ) + ) + + # num_output_tokens + if self.additional_outputs["return_num_output_tokens"]: + if "prev_lens_num_output_tokens" not in request_output_state: + request_output_state["prev_lens_num_output_tokens"] = [0] * len( + request_output.outputs + ) + prev_lens = request_output_state["prev_lens_num_output_tokens"] + num_output_tokens = [ + (len(output.token_ids) - prev_len) + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + request_output_state["prev_lens_num_output_tokens"] = [ + len(output.token_ids) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) + + +class EmbedRequest(RequestBase): + def __init__(self, request, executor_callback, output_dtype): + super().__init__(request, executor_callback, output_dtype) + + def _get_input_tensors(self): + embedding_request = pb_utils.get_input_tensor_by_name( + self.request, "embedding_request" + ).as_numpy()[0] + embedding_request = json.loads(embedding_request.decode("utf-8")) + # prompt + prompt = embedding_request["input"] + if isinstance(prompt, str): + pass # do nothing + elif ( + isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int) + ): + # Single list of token IDs + prompt = TokensPrompt(prompt_token_ids=prompt) + + # pooling_params + pooling_params = self._to_pooling_params(embedding_request) + + # additional outputs + additional_outputs = { + "return_num_input_tokens": None, + "return_num_output_tokens": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, pooling_params, additional_outputs + + async def execute(self): + ( + prompt, + pooling_params, + self.additional_outputs, + ) = self._get_input_tensors() + + # Create PoolingParams for embeddings + response_iterator = self.executor_callback(prompt, pooling_params, self.id) + + # Yield each response from the async iterator + async for response in response_iterator: + yield response + + def _to_pooling_params(self, embedding_request: dict): + pooling_params_dict = embedding_request.get("pooling_params", {}) + + pooling_params = PoolingParams(task="embed") + dims = None + if "dimensions" in pooling_params_dict: + dims = pooling_params_dict["dimensions"][0] + pooling_params = PoolingParams(dimensions=dims, task="embed") + return pooling_params + + def create_response(self, request_output): + output_tensors = [] + + # Extract embedding vector from output + # PoolingRequestOutput.outputs is a PoolingOutput with .data (torch.Tensor) + pooling_data = request_output.outputs.data + + # 
Convert torch tensor to numpy array then to list for JSON serialization + if hasattr(pooling_data, "cpu"): + # It's a torch tensor - move to CPU and convert to numpy + embedding_array = pooling_data.cpu().numpy() + else: + # Already numpy or list + embedding_array = np.array(pooling_data, dtype=np.float32) + + # Create response tensor - for embeddings, we use text_output to return the vector + # (This is a simplification - you may want to define a proper embedding output tensor) + embedding_list = ( + embedding_array.tolist() + if hasattr(embedding_array, "tolist") + else list(embedding_array) + ) + embedding_str = json.dumps(embedding_list) + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray([embedding_str], dtype=self.output_dtype) + ) + ) + + # num_input_tokens + if self.additional_outputs["return_num_input_tokens"]: + num_input_tokens = len(request_output.prompt_token_ids) + output_tensors.append( + pb_utils.Tensor( + "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32) + ) + ) + + # For embeddings, num_output_tokens is 0 (no generation happened) + if self.additional_outputs["return_num_output_tokens"]: + output_tensors.append( + pb_utils.Tensor("num_output_tokens", np.asarray(0, dtype=np.uint32)) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) From 943ee5f8c0e8f98f3adb8f7dd8acd5abc6ce23c7 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 30 Oct 2025 15:45:48 -0700 Subject: [PATCH 6/7] Address comment and rebase to r25.10 (V1 API) --- src/model.py | 17 +++++++----- src/utils/request.py | 61 +++++++++++++++++++++++--------------------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/model.py b/src/model.py index 25a9e7d7..f5957b28 100644 --- a/src/model.py +++ b/src/model.py @@ -40,7 +40,7 @@ ) from utils.metrics import VllmStatLoggerFactory -from utils.vllm_backend_utils import TritonSamplingParams +from utils.request import EmbedRequest, GenerateRequest _VLLM_ENGINE_ARGS_FILENAME = "model.json" _MULTI_LORA_ARGS_FILENAME = "multi_lora.json" @@ -249,6 +249,11 @@ def _init_engine(self): self._event_thread = None raise e + # Get supported tasks from the engine running in another thread + self.supported_tasks = asyncio.run_coroutine_threadsafe( + self._llm_engine.get_supported_tasks(), self._event_loop + ).result() + async def _run_llm_engine(self): # Counter to keep track of ongoing request counts. self._ongoing_request_count = 0 @@ -453,11 +458,11 @@ async def _infer(self, request): request_task_name = self._validate_request_task_name(request) if request_task_name == "generate": request = GenerateRequest( - request, self._llm_engine.generate, self.output_dtype + request, self._llm_engine.generate, self.output_dtype, self.logger ) elif request_task_name == "embed": request = EmbedRequest( - request, self._llm_engine.encode, self.output_dtype + request, self._llm_engine.encode, self.output_dtype, self.logger ) else: raise ValueError( @@ -499,10 +504,9 @@ async def _infer(self, request): # Send each response if streaming. 
if request.stream: response = request.create_response( - request_output_state, request_output, + request_output_state, prepend_input=False, - additional_outputs=request.additional_outputs, ) flags = 0 if request_output.finished: @@ -515,10 +519,9 @@ async def _infer(self, request): if not request.stream: if request_task_name == "generate": response = request.create_response( - request_output_state={}, request_output=request_output, + request_output_state={}, prepend_input=request.prepend_input, - additional_outputs=request.additional_outputs, ) else: response = request.create_response( diff --git a/src/utils/request.py b/src/utils/request.py index e8ae26aa..ff9b6c6b 100644 --- a/src/utils/request.py +++ b/src/utils/request.py @@ -28,12 +28,19 @@ import json from abc import abstractmethod from io import BytesIO +from typing import Callable import numpy as np import triton_python_backend_utils as pb_utils from PIL import Image from vllm.inputs.data import TokensPrompt from vllm.lora.request import LoRARequest +from vllm.outputs import ( + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.pooling_params import PoolingParams from vllm.utils import random_uuid @@ -41,10 +48,13 @@ class RequestBase: - def __init__(self, request, executor_callback, output_dtype): + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): self.request = request self.executor_callback = executor_callback self.output_dtype = output_dtype + self.logger = logger self.id = random_uuid() self.stream = False self.prepend_input = False @@ -58,13 +68,15 @@ def execute(self): raise NotImplementedError @abstractmethod - def create_response(self, *args, **kwargs): + def create_response(self, request_output, *args, **kwargs): raise NotImplementedError class GenerateRequest(RequestBase): - def __init__(self, request, executor_callback, output_dtype): - super().__init__(request, executor_callback, output_dtype) + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): + super().__init__(request, executor_callback, output_dtype, logger) def _get_input_tensors(self): # prompt @@ -166,7 +178,12 @@ async def execute(self): async for response in response_iterator: yield response - def create_response(self, request_output_state, request_output, prepend_input): + def create_response( + self, + request_output: RequestOutput, + request_output_state: dict, + prepend_input: bool, + ): output_tensors = [] # text_output @@ -278,8 +295,10 @@ def create_response(self, request_output_state, request_output, prepend_input): class EmbedRequest(RequestBase): - def __init__(self, request, executor_callback, output_dtype): - super().__init__(request, executor_callback, output_dtype) + def __init__( + self, request, executor_callback: Callable, output_dtype: np.dtype, logger + ): + super().__init__(request, executor_callback, output_dtype, logger) def _get_input_tensors(self): embedding_request = pb_utils.get_input_tensor_by_name( @@ -338,32 +357,16 @@ def _to_pooling_params(self, embedding_request: dict): pooling_params = PoolingParams(dimensions=dims, task="embed") return pooling_params - def create_response(self, request_output): + def create_response(self, request_output: PoolingRequestOutput[EmbeddingOutput]): output_tensors = [] + request_output = EmbeddingRequestOutput.from_base(request_output) - # Extract embedding vector from output - # PoolingRequestOutput.outputs is a PoolingOutput with .data (torch.Tensor) - 
pooling_data = request_output.outputs.data - - # Convert torch tensor to numpy array then to list for JSON serialization - if hasattr(pooling_data, "cpu"): - # It's a torch tensor - move to CPU and convert to numpy - embedding_array = pooling_data.cpu().numpy() - else: - # Already numpy or list - embedding_array = np.array(pooling_data, dtype=np.float32) - - # Create response tensor - for embeddings, we use text_output to return the vector - # (This is a simplification - you may want to define a proper embedding output tensor) - embedding_list = ( - embedding_array.tolist() - if hasattr(embedding_array, "tolist") - else list(embedding_array) - ) - embedding_str = json.dumps(embedding_list) + # Extract embedding list from output + embedding: list[float] = request_output.outputs.embedding output_tensors.append( pb_utils.Tensor( - "text_output", np.asarray([embedding_str], dtype=self.output_dtype) + "text_output", + np.asarray([json.dumps(embedding)], dtype=self.output_dtype), ) ) From 6bd56c46cdf4a206674cc93ed00c58c0989150b3 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 31 Oct 2025 14:45:56 -0700 Subject: [PATCH 7/7] Add warning --- src/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index f5957b28..067b2161 100644 --- a/src/model.py +++ b/src/model.py @@ -124,7 +124,8 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "dims": [1], "optional": True, }, - # Tentative input reserved for embedding requests in OpenAI frontend. May change in the future. + # Tentative input reserved for embedding requests in OpenAI-compatible frontend. Subject to change in the future. + # WARN: Triton client should never set this input. It is reserved for embedding requests in OpenAI-compatible frontend. { "name": "embedding_request", "data_type": "TYPE_STRING",
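
For reference, a minimal sketch of the payload the new embedding path expects: EmbedRequest._get_input_tensors() in src/utils/request.py reads the embedding_request tensor as JSON with an "input" field (a plain string or a single list of token IDs) and an optional "pooling_params" object whose "dimensions" list feeds PoolingParams(dimensions=..., task="embed"). The prompt text and dimension value below are made up for illustration; as the warning added in the last patch notes, this tensor is meant to be set only by the OpenAI-compatible frontend, not by Triton clients directly.

import json

import numpy as np

# Hypothetical example payload; field names follow the accessors in
# EmbedRequest._get_input_tensors().
embedding_request = {
    # Either a plain string or one list of token IDs.
    "input": "The quick brown fox jumps over the lazy dog",
    # Optional; "dimensions" is read as a one-element list.
    "pooling_params": {"dimensions": [256]},
}

# Serialized form that would be placed into the "embedding_request"
# TYPE_STRING input tensor declared in auto_complete_config().
payload = np.asarray([json.dumps(embedding_request)], dtype=np.object_)
print(payload)
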