Commit 5631b2e

Sigrid Jin authored and committed
fix: XProvence context pruning with bfloat16 and flash_attn compatibility
1 parent fc4f021 commit 5631b2e

File tree

2 files changed: 70 additions & 31 deletions

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 15 additions & 6 deletions
@@ -12,11 +12,19 @@
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
 from text_embeddings_server.models.xprovence_model import XProvenceModel
-from text_embeddings_server.models.jinaBert_model import FlashJinaBert
-from text_embeddings_server.models.flash_mistral import FlashMistral
-from text_embeddings_server.models.flash_qwen3 import FlashQwen3
 from text_embeddings_server.utils.device import get_device, use_ipex
 
+# Flash attention models are optional (require flash_attn)
+FlashJinaBert = None
+FlashMistral = None
+FlashQwen3 = None
+try:
+    from text_embeddings_server.models.jinaBert_model import FlashJinaBert
+    from text_embeddings_server.models.flash_mistral import FlashMistral
+    from text_embeddings_server.models.flash_qwen3 import FlashQwen3
+except ImportError as e:
+    logger.warning(f"Flash attention models not available: {e}")
+
 __all__ = ["Model"]
 
 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
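
The guard above follows a common optional-dependency pattern: bind module-level sentinels first, then overwrite them inside a try/except ImportError so the server still starts when flash_attn is missing. A minimal standalone sketch of the same idea (the optional_backend module name is hypothetical, not part of TEI):

import logging

logger = logging.getLogger(__name__)

# Sentinel stays None unless the optional dependency imports cleanly.
OptionalBackend = None
try:
    from optional_backend import OptionalBackend  # hypothetical optional package
except ImportError as e:
    logger.warning(f"optional_backend not available: {e}")

def backend_name() -> str:
    # Callers branch on the sentinel instead of re-importing.
    return "optional" if OptionalBackend is not None else "default"

print(backend_name())  # prints "default" unless optional_backend is installed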
@@ -86,7 +94,8 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         return XProvenceModel(model_path, device, datatype, trust_remote=True)
 
     if (
-        hasattr(config, "auto_map")
+        FlashJinaBert is not None
+        and hasattr(config, "auto_map")
         and isinstance(config.auto_map, dict)
         and "AutoModel" in config.auto_map
         and config.auto_map["AutoModel"]
@@ -126,13 +135,13 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         else:
             return create_model(DefaultModel, model_path, device, datatype, pool)
 
-    if config.model_type == "mistral" and device.type == "hpu":
+    if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu":
         try:
             return create_model(FlashMistral, model_path, device, datatype, pool)
         except FileNotFoundError:
             return create_model(DefaultModel, model_path, device, datatype, pool)
 
-    if config.model_type == "qwen3" and device.type == "hpu":
+    if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu":
         try:
             return create_model(FlashQwen3, model_path, device, datatype, pool)
         except FileNotFoundError:
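
With the `is not None` guards in place, a missing flash model simply makes the branch short-circuit and later code falls through to the default implementation. A rough illustrative sketch of that dispatch (the classes and function are stand-ins, not the real TEI `get_model`, which has more branches):

FlashMistralStub = None  # would be set by the optional import block above


class DefaultModelStub:
    pass


def pick_model(model_type: str, device_type: str):
    # When the optional backend never imported, the guard fails fast
    # and the request falls through to the default model.
    if FlashMistralStub is not None and model_type == "mistral" and device_type == "hpu":
        return FlashMistralStub
    return DefaultModelStub


print(pick_model("mistral", "hpu"))  # <class '__main__.DefaultModelStub'>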

backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 55 additions & 25 deletions
@@ -46,6 +46,13 @@ def __init__(
     ):
         # XProvence requires AutoModel with trust_remote_code=True
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+
+        # XProvence's process() method doesn't support bfloat16,
+        # so we use float32 for full pruning support
+        if dtype == torch.bfloat16:
+            logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility")
+            dtype = torch.float32
+
         model = model.to(dtype).to(device)
 
         self.hidden_size = model.config.hidden_size
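
The comment above only states that process() does not handle bfloat16. One plausible failure mode (an assumption here, not stated in the commit) is a round trip through NumPy, which has no bfloat16 dtype, so keeping the model in float32 avoids the conversion error:

import torch

t_bf16 = torch.tensor([0.5, 0.25], dtype=torch.bfloat16)

# NumPy has no bfloat16 dtype, so converting raises a TypeError,
# while a float32 copy converts cleanly.
try:
    t_bf16.numpy()
except TypeError as err:
    print(f"bfloat16 -> numpy failed: {err}")

print(t_bf16.to(torch.float32).numpy())  # array([0.5 , 0.25], dtype=float32)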
@@ -105,18 +112,20 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
     @tracer.start_as_current_span("predict")
     def predict(self, batch: PaddedBatch) -> List[Score]:
         """
-        XProvence prediction with context pruning.
+        XProvence prediction with context pruning support.
 
-        If raw_query and raw_text are provided, uses XProvence's process()
-        method to perform sentence-level context pruning.
-        Otherwise, falls back to standard reranking without pruning.
+        For single-item batches with raw_query/raw_text available,
+        uses XProvence's process() method for sentence-level pruning.
+        Otherwise falls back to standard forward pass.
         """
-        # Check if raw text is available for XProvence processing
-        if batch.raw_query and batch.raw_text:
+        batch_size = len(batch)
+
+        # Use pruning only for single-item batches with raw text
+        if batch_size == 1 and batch.raw_query and batch.raw_text:
             return self._predict_with_pruning(batch.raw_query, batch.raw_text)
-        else:
-            # Fallback: standard forward pass without pruning
-            return self._predict_standard(batch)
+
+        # Multi-item batches or no raw text: use standard forward pass
+        return self._predict_standard(batch)
 
     def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
         """
@@ -125,12 +134,24 @@ def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
         Returns score with pruned_text containing only relevant sentences.
         """
         try:
-            output = self.model.process(
-                raw_query,
-                raw_text,
-                threshold=self.threshold,
-                always_select_title=self.always_select_title,
-            )
+            # Disable tqdm progress bar to avoid broken pipe errors
+            # when stdout/stderr is captured by the Rust process
+            os.environ["TQDM_DISABLE"] = "1"
+
+            # Force float32 for XProvence's process() method
+            # which creates internal tensors and doesn't support bfloat16
+            original_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(torch.float32)
+
+            try:
+                output = self.model.process(
+                    raw_query,
+                    raw_text,
+                    threshold=self.threshold,
+                    always_select_title=self.always_select_title,
+                )
+            finally:
+                torch.set_default_dtype(original_dtype)
 
             reranking_score = float(output["reranking_score"])
             pruned_context = output["pruned_context"]
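
The save/override/restore around process() could equally be packaged as a context manager. A sketch under that assumption (default_dtype is a hypothetical helper, not part of the commit), shown with a commented-out usage line because self.model.process stands in for the call above:

from contextlib import contextmanager

import torch


@contextmanager
def default_dtype(dtype: torch.dtype):
    # Temporarily override torch's global default dtype and restore it
    # even if the wrapped call raises, mirroring the try/finally above.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


# Usage sketch:
# with default_dtype(torch.float32):
#     output = self.model.process(raw_query, raw_text, threshold=0.1)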
@@ -157,17 +178,26 @@ def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
 
         output = self.model(**kwargs, return_dict=True)
 
-        # XProvence forward returns ranking logits at position 0
-        if hasattr(output, "logits"):
-            logits = output.logits
+        # XProvence returns RankingCompressionOutput with ranking_scores
+        if hasattr(output, "ranking_scores"):
+            scores_tensor = output.ranking_scores
+        elif hasattr(output, "logits"):
+            # Fallback for standard classification models
+            scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits
         else:
-            # Assume first element is ranking logits
-            logits = output[0]
+            # Assume first element is scores
+            scores_tensor = output[0]
 
-        # Extract scores (first column if multi-dimensional)
-        if logits.dim() == 2 and logits.size(1) >= 1:
-            scores = logits[:, 0].tolist()
+        # Handle scalar (batch_size=1) vs tensor (batch_size>1)
+        if scores_tensor.dim() == 0:
+            # Scalar - single item batch
+            scores = [float(scores_tensor.item())]
         else:
-            scores = logits.tolist()
+            # 1D tensor - multiple items
+            scores = scores_tensor.view(-1).tolist()
+
+        # Ensure scores is a list
+        if isinstance(scores, float):
+            scores = [scores]
 
-        return [Score(values=[s], pruned_text=None) for s in scores]
+        return [Score(values=[float(s)], pruned_text=None) for s in scores]
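
A standalone version of the scalar-vs-1D normalization above, using a hypothetical helper name, shows the two shapes the reranker output can take:

from typing import List

import torch


def to_score_list(scores_tensor: torch.Tensor) -> List[float]:
    # 0-d tensor: single-item batch; 1-d tensor: one score per item.
    if scores_tensor.dim() == 0:
        return [float(scores_tensor.item())]
    return [float(s) for s in scores_tensor.view(-1).tolist()]


print(to_score_list(torch.tensor(0.83)))           # single score -> one-element list
print(to_score_list(torch.tensor([0.83, -1.20])))  # one score per batch item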
