
Commit fc4f021

Sigrid Jin authored and committed
feat: xprovenance
1 parent 78502d8 commit fc4f021

12 files changed (+274, -11 lines)


backends/core/src/lib.rs

Lines changed: 14 additions & 1 deletion
@@ -14,6 +14,10 @@ pub struct Batch {
     pub max_length: u32,
     pub pooled_indices: Vec<u32>,
     pub raw_indices: Vec<u32>,
+    /// XProvence: raw query texts for context pruning
+    pub raw_queries: Vec<Option<String>>,
+    /// XProvence: raw context texts for context pruning
+    pub raw_texts: Vec<Option<String>>,
 }
 
 impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
 }
 
 pub type Embeddings = IntMap<usize, Embedding>;
-pub type Predictions = IntMap<usize, Vec<f32>>;
+
+/// XProvence: Prediction result containing scores and optional pruned text
+#[derive(Debug, Clone)]
+pub struct Prediction {
+    pub scores: Vec<f32>,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
+}
+
+pub type Predictions = IntMap<usize, Prediction>;
 
 pub trait Backend {
     fn health(&self) -> Result<(), BackendError>;

backends/grpc-client/src/client.rs

Lines changed: 4 additions & 0 deletions
@@ -73,13 +73,17 @@ impl Client {
         position_ids: Vec<u32>,
         cu_seq_lengths: Vec<u32>,
         max_length: u32,
+        raw_query: Option<String>,
+        raw_text: Option<String>,
     ) -> Result<Vec<Score>> {
         let request = tonic::Request::new(EmbedRequest {
             input_ids,
             token_type_ids,
             position_ids,
             max_length,
             cu_seq_lengths,
+            raw_query,
+            raw_text,
         })
         .inject_context();
         let response = self.stub.predict(request).await?.into_inner();

backends/proto/embed.proto

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,10 @@ message EmbedRequest {
     repeated uint32 cu_seq_lengths = 4;
     /// Length of the longest request
     uint32 max_length = 5;
+    /// XProvence: raw query text for context pruning
+    optional string raw_query = 6;
+    /// XProvence: raw context text for context pruning
+    optional string raw_text = 7;
 }
 
 message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {
 
 message Score {
     repeated float values = 1;
+    /// XProvence: pruned context text after removing irrelevant sentences
+    optional string pruned_text = 2;
 }
 
 message PredictResponse {
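
For illustration, a minimal Python sketch of an EmbedRequest carrying the two new optional fields; the generated module name (embed_pb2) and the token values are assumptions for the example, not part of this commit.

# Minimal sketch, assuming Python stubs have been generated from embed.proto;
# the module name and the token ids below are illustrative only.
import embed_pb2

request = embed_pb2.EmbedRequest(
    input_ids=[0, 9374, 83, 2],            # pre-tokenized query + passage
    token_type_ids=[0, 0, 0, 0],
    position_ids=[0, 1, 2, 3],
    cu_seq_lengths=[0, 4],                  # cumulative sequence lengths
    max_length=4,                           # length of the longest request
    raw_query="what does xprovence do?",    # new optional field 6
    raw_text="XProvence prunes context. It also reranks.",  # new optional field 7
)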

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,7 @@
 from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
+from text_embeddings_server.models.xprovence_model import XProvenceModel
 from text_embeddings_server.models.jinaBert_model import FlashJinaBert
 from text_embeddings_server.models.flash_mistral import FlashMistral
 from text_embeddings_server.models.flash_qwen3 import FlashQwen3
@@ -75,6 +76,15 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
 
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
 
+    # XProvence: Check for XProvence architecture (context pruning reranker)
+    if (
+        hasattr(config, "architectures")
+        and config.architectures
+        and "XProvence" in config.architectures[0]
+    ):
+        logger.info("Detected XProvence model for context pruning")
+        return XProvenceModel(model_path, device, datatype, trust_remote=True)
+
     if (
         hasattr(config, "auto_map")
         and isinstance(config.auto_map, dict)

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,9 @@ class PaddedBatch(Batch):
     token_type_ids: torch.Tensor
     position_ids: torch.Tensor
     attention_mask: torch.Tensor
+    # XProvence: raw text for context pruning
+    raw_query: str = None
+    raw_text: str = None
 
     @classmethod
     @tracer.start_as_current_span("from_pb")
@@ -77,11 +80,17 @@ def from_pb(
         # Move padded tensors all at once
         all_tensors = all_tensors.to(device)
 
+        # XProvence: Extract raw text if present in proto
+        raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None
+        raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None
+
         return PaddedBatch(
             input_ids=all_tensors[0],
             token_type_ids=all_tensors[1],
             position_ids=all_tensors[2],
             attention_mask=all_tensors[3],
+            raw_query=raw_query,
+            raw_text=raw_text,
        )
 
     def __len__(self):
backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
+import os
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModel
+from opentelemetry import trace
+from loguru import logger
+
+from text_embeddings_server.models.model import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+
+def _parse_bool(value: str) -> bool:
+    """Parse boolean from string with common conventions."""
+    return str(value).lower() in ("true", "1", "t", "yes", "on")
+
+
+class XProvenceModel(Model):
+    """
+    XProvence: Zero-cost context pruning model for RAG.
+
+    XProvence removes irrelevant sentences from passages based on relevance
+    to the query, returning both a reranking score and pruned context.
+
+    Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages.
+
+    Environment Variables:
+        XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0.
+            - 0.3 (default): Conservative pruning, minimal performance drop
+            - 0.7: Aggressive pruning, higher compression
+        XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title.
+            - true (default): Always include first sentence (useful for Wikipedia)
+            - false: Only include sentences above threshold
+    """
+
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str = "cls",
+        trust_remote: bool = True,
+    ):
+        # XProvence requires AutoModel with trust_remote_code=True
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+        model = model.to(dtype).to(device)
+
+        self.hidden_size = model.config.hidden_size
+
+        # XProvence is based on XLM-RoBERTa
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
+        # XProvence pruning options from environment variables
+        # XPROVENCE_THRESHOLD: 0.0-1.0, lower = more conservative (default: 0.3)
+        # XPROVENCE_ALWAYS_SELECT_TITLE: keep first sentence as title (default: true)
+        try:
+            threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3")
+            self.threshold = float(threshold_env)
+            if not (0.0 <= self.threshold <= 1.0):
+                logger.warning(
+                    f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], "
+                    f"defaulting to 0.3"
+                )
+                self.threshold = 0.3
+        except ValueError:
+            logger.error(
+                f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3"
+            )
+            self.threshold = 0.3
+
+        self.always_select_title = _parse_bool(
+            os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true")
+        )
+
+        logger.info(
+            f"XProvence model loaded: threshold={self.threshold}, "
+            f"always_select_title={self.always_select_title} "
+            f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)"
+        )
+
+        super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device)
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        # XProvence is a reranker, not an embedding model
+        pass
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        """
+        XProvence prediction with context pruning.
+
+        If raw_query and raw_text are provided, uses XProvence's process()
+        method to perform sentence-level context pruning.
+        Otherwise, falls back to standard reranking without pruning.
+        """
+        # Check if raw text is available for XProvence processing
+        if batch.raw_query and batch.raw_text:
+            return self._predict_with_pruning(batch.raw_query, batch.raw_text)
+        else:
+            # Fallback: standard forward pass without pruning
+            return self._predict_standard(batch)
+
+    def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
+        """
+        Use XProvence's process() method for context pruning.
+
+        Returns score with pruned_text containing only relevant sentences.
+        """
+        try:
+            output = self.model.process(
+                raw_query,
+                raw_text,
+                threshold=self.threshold,
+                always_select_title=self.always_select_title,
+            )
+
+            reranking_score = float(output["reranking_score"])
+            pruned_context = output["pruned_context"]
+
+            logger.debug(
+                f"XProvence pruning: score={reranking_score:.4f}, "
+                f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}"
+            )
+
+            return [Score(values=[reranking_score], pruned_text=pruned_context)]
+
+        except Exception as e:
+            logger.error(f"XProvence process() failed: {e}, falling back to standard")
+            # Return a default score without pruning on error
+            return [Score(values=[0.0], pruned_text=None)]
+
+    def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
+        """
+        Standard forward pass without context pruning.
+
+        Used as fallback when raw text is not available.
+        """
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+
+        output = self.model(**kwargs, return_dict=True)
+
+        # XProvence forward returns ranking logits at position 0
+        if hasattr(output, "logits"):
+            logits = output.logits
+        else:
+            # Assume first element is ranking logits
+            logits = output[0]
+
+        # Extract scores (first column if multi-dimensional)
+        if logits.dim() == 2 and logits.size(1) >= 1:
+            scores = logits[:, 0].tolist()
+        else:
+            scores = logits.tolist()
+
+        return [Score(values=[s], pruned_text=None) for s in scores]
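
To make the pruning path concrete, a hedged usage sketch of the process() call the new model class relies on, run outside the server; the checkpoint path, dtype, and example texts are placeholders, while the argument names and output keys mirror _predict_with_pruning above.

# Hedged sketch: exercising XProvence's process() directly, mirroring
# _predict_with_pruning. The model path is a placeholder.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("path/to/xprovence", trust_remote_code=True)
model = model.to(torch.float32).eval()

output = model.process(
    "what is the capital of france?",                  # raw query
    "Paris is the capital of France. The Louvre is a museum. "
    "Berlin is the capital of Germany.",               # raw context
    threshold=0.3,                 # same default as XPROVENCE_THRESHOLD
    always_select_title=True,      # same default as XPROVENCE_ALWAYS_SELECT_TITLE
)

print(output["reranking_score"])   # float relevance score used for reranking
print(output["pruned_context"])    # context with irrelevant sentences removed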

backends/python/src/lib.rs

Lines changed: 16 additions & 4 deletions
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
 use nohash_hasher::BuildNoHashHasher;
 use std::collections::HashMap;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
 };
 use tokio::runtime::Runtime;
 
@@ -108,6 +108,11 @@ impl Backend for PythonBackend {
            ));
        }
        let batch_size = batch.len();
+
+        // XProvence: Get first raw query/text from batch (for single request)
+        let raw_query = batch.raw_queries.first().cloned().flatten();
+        let raw_text = batch.raw_texts.first().cloned().flatten();
+
        let results = self
            .tokio_runtime
            .block_on(self.backend_client.clone().predict(
@@ -116,15 +121,22 @@ impl Backend for PythonBackend {
                batch.position_ids,
                batch.cumulative_seq_lengths,
                batch.max_length,
+                raw_query,
+                raw_text,
            ))
            .map_err(|err| BackendError::Inference(err.to_string()))?;
-        let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
 
        let mut predictions =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
 
-        for (i, r) in raw_results.into_iter().enumerate() {
-            predictions.insert(i, r);
+        for (i, score) in results.into_iter().enumerate() {
+            predictions.insert(
+                i,
+                Prediction {
+                    scores: score.values,
+                    pruned_text: score.pruned_text,
+                },
+            );
        }
 
        Ok(predictions)

core/src/infer.rs

Lines changed: 7 additions & 3 deletions
@@ -561,11 +561,13 @@ async fn backend_task(backend: Backend, mut embed_receiver: mpsc::Receiver<NextB
            inference: inference_duration,
        };
 
+        let prediction = predictions.remove(&i).expect(
+            "prediction not found in results. This is a backend bug.",
+        );
        let _ = m.response_tx.send(Ok(InferResult::Classification(
            ClassificationInferResponse {
-                results: predictions.remove(&i).expect(
-                    "prediction not found in results. This is a backend bug.",
-                ),
+                results: prediction.scores,
+                pruned_text: prediction.pruned_text,
                metadata: infer_metadata,
            },
        )));
@@ -642,6 +644,8 @@ pub(crate) enum InferResult {
 #[derive(Debug)]
 pub struct ClassificationInferResponse {
     pub results: Vec<f32>,
+    /// XProvence: pruned context text after removing irrelevant sentences
+    pub pruned_text: Option<String>,
     pub metadata: InferMetadata,
 }

core/src/queue.rs

Lines changed: 10 additions & 0 deletions
@@ -129,6 +129,10 @@ fn queue_blocking_task(
        let mut cu_seq_lengths = Vec::with_capacity(capacity);
        cu_seq_lengths.push(0);
 
+        // XProvence: raw text vectors for context pruning
+        let mut raw_queries = Vec::with_capacity(capacity);
+        let mut raw_texts = Vec::with_capacity(capacity);
+
        let mut current_tokens = 0;
        let mut max_length = 0;
 
@@ -168,6 +172,10 @@ fn queue_blocking_task(
            token_type_ids.extend(entry.encoding.token_type_ids);
            position_ids.extend(entry.encoding.position_ids);
 
+            // XProvence: collect raw texts for context pruning
+            raw_queries.push(entry.encoding.raw_query);
+            raw_texts.push(entry.encoding.raw_text);
+
            current_tokens += entry_tokens;
            metadata.push(entry.metadata);
            cu_seq_lengths.push(current_tokens as u32);
@@ -193,6 +201,8 @@ fn queue_blocking_task(
                max_length,
                pooled_indices,
                raw_indices,
+                raw_queries,
+                raw_texts,
            },
        ))
    };
