
Commit 89441fe

Author: Sigrid Jin (committed)
feat: xprovence
1 parent 78502d8 commit 89441fe

12 files changed, +290 -19 lines changed


backends/core/src/lib.rs

Lines changed: 14 additions & 1 deletion
@@ -14,6 +14,10 @@ pub struct Batch {
     pub max_length: u32,
     pub pooled_indices: Vec<u32>,
     pub raw_indices: Vec<u32>,
+    /// XProvence: raw query texts for context pruning
+    pub raw_queries: Vec<Option<String>>,
+    /// XProvence: raw context texts for context pruning
+    pub raw_texts: Vec<Option<String>>,
 }

 impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
 }

 pub type Embeddings = IntMap<usize, Embedding>;
-pub type Predictions = IntMap<usize, Vec<f32>>;
+
+/// XProvence: Prediction result containing scores and optional pruned text
+#[derive(Debug, Clone)]
+pub struct Prediction {
+    pub scores: Vec<f32>,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
+}
+
+pub type Predictions = IntMap<usize, Prediction>;

 pub trait Backend {
     fn health(&self) -> Result<(), BackendError>;

backends/grpc-client/src/client.rs

Lines changed: 4 additions & 0 deletions
@@ -73,13 +73,17 @@ impl Client {
         position_ids: Vec<u32>,
         cu_seq_lengths: Vec<u32>,
         max_length: u32,
+        raw_query: Option<String>,
+        raw_text: Option<String>,
     ) -> Result<Vec<Score>> {
         let request = tonic::Request::new(EmbedRequest {
             input_ids,
             token_type_ids,
             position_ids,
             max_length,
             cu_seq_lengths,
+            raw_query,
+            raw_text,
         })
         .inject_context();
         let response = self.stub.predict(request).await?.into_inner();

backends/proto/embed.proto

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,10 @@ message EmbedRequest {
     repeated uint32 cu_seq_lengths = 4;
     /// Length of the longest request
     uint32 max_length = 5;
+    /// XProvence: raw query text for context pruning
+    optional string raw_query = 6;
+    /// XProvence: raw context text for context pruning
+    optional string raw_text = 7;
 }

 message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {

 message Score {
     repeated float values = 1;
+    /// XProvence: pruned context text after removing irrelevant sentences
+    optional string pruned_text = 2;
 }

 message PredictResponse {
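For illustration, here is a minimal sketch of how a caller could populate the two new optional fields, assuming the Python stubs generated from embed.proto are importable as an `embed_pb2` module (the module name and the example values are assumptions, not part of this commit). Because the fields are proto3 `optional`, presence is tracked explicitly and an unset string reads back as empty.

import embed_pb2  # assumed name of the module generated from embed.proto

# Build a predict request carrying the raw texts needed for pruning.
request = embed_pb2.EmbedRequest(
    input_ids=[0, 9226, 2],
    token_type_ids=[0, 0, 0],
    position_ids=[0, 1, 2],
    cu_seq_lengths=[0, 3],
    max_length=3,
    raw_query="what is xprovence?",                     # new field 6
    raw_text="XProvence prunes irrelevant sentences.",  # new field 7
)

# proto3 `optional` keeps explicit presence information.
assert request.HasField("raw_query") and request.HasField("raw_text")

# The response side mirrors this: each Score may carry a pruned_text.
score = embed_pb2.Score(values=[0.92], pruned_text="XProvence prunes irrelevant sentences.")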

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 23 additions & 8 deletions
@@ -11,11 +11,19 @@
 from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
-from text_embeddings_server.models.jinaBert_model import FlashJinaBert
-from text_embeddings_server.models.flash_mistral import FlashMistral
-from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+from text_embeddings_server.models.xprovence_model import XProvenceModel
 from text_embeddings_server.utils.device import get_device, use_ipex

+FlashJinaBert = None
+FlashMistral = None
+FlashQwen3 = None
+try:
+    from text_embeddings_server.models.jinaBert_model import FlashJinaBert
+    from text_embeddings_server.models.flash_mistral import FlashMistral
+    from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+except ImportError as e:
+    logger.warning(f"Flash attention models not available: {e}")
+
 __all__ = ["Model"]

 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
@@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

     if (
-        hasattr(config, "auto_map")
+        hasattr(config, "architectures")
+        and config.architectures
+        and "XProvence" in config.architectures[0]
+    ):
+        logger.info("Detected XProvence model for context pruning")
+        return XProvenceModel(model_path, device, datatype, trust_remote=True)
+
+    if (
+        FlashJinaBert is not None
+        and hasattr(config, "auto_map")
         and isinstance(config.auto_map, dict)
         and "AutoModel" in config.auto_map
         and config.auto_map["AutoModel"]
         == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
     ):
-        # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
         return create_model(FlashJinaBert, model_path, device, datatype)

     if config.model_type == "bert":
@@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     else:
         return create_model(DefaultModel, model_path, device, datatype, pool)

-    if config.model_type == "mistral" and device.type == "hpu":
+    if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu":
         try:
             return create_model(FlashMistral, model_path, device, datatype, pool)
         except FileNotFoundError:
             return create_model(DefaultModel, model_path, device, datatype, pool)

-    if config.model_type == "qwen3" and device.type == "hpu":
+    if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu":
         try:
             return create_model(FlashQwen3, model_path, device, datatype, pool)
         except FileNotFoundError:
             return create_model(DefaultModel, model_path, device, datatype, pool)

-    # Default case
     if config.architectures[0].endswith("Classification"):
         return create_model(ClassificationModel, model_path, device, datatype)
     elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
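The routing above keys off `config.architectures` rather than `auto_map`: any checkpoint whose first architecture name contains "XProvence" is sent to XProvenceModel before the other model_type checks run. A standalone sketch of that detection rule (the checkpoint path is a placeholder):

from transformers import AutoConfig

# "path/to/xprovence-checkpoint" is a placeholder, not a real model id.
config = AutoConfig.from_pretrained("path/to/xprovence-checkpoint", trust_remote_code=True)

is_xprovence = (
    hasattr(config, "architectures")
    and config.architectures                    # non-empty list
    and "XProvence" in config.architectures[0]  # first architecture name contains "XProvence"
)

if is_xprovence:
    print("route to XProvenceModel")
else:
    print("fall through to model_type-based routing")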

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,9 @@ class PaddedBatch(Batch):
     token_type_ids: torch.Tensor
     position_ids: torch.Tensor
     attention_mask: torch.Tensor
+    # XProvence: raw text for context pruning
+    raw_query: str = None
+    raw_text: str = None

     @classmethod
     @tracer.start_as_current_span("from_pb")
@@ -77,11 +80,17 @@ def from_pb(
         # Move padded tensors all at once
         all_tensors = all_tensors.to(device)

+        # XProvence: Extract raw text if present in proto
+        raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None
+        raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None
+
         return PaddedBatch(
             input_ids=all_tensors[0],
             token_type_ids=all_tensors[1],
             position_ids=all_tensors[2],
             attention_mask=all_tensors[3],
+            raw_query=raw_query,
+            raw_text=raw_text,
         )

     def __len__(self):
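The extraction above treats a missing field and an empty one the same way: proto3 strings that were never set read back as "", so the truthiness check maps both cases to None, and the hasattr guard keeps the code working against stubs that lack the new fields. A small sketch of that behaviour using a plain namespace object as a stand-in for the protobuf request:

from types import SimpleNamespace  # stand-in for the EmbedRequest message

def extract_raw(pb):
    raw_query = pb.raw_query if hasattr(pb, "raw_query") and pb.raw_query else None
    raw_text = pb.raw_text if hasattr(pb, "raw_text") and pb.raw_text else None
    return raw_query, raw_text

print(extract_raw(SimpleNamespace(raw_query="why is the sky blue?", raw_text="Rayleigh scattering.")))
# ('why is the sky blue?', 'Rayleigh scattering.')

print(extract_raw(SimpleNamespace(raw_query="", raw_text="")))  # unset proto3 strings read as ""
# (None, None)

print(extract_raw(SimpleNamespace()))  # stubs generated before this commit
# (None, None)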
backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+import os
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModel
+from opentelemetry import trace
+from loguru import logger
+
+from text_embeddings_server.models.model import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+
+def _parse_bool(value: str) -> bool:
+    """Parse boolean from string with common conventions."""
+    return str(value).lower() in ("true", "1", "t", "yes", "on")
+
+
+class XProvenceModel(Model):
+    """
+    XProvence: Zero-cost context pruning model for RAG.
+
+    XProvence removes irrelevant sentences from passages based on relevance
+    to the query, returning both a reranking score and pruned context.
+
+    Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages.
+
+    Environment Variables:
+        XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0.
+            - 0.3 (default): Conservative pruning, minimal performance drop
+            - 0.7: Aggressive pruning, higher compression
+        XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title.
+            - true (default): Always include first sentence (useful for Wikipedia)
+            - false: Only include sentences above threshold
+    """
+
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str = "cls",
+        trust_remote: bool = True,
+    ):
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+
+        if dtype == torch.bfloat16:
+            logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility")
+            dtype = torch.float32
+
+        model = model.to(dtype).to(device)
+
+        self.hidden_size = model.config.hidden_size
+
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
+        try:
+            threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3")
+            self.threshold = float(threshold_env)
+            if not (0.0 <= self.threshold <= 1.0):
+                logger.warning(
+                    f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], "
+                    f"defaulting to 0.3"
+                )
+                self.threshold = 0.3
+        except ValueError:
+            logger.error(
+                f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3"
+            )
+            self.threshold = 0.3
+
+        self.always_select_title = _parse_bool(
+            os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true")
+        )
+
+        logger.info(
+            f"XProvence model loaded: threshold={self.threshold}, "
+            f"always_select_title={self.always_select_title} "
+            f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)"
+        )
+
+        super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device)
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        pass
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        """
+        XProvence prediction with context pruning support.
+
+        For single-item batches with raw_query/raw_text available,
+        uses XProvence's process() method for sentence-level pruning.
+        Otherwise falls back to standard forward pass.
+        """
+        batch_size = len(batch)
+
+        if batch_size == 1 and batch.raw_query and batch.raw_text:
+            return self._predict_with_pruning(batch.raw_query, batch.raw_text)
+
+        return self._predict_standard(batch)
+
+    def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
+        """
+        Use XProvence's process() method for context pruning.
+
+        Returns score with pruned_text containing only relevant sentences.
+        """
+        try:
+            os.environ["TQDM_DISABLE"] = "1"
+
+            original_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(torch.float32)
+
+            try:
+                output = self.model.process(
+                    raw_query,
+                    raw_text,
+                    threshold=self.threshold,
+                    always_select_title=self.always_select_title,
+                )
+            finally:
+                torch.set_default_dtype(original_dtype)
+
+            reranking_score = float(output["reranking_score"])
+            pruned_context = output["pruned_context"]
+
+            logger.debug(
+                f"XProvence pruning: score={reranking_score:.4f}, "
+                f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}"
+            )
+
+            return [Score(values=[reranking_score], pruned_text=pruned_context)]
+
+        except Exception as e:
+            logger.error(f"XProvence process() failed: {e}, falling back to standard")
+            return [Score(values=[0.0], pruned_text=None)]
+
+    def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+
+        output = self.model(**kwargs, return_dict=True)
+
+        if hasattr(output, "ranking_scores"):
+            scores_tensor = output.ranking_scores
+        elif hasattr(output, "logits"):
+            scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits
+        else:
+            scores_tensor = output[0]
+
+        if scores_tensor.dim() == 0:
+            scores = [float(scores_tensor.item())]
+        else:
+            scores = scores_tensor.view(-1).tolist()
+
+        if isinstance(scores, float):
+            scores = [scores]
+
+        return [Score(values=[float(s)], pruned_text=None) for s in scores]
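The same process() path can be exercised outside the server. A rough sketch, under the assumption that the checkpoint ships a process() method via trust_remote_code and returns the reranking_score / pruned_context keys consumed above (the model path and example texts are placeholders):

import os
import torch
from transformers import AutoModel

os.environ.setdefault("XPROVENCE_THRESHOLD", "0.3")             # conservative pruning
os.environ.setdefault("XPROVENCE_ALWAYS_SELECT_TITLE", "true")  # keep the first sentence

# "path/to/xprovence-checkpoint" is a placeholder for an XProvence checkpoint.
model = AutoModel.from_pretrained("path/to/xprovence-checkpoint", trust_remote_code=True)
model = model.to(torch.float32)  # process() runs in float32, as enforced in __init__ above

output = model.process(
    "What causes tides?",
    "Tides are caused by the gravitational pull of the Moon and Sun. "
    "The harbour town also hosts a jazz festival every summer.",
    threshold=float(os.environ["XPROVENCE_THRESHOLD"]),
    always_select_title=os.environ["XPROVENCE_ALWAYS_SELECT_TITLE"].lower() == "true",
)

print(output["reranking_score"])  # relevance score used for reranking
print(output["pruned_context"])   # passage with irrelevant sentences removed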

backends/python/src/lib.rs

Lines changed: 16 additions & 4 deletions
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
 use nohash_hasher::BuildNoHashHasher;
 use std::collections::HashMap;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
 };
 use tokio::runtime::Runtime;

@@ -108,6 +108,11 @@ impl Backend for PythonBackend {
            ));
        }
        let batch_size = batch.len();
+
+        // XProvence: Get first raw query/text from batch (for single request)
+        let raw_query = batch.raw_queries.first().cloned().flatten();
+        let raw_text = batch.raw_texts.first().cloned().flatten();
+
        let results = self
            .tokio_runtime
            .block_on(self.backend_client.clone().predict(
@@ -116,15 +121,22 @@ impl Backend for PythonBackend {
                batch.position_ids,
                batch.cumulative_seq_lengths,
                batch.max_length,
+                raw_query,
+                raw_text,
            ))
            .map_err(|err| BackendError::Inference(err.to_string()))?;
-        let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

        let mut predictions =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

-        for (i, r) in raw_results.into_iter().enumerate() {
-            predictions.insert(i, r);
+        for (i, score) in results.into_iter().enumerate() {
+            predictions.insert(
+                i,
+                Prediction {
+                    scores: score.values,
+                    pruned_text: score.pruned_text,
+                },
+            );
        }

        Ok(predictions)
