diff --git a/Dockerfile-neuron b/Dockerfile-neuron
new file mode 100644
index 00000000..e09c6491
--- /dev/null
+++ b/Dockerfile-neuron
@@ -0,0 +1,194 @@
+# Shared Rust toolchain stage: cargo-chef plus sccache as the rustc wrapper.
+FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
+WORKDIR /usr/src
+
+ENV SCCACHE=0.10.0
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+
+# Download, configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+# Planner stage: compute the cargo-chef dependency recipe from the workspace sources.
+FROM chef AS planner
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN cargo chef prepare --recipe-path recipe.json
+
+# Builder stage: cook the dependencies from the recipe, then add the sources and protoc.
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+# sccache specific variables
+ARG SCCACHE_GHA_ENABLED
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+FROM builder AS http-builder
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
+
+FROM builder AS grpc-builder
+
+COPY proto proto
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s
+
+# Runtime base stage: Ubuntu with Python, the TEI Python backend and the AWS Neuron packages.
+FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential \
+    git \
+    curl \
+    cmake \
+    pkg-config \
+    protobuf-compiler \
+    ninja-build \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN ln -s /usr/bin/python3 /usr/local/bin/python || true
+RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+# `make -C` runs the install target from the server directory without a `cd` inside RUN.
+RUN make -C backends/python/server install
+
+ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
+ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
+ARG NEURONX_TOOLS_VERSION=2.26.14.0
+
+ARG NEURONX_CC_VERSION=2.21.18209.0+043b1bf7
+ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.13553+1e4dd6ca
+ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf
+ARG NEURONX_DISTRIBUTED_INFERENCE_VERSION=0.6.10598+a59fdc00
+
+RUN apt-get update \
+    && apt-get upgrade -y \
+    && apt-get install -y --no-install-recommends \
+    apt-transport-https \
+    build-essential \
+    ca-certificates \
+    cmake \
+    curl \
+    emacs \
+    git \
+    gnupg2 \
+    gpg-agent \
+    jq \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    libsm6 \
+    libxext6 \
+    libxrender-dev \
+    libcap-dev \
+    libhwloc-dev \
+    openjdk-11-jdk \
+    unzip \
+    vim \
+    wget \
+    zlib1g-dev \
+    && rm -rf /var/lib/apt/lists/* \
+    && rm -rf /tmp/tmp* \
+    && apt-get clean
+
+# Register the Neuron apt repository with a dedicated keyring (`signed-by`) instead of the
+# deprecated `apt-key`. NOTE(review): the repo suite is `focal` while the base image is
+# Ubuntu 22.04 (jammy) - confirm this is intended.
+RUN mkdir -p /etc/apt/keyrings \
+    && wget -qO /etc/apt/keyrings/neuron.asc https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB \
+    && echo "deb [signed-by=/etc/apt/keyrings/neuron.asc] https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list
+
+RUN apt-get update \
+    && apt-get install -y \
+    aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
+    aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
+    aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
+    && rm -rf /var/lib/apt/lists/* \
+    && rm -rf /tmp/tmp* \
+    && apt-get clean
+
+RUN pip install --no-cache-dir --index-url https://pip.repos.neuron.amazonaws.com \
+    --extra-index-url https://pypi.org/simple \
+    --trusted-host pip.repos.neuron.amazonaws.com \
+    neuronx-cc==$NEURONX_CC_VERSION \
+    torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
+    neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
+    && rm -rf ~/.cache/pip/*
+
+# HF ARGS
+ARG TRANSFORMERS_VERSION=4.55.4
+ARG DIFFUSERS_VERSION=0.35.2
+ARG HUGGINGFACE_HUB_VERSION=0.36.0
+ARG OPTIMUM_NEURON_VERSION=0.4.1
+ARG SENTENCE_TRANSFORMERS=5.1.2
+ARG PEFT_VERSION=0.17.0
+ARG DATASETS_VERSION=4.1.1
+
+# install Hugging Face libraries and its dependencies
+RUN pip install --no-cache-dir -U \
+    networkx==2.8.8 \
+    transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
+    diffusers==${DIFFUSERS_VERSION} \
+    compel \
+    controlnet-aux \
+    huggingface_hub==${HUGGINGFACE_HUB_VERSION} \
+    hf_transfer \
+    datasets==${DATASETS_VERSION} \
+    optimum-neuron==${OPTIMUM_NEURON_VERSION} \
+    sentence_transformers==${SENTENCE_TRANSFORMERS} \
+    peft==${PEFT_VERSION} \
+    && rm -rf ~/.cache/pip/*
+
+
+# gRPC image variant: ships the router binary built with the grpc feature.
+FROM neuron AS grpc
+
+COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
+
+# Default (last) stage: ships the router binary built with the http feature.
+FROM neuron
+
+COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py
b/backends/python/server/text_embeddings_server/models/__init__.py
index 1e919f23..8fb4076c 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -14,7 +14,12 @@
 from text_embeddings_server.models.jinaBert_model import FlashJinaBert
 from text_embeddings_server.models.flash_mistral import FlashMistral
 from text_embeddings_server.models.flash_qwen3 import FlashQwen3
-from text_embeddings_server.utils.device import get_device, use_ipex
+from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron
+
+# neuron_models imports optimum.neuron, so only load it when a Neuron device is
+# present; an unconditional import would fail wherever optimum-neuron is absent.
+if is_neuron():
+    from text_embeddings_server.models.neuron_models import NeuronSentenceTransformers
 
 __all__ = ["Model"]
 
@@ -74,6 +79,13 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     logger.info(f"backend device: {device}")
 
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
+
+    # Neuron cases
+    # NOTE(review): create_model is called with only the class and the path here;
+    # confirm it supplies the device/dtype that NeuronSentenceTransformers requires.
+    if is_neuron():
+        if config.model_type == "bert":
+            return create_model(NeuronSentenceTransformers, model_path)
 
     if (
         hasattr(config, "auto_map")
diff --git a/backends/python/server/text_embeddings_server/models/neuron_models.py b/backends/python/server/text_embeddings_server/models/neuron_models.py
new file mode 100644
index 00000000..d795db07
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/neuron_models.py
@@ -0,0 +1,73 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from optimum.neuron import NeuronModelForSentenceTransformers
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+
+class NeuronSentenceTransformers(Model):
+    """Model wrapper around optimum-neuron's NeuronModelForSentenceTransformers."""
+
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        """Load the model from `model_path` and record which inputs it accepts.
+
+        `device` and `dtype` are forwarded unchanged to the `Model` base class.
+        """
+        model = NeuronModelForSentenceTransformers.from_pretrained(model_path)
+
+        self.hidden_size = model.config.hidden_size
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
+        # Inspect forward() once to learn which optional inputs it accepts.
+        forward_params = inspect.signature(model.forward).parameters
+        self.has_position_ids = "position_ids" in forward_params
+        self.has_token_type_ids = "token_type_ids" in forward_params
+
+        super(NeuronSentenceTransformers, self).__init__(
+            model=model, dtype=dtype, device=device
+        )
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        # Fail loudly instead of silently handing None back to the caller.
+        raise NotImplementedError(
+            "embed is not implemented yet for NeuronSentenceTransformers"
+        )
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        """Run the model on `batch` and return one `Score` per row of its logits."""
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+
+        output = self.model(**kwargs, return_dict=True)
+        all_scores = output.logits.tolist()
+        return [Score(values=scores) for scores in all_scores]
diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py
index 3f3b04dd..46b81370 100644
--- a/backends/python/server/text_embeddings_server/utils/device.py
+++ b/backends/python/server/text_embeddings_server/utils/device.py
@@ -1,4 +1,6 @@
 import os
+import re
+import functools
 from loguru import logger
 import importlib.metadata
 import importlib.util
@@ -49,6 +51,28 @@ def is_hpu() -> bool:
         is_hpu_available = False
     return is_hpu_available
 
+@functools.cache
+def get_neuron_major() -> int:
+    """Return the major number of the `neuron` entry in /proc/devices, or -1.
+
+    -1 is returned when the file is missing or contains no `neuron` line.
+    The result is cached for the lifetime of the process.
+    """
+    MAJORS_FILE = "/proc/devices"
+    NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$")
+    if not os.path.exists(MAJORS_FILE):
+        return -1
+    with open(MAJORS_FILE, "r") as f:
+        for line in f:
+            m = NEURON_MAJOR_LINE.match(line)
+            if m:
+                return int(m.group(1))
+    return -1
+
+def is_neuron() -> bool:
+    """Tell whether a `neuron` device entry is listed in /proc/devices."""
+    # Call the function: comparing the function object itself to an int raises TypeError.
+    return get_neuron_major() > -1
 
 def use_ipex() -> bool:
     value = os.environ.get("USE_IPEX", "True").lower()
@@ -72,5 +96,7 @@ def get_device():
 
         if hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
+    elif is_neuron():
+        device = torch.device("xla")
 
     return device
diff --git a/docs/source/en/aws_neuron.md b/docs/source/en/aws_neuron.md
new file mode 100644
index 00000000..13ea7f86
--- /dev/null
+++ b/docs/source/en/aws_neuron.md
@@ -0,0 +1,23 @@
+
+# Using TEI Container with AWS Trainium and Inferentia Instances
+
+## Build Docker Image
+
+To build a container optimized for AWS Neuron devices, run the following command:
+
+```shell
+docker build . -f Dockerfile-neuron -t tei_neuron
+```
+
+## Deploy Docker Container
+
+To deploy your model on an AWS Trainium or Inferentia instance, expose the Neuron device to the container with `--device` and run the following command:
+
+```shell
+model='Qwen/Qwen3-Embedding-0.6B'
+volume=$PWD/data
+
+docker run -p 8080:80 -v $volume:/data --device=/dev/neuron0 tei_neuron --model-id $model
+```
+
+> **Note:** The Python backend currently routes only models whose `model_type` is `bert` to the Neuron implementation; other architectures go through the regular model selection.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index fa6f21e6..b9eebac2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -19,6 +19,8 @@
     title: Build custom container for TEI
   - local: intel_container
     title: Using TEI container with Intel Hardware
+  - local: aws_neuron
+    title: Using TEI container with AWS Neuron
   - local: examples
     title: Example uses
   title: Tutorials
diff --git a/integration_tests/neuron/conftest.py b/integration_tests/neuron/conftest.py
new file mode 100644
index 00000000..e69de29b
diff --git a/integration_tests/neuron/test_embed.py b/integration_tests/neuron/test_embed.py
new file mode 100644
index 00000000..e69de29b