From 5be80a0aaec4131a1e2215ecb59f25a2620eedfe Mon Sep 17 00:00:00 2001 From: Sigrid Jin Date: Fri, 5 Dec 2025 10:20:29 +0000 Subject: [PATCH] feat: add XProvence context pruning support for RAG Add XProvence model integration for zero-cost context pruning in reranking. XProvence removes irrelevant sentences from passages based on query relevance, returning both reranking scores and pruned context. Changes: - Add XProvenceModel class with process() method for sentence-level pruning - Add pruned_text field to Score/Prediction types and HTTP response - Pass raw_query/raw_text through tokenization pipeline for pruning - Make flash_attn imports optional for XProvence compatibility - Add XProvence architecture detection in router and Python backend - Handle bfloat16 to float32 conversion for XProvence process() method - Update candle, ort backends to support Prediction with pruned_text - Add Dockerfile-cuda-python for Python backend with CUDA support Configuration: - XPROVENCE_THRESHOLD: Pruning threshold 0.0-1.0 (default: 0.3) - XPROVENCE_ALWAYS_SELECT_TITLE: Keep first sentence as title (default: true) Usage: XPROVENCE_THRESHOLD=0.3 text-embeddings-router \ --model-id naver/xprovence-reranker-bgem3-v1 --port 8080 Docker build: docker build -f Dockerfile-cuda-python -t tei-python-cuda . --- Dockerfile | 158 ++++++++-------- backends/candle/src/lib.rs | 7 +- backends/core/src/lib.rs | 15 +- backends/grpc-client/src/client.rs | 6 + backends/ort/src/lib.rs | 7 +- backends/proto/embed.proto | 6 + .../text_embeddings_server/models/__init__.py | 31 ++- .../text_embeddings_server/models/types.py | 9 + .../models/xprovence_model.py | 176 ++++++++++++++++++ backends/python/src/lib.rs | 20 +- backends/src/lib.rs | 6 + core/src/infer.rs | 10 +- core/src/queue.rs | 10 + core/src/tokenization.rs | 12 ++ router/src/http/server.rs | 12 +- router/src/http/types.rs | 4 + router/src/lib.rs | 3 +- 17 files changed, 394 insertions(+), 98 deletions(-) create mode 100644 backends/python/server/text_embeddings_server/models/xprovence_model.py diff --git a/Dockerfile b/Dockerfile index e4a01b249..fbddbf631 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,35 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef -WORKDIR /usr/src +# Dockerfile for TEI with Python backend and CUDA support +# Supports: L40s (sm_89), RTX 3090 (sm_86) + +# ============================================================================= +# Stage 1: Rust Builder +# ============================================================================= +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder ENV SCCACHE=0.10.0 ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" +ENV CARGO_CHEF=0.1.71 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* -# Donwload, configure sccache RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ chmod +x /usr/local/bin/sccache -FROM chef AS planner +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --version $CARGO_CHEF --locked + +# ============================================================================= +# Stage 2: Recipe Planner +# ============================================================================= +FROM 
rust-builder AS planner + +WORKDIR /usr/src COPY backends backends COPY core core @@ -16,34 +37,21 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -RUN cargo chef prepare --recipe-path recipe.json +RUN cargo chef prepare --recipe-path recipe.json -FROM chef AS builder +# ============================================================================= +# Stage 3: Dependency Builder +# ============================================================================= +FROM rust-builder AS builder ARG GIT_SHA ARG DOCKER_LABEL -# sccache specific variables -ARG SCCACHE_GHA_ENABLED - -RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ - | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \ - tee /etc/apt/sources.list.d/oneAPI.list - -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - intel-oneapi-mkl-devel=2024.0.0-49656 \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \ - gcc -shared -fPIC -o libfakeintel.so fakeintel.c +WORKDIR /usr/src COPY --from=planner /usr/src/recipe.json recipe.json -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s +RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s COPY backends backends COPY core core @@ -51,73 +59,75 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -FROM builder AS http-builder +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s +# ============================================================================= +# Stage 4: Python Environment +# ============================================================================= +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder -FROM builder AS grpc-builder +ENV DEBIAN_FRONTEND=noninteractive -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.10 \ + python3.10-dev \ + python3-pip \ + git \ + && rm -rf /var/lib/apt/lists/* -COPY proto proto +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s +RUN pip install --no-cache-dir --upgrade pip setuptools wheel -FROM debian:bookworm-slim AS base +WORKDIR /opt/server -ENV 
HUGGINGFACE_HUB_CACHE=/data \ - PORT=80 \ - MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \ - RAYON_NUM_THREADS=8 \ - LD_PRELOAD=/usr/local/libfakeintel.so \ - LD_LIBRARY_PATH=/usr/local/lib +COPY backends/proto /opt/proto +COPY backends/python/server /opt/server -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - libomp-dev \ - ca-certificates \ - libssl-dev \ - curl \ - && rm -rf /var/lib/apt/lists/* +RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \ + mkdir -p text_embeddings_server/pb && \ + python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \ + --grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \ + find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \ + touch text_embeddings_server/pb/__init__.py -# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch... -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 -COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so +RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124 -FROM base AS grpc +RUN pip install --no-cache-dir -r requirements.txt -COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +RUN pip install --no-cache-dir -e . 
-ENTRYPOINT ["text-embeddings-router"]
-CMD ["--json-output"]
+# =============================================================================
+# Stage 5: Final Image
+# =============================================================================
+FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV HUGGINGFACE_HUB_CACHE=/data
+ENV PORT=80
+ENV TQDM_DISABLE=1
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3-pip \
+    ca-certificates \
+    libssl-dev \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
 
-FROM base AS http
+RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
+    ln -sf /usr/bin/python3.10 /usr/bin/python3
 
-COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages
+COPY --from=python-builder /opt/server /opt/server
 
-# Amazon SageMaker compatible image
-FROM http AS sagemaker
-COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh
+COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
 
-ENTRYPOINT ["./entrypoint.sh"]
+ENV PATH="/usr/local/bin:${PATH}"
+ENV PYTHONPATH="/opt/server:${PYTHONPATH}"
 
-# Default image
-FROM http
+WORKDIR /opt/server
 
 ENTRYPOINT ["text-embeddings-router"]
 CMD ["--json-output"]
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index ff824f555..0d5fa97fc 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize};
 use std::collections::HashMap;
 use std::path::Path;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions,
 };
 
 #[cfg(feature = "cuda")]
@@ -653,7 +653,10 @@ impl Backend for CandleBackend {
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
         for (i, r) in results.into_iter().enumerate() {
-            predictions.insert(i, r);
+            predictions.insert(i, Prediction {
+                scores: r,
+                pruned_text: None,
+            });
         }
 
         Ok(predictions)
diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs
index 8e134d2be..55dad0d8e 100644
--- a/backends/core/src/lib.rs
+++ b/backends/core/src/lib.rs
@@ -14,6 +14,10 @@ pub struct Batch {
     pub max_length: u32,
     pub pooled_indices: Vec<u32>,
     pub raw_indices: Vec<u32>,
+    /// XProvence: raw query texts for context pruning
+    pub raw_queries: Vec<Option<String>>,
+    /// XProvence: raw context texts for context pruning
+    pub raw_texts: Vec<Option<String>>,
 }
 
 impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
 }
 
 pub type Embeddings = IntMap<usize, Embedding>;
-pub type Predictions = IntMap<usize, Vec<f32>>;
+
+/// XProvence: Prediction result containing scores and optional pruned text
+#[derive(Debug, Clone)]
+pub struct Prediction {
+    pub scores: Vec<f32>,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
+}
+
+pub type Predictions = IntMap<usize, Prediction>;
 
 pub trait Backend {
     fn health(&self) -> Result<(), BackendError>;
diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs
index 1f6036eed..33f75da6e 100644
--- a/backends/grpc-client/src/client.rs
+++ b/backends/grpc-client/src/client.rs
@@ -59,6 +59,8 @@ impl Client {
             position_ids,
             max_length,
             cu_seq_lengths,
+            raw_query: None,
+            raw_text: None,
         })
         .inject_context();
         let response = self.stub.embed(request).await?.into_inner();
@@ -73,6 +75,8 @@ 
impl Client {
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
+        raw_query: Option<String>,
+        raw_text: Option<String>,
    ) -> Result<Vec<Score>> {
        let request = tonic::Request::new(EmbedRequest {
            input_ids,
@@ -80,6 +84,8 @@ impl Client {
            position_ids,
            max_length,
            cu_seq_lengths,
+            raw_query,
+            raw_text,
        })
        .inject_context();
        let response = self.stub.predict(request).await?.into_inner();
diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs
index bfc2d03ad..4f84d4f79 100644
--- a/backends/ort/src/lib.rs
+++ b/backends/ort/src/lib.rs
@@ -8,7 +8,7 @@ use std::ops::{Div, Mul};
 use std::path::Path;
 use std::sync::Mutex;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
 };
 
 #[derive(Debug, Clone, Deserialize)]
@@ -679,7 +679,10 @@ impl Backend for OrtBackend {
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
         for (i, r) in outputs.rows().into_iter().enumerate() {
-            predictions.insert(i, r.to_vec());
+            predictions.insert(i, Prediction {
+                scores: r.to_vec(),
+                pruned_text: None,
+            });
         }
 
         Ok(predictions)
diff --git a/backends/proto/embed.proto b/backends/proto/embed.proto
index 036f3db4b..55df0889f 100644
--- a/backends/proto/embed.proto
+++ b/backends/proto/embed.proto
@@ -21,6 +21,10 @@ message EmbedRequest {
     repeated uint32 cu_seq_lengths = 4;
     /// Length of the longest request
     uint32 max_length = 5;
+    /// XProvence: raw query text for context pruning
+    optional string raw_query = 6;
+    /// XProvence: raw context text for context pruning
+    optional string raw_text = 7;
 }
 
 message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {
 
 message Score {
     repeated float values = 1;
+    /// XProvence: pruned context text after removing irrelevant sentences
+    optional string pruned_text = 2;
 }
 
 message PredictResponse {
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 1e919f233..55f93c0bf 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -11,11 +11,19 @@
 from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
-from text_embeddings_server.models.jinaBert_model import FlashJinaBert
-from text_embeddings_server.models.flash_mistral import FlashMistral
-from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+from text_embeddings_server.models.xprovence_model import XProvenceModel
 from text_embeddings_server.utils.device import get_device, use_ipex
 
+FlashJinaBert = None
+FlashMistral = None
+FlashQwen3 = None
+try:
+    from text_embeddings_server.models.jinaBert_model import FlashJinaBert
+    from text_embeddings_server.models.flash_mistral import FlashMistral
+    from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+except ImportError as e:
+    logger.warning(f"Flash attention models not available: {e}")
+
 __all__ = ["Model"]
 
 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
@@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
 
     if (
-        hasattr(config, 
"architectures") + and config.architectures + and "XProvence" in config.architectures[0] + ): + logger.info("Detected XProvence model for context pruning") + return XProvenceModel(model_path, device, datatype, trust_remote=True) + + if ( + FlashJinaBert is not None + and hasattr(config, "auto_map") and isinstance(config.auto_map, dict) and "AutoModel" in config.auto_map and config.auto_map["AutoModel"] == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel" ): - # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository return create_model(FlashJinaBert, model_path, device, datatype) if config.model_type == "bert": @@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): else: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "mistral" and device.type == "hpu": + if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu": try: return create_model(FlashMistral, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "qwen3" and device.type == "hpu": + if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu": try: return create_model(FlashQwen3, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - # Default case if config.architectures[0].endswith("Classification"): return create_model(ClassificationModel, model_path, device, datatype) elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py index f27572a9b..f4da0da8e 100644 --- a/backends/python/server/text_embeddings_server/models/types.py +++ b/backends/python/server/text_embeddings_server/models/types.py @@ -36,6 +36,9 @@ class PaddedBatch(Batch): token_type_ids: torch.Tensor position_ids: torch.Tensor attention_mask: torch.Tensor + # XProvence: raw text for context pruning + raw_query: str = None + raw_text: str = None @classmethod @tracer.start_as_current_span("from_pb") @@ -77,11 +80,17 @@ def from_pb( # Move padded tensors all at once all_tensors = all_tensors.to(device) + # XProvence: Extract raw text if present in proto + raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None + raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None + return PaddedBatch( input_ids=all_tensors[0], token_type_ids=all_tensors[1], position_ids=all_tensors[2], attention_mask=all_tensors[3], + raw_query=raw_query, + raw_text=raw_text, ) def __len__(self): diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py new file mode 100644 index 000000000..f4145a871 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -0,0 +1,176 @@ +import os +import torch + +from pathlib import Path +from typing import Type, List +from transformers import AutoModel +from opentelemetry import trace +from loguru import logger + +from text_embeddings_server.models.model import Model +from text_embeddings_server.models.types import PaddedBatch, Embedding, Score + +tracer = trace.get_tracer(__name__) + + +def _parse_bool(value: str) -> bool: + 
"""Parse boolean from string with common conventions.""" + return str(value).lower() in ("true", "1", "t", "yes", "on") + + +class XProvenceModel(Model): + """ + XProvence: Zero-cost context pruning model for RAG. + + XProvence removes irrelevant sentences from passages based on relevance + to the query, returning both a reranking score and pruned context. + + Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages. + + Environment Variables: + XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0. + - 0.3 (default): Conservative pruning, minimal performance drop + - 0.7: Aggressive pruning, higher compression + XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title. + - true (default): Always include first sentence (useful for Wikipedia) + - false: Only include sentences above threshold + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = True, + ): + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + + if dtype == torch.bfloat16: + logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility") + dtype = torch.float32 + + model = model.to(dtype).to(device) + + self.hidden_size = model.config.hidden_size + + position_offset = 0 + model_type = model.config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = model.config.pad_token_id + 1 + + if hasattr(model.config, "max_seq_length"): + self.max_input_length = model.config.max_seq_length + else: + self.max_input_length = ( + model.config.max_position_embeddings - position_offset + ) + + try: + threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3") + self.threshold = float(threshold_env) + if not (0.0 <= self.threshold <= 1.0): + logger.warning( + f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], " + f"defaulting to 0.3" + ) + self.threshold = 0.3 + except ValueError: + logger.error( + f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3" + ) + self.threshold = 0.3 + + self.always_select_title = _parse_bool( + os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true") + ) + + logger.info( + f"XProvence model loaded: threshold={self.threshold}, " + f"always_select_title={self.always_select_title} " + f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)" + ) + + super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device) + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + pass + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + """ + XProvence prediction with context pruning support. + + For single-item batches with raw_query/raw_text available, + uses XProvence's process() method for sentence-level pruning. + Otherwise falls back to standard forward pass. + """ + batch_size = len(batch) + + if batch_size == 1 and batch.raw_query and batch.raw_text: + return self._predict_with_pruning(batch.raw_query, batch.raw_text) + + return self._predict_standard(batch) + + def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: + """ + Use XProvence's process() method for context pruning. + + Returns score with pruned_text containing only relevant sentences. 
+ """ + try: + os.environ["TQDM_DISABLE"] = "1" + + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + + try: + output = self.model.process( + raw_query, + raw_text, + threshold=self.threshold, + always_select_title=self.always_select_title, + ) + finally: + torch.set_default_dtype(original_dtype) + + reranking_score = float(output["reranking_score"]) + pruned_context = output["pruned_context"] + + logger.debug( + f"XProvence pruning: score={reranking_score:.4f}, " + f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}" + ) + + return [Score(values=[reranking_score], pruned_text=pruned_context)] + + except Exception as e: + logger.error(f"XProvence process() failed: {e}, falling back to standard") + return [Score(values=[0.0], pruned_text=None)] + + def _predict_standard(self, batch: PaddedBatch) -> List[Score]: + kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} + + output = self.model(**kwargs, return_dict=True) + + if hasattr(output, "ranking_scores"): + scores_tensor = output.ranking_scores + elif hasattr(output, "logits"): + scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits + else: + scores_tensor = output[0] + + if scores_tensor.dim() == 0: + scores = [float(scores_tensor.item())] + else: + scores = scores_tensor.view(-1).tolist() + + if isinstance(scores, float): + scores = [scores] + + return [Score(values=[float(s)], pruned_text=None) for s in scores] diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 53255b07d..0c1d04684 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -5,7 +5,7 @@ use backend_grpc_client::Client; use nohash_hasher::BuildNoHashHasher; use std::collections::HashMap; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions, }; use tokio::runtime::Runtime; @@ -108,6 +108,11 @@ impl Backend for PythonBackend { )); } let batch_size = batch.len(); + + // XProvence: Get first raw query/text from batch (for single request) + let raw_query = batch.raw_queries.first().cloned().flatten(); + let raw_text = batch.raw_texts.first().cloned().flatten(); + let results = self .tokio_runtime .block_on(self.backend_client.clone().predict( @@ -116,15 +121,22 @@ impl Backend for PythonBackend { batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, + raw_query, + raw_text, )) .map_err(|err| BackendError::Inference(err.to_string()))?; - let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, r) in raw_results.into_iter().enumerate() { - predictions.insert(i, r); + for (i, score) in results.into_iter().enumerate() { + predictions.insert( + i, + Prediction { + scores: score.values, + pruned_text: score.pruned_text, + }, + ); } Ok(predictions) diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 245715b38..79bc05d29 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -223,6 +223,8 @@ impl Backend { max_length: tmp_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], } } @@ -280,6 +282,8 @@ impl Backend { max_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], }; match &self.model_type { @@ -314,6 +318,8 @@ impl Backend { max_length: 1, 
pooled_indices: vec![0],
            raw_indices: vec![],
+            raw_queries: vec![],
+            raw_texts: vec![],
        };
 
        match &self.model_type {
            ModelType::Classifier => self.predict(batch).await.map(|_| ()),
diff --git a/core/src/infer.rs b/core/src/infer.rs
index a2ff22c51..fb16eb15a 100644
--- a/core/src/infer.rs
+++ b/core/src/infer.rs
@@ -561,11 +561,13 @@ async fn backend_task(backend: Backend, mut embed_receiver: mpsc::Receiver,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
     pub metadata: InferMetadata,
 }
diff --git a/core/src/queue.rs b/core/src/queue.rs
index 3fd8b7715..acc3409d4 100644
--- a/core/src/queue.rs
+++ b/core/src/queue.rs
@@ -129,6 +129,10 @@ fn queue_blocking_task(
         let mut cu_seq_lengths = Vec::with_capacity(capacity);
         cu_seq_lengths.push(0);
 
+        // XProvence: raw text vectors for context pruning
+        let mut raw_queries = Vec::with_capacity(capacity);
+        let mut raw_texts = Vec::with_capacity(capacity);
+
         let mut current_tokens = 0;
         let mut max_length = 0;
 
@@ -168,6 +172,10 @@ fn queue_blocking_task(
             token_type_ids.extend(entry.encoding.token_type_ids);
             position_ids.extend(entry.encoding.position_ids);
 
+            // XProvence: collect raw texts for context pruning
+            raw_queries.push(entry.encoding.raw_query);
+            raw_texts.push(entry.encoding.raw_text);
+
             current_tokens += entry_tokens;
             metadata.push(entry.metadata);
             cu_seq_lengths.push(current_tokens as u32);
@@ -193,6 +201,8 @@ fn queue_blocking_task(
                     max_length,
                     pooled_indices,
                     raw_indices,
+                    raw_queries,
+                    raw_texts,
                 },
             ))
         };
diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 3639b9845..f42ceb352 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -374,6 +374,12 @@ fn encode_input(
     prompts: Option<&HashMap<String, String>>,
     tokenizer: &mut Tokenizer,
 ) -> Result<ValidEncoding, TextEmbeddingsError> {
+    // XProvence: Extract raw query and text before tokenization (for Dual inputs)
+    let (raw_query, raw_text) = match &inputs {
+        EncodingInput::Dual(query, text) => (Some(query.clone()), Some(text.clone())),
+        _ => (None, None),
+    };
+
     // Default truncation params
     let truncate_params = truncate.then_some(TruncationParams {
         direction: truncation_direction,
@@ -406,6 +412,8 @@ fn encode_input(
         token_type_ids: encoding.get_type_ids().to_vec(),
         position_ids: (position_offset as u32..(seq_len + position_offset) as u32)
             .collect::<Vec<_>>(),
+        raw_query,
+        raw_text,
     })
 }
 
@@ -414,6 +422,10 @@ pub struct ValidEncoding {
     pub input_ids: Vec<u32>,
     pub token_type_ids: Vec<u32>,
     pub position_ids: Vec<u32>,
+    /// XProvence: raw query text for context pruning (from Dual input)
+    pub raw_query: Option<String>,
+    /// XProvence: raw context text for context pruning (from Dual input)
+    pub raw_text: Option<String>,
 }
 
 #[derive(Debug)]
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index a22af9628..1cb57a165 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -361,13 +361,16 @@ async fn rerank(
                 .map_err(ErrorResponse::from)?;
 
             let score = response.results[0];
+            // XProvence: extract pruned_text from response
+            let pruned_text = response.pruned_text;
 
-            Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>((
+            Ok::<(usize, Duration, Duration, Duration, f32, Option<String>), ErrorResponse>((
                 response.metadata.prompt_tokens,
                 response.metadata.tokenization,
                 response.metadata.queue,
                 response.metadata.inference,
                 score,
+                pruned_text,
             ))
         };
 
@@ -410,7 +413,7 @@ async fn rerank(
     let results = join_all(futures)
         .await
         .into_iter()
-        .collect::<Result<Vec<(usize, Duration, Duration, Duration, f32)>, ErrorResponse>>()?;
+        .collect::<Result<Vec<(usize, Duration, Duration, Duration, f32, Option<String>)>, ErrorResponse>>()?;
 
     let mut ranks = Vec::with_capacity(batch_size);
let mut total_tokenization_time = 0;
@@ -430,6 +433,9 @@ async fn rerank(
        };
 
        let score = r.4;
+        // XProvence: extract pruned_text from result
+        let pruned_text = r.5;
+
        // Check that s is not NaN or the partial_cmp below will panic
        if score.is_nan() {
            Err(ErrorResponse {
@@ -438,7 +444,7 @@ async fn rerank(
            })?;
        }
 
-        ranks.push(Rank { index, text, score })
+        ranks.push(Rank { index, text, score, pruned_text })
     }
 
     // Reverse sort
diff --git a/router/src/http/types.rs b/router/src/http/types.rs
index dedaab60a..ce9994b22 100644
--- a/router/src/http/types.rs
+++ b/router/src/http/types.rs
@@ -266,6 +266,10 @@ pub(crate) struct Rank {
     pub text: Option<String>,
     #[schema(example = "1.0")]
     pub score: f32,
+    /// XProvence: pruned context with irrelevant sentences removed
+    #[schema(nullable = true, default = "null")]
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub pruned_text: Option<String>,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d83bd95c5..9c5eb98f4 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -396,7 +396,8 @@ fn get_backend_model_type(
         return Ok(text_embeddings_backend::ModelType::Embedding(
             text_embeddings_backend::Pool::Splade,
         ));
-    } else if arch.ends_with("Classification") {
+    } else if arch.ends_with("Classification") || arch == "XProvence" {
+        // XProvence is a reranker model for context pruning
         if pooling.is_some() {
             tracing::warn!(
                 "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."