187 changes: 187 additions & 0 deletions Dockerfile-neuron
@@ -0,0 +1,187 @@
ARG PLATFORM=neuron
FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
WORKDIR /usr/src

ENV SCCACHE=0.10.0
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

FROM builder AS http-builder

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s

FROM builder AS grpc-builder

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

COPY proto proto

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s

FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-dev \
build-essential \
git \
curl \
cmake \
pkg-config \
protobuf-compiler \
ninja-build \
&& rm -rf /var/lib/apt/lists/*

RUN ln -s /usr/bin/python3 /usr/local/bin/python || true
RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
RUN cd backends/python/server && \
make install

ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
ARG NEURONX_TOOLS_VERSION=2.26.14.0

ARG NEURONX_CC_VERSION=2.21.18209.0+043b1bf7
ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.13553+1e4dd6ca
ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf
ARG NEURONX_DISTRIBUTED_INFERENCE_VERSION=0.6.10598+a59fdc00

RUN apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y --no-install-recommends \
apt-transport-https \
build-essential \
ca-certificates \
cmake \
curl \
emacs \
git \
gnupg2 \
gpg-agent \
jq \
libgl1-mesa-glx \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
libcap-dev \
libhwloc-dev \
openjdk-11-jdk \
unzip \
vim \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/* \
&& rm -rf /tmp/tmp* \
&& apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

RUN apt-get update \
&& apt-get install -y \
aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
&& rm -rf /var/lib/apt/lists/* \
&& rm -rf /tmp/tmp* \
&& apt-get clean

RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \
--extra-index-url https://pypi.org/simple \
--trusted-host pip.repos.neuron.amazonaws.com \
neuronx-cc==$NEURONX_CC_VERSION \
torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
&& rm -rf ~/.cache/pip/*

# HF ARGS
ARG TRANSFORMERS_VERSION=4.55.4
ARG DIFFUSERS_VERSION=0.35.2
ARG HUGGINGFACE_HUB_VERSION=0.36.0
ARG OPTIMUM_NEURON_VERSION=0.4.1
ARG SENTENCE_TRANSFORMERS=5.1.2
ARG PEFT_VERSION=0.17.0
ARG DATASETS_VERSION=4.1.1

# Install Hugging Face libraries and their dependencies
RUN pip install --no-cache-dir -U \
networkx==2.8.8 \
transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
diffusers==${DIFFUSERS_VERSION} \
compel \
controlnet-aux \
huggingface_hub==${HUGGINGFACE_HUB_VERSION} \
hf_transfer \
datasets==${DATASETS_VERSION} \
optimum-neuron==${OPTIMUM_NEURON_VERSION} \
sentence_transformers==${SENTENCE_TRANSFORMERS} \
peft==${PEFT_VERSION} \
&& rm -rf ~/.cache/pip/*


FROM neuron AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM neuron

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]


1 change: 1 addition & 0 deletions backends/python/server/requirements-neuron.txt
@@ -0,0 +1 @@
transformers==4.55.4
backends/python/server/text_embeddings_server/models/__init__.py
@@ -14,7 +14,9 @@
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.utils.device import get_device, use_ipex
from text_embeddings_server.models.neuron_models import NeuronSentenceTransformers

from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron

__all__ = ["Model"]

@@ -74,6 +76,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
logger.info(f"backend device: {device}")

config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

# Neuron cases
if is_neuron():
if config.model_type == "bert":
return create_model(NeuronSentenceTransformers, model_path)

if (
hasattr(config, "auto_map")
67 changes: 67 additions & 0 deletions backends/python/server/text_embeddings_server/models/neuron_models.py
@@ -0,0 +1,67 @@
import inspect
import torch

from pathlib import Path
from typing import Type, List
from optimum.neuron import NeuronModelForSentenceTransformers
from opentelemetry import trace

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


class NeuronSentenceTransformers(Model):
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
):
model = NeuronModelForSentenceTransformers.from_pretrained(model_path)

self.hidden_size = model.config.hidden_size
position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
position_offset = model.config.pad_token_id + 1
if hasattr(model.config, "max_seq_length"):
self.max_input_length = model.config.max_seq_length
else:
self.max_input_length = (
model.config.max_position_embeddings - position_offset
)

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.has_token_type_ids = (
inspect.signature(model.forward).parameters.get("token_type_ids", None)
is not None
)

super(NeuronSentenceTransformers, self).__init__(
model=model, dtype=dtype, device=device
)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
        # Minimal sketch, assuming the exported model follows optimum-neuron's
        # sentence-transformers convention and exposes a `sentence_embedding` output.
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        if self.has_token_type_ids:
            kwargs["token_type_ids"] = batch.token_type_ids
        output = self.model(**kwargs)
        return [Embedding(values=e) for e in output.sentence_embedding.tolist()]

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
if self.has_token_type_ids:
kwargs["token_type_ids"] = batch.token_type_ids
if self.has_position_ids:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs, return_dict=True)
all_scores = output.logits.tolist()
return [Score(values=scores) for scores in all_scores]
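
For reference, a minimal smoke test of the class above; the model directory below is hypothetical and stands in for a model exported with optimum-neuron:

```python
import torch
from pathlib import Path

from text_embeddings_server.models.neuron_models import NeuronSentenceTransformers

# "/data/bge-base-neuron" is a placeholder for a directory produced by an
# optimum-neuron export; it is not part of this diff.
model = NeuronSentenceTransformers(
    model_path=Path("/data/bge-base-neuron"),
    device=torch.device("xla"),
    dtype=torch.float32,
)
print(model.max_input_length, model.batch_type)
```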
19 changes: 19 additions & 0 deletions backends/python/server/text_embeddings_server/utils/device.py
@@ -1,4 +1,6 @@
import os
import re
import functools
from loguru import logger
import importlib.metadata
import importlib.util
@@ -49,6 +51,21 @@ def is_hpu() -> bool:
is_hpu_available = False
return is_hpu_available

@functools.cache
def get_neuron_major() -> int:
MAJORS_FILE = "/proc/devices"
NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$")
if not os.path.exists(MAJORS_FILE):
return -1
with open(MAJORS_FILE, "r") as f:
        for line in f:
            m = NEURON_MAJOR_LINE.match(line)
if m:
return int(m.group(1))
return -1

def is_neuron() -> bool:
    return get_neuron_major() > -1

def use_ipex() -> bool:
value = os.environ.get("USE_IPEX", "True").lower()
@@ -72,5 +89,7 @@ def get_device():

if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
elif is_neuron():
device = torch.device("xla")

return device
37 changes: 37 additions & 0 deletions docs/source/en/aws_neuron.md
@@ -0,0 +1,37 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Using TEI Container with AWS Trainium and Inferentia Instances

## Build Docker Image

To build a container optimized for AWS Neuron devices, run the following command:

```shell
platform="neuron"

docker build . -f Dockerfile-neuron --build-arg PLATFORM=$platform -t tei_neuron
```
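
The final stage of `Dockerfile-neuron` ships the HTTP router. To build the gRPC variant instead, target the named stage from the same Dockerfile (the stage name comes from the `FROM neuron AS grpc` line above):

```shell
docker build . -f Dockerfile-neuron --target grpc -t tei_neuron-grpc
```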

## Deploy Docker Container

To deploy your model on an AWS Trainium or Inferentia instance, use the following command:

```shell
model='Qwen/Qwen3-Embedding-0.6B'
volume=$PWD/data

docker run -p 8080:80 -v $volume:/data tei_neuron --model-id $model
```
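
Once the container is running, you can send a test request; this assumes the standard TEI `/embed` route on the mapped port:

```shell
curl 127.0.0.1:8080/embed \
    -X POST \
    -d '{"inputs":"What is Deep Learning?"}' \
    -H 'Content-Type: application/json'
```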
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -19,6 +19,8 @@
title: Build custom container for TEI
- local: intel_container
title: Using TEI container with Intel Hardware
- local: aws_neuron
title: Using TEI container with AWS Neuron
- local: examples
title: Example uses
title: Tutorials