import os
import torch

from pathlib import Path
from typing import Type, List
from transformers import AutoModel
from opentelemetry import trace
from loguru import logger

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


def _parse_bool(value: str) -> bool:
    """Parse a boolean from a string using common conventions."""
    return str(value).lower() in ("true", "1", "t", "yes", "on")
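# Illustrative behavior of _parse_bool (anything outside the accepted set,
# e.g. "false", "no", or "", parses as False):
#   _parse_bool("True") -> True
#   _parse_bool("on")   -> True
#   _parse_bool("0")    -> False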


class XProvenceModel(Model):
    """
    XProvence: zero-cost context pruning model for RAG.

    XProvence removes irrelevant sentences from a passage based on their
    relevance to the query, returning both a reranking score and the pruned
    context.

    Based on bge-reranker-v2-m3 (XLM-RoBERTa); supports 16+ languages.

    Environment variables:
        XPROVENCE_THRESHOLD (float): Pruning threshold in [0.0, 1.0].
            - 0.3 (default): conservative pruning, minimal performance drop
            - 0.7: aggressive pruning, higher compression
        XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep the first sentence as a title.
            - true (default): always include the first sentence (useful for Wikipedia)
            - false: only include sentences above the threshold
    """
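    # Example configuration (illustrative values; both variables are read
    # once, at model construction time):
    #   export XPROVENCE_THRESHOLD=0.5
    #   export XPROVENCE_ALWAYS_SELECT_TITLE=false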

    def __init__(
        self,
        model_path: Path,
        device: torch.device,
        dtype: torch.dtype,
        pool: str = "cls",
        trust_remote: bool = True,
    ):
        # XProvence ships custom modeling code, so it must be loaded via
        # AutoModel with trust_remote_code enabled
        model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
        model = model.to(dtype).to(device)

        self.hidden_size = model.config.hidden_size

        # XProvence is based on XLM-RoBERTa, whose position ids start after
        # the padding index
        position_offset = 0
        model_type = model.config.model_type
        if model_type in ["xlm-roberta", "camembert", "roberta"]:
            position_offset = model.config.pad_token_id + 1

        if hasattr(model.config, "max_seq_length"):
            self.max_input_length = model.config.max_seq_length
        else:
            self.max_input_length = (
                model.config.max_position_embeddings - position_offset
            )
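        # Worked example: for xlm-roberta, pad_token_id is 1, so
        # position_offset = 2 and max_position_embeddings = 514 leaves
        # 512 usable input tokens.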

        # XProvence pruning options from environment variables
        # XPROVENCE_THRESHOLD: 0.0-1.0, lower = more conservative (default: 0.3)
        # XPROVENCE_ALWAYS_SELECT_TITLE: keep first sentence as title (default: true)
        threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3")
        try:
            self.threshold = float(threshold_env)
            if not (0.0 <= self.threshold <= 1.0):
                logger.warning(
                    f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], "
                    f"defaulting to 0.3"
                )
                self.threshold = 0.3
        except ValueError:
            logger.error(
                f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3"
            )
            self.threshold = 0.3

        self.always_select_title = _parse_bool(
            os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true")
        )

        logger.info(
            f"XProvence model loaded: threshold={self.threshold}, "
            f"always_select_title={self.always_select_title} "
            f"(configure via XPROVENCE_THRESHOLD and XPROVENCE_ALWAYS_SELECT_TITLE)"
        )

        super().__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        # XProvence is a reranker, not an embedding model
        raise NotImplementedError("XProvence does not support embedding requests")

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Score]:
        """
        XProvence prediction with context pruning.

        If raw_query and raw_text are provided, uses XProvence's process()
        method to perform sentence-level context pruning. Otherwise, falls
        back to standard reranking without pruning.
        """
        # Check if raw text is available for XProvence processing
        if batch.raw_query and batch.raw_text:
            return self._predict_with_pruning(batch.raw_query, batch.raw_text)
        else:
            # Fallback: standard forward pass without pruning
            return self._predict_standard(batch)
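
    # Illustrative request flow (assumes the serving layer populates the
    # raw_query/raw_text fields on PaddedBatch; the strings are made up):
    #   batch.raw_query = "What is the Eiffel Tower?"
    #   batch.raw_text  = "The Eiffel Tower is in Paris. Berlin has the TV Tower."
    #   model.predict(batch)
    #   # -> [Score(values=[<score>], pruned_text="The Eiffel Tower is in Paris.")]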

    def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
        """
        Use XProvence's process() method for context pruning.

        Returns a score whose pruned_text contains only the relevant sentences.
        """
        try:
            output = self.model.process(
                raw_query,
                raw_text,
                threshold=self.threshold,
                always_select_title=self.always_select_title,
            )
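            # process() is expected to return a dict carrying at least
            # "reranking_score" (float) and "pruned_context" (str); any
            # other keys are ignored here.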

            reranking_score = float(output["reranking_score"])
            pruned_context = output["pruned_context"]

            logger.debug(
                f"XProvence pruning: score={reranking_score:.4f}, "
                f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}"
            )

            return [Score(values=[reranking_score], pruned_text=pruned_context)]

        except Exception as e:
            logger.error(f"XProvence process() failed: {e}, returning default score")
            # Return a default score without pruning on error
            return [Score(values=[0.0], pruned_text=None)]
    def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
        """
        Standard forward pass without context pruning.

        Used as a fallback when raw text is not available.
        """
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}

        output = self.model(**kwargs, return_dict=True)

        # XProvence's forward returns the ranking logits at position 0
        if hasattr(output, "logits"):
            logits = output.logits
        else:
            # Assume the first element is the ranking logits
            logits = output[0]

        # Extract scores (first column if multi-dimensional)
        if logits.dim() == 2 and logits.size(1) >= 1:
            scores = logits[:, 0].tolist()
        else:
            scores = logits.tolist()

        return [Score(values=[s], pruned_text=None) for s in scores]
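

# Minimal local smoke test: a sketch only, not part of the serving path. It
# assumes an XProvence checkpoint has already been downloaded to MODEL_DIR
# (the path and the query/passage strings below are illustrative).
if __name__ == "__main__":
    MODEL_DIR = Path(os.getenv("MODEL_DIR", "/data/xprovence"))
    model = XProvenceModel(
        model_path=MODEL_DIR,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dtype=torch.float32,
    )
    result = model._predict_with_pruning(
        "What is the capital of France?",
        "Paris is the capital of France. Berlin is the capital of Germany.",
    )
    print(result[0])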