From 710b8c17c13bd24b839121608efd29fd5801e4f0 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 22 Aug 2025 11:23:55 +0000 Subject: [PATCH 1/5] 1st draft --- Dockerfile.neuron | 43 +++++ backends/Cargo.toml | 1 + backends/neuron/Cargo.toml | 16 ++ backends/neuron/server/README.md | 0 .../server/text_embeddings_server/__init__.py | 0 .../server/text_embeddings_server/cli.py | 55 +++++++ .../text_embeddings_server/models/__init__.py | 126 +++++++++++++++ .../server/text_embeddings_server/server.py | 92 +++++++++++ backends/neuron/src/lib.rs | 132 ++++++++++++++++ backends/neuron/src/logging.rs | 61 ++++++++ backends/neuron/src/management.rs | 148 ++++++++++++++++++ docs/source/en/_toctree.yml | 2 + docs/source/en/local_neuron.md | 1 + integration_tests/neuron/conftest.py | 0 integration_tests/neuron/test_embed.py | 0 15 files changed, 677 insertions(+) create mode 100644 Dockerfile.neuron create mode 100644 backends/neuron/Cargo.toml create mode 100644 backends/neuron/server/README.md create mode 100644 backends/neuron/server/text_embeddings_server/__init__.py create mode 100644 backends/neuron/server/text_embeddings_server/cli.py create mode 100644 backends/neuron/server/text_embeddings_server/models/__init__.py create mode 100644 backends/neuron/server/text_embeddings_server/server.py create mode 100644 backends/neuron/src/lib.rs create mode 100644 backends/neuron/src/logging.rs create mode 100644 backends/neuron/src/management.rs create mode 100644 docs/source/en/local_neuron.md create mode 100644 integration_tests/neuron/conftest.py create mode 100644 integration_tests/neuron/test_embed.py diff --git a/Dockerfile.neuron b/Dockerfile.neuron new file mode 100644 index 00000000..f8b03ab2 --- /dev/null +++ b/Dockerfile.neuron @@ -0,0 +1,43 @@ +ARG PLATFORM=neuron +FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef +WORKDIR /usr/src + +ENV SCCACHE=0.10.0 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache + +# Donwload, configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +FROM chef AS planner + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG SCCACHE_GHA_ENABLED + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +WORKDIR /usr/src + diff --git a/backends/Cargo.toml b/backends/Cargo.toml index bb9d7419..7d821ff4 100644 --- a/backends/Cargo.toml +++ b/backends/Cargo.toml @@ -21,6 +21,7 @@ rand = { workspace = true } [features] clap = ["dep:clap", "text-embeddings-backend-core/clap"] python = ["dep:text-embeddings-backend-python"] +neuron = ["dep:text-embeddings-backend-neuron"] ort = ["dep:text-embeddings-backend-ort"] candle = ["dep:text-embeddings-backend-candle"] cuda = ["text-embeddings-backend-candle?/cuda"] diff --git a/backends/neuron/Cargo.toml b/backends/neuron/Cargo.toml new file mode 
100644 index 00000000..b38f350e --- /dev/null +++ b/backends/neuron/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "text-embeddings-backend-python" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +backend-grpc-client = { path = "../grpc-client" } +nohash-hasher = "^0.2" +serde = { version = "^1.0", features = ["derive"] } +serde_json = "^1.0" +text-embeddings-backend-core = { path = "../core" } +thiserror = "^1.0" +tokio = { version = "^1.25", features = ["sync"] } +tracing = "^0.1" diff --git a/backends/neuron/server/README.md b/backends/neuron/server/README.md new file mode 100644 index 00000000..e69de29b diff --git a/backends/neuron/server/text_embeddings_server/__init__.py b/backends/neuron/server/text_embeddings_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backends/neuron/server/text_embeddings_server/cli.py b/backends/neuron/server/text_embeddings_server/cli.py new file mode 100644 index 00000000..c4dfaa4c --- /dev/null +++ b/backends/neuron/server/text_embeddings_server/cli.py @@ -0,0 +1,55 @@ +import sys +import typer + +from pathlib import Path +from loguru import logger +from typing import Optional +from enum import Enum + +app = typer.Typer() + + +class Dtype(str, Enum): + float32 = "float32" + float16 = "float16" + bloat16 = "bfloat16" + + +@app.command() +def serve( + model_path: Path, + dtype: Dtype = "float32", + uds_path: Path = "/tmp/text-embeddings-server", + logger_level: str = "INFO", + json_output: bool = False, + otlp_endpoint: Optional[str] = None, + otlp_service_name: str = "text-embeddings-inference.server", + pool: str = "cls", +): + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_embeddings_server", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + # Import here after the logger is added to log potential import exceptions + from text_embeddings_server import server + from text_embeddings_server.utils.tracing import setup_tracing + + # Setup OpenTelemetry distributed tracing + if otlp_endpoint is not None: + setup_tracing(otlp_endpoint=otlp_endpoint, otlp_service_name=otlp_service_name) + + # Downgrade enum into str for easier management later on + dtype = None if dtype is None else dtype.value + server.serve(model_path, dtype, uds_path, pool) + + +if __name__ == "__main__": + app() diff --git a/backends/neuron/server/text_embeddings_server/models/__init__.py b/backends/neuron/server/text_embeddings_server/models/__init__.py new file mode 100644 index 00000000..06c39832 --- /dev/null +++ b/backends/neuron/server/text_embeddings_server/models/__init__.py @@ -0,0 +1,126 @@ +import os +import torch + +from loguru import logger +from pathlib import Path +from typing import Optional +from transformers import AutoConfig +from transformers.models.bert import BertConfig + +from text_embeddings_server.models.model import Model +from text_embeddings_server.models.masked_model import MaskedLanguageModel +from text_embeddings_server.models.default_model import DefaultModel +from text_embeddings_server.models.classification_model import ClassificationModel +from text_embeddings_server.models.jinaBert_model import FlashJinaBert +from text_embeddings_server.models.flash_mistral import FlashMistral +from text_embeddings_server.models.flash_qwen3 import FlashQwen3 +from text_embeddings_server.utils.device import get_device, use_ipex + +__all__ = ["Model"] + +TRUST_REMOTE_CODE 
= os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] +DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [ + "true", + "1", +] +# Disable gradients +torch.set_grad_enabled(False) + +FLASH_ATTENTION = True +try: + from text_embeddings_server.models.flash_bert import FlashBert +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") + FLASH_ATTENTION = False + +if FLASH_ATTENTION: + __all__.append(FlashBert) + + +def create_model(model_class, model_path, device, datatype, pool="cls"): + """Create a model instance and load it into Neuron devices.""" + model_handle = model_class( + model_path, + device, + datatype, + pool, + trust_remote=TRUST_REMOTE_CODE, + ) + return model_handle + + +def get_model(model_path: Path, dtype: Optional[str], pool: str): + if dtype == "float32": + datatype = torch.float32 + elif dtype == "float16": + datatype = torch.float16 + elif dtype == "bfloat16": + datatype = torch.bfloat16 + else: + raise RuntimeError(f"Unknown dtype {dtype}") + + device = get_device() + logger.info(f"backend device: {device}") + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) + + if ( + hasattr(config, "auto_map") + and isinstance(config.auto_map, dict) + and "AutoModel" in config.auto_map + and config.auto_map["AutoModel"] + == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel" + ): + # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository + return create_model(FlashJinaBert, model_path, device, datatype) + + if config.model_type == "bert": + config: BertConfig + if ( + use_ipex() + or device.type in ["cuda", "hpu"] + and config.position_embedding_type == "absolute" + and datatype in [torch.float16, torch.bfloat16] + and FLASH_ATTENTION + ): + if pool != "cls": + if config.architectures[0].endswith("ForMaskedLM") and pool == "splade": + return create_model( + MaskedLanguageModel, model_path, device, datatype, pool + ) + return create_model(DefaultModel, model_path, device, datatype, pool) + + try: + return create_model(FlashBert, model_path, device, datatype) + except FileNotFoundError: + logger.info( + "Do not have safetensors file for this model, use default transformers model path instead" + ) + return create_model(DefaultModel, model_path, device, datatype, pool) + + if config.architectures[0].endswith("Classification"): + return create_model(ClassificationModel, model_path, device, datatype) + elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": + return create_model(MaskedLanguageModel, model_path, device, datatype) + else: + return create_model(DefaultModel, model_path, device, datatype, pool) + + if config.model_type == "mistral" and device.type == "hpu": + try: + return create_model(FlashMistral, model_path, device, datatype, pool) + except FileNotFoundError: + return create_model(DefaultModel, model_path, device, datatype, pool) + + if config.model_type == "qwen3" and device.type == "hpu": + try: + return create_model(FlashQwen3, model_path, device, datatype, pool) + except FileNotFoundError: + return create_model(DefaultModel, model_path, device, datatype, pool) + + # Default case + if config.architectures[0].endswith("Classification"): + return create_model(ClassificationModel, model_path, device, datatype) + elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": + return create_model(MaskedLanguageModel, model_path, 
device, datatype) + else: + return create_model(DefaultModel, model_path, device, datatype, pool) diff --git a/backends/neuron/server/text_embeddings_server/server.py b/backends/neuron/server/text_embeddings_server/server.py new file mode 100644 index 00000000..646d79bc --- /dev/null +++ b/backends/neuron/server/text_embeddings_server/server.py @@ -0,0 +1,92 @@ +import asyncio +import torch +from grpc import aio +from loguru import logger + +from grpc_reflection.v1alpha import reflection +from pathlib import Path +from typing import Optional + +from text_embeddings_server.models import Model, get_model +from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2 +from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor +from text_embeddings_server.utils.interceptor import ExceptionInterceptor + + +class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer): + def __init__(self, model: Model): + self.model = model + # Force inference mode for the lifetime of EmbeddingService + self._inference_mode_raii_guard = torch._C._InferenceMode(True) + + async def Health(self, request, context): + if self.model.device.type == "cuda": + torch.zeros((2, 2), device="cuda") + return embed_pb2.HealthResponse() + + async def Embed(self, request, context): + max_input_length = self.model.max_input_length + batch = self.model.batch_type.from_pb( + request, self.model.device, max_input_length + ) + + embeddings = self.model.embed(batch) + + return embed_pb2.EmbedResponse(embeddings=embeddings) + + async def Predict(self, request, context): + max_input_length = self.model.max_input_length + batch = self.model.batch_type.from_pb( + request, self.model.device, max_input_length + ) + + scores = self.model.predict(batch) + + return embed_pb2.PredictResponse(scores=scores) + + +def serve( + model_path: Path, + dtype: Optional[str], + uds_path: Path, + pool: str, +): + async def serve_inner( + model_path: Path, + dtype: Optional[str] = None, + ): + unix_socket = f"unix://{uds_path}" + + try: + model = get_model(model_path, dtype, pool) + except Exception: + logger.exception("Error when initializing model") + raise + + server = aio.server( + interceptors=[ + ExceptionInterceptor(), + UDSOpenTelemetryAioServerInterceptor(), + ] + ) + embed_pb2_grpc.add_EmbeddingServiceServicer_to_server( + EmbeddingService(model), server + ) + SERVICE_NAMES = ( + embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + server.add_insecure_port(unix_socket) + + await server.start() + + logger.info(f"Server started at {unix_socket}") + + try: + await server.wait_for_termination() + except KeyboardInterrupt: + logger.info("Signal received. 
Shutting down") + await server.stop(0) + + asyncio.run(serve_inner(model_path, dtype)) diff --git a/backends/neuron/src/lib.rs b/backends/neuron/src/lib.rs new file mode 100644 index 00000000..53255b07 --- /dev/null +++ b/backends/neuron/src/lib.rs @@ -0,0 +1,132 @@ +mod logging; +mod management; + +use backend_grpc_client::Client; +use nohash_hasher::BuildNoHashHasher; +use std::collections::HashMap; +use text_embeddings_backend_core::{ + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, +}; +use tokio::runtime::Runtime; + +pub struct PythonBackend { + _backend_process: management::BackendProcess, + tokio_runtime: Runtime, + backend_client: Client, +} + +impl PythonBackend { + pub fn new( + model_path: String, + dtype: String, + model_type: ModelType, + uds_path: String, + otlp_endpoint: Option, + otlp_service_name: String, + ) -> Result { + let pool = match model_type { + ModelType::Classifier => Pool::Cls, + ModelType::Embedding(pool) => pool, + }; + + let backend_process = management::BackendProcess::new( + model_path, + dtype, + &uds_path, + otlp_endpoint, + otlp_service_name, + pool, + )?; + let tokio_runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; + + let backend_client = tokio_runtime + .block_on(Client::connect_uds(uds_path)) + .map_err(|err| { + BackendError::Start(format!("Could not connect to backend process: {err}")) + })?; + + Ok(Self { + _backend_process: backend_process, + tokio_runtime, + backend_client, + }) + } +} + +impl Backend for PythonBackend { + fn health(&self) -> Result<(), BackendError> { + if self + .tokio_runtime + .block_on(self.backend_client.clone().health()) + .is_err() + { + return Err(BackendError::Unhealthy); + } + Ok(()) + } + + fn is_padded(&self) -> bool { + false + } + + fn embed(&self, batch: Batch) -> Result { + if !batch.raw_indices.is_empty() { + return Err(BackendError::Inference( + "raw embeddings are not supported for the Python backend.".to_string(), + )); + } + let batch_size = batch.len(); + + let results = self + .tokio_runtime + .block_on(self.backend_client.clone().embed( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, + batch.cumulative_seq_lengths, + batch.max_length, + )) + .map_err(|err| BackendError::Inference(err.to_string()))?; + let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); + + let mut embeddings = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + for (i, e) in pooled_embeddings.into_iter().enumerate() { + embeddings.insert(i, Embedding::Pooled(e)); + } + + Ok(embeddings) + } + + fn predict(&self, batch: Batch) -> Result { + if !batch.raw_indices.is_empty() { + return Err(BackendError::Inference( + "raw embeddings are not supported for the Python backend.".to_string(), + )); + } + let batch_size = batch.len(); + let results = self + .tokio_runtime + .block_on(self.backend_client.clone().predict( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, + batch.cumulative_seq_lengths, + batch.max_length, + )) + .map_err(|err| BackendError::Inference(err.to_string()))?; + let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); + + let mut predictions = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + + for (i, r) in raw_results.into_iter().enumerate() { + predictions.insert(i, r); + } + + Ok(predictions) + } +} diff --git 
a/backends/neuron/src/logging.rs b/backends/neuron/src/logging.rs new file mode 100644 index 00000000..8f55e8e6 --- /dev/null +++ b/backends/neuron/src/logging.rs @@ -0,0 +1,61 @@ +use serde::Deserialize; +use std::io::{BufRead, Lines}; + +#[derive(Deserialize)] +#[serde(rename_all = "UPPERCASE")] +enum PythonLogLevelEnum { + Trace, + Debug, + Info, + Success, + Warning, + Error, + Critical, +} + +#[derive(Deserialize)] +struct PythonLogLevel { + name: PythonLogLevelEnum, +} + +#[derive(Deserialize)] +struct PythonLogRecord { + level: PythonLogLevel, +} + +#[derive(Deserialize)] +struct PythonLogMessage { + text: String, + record: PythonLogRecord, +} + +impl PythonLogMessage { + fn trace(&self) { + match self.record.level.name { + PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), + PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), + PythonLogLevelEnum::Info => tracing::info!("{}", self.text), + PythonLogLevelEnum::Success => tracing::info!("{}", self.text), + PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), + PythonLogLevelEnum::Error => tracing::error!("{}", self.text), + PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), + } + } +} + +impl TryFrom<&String> for PythonLogMessage { + type Error = serde_json::Error; + + fn try_from(value: &String) -> Result { + serde_json::from_str::(value) + } +} + +pub(crate) fn log_lines(lines: Lines) { + for line in lines.map_while(Result::ok) { + match PythonLogMessage::try_from(&line) { + Ok(log) => log.trace(), + Err(_) => tracing::debug!("{line}"), + } + } +} diff --git a/backends/neuron/src/management.rs b/backends/neuron/src/management.rs new file mode 100644 index 00000000..81c294a9 --- /dev/null +++ b/backends/neuron/src/management.rs @@ -0,0 +1,148 @@ +use crate::logging::log_lines; +use std::ffi::OsString; +use std::io::{BufRead, BufReader}; +use std::os::unix::process::{CommandExt, ExitStatusExt}; +use std::path::Path; +use std::process::{Child, Command, Stdio}; +use std::sync::mpsc; +use std::thread::sleep; +use std::time::{Duration, Instant}; +use std::{env, fs, io, thread}; +use text_embeddings_backend_core::{BackendError, Pool}; + +#[derive(Debug)] +pub(crate) struct BackendProcess { + inner: Child, +} + +impl BackendProcess { + pub(crate) fn new( + model_path: String, + dtype: String, + uds_path: &str, + otlp_endpoint: Option, + otlp_service_name: String, + pool: Pool, + ) -> Result { + // Get UDS path + let uds = Path::new(uds_path); + + // Clean previous runs + if uds.exists() { + fs::remove_file(uds).expect("could not remove UDS file"); + } + + let pool = match pool { + Pool::Cls => "cls", + Pool::Mean => "mean", + Pool::LastToken => "lasttoken", + Pool::Splade => "splade", + }; + + // Process args + let mut python_server_args = vec![ + model_path, + "--dtype".to_owned(), + dtype, + "--uds-path".to_owned(), + uds_path.to_owned(), + "--logger-level".to_owned(), + "INFO".to_owned(), + "--json-output".to_owned(), + "--pool".to_owned(), + pool.to_owned(), + ]; + + // OpenTelemetry + if let Some(otlp_endpoint) = otlp_endpoint { + python_server_args.push("--otlp-endpoint".to_owned()); + python_server_args.push(otlp_endpoint); + } + + python_server_args.push("--otlp-service-name".to_owned()); + python_server_args.push(otlp_service_name); + + // Copy current process env + let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + tracing::info!("Starting Python backend"); + let mut p = match Command::new("python-text-embeddings-server") + .args(python_server_args) + .envs(envs) + 
.stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + if err.kind() == io::ErrorKind::NotFound { + return Err(BackendError::Start( + "python-text-embeddings-server not found in PATH".to_owned(), + )); + } + return Err(BackendError::Start(err.to_string())); + } + }; + + let stdout_reader = BufReader::new(p.stdout.take().unwrap()); + let stderr_reader = BufReader::new(p.stderr.take().unwrap()); + + //stdout tracing thread + thread::spawn(move || { + let _span = tracing::span!(tracing::Level::INFO, "python-backend").entered(); + log_lines(stdout_reader.lines()); + }); + + let start_time = Instant::now(); + let mut wait_time = Instant::now(); + + loop { + // Process exited + if let Some(exit_status) = p.try_wait().unwrap() { + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in stderr_reader.lines().map_while(Result::ok) { + err_sender.send(line).unwrap_or(()); + } + }); + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + + tracing::debug!("Python Backend complete standard error output:\n{err}"); + + if let Some(signal) = exit_status.signal() { + return Err(BackendError::Start(format!( + "Python Backend process was signaled to shutdown with signal {signal}" + ))); + } + return Err(BackendError::Start( + "Python backend failed to start".to_string(), + )); + } + + // Shard is ready + if uds.exists() { + tracing::info!("Python backend ready in {:?}", start_time.elapsed()); + break; + } else if wait_time.elapsed() > Duration::from_secs(10) { + tracing::info!("Waiting for Python backend to be ready..."); + wait_time = Instant::now(); + } + sleep(Duration::from_millis(5)); + } + + Ok(Self { inner: p }) + } +} + +impl Drop for BackendProcess { + fn drop(&mut self) { + self.inner.kill().unwrap(); + let _ = self.inner.wait(); + tracing::info!("Python backend process terminated"); + } +} diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index fa6f21e6..b9eebac2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -19,6 +19,8 @@ title: Build custom container for TEI - local: intel_container title: Using TEI container with Intel Hardware + - local: local_neuron + title: Using TEI container with AWS Neuron - local: examples title: Example uses title: Tutorials diff --git a/docs/source/en/local_neuron.md b/docs/source/en/local_neuron.md new file mode 100644 index 00000000..e0a2cf2b --- /dev/null +++ b/docs/source/en/local_neuron.md @@ -0,0 +1 @@ +# Neuron backend for AWS Trainium and Inferentia \ No newline at end of file diff --git a/integration_tests/neuron/conftest.py b/integration_tests/neuron/conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/integration_tests/neuron/test_embed.py b/integration_tests/neuron/test_embed.py new file mode 100644 index 00000000..e69de29b From 139b179f1cd346705fc3267b1f39162e438d8b21 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 22 Oct 2025 16:29:32 +0000 Subject: [PATCH 2/5] feat: sentence transformer for neuron --- backends/neuron/Cargo.toml | 16 -- backends/neuron/server/README.md | 0 .../server/text_embeddings_server/__init__.py | 0 .../server/text_embeddings_server/cli.py | 55 ------- .../text_embeddings_server/models/__init__.py | 126 --------------- .../server/text_embeddings_server/server.py | 92 ----------- 
backends/neuron/src/lib.rs | 132 ---------------- backends/neuron/src/logging.rs | 61 -------- backends/neuron/src/management.rs | 148 ------------------ .../text_embeddings_server/models/__init__.py | 9 +- .../models/neuron_models.py | 67 ++++++++ .../text_embeddings_server/utils/device.py | 19 +++ 12 files changed, 94 insertions(+), 631 deletions(-) delete mode 100644 backends/neuron/Cargo.toml delete mode 100644 backends/neuron/server/README.md delete mode 100644 backends/neuron/server/text_embeddings_server/__init__.py delete mode 100644 backends/neuron/server/text_embeddings_server/cli.py delete mode 100644 backends/neuron/server/text_embeddings_server/models/__init__.py delete mode 100644 backends/neuron/server/text_embeddings_server/server.py delete mode 100644 backends/neuron/src/lib.rs delete mode 100644 backends/neuron/src/logging.rs delete mode 100644 backends/neuron/src/management.rs create mode 100644 backends/python/server/text_embeddings_server/models/neuron_models.py diff --git a/backends/neuron/Cargo.toml b/backends/neuron/Cargo.toml deleted file mode 100644 index b38f350e..00000000 --- a/backends/neuron/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "text-embeddings-backend-python" -version.workspace = true -edition.workspace = true -authors.workspace = true -homepage.workspace = true - -[dependencies] -backend-grpc-client = { path = "../grpc-client" } -nohash-hasher = "^0.2" -serde = { version = "^1.0", features = ["derive"] } -serde_json = "^1.0" -text-embeddings-backend-core = { path = "../core" } -thiserror = "^1.0" -tokio = { version = "^1.25", features = ["sync"] } -tracing = "^0.1" diff --git a/backends/neuron/server/README.md b/backends/neuron/server/README.md deleted file mode 100644 index e69de29b..00000000 diff --git a/backends/neuron/server/text_embeddings_server/__init__.py b/backends/neuron/server/text_embeddings_server/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/backends/neuron/server/text_embeddings_server/cli.py b/backends/neuron/server/text_embeddings_server/cli.py deleted file mode 100644 index c4dfaa4c..00000000 --- a/backends/neuron/server/text_embeddings_server/cli.py +++ /dev/null @@ -1,55 +0,0 @@ -import sys -import typer - -from pathlib import Path -from loguru import logger -from typing import Optional -from enum import Enum - -app = typer.Typer() - - -class Dtype(str, Enum): - float32 = "float32" - float16 = "float16" - bloat16 = "bfloat16" - - -@app.command() -def serve( - model_path: Path, - dtype: Dtype = "float32", - uds_path: Path = "/tmp/text-embeddings-server", - logger_level: str = "INFO", - json_output: bool = False, - otlp_endpoint: Optional[str] = None, - otlp_service_name: str = "text-embeddings-inference.server", - pool: str = "cls", -): - # Remove default handler - logger.remove() - logger.add( - sys.stdout, - format="{message}", - filter="text_embeddings_server", - level=logger_level, - serialize=json_output, - backtrace=True, - diagnose=False, - ) - - # Import here after the logger is added to log potential import exceptions - from text_embeddings_server import server - from text_embeddings_server.utils.tracing import setup_tracing - - # Setup OpenTelemetry distributed tracing - if otlp_endpoint is not None: - setup_tracing(otlp_endpoint=otlp_endpoint, otlp_service_name=otlp_service_name) - - # Downgrade enum into str for easier management later on - dtype = None if dtype is None else dtype.value - server.serve(model_path, dtype, uds_path, pool) - - -if __name__ == "__main__": - app() 
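
The `cli.py` shown above starts the gRPC server with loguru configured via `serialize=json_output`. On the Rust side, `logging.rs` deserializes each stdout line and forwards only the `text` and `record.level.name` fields to `tracing`. A minimal sketch of that stdout contract, assuming loguru's standard `serialize=True` record shape:

```python
# Sketch: what the Rust supervisor (management.rs + logging.rs) reads from the
# Python server's stdout when it is launched with --json-output.
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="{message}", serialize=True, level="INFO")

# Emits one JSON object per line, e.g.
# {"text": "Python backend ready\n", "record": {..., "level": {"name": "INFO", ...}, ...}}
# logging.rs parses only "text" and "record.level.name"; any line that is not
# valid JSON is logged verbatim at debug level.
logger.info("Python backend ready")
```
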
diff --git a/backends/neuron/server/text_embeddings_server/models/__init__.py b/backends/neuron/server/text_embeddings_server/models/__init__.py deleted file mode 100644 index 06c39832..00000000 --- a/backends/neuron/server/text_embeddings_server/models/__init__.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -import torch - -from loguru import logger -from pathlib import Path -from typing import Optional -from transformers import AutoConfig -from transformers.models.bert import BertConfig - -from text_embeddings_server.models.model import Model -from text_embeddings_server.models.masked_model import MaskedLanguageModel -from text_embeddings_server.models.default_model import DefaultModel -from text_embeddings_server.models.classification_model import ClassificationModel -from text_embeddings_server.models.jinaBert_model import FlashJinaBert -from text_embeddings_server.models.flash_mistral import FlashMistral -from text_embeddings_server.models.flash_qwen3 import FlashQwen3 -from text_embeddings_server.utils.device import get_device, use_ipex - -__all__ = ["Model"] - -TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] -DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [ - "true", - "1", -] -# Disable gradients -torch.set_grad_enabled(False) - -FLASH_ATTENTION = True -try: - from text_embeddings_server.models.flash_bert import FlashBert -except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") - FLASH_ATTENTION = False - -if FLASH_ATTENTION: - __all__.append(FlashBert) - - -def create_model(model_class, model_path, device, datatype, pool="cls"): - """Create a model instance and load it into Neuron devices.""" - model_handle = model_class( - model_path, - device, - datatype, - pool, - trust_remote=TRUST_REMOTE_CODE, - ) - return model_handle - - -def get_model(model_path: Path, dtype: Optional[str], pool: str): - if dtype == "float32": - datatype = torch.float32 - elif dtype == "float16": - datatype = torch.float16 - elif dtype == "bfloat16": - datatype = torch.bfloat16 - else: - raise RuntimeError(f"Unknown dtype {dtype}") - - device = get_device() - logger.info(f"backend device: {device}") - - config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) - - if ( - hasattr(config, "auto_map") - and isinstance(config.auto_map, dict) - and "AutoModel" in config.auto_map - and config.auto_map["AutoModel"] - == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel" - ): - # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository - return create_model(FlashJinaBert, model_path, device, datatype) - - if config.model_type == "bert": - config: BertConfig - if ( - use_ipex() - or device.type in ["cuda", "hpu"] - and config.position_embedding_type == "absolute" - and datatype in [torch.float16, torch.bfloat16] - and FLASH_ATTENTION - ): - if pool != "cls": - if config.architectures[0].endswith("ForMaskedLM") and pool == "splade": - return create_model( - MaskedLanguageModel, model_path, device, datatype, pool - ) - return create_model(DefaultModel, model_path, device, datatype, pool) - - try: - return create_model(FlashBert, model_path, device, datatype) - except FileNotFoundError: - logger.info( - "Do not have safetensors file for this model, use default transformers model path instead" - ) - return create_model(DefaultModel, model_path, device, datatype, pool) - - if 
config.architectures[0].endswith("Classification"): - return create_model(ClassificationModel, model_path, device, datatype) - elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": - return create_model(MaskedLanguageModel, model_path, device, datatype) - else: - return create_model(DefaultModel, model_path, device, datatype, pool) - - if config.model_type == "mistral" and device.type == "hpu": - try: - return create_model(FlashMistral, model_path, device, datatype, pool) - except FileNotFoundError: - return create_model(DefaultModel, model_path, device, datatype, pool) - - if config.model_type == "qwen3" and device.type == "hpu": - try: - return create_model(FlashQwen3, model_path, device, datatype, pool) - except FileNotFoundError: - return create_model(DefaultModel, model_path, device, datatype, pool) - - # Default case - if config.architectures[0].endswith("Classification"): - return create_model(ClassificationModel, model_path, device, datatype) - elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": - return create_model(MaskedLanguageModel, model_path, device, datatype) - else: - return create_model(DefaultModel, model_path, device, datatype, pool) diff --git a/backends/neuron/server/text_embeddings_server/server.py b/backends/neuron/server/text_embeddings_server/server.py deleted file mode 100644 index 646d79bc..00000000 --- a/backends/neuron/server/text_embeddings_server/server.py +++ /dev/null @@ -1,92 +0,0 @@ -import asyncio -import torch -from grpc import aio -from loguru import logger - -from grpc_reflection.v1alpha import reflection -from pathlib import Path -from typing import Optional - -from text_embeddings_server.models import Model, get_model -from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2 -from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor -from text_embeddings_server.utils.interceptor import ExceptionInterceptor - - -class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer): - def __init__(self, model: Model): - self.model = model - # Force inference mode for the lifetime of EmbeddingService - self._inference_mode_raii_guard = torch._C._InferenceMode(True) - - async def Health(self, request, context): - if self.model.device.type == "cuda": - torch.zeros((2, 2), device="cuda") - return embed_pb2.HealthResponse() - - async def Embed(self, request, context): - max_input_length = self.model.max_input_length - batch = self.model.batch_type.from_pb( - request, self.model.device, max_input_length - ) - - embeddings = self.model.embed(batch) - - return embed_pb2.EmbedResponse(embeddings=embeddings) - - async def Predict(self, request, context): - max_input_length = self.model.max_input_length - batch = self.model.batch_type.from_pb( - request, self.model.device, max_input_length - ) - - scores = self.model.predict(batch) - - return embed_pb2.PredictResponse(scores=scores) - - -def serve( - model_path: Path, - dtype: Optional[str], - uds_path: Path, - pool: str, -): - async def serve_inner( - model_path: Path, - dtype: Optional[str] = None, - ): - unix_socket = f"unix://{uds_path}" - - try: - model = get_model(model_path, dtype, pool) - except Exception: - logger.exception("Error when initializing model") - raise - - server = aio.server( - interceptors=[ - ExceptionInterceptor(), - UDSOpenTelemetryAioServerInterceptor(), - ] - ) - embed_pb2_grpc.add_EmbeddingServiceServicer_to_server( - EmbeddingService(model), server - ) - SERVICE_NAMES = ( - 
embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name, - reflection.SERVICE_NAME, - ) - reflection.enable_server_reflection(SERVICE_NAMES, server) - server.add_insecure_port(unix_socket) - - await server.start() - - logger.info(f"Server started at {unix_socket}") - - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) - - asyncio.run(serve_inner(model_path, dtype)) diff --git a/backends/neuron/src/lib.rs b/backends/neuron/src/lib.rs deleted file mode 100644 index 53255b07..00000000 --- a/backends/neuron/src/lib.rs +++ /dev/null @@ -1,132 +0,0 @@ -mod logging; -mod management; - -use backend_grpc_client::Client; -use nohash_hasher::BuildNoHashHasher; -use std::collections::HashMap; -use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, -}; -use tokio::runtime::Runtime; - -pub struct PythonBackend { - _backend_process: management::BackendProcess, - tokio_runtime: Runtime, - backend_client: Client, -} - -impl PythonBackend { - pub fn new( - model_path: String, - dtype: String, - model_type: ModelType, - uds_path: String, - otlp_endpoint: Option, - otlp_service_name: String, - ) -> Result { - let pool = match model_type { - ModelType::Classifier => Pool::Cls, - ModelType::Embedding(pool) => pool, - }; - - let backend_process = management::BackendProcess::new( - model_path, - dtype, - &uds_path, - otlp_endpoint, - otlp_service_name, - pool, - )?; - let tokio_runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; - - let backend_client = tokio_runtime - .block_on(Client::connect_uds(uds_path)) - .map_err(|err| { - BackendError::Start(format!("Could not connect to backend process: {err}")) - })?; - - Ok(Self { - _backend_process: backend_process, - tokio_runtime, - backend_client, - }) - } -} - -impl Backend for PythonBackend { - fn health(&self) -> Result<(), BackendError> { - if self - .tokio_runtime - .block_on(self.backend_client.clone().health()) - .is_err() - { - return Err(BackendError::Unhealthy); - } - Ok(()) - } - - fn is_padded(&self) -> bool { - false - } - - fn embed(&self, batch: Batch) -> Result { - if !batch.raw_indices.is_empty() { - return Err(BackendError::Inference( - "raw embeddings are not supported for the Python backend.".to_string(), - )); - } - let batch_size = batch.len(); - - let results = self - .tokio_runtime - .block_on(self.backend_client.clone().embed( - batch.input_ids, - batch.token_type_ids, - batch.position_ids, - batch.cumulative_seq_lengths, - batch.max_length, - )) - .map_err(|err| BackendError::Inference(err.to_string()))?; - let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); - - let mut embeddings = - HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, e) in pooled_embeddings.into_iter().enumerate() { - embeddings.insert(i, Embedding::Pooled(e)); - } - - Ok(embeddings) - } - - fn predict(&self, batch: Batch) -> Result { - if !batch.raw_indices.is_empty() { - return Err(BackendError::Inference( - "raw embeddings are not supported for the Python backend.".to_string(), - )); - } - let batch_size = batch.len(); - let results = self - .tokio_runtime - .block_on(self.backend_client.clone().predict( - batch.input_ids, - batch.token_type_ids, - batch.position_ids, - batch.cumulative_seq_lengths, - batch.max_length, - 
)) - .map_err(|err| BackendError::Inference(err.to_string()))?; - let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); - - let mut predictions = - HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - - for (i, r) in raw_results.into_iter().enumerate() { - predictions.insert(i, r); - } - - Ok(predictions) - } -} diff --git a/backends/neuron/src/logging.rs b/backends/neuron/src/logging.rs deleted file mode 100644 index 8f55e8e6..00000000 --- a/backends/neuron/src/logging.rs +++ /dev/null @@ -1,61 +0,0 @@ -use serde::Deserialize; -use std::io::{BufRead, Lines}; - -#[derive(Deserialize)] -#[serde(rename_all = "UPPERCASE")] -enum PythonLogLevelEnum { - Trace, - Debug, - Info, - Success, - Warning, - Error, - Critical, -} - -#[derive(Deserialize)] -struct PythonLogLevel { - name: PythonLogLevelEnum, -} - -#[derive(Deserialize)] -struct PythonLogRecord { - level: PythonLogLevel, -} - -#[derive(Deserialize)] -struct PythonLogMessage { - text: String, - record: PythonLogRecord, -} - -impl PythonLogMessage { - fn trace(&self) { - match self.record.level.name { - PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), - PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), - PythonLogLevelEnum::Info => tracing::info!("{}", self.text), - PythonLogLevelEnum::Success => tracing::info!("{}", self.text), - PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), - PythonLogLevelEnum::Error => tracing::error!("{}", self.text), - PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), - } - } -} - -impl TryFrom<&String> for PythonLogMessage { - type Error = serde_json::Error; - - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) - } -} - -pub(crate) fn log_lines(lines: Lines) { - for line in lines.map_while(Result::ok) { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), - } - } -} diff --git a/backends/neuron/src/management.rs b/backends/neuron/src/management.rs deleted file mode 100644 index 81c294a9..00000000 --- a/backends/neuron/src/management.rs +++ /dev/null @@ -1,148 +0,0 @@ -use crate::logging::log_lines; -use std::ffi::OsString; -use std::io::{BufRead, BufReader}; -use std::os::unix::process::{CommandExt, ExitStatusExt}; -use std::path::Path; -use std::process::{Child, Command, Stdio}; -use std::sync::mpsc; -use std::thread::sleep; -use std::time::{Duration, Instant}; -use std::{env, fs, io, thread}; -use text_embeddings_backend_core::{BackendError, Pool}; - -#[derive(Debug)] -pub(crate) struct BackendProcess { - inner: Child, -} - -impl BackendProcess { - pub(crate) fn new( - model_path: String, - dtype: String, - uds_path: &str, - otlp_endpoint: Option, - otlp_service_name: String, - pool: Pool, - ) -> Result { - // Get UDS path - let uds = Path::new(uds_path); - - // Clean previous runs - if uds.exists() { - fs::remove_file(uds).expect("could not remove UDS file"); - } - - let pool = match pool { - Pool::Cls => "cls", - Pool::Mean => "mean", - Pool::LastToken => "lasttoken", - Pool::Splade => "splade", - }; - - // Process args - let mut python_server_args = vec![ - model_path, - "--dtype".to_owned(), - dtype, - "--uds-path".to_owned(), - uds_path.to_owned(), - "--logger-level".to_owned(), - "INFO".to_owned(), - "--json-output".to_owned(), - "--pool".to_owned(), - pool.to_owned(), - ]; - - // OpenTelemetry - if let Some(otlp_endpoint) = otlp_endpoint { - python_server_args.push("--otlp-endpoint".to_owned()); - 
python_server_args.push(otlp_endpoint); - } - - python_server_args.push("--otlp-service-name".to_owned()); - python_server_args.push(otlp_service_name); - - // Copy current process env - let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - - tracing::info!("Starting Python backend"); - let mut p = match Command::new("python-text-embeddings-server") - .args(python_server_args) - .envs(envs) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .spawn() - { - Ok(p) => p, - Err(err) => { - if err.kind() == io::ErrorKind::NotFound { - return Err(BackendError::Start( - "python-text-embeddings-server not found in PATH".to_owned(), - )); - } - return Err(BackendError::Start(err.to_string())); - } - }; - - let stdout_reader = BufReader::new(p.stdout.take().unwrap()); - let stderr_reader = BufReader::new(p.stderr.take().unwrap()); - - //stdout tracing thread - thread::spawn(move || { - let _span = tracing::span!(tracing::Level::INFO, "python-backend").entered(); - log_lines(stdout_reader.lines()); - }); - - let start_time = Instant::now(); - let mut wait_time = Instant::now(); - - loop { - // Process exited - if let Some(exit_status) = p.try_wait().unwrap() { - // We read stderr in another thread as it seems that lines() can block in some cases - let (err_sender, err_receiver) = mpsc::channel(); - thread::spawn(move || { - for line in stderr_reader.lines().map_while(Result::ok) { - err_sender.send(line).unwrap_or(()); - } - }); - let mut err = String::new(); - while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { - err = err + "\n" + &line; - } - - tracing::debug!("Python Backend complete standard error output:\n{err}"); - - if let Some(signal) = exit_status.signal() { - return Err(BackendError::Start(format!( - "Python Backend process was signaled to shutdown with signal {signal}" - ))); - } - return Err(BackendError::Start( - "Python backend failed to start".to_string(), - )); - } - - // Shard is ready - if uds.exists() { - tracing::info!("Python backend ready in {:?}", start_time.elapsed()); - break; - } else if wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for Python backend to be ready..."); - wait_time = Instant::now(); - } - sleep(Duration::from_millis(5)); - } - - Ok(Self { inner: p }) - } -} - -impl Drop for BackendProcess { - fn drop(&mut self) { - self.inner.kill().unwrap(); - let _ = self.inner.wait(); - tracing::info!("Python backend process terminated"); - } -} diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 1e919f23..8fb4076c 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -14,7 +14,9 @@ from text_embeddings_server.models.jinaBert_model import FlashJinaBert from text_embeddings_server.models.flash_mistral import FlashMistral from text_embeddings_server.models.flash_qwen3 import FlashQwen3 -from text_embeddings_server.utils.device import get_device, use_ipex +from text_embeddings_server.models.neuron_models import NeuronSentenceTransformers + +from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron __all__ = ["Model"] @@ -74,6 +76,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): logger.info(f"backend device: {device}") config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) + + # Neuron cases + if is_neuron(): + if 
config.model_type == "bert": + return create_model(NeuronSentenceTransformers, model_path) if ( hasattr(config, "auto_map") diff --git a/backends/python/server/text_embeddings_server/models/neuron_models.py b/backends/python/server/text_embeddings_server/models/neuron_models.py new file mode 100644 index 00000000..d795db07 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/neuron_models.py @@ -0,0 +1,67 @@ +import inspect +import torch + +from pathlib import Path +from typing import Type, List +from optimum.neuron import NeuronModelForSentenceTransformers +from opentelemetry import trace + +from text_embeddings_server.models import Model +from text_embeddings_server.models.types import PaddedBatch, Embedding, Score + +tracer = trace.get_tracer(__name__) + + +class NeuronSentenceTransformers(Model): + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + ): + model = NeuronModelForSentenceTransformers.from_pretrained(model_path) + + self.hidden_size = model.config.hidden_size + position_offset = 0 + model_type = model.config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = model.config.pad_token_id + 1 + if hasattr(model.config, "max_seq_length"): + self.max_input_length = model.config.max_seq_length + else: + self.max_input_length = ( + model.config.max_position_embeddings - position_offset + ) + + self.has_position_ids = ( + inspect.signature(model.forward).parameters.get("position_ids", None) + is not None + ) + self.has_token_type_ids = ( + inspect.signature(model.forward).parameters.get("token_type_ids", None) + is not None + ) + + super(NeuronSentenceTransformers, self).__init__( + model=model, dtype=dtype, device=device + ) + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + pass + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} + if self.has_token_type_ids: + kwargs["token_type_ids"] = batch.token_type_ids + if self.has_position_ids: + kwargs["position_ids"] = batch.position_ids + + output = self.model(**kwargs, return_dict=True) + all_scores = output.logits.tolist() + return [Score(values=scores) for scores in all_scores] diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py index 3f3b04dd..46b81370 100644 --- a/backends/python/server/text_embeddings_server/utils/device.py +++ b/backends/python/server/text_embeddings_server/utils/device.py @@ -1,4 +1,6 @@ import os +import re +import functools from loguru import logger import importlib.metadata import importlib.util @@ -49,6 +51,21 @@ def is_hpu() -> bool: is_hpu_available = False return is_hpu_available +@functools.cache +def get_neuron_major() -> int: + MAJORS_FILE = "/proc/devices" + NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$") + if not os.path.exists(MAJORS_FILE): + return -1 + with open(MAJORS_FILE, "r") as f: + for l in f.readlines(): + m = NEURON_MAJOR_LINE.match(l) + if m: + return int(m.group(1)) + return -1 + +def is_neuron() -> bool: + return get_neuron_major > -1 def use_ipex() -> bool: value = os.environ.get("USE_IPEX", "True").lower() @@ -72,5 +89,7 @@ def get_device(): if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") + elif 
is_neuron(): + device = torch.device("xla") return device From dd0c08ddad7abe38caf76f720844e7438e42067a Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 27 Oct 2025 17:10:38 +0000 Subject: [PATCH 3/5] fix: neuron dockerfile --- Dockerfile-neuron | 187 ++++++++++++++++++ Dockerfile.neuron | 43 ---- backends/Cargo.toml | 1 - .../python/server/requirements-neuron.txt | 1 + docs/source/en/ aws_neuron.md | 37 ++++ docs/source/en/local_neuron.md | 1 - 6 files changed, 225 insertions(+), 45 deletions(-) create mode 100644 Dockerfile-neuron delete mode 100644 Dockerfile.neuron create mode 100644 backends/python/server/requirements-neuron.txt create mode 100644 docs/source/en/ aws_neuron.md delete mode 100644 docs/source/en/local_neuron.md diff --git a/Dockerfile-neuron b/Dockerfile-neuron new file mode 100644 index 00000000..52797d68 --- /dev/null +++ b/Dockerfile-neuron @@ -0,0 +1,187 @@ +ARG PLATFORM=neuron +FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef +WORKDIR /usr/src + +ENV SCCACHE=0.10.0 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache + +# Donwload, configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +FROM chef AS planner + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG SCCACHE_GHA_ENABLED + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +FROM builder AS http-builder + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s + +FROM builder AS grpc-builder + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s + +FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + build-essential \ + git \ + curl \ + cmake \ + pkg-config \ + protobuf-compiler \ + ninja-build \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/bin/python3 /usr/local/bin/python || true +RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true + +WORKDIR /usr/src 
+COPY backends backends +COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py +COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml +RUN cd backends/python/server && \ + make install + +ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58 +ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008 +ARG NEURONX_TOOLS_VERSION=2.26.14.0 + +ARG NEURONX_CC_VERSION=2.21.18209.0+043b1bf7 +ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.13553+1e4dd6ca +ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf +ARG NEURONX_DISTRIBUTED_INFERENCE_VERSION=0.6.10598+a59fdc00 + +RUN apt-get update \ + && apt-get upgrade -y \ + && apt-get install -y --no-install-recommends \ + apt-transport-https \ + build-essential \ + ca-certificates \ + cmake \ + curl \ + emacs \ + git \ + gnupg2 \ + gpg-agent \ + jq \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libcap-dev \ + libhwloc-dev \ + openjdk-11-jdk \ + unzip \ + vim \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list +RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add - + +RUN apt-get update \ + && apt-get install -y \ + aws-neuronx-tools=$NEURONX_TOOLS_VERSION \ + aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \ + aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \ + --extra-index-url https://pypi.org/simple \ + --trusted-host pip.repos.neuron.amazonaws.com \ + neuronx-cc==$NEURONX_CC_VERSION \ + torch-neuronx==$NEURONX_FRAMEWORK_VERSION \ + neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \ + && rm -rf ~/.cache/pip/* + +# HF ARGS +ARG TRANSFORMERS_VERSION=4.55.4 +ARG DIFFUSERS_VERSION=0.35.2 +ARG HUGGINGFACE_HUB_VERSION=0.36.0 +ARG OPTIMUM_NEURON_VERSION=0.4.1 +ARG SENTENCE_TRANSFORMERS=5.1.2 +ARG PEFT_VERSION=0.17.0 +ARG DATASETS_VERSION=4.1.1 + +# install Hugging Face libraries and its dependencies +RUN pip install --no-cache-dir -U \ + networkx==2.8.8 \ + transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \ + diffusers==${DIFFUSERS_VERSION} \ + compel \ + controlnet-aux \ + huggingface_hub==${HUGGINGFACE_HUB_VERSION} \ + hf_transfer \ + datasets==${DATASETS_VERSION} \ + optimum-neuron==${OPTIMUM_NEURON_VERSION} \ + sentence_transformers==${SENTENCE_TRANSFORMERS} \ + peft==${PEFT_VERSION} \ + && rm -rf ~/.cache/pip/* + + +FROM neuron AS grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM neuron + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + + diff --git a/Dockerfile.neuron b/Dockerfile.neuron deleted file mode 100644 index f8b03ab2..00000000 --- a/Dockerfile.neuron +++ /dev/null @@ -1,43 +0,0 @@ -ARG PLATFORM=neuron -FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef -WORKDIR /usr/src - -ENV SCCACHE=0.10.0 -ENV RUSTC_WRAPPER=/usr/local/bin/sccache - -# Donwload, configure sccache -RUN curl -fsSL 
https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ - chmod +x /usr/local/bin/sccache - -FROM chef AS planner - -COPY backends backends -COPY core core -COPY router router -COPY Cargo.toml ./ -COPY Cargo.lock ./ - -RUN cargo chef prepare --recipe-path recipe.json - -FROM chef AS builder - -ARG GIT_SHA -ARG DOCKER_LABEL - -# sccache specific variables -ARG SCCACHE_GHA_ENABLED - -COPY --from=planner /usr/src/recipe.json recipe.json - -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s - -COPY backends backends -COPY core core -COPY router router -COPY Cargo.toml ./ -COPY Cargo.lock ./ - -WORKDIR /usr/src - diff --git a/backends/Cargo.toml b/backends/Cargo.toml index 7d821ff4..bb9d7419 100644 --- a/backends/Cargo.toml +++ b/backends/Cargo.toml @@ -21,7 +21,6 @@ rand = { workspace = true } [features] clap = ["dep:clap", "text-embeddings-backend-core/clap"] python = ["dep:text-embeddings-backend-python"] -neuron = ["dep:text-embeddings-backend-neuron"] ort = ["dep:text-embeddings-backend-ort"] candle = ["dep:text-embeddings-backend-candle"] cuda = ["text-embeddings-backend-candle?/cuda"] diff --git a/backends/python/server/requirements-neuron.txt b/backends/python/server/requirements-neuron.txt new file mode 100644 index 00000000..b8ce3518 --- /dev/null +++ b/backends/python/server/requirements-neuron.txt @@ -0,0 +1 @@ +transformers==4.55.4 \ No newline at end of file diff --git a/docs/source/en/ aws_neuron.md b/docs/source/en/ aws_neuron.md new file mode 100644 index 00000000..13ea7f86 --- /dev/null +++ b/docs/source/en/ aws_neuron.md @@ -0,0 +1,37 @@ + +# Using TEI Container with AWS Trainium and Inferentia Instances + +## Build Docker Image + +To build a container optimized for AWS Neuron devices, run the following command: + +```shell +platform="neuron" + +docker build . 
-f Dockerfile-neuron -t tei_neuron +``` + +### Deploy Docker Container + +To deploy your model on an AWS Trainium or Inferentia instance, use the following command: + +```shell +model='Qwen/Qwen3-Embedding-0.6B' +volume=$PWD/data + +docker run -p 8080:80 -v $volume:/data tei_neuron --model-id $model +``` \ No newline at end of file diff --git a/docs/source/en/local_neuron.md b/docs/source/en/local_neuron.md deleted file mode 100644 index e0a2cf2b..00000000 --- a/docs/source/en/local_neuron.md +++ /dev/null @@ -1 +0,0 @@ -# Neuron backend for AWS Trainium and Inferentia \ No newline at end of file From 1e4f3c92d03c9193c62f9d7d20e476b0f2f11dda Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 28 Oct 2025 17:23:49 +0000 Subject: [PATCH 4/5] remove useless --- Dockerfile-neuron | 2 -- backends/python/server/requirements-neuron.txt | 1 - 2 files changed, 3 deletions(-) delete mode 100644 backends/python/server/requirements-neuron.txt diff --git a/Dockerfile-neuron b/Dockerfile-neuron index 52797d68..a536ab7d 100644 --- a/Dockerfile-neuron +++ b/Dockerfile-neuron @@ -183,5 +183,3 @@ COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/loc ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] - - diff --git a/backends/python/server/requirements-neuron.txt b/backends/python/server/requirements-neuron.txt deleted file mode 100644 index b8ce3518..00000000 --- a/backends/python/server/requirements-neuron.txt +++ /dev/null @@ -1 +0,0 @@ -transformers==4.55.4 \ No newline at end of file From a25cf98d6d98135258ad5ec18549ebdeed02f02a Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 31 Oct 2025 13:11:12 +0000 Subject: [PATCH 5/5] fix dockerfile --- Dockerfile-neuron | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/Dockerfile-neuron b/Dockerfile-neuron index a536ab7d..e09c6491 100644 --- a/Dockerfile-neuron +++ b/Dockerfile-neuron @@ -1,4 +1,3 @@ -ARG PLATFORM=neuron FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef WORKDIR /usr/src @@ -31,7 +30,7 @@ COPY --from=planner /usr/src/recipe.json recipe.json RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s + cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s COPY backends backends COPY core core @@ -39,25 +38,25 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + FROM builder AS http-builder RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s + cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s FROM builder AS grpc-builder -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o 
$PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP - COPY proto proto RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s + cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron
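
As of this patch, `NeuronSentenceTransformers.embed()` introduced in patch 2 is still a stub (`pass`). A possible completion is sketched below; it assumes the Neuron-compiled model returns a pooled `sentence_embedding` tensor (as optimum-neuron sentence-transformers models typically do) and mirrors the kwargs handling already used by `predict()` in the same file:

```python
# Sketch only — intended as a method of NeuronSentenceTransformers.
# Assumption: the model output exposes `sentence_embedding` with shape
# (batch_size, hidden_size); `Embedding`, `PaddedBatch`, `List` and `tracer`
# are the objects already imported in neuron_models.py.
@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
    kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
    if self.has_token_type_ids:
        kwargs["token_type_ids"] = batch.token_type_ids
    if self.has_position_ids:
        kwargs["position_ids"] = batch.position_ids

    output = self.model(**kwargs, return_dict=True)
    embeddings = output.sentence_embedding  # assumed pooled output, (batch_size, hidden_size)

    cpu_results = embeddings.reshape(-1).tolist()
    return [
        Embedding(values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size])
        for i in range(embeddings.shape[0])
    ]
```
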