@@ -46,6 +46,13 @@ def __init__(
     ):
         # XProvence requires AutoModel with trust_remote_code=True
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+
+        # XProvence's process() method doesn't support bfloat16,
+        # so we use float32 for full pruning support
+        if dtype == torch.bfloat16:
+            logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility")
+            dtype = torch.float32
+
         model = model.to(dtype).to(device)
 
         self.hidden_size = model.config.hidden_size
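A minimal standalone sketch of one plausible failure mode this fallback guards against (not taken from this change): bfloat16 tensors cannot be converted to NumPy arrays, which Python-side postprocessing such as sentence pruning typically relies on.

```python
# Standalone sketch: bfloat16 -> NumPy conversion fails, so any
# process()-style code that round-trips through NumPy needs float32.
import torch

t = torch.tensor([0.5], dtype=torch.bfloat16)
try:
    t.numpy()
except TypeError as e:
    print(f"bfloat16 conversion fails: {e}")  # Got unsupported ScalarType BFloat16

print(t.to(torch.float32).numpy())  # [0.5] -- works after the cast
```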
@@ -105,18 +112,20 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
     @tracer.start_as_current_span("predict")
     def predict(self, batch: PaddedBatch) -> List[Score]:
         """
-        XProvence prediction with context pruning.
+        XProvence prediction with context pruning support.
 
-        If raw_query and raw_text are provided, uses XProvence's process()
-        method to perform sentence-level context pruning.
-        Otherwise, falls back to standard reranking without pruning.
+        For single-item batches with raw_query/raw_text available,
+        uses XProvence's process() method for sentence-level pruning.
+        Otherwise falls back to standard forward pass.
         """
-        # Check if raw text is available for XProvence processing
-        if batch.raw_query and batch.raw_text:
+        batch_size = len(batch)
+
+        # Use pruning only for single-item batches with raw text
+        if batch_size == 1 and batch.raw_query and batch.raw_text:
             return self._predict_with_pruning(batch.raw_query, batch.raw_text)
-        else:
-            # Fallback: standard forward pass without pruning
-            return self._predict_standard(batch)
+
+        # Multi-item batches or no raw text: use standard forward pass
+        return self._predict_standard(batch)
 
     def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
         """
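The new gating condition in predict() reads as a pure predicate; a hypothetical distillation for clarity (should_prune is an invented name, not part of this change):

```python
# Hypothetical distillation of the gating rule in predict():
# pruning only runs for a single-item batch that carries raw text.
def should_prune(batch_size: int, raw_query, raw_text) -> bool:
    return batch_size == 1 and bool(raw_query) and bool(raw_text)

assert should_prune(1, "query", "some passage")          # pruning path
assert not should_prune(4, "query", "some passage")      # multi-item batch
assert not should_prune(1, None, None)                   # tokenized-only batch
```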
@@ -125,12 +134,24 @@ def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
         Returns score with pruned_text containing only relevant sentences.
         """
         try:
-            output = self.model.process(
-                raw_query,
-                raw_text,
-                threshold=self.threshold,
-                always_select_title=self.always_select_title,
-            )
+            # Disable tqdm progress bar to avoid broken pipe errors
+            # when stdout/stderr is captured by the Rust process
+            os.environ["TQDM_DISABLE"] = "1"
+
+            # Force float32 for XProvence's process() method,
+            # which creates internal tensors and doesn't support bfloat16
+            original_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(torch.float32)
+
+            try:
+                output = self.model.process(
+                    raw_query,
+                    raw_text,
+                    threshold=self.threshold,
+                    always_select_title=self.always_select_title,
+                )
+            finally:
+                torch.set_default_dtype(original_dtype)
 
             reranking_score = float(output["reranking_score"])
             pruned_context = output["pruned_context"]
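The save/set/restore dance around the default dtype is a classic try/finally pattern; a hypothetical helper (not part of this change) packages it as a context manager, which guarantees restoration even if process() raises:

```python
# Hypothetical helper, not in this change: the same save/restore
# pattern as a reusable context manager.
from contextlib import contextmanager
import torch

@contextmanager
def default_dtype(dtype: torch.dtype):
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)  # only floating-point dtypes are accepted
    try:
        yield
    finally:
        torch.set_default_dtype(previous)  # restored even on exceptions

# Usage, equivalent to the inline try/finally above:
# with default_dtype(torch.float32):
#     output = self.model.process(raw_query, raw_text, ...)
```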
@@ -157,17 +178,26 @@ def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
 
         output = self.model(**kwargs, return_dict=True)
 
-        # XProvence forward returns ranking logits at position 0
-        if hasattr(output, "logits"):
-            logits = output.logits
+        # XProvence returns RankingCompressionOutput with ranking_scores
+        if hasattr(output, "ranking_scores"):
+            scores_tensor = output.ranking_scores
+        elif hasattr(output, "logits"):
+            # Fallback for standard classification models
+            scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits
         else:
-            # Assume first element is ranking logits
-            logits = output[0]
+            # Assume first element is scores
+            scores_tensor = output[0]
 
-        # Extract scores (first column if multi-dimensional)
-        if logits.dim() == 2 and logits.size(1) >= 1:
-            scores = logits[:, 0].tolist()
+        # Handle scalar (batch_size=1) vs tensor (batch_size>1)
+        if scores_tensor.dim() == 0:
+            # Scalar - single item batch
+            scores = [float(scores_tensor.item())]
         else:
-            scores = logits.tolist()
+            # 1D tensor - multiple items
+            scores = scores_tensor.view(-1).tolist()
+
+        # Ensure scores is a list
+        if isinstance(scores, float):
+            scores = [scores]
 
-        return [Score(values=[s], pruned_text=None) for s in scores]
+        return [Score(values=[float(s)], pruned_text=None) for s in scores]
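The scalar-vs-tensor handling at the end of the hunk can be exercised in isolation; a standalone sketch with toy tensors (normalize_scores is an invented name for illustration):

```python
# Standalone sketch of the score-shape normalization above.
import torch

def normalize_scores(scores_tensor: torch.Tensor) -> list:
    if scores_tensor.dim() == 0:
        return [float(scores_tensor.item())]  # 0-d scalar: single-item batch
    return [float(s) for s in scores_tensor.view(-1).tolist()]  # flatten the rest

print(normalize_scores(torch.tensor(0.5)))           # [0.5]
print(normalize_scores(torch.tensor([0.5, -1.25])))  # [0.5, -1.25]
```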