From 292f91f9e66f18fb8e43d5598f0d003ae00a42b8 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 11 Aug 2025 14:30:34 +0300 Subject: [PATCH 01/16] Seperate mock inference class and add control Signed-off-by: elronbandel --- src/unitxt/inference.py | 96 +++++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 87488d1da7..8becb21aa1 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -161,7 +161,51 @@ def __repr__(self): return f"ListWithMetadata(data={super().__repr__()}, metadata={self.metadata})" -class InferenceEngine(Artifact): +class MockInferenceMixin(Artifact): + use_mock: bool = False + + @property + def is_mock(self): + return self.use_mock or settings.mock_inference_mode + + def _mock_infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + result = [] + for instance in dataset: + prediction = str(instance["source"]) + if return_meta_data: + result.append( + TextGenerationInferenceOutput( + prediction=prediction, generated_text=prediction + ) + ) + else: + result.append(prediction) + return result + + @staticmethod + def mock_logprobs_default_value_factory() -> List[Dict[str, Any]]: + return [ + { + "logprob": -1, + "text": "[[10]]", + "top_tokens": [ + {"logprob": -1, "text": "[[10]]"}, + ], + } + ] + + def _mock_infer_log_probs( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + return [self.mock_logprobs_default_value_factory() for _ in dataset] + + +class InferenceEngine(abc.ABC, MockInferenceMixin): """Abstract base class for inference.""" cache_batch_size: int = 100 @@ -187,7 +231,7 @@ def prepare_engine(self): pass def prepare(self): - if not settings.mock_inference_mode: + if not self.is_mock: super().prepare() # no need to prepare a mock with error_context( self, @@ -254,7 +298,7 @@ def infer( predictions. 
""" self.verify_infer_inputs(dataset, return_meta_data) - if settings.mock_inference_mode: + if self.is_mock: result = self._mock_infer(dataset, return_meta_data) else: if self.use_cache: @@ -330,24 +374,6 @@ def infer( }, ) - def _mock_infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - result = [] - for instance in dataset: - prediction = str(instance["source"]) - if return_meta_data: - result.append( - TextGenerationInferenceOutput( - prediction=prediction, generated_text=prediction - ) - ) - else: - result.append(prediction) - return result - @abc.abstractmethod def get_engine_id(self): raise NotImplementedError() @@ -409,7 +435,7 @@ def to_tools(self, instance): return None -class LogProbInferenceEngine(abc.ABC, Artifact): +class LogProbInferenceEngine(abc.ABC, MockInferenceMixin): """Abstract base class for inference with log probs.""" @abc.abstractmethod @@ -426,12 +452,6 @@ def _infer_log_probs( """ pass - def _mock_infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return [mock_logprobs_default_value_factory() for instance in dataset] - def infer_log_probs( self, dataset: Union[List[Dict[str, Any]], Dataset], @@ -452,7 +472,7 @@ def infer_log_probs( [self.verify_instance(instance) for instance in dataset] - if settings.mock_inference_mode: + if self.is_mock: result = self._mock_infer_log_probs(dataset) else: result = self._infer_log_probs(dataset, return_meta_data) @@ -1288,23 +1308,11 @@ def get_return_object(self, output, inp, return_meta_data): return output["generated_text"] -def mock_logprobs_default_value_factory() -> List[Dict[str, Any]]: - return [ - { - "logprob": -1, - "text": "[[10]]", - "top_tokens": [ - {"logprob": -1, "text": "[[10]]"}, - ], - } - ] - - class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine): model_name: str default_inference_value: str = "[[10]]" default_inference_value_logprob: List[Dict[str, Any]] = dataclasses.field( - default_factory=mock_logprobs_default_value_factory, + default_factory=MockInferenceMixin.mock_logprobs_default_value_factory, ) label: str = "mock_inference_engine" @@ -1374,10 +1382,6 @@ def get_return_object( return predict_result -class MockModeMixin(Artifact): - mock_mode: bool = False - - class GenericInferenceEngine( InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine ): From 1e3872ee9b762e9cf530d04be1e618da941dd47a Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 11 Aug 2025 20:40:06 +0300 Subject: [PATCH 02/16] Seperate cache and prepare for streaming Signed-off-by: elronbandel --- src/unitxt/inference.py | 236 ++++++++++++++++------- tests/inference/test_inference_engine.py | 6 +- 2 files changed, 170 insertions(+), 72 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 8becb21aa1..33be6dc15d 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -6,7 +6,6 @@ import io import json import logging -import math import os import re import sys @@ -205,11 +204,112 @@ def _mock_infer_log_probs( return [self.mock_logprobs_default_value_factory() for _ in dataset] -class InferenceEngine(abc.ABC, MockInferenceMixin): - """Abstract base class for inference.""" +class CachedInferenceMixin(Artifact): + """Mixin that provides caching functionality for inference engines.""" - cache_batch_size: int = 100 + cache_batch_size: int = ( + 100 # Kept for backwards compatibility, 
not used in streaming + ) use_cache: bool = True + _cache: Any = InternalField(default=None, name="Disk cache instance") + + def _initialize_cache(self): + """Initialize the disk cache if caching is enabled.""" + if self.use_cache and self._cache is None: + from diskcache import Cache + + self._cache = Cache( + os.path.join( + settings.inference_engine_cache_path, self.__class__.__name__ + ) + ) + + def get_instance_cache_key(self, instance): + """Extract cacheable fields from an instance.""" + instance_key_fields = ["media", "source", "task_data"] + return {key: instance[key] for key in instance if key in instance_key_fields} + + def _get_cache_key(self, instance: Dict[str, Any]) -> str: + """Generate a unique cache key for each input.""" + record = self.get_instance_cache_key(instance) + record["version"] = constants.version + record.update(self.to_dict()) + instance_str = json.dumps(record, sort_keys=True) + return hashlib.md5(instance_str.encode()).hexdigest() + + def _get_cached_result(self, instance: Dict[str, Any]): + """Get cached result for an instance, returns None if not found.""" + if not self.use_cache or self._cache is None: + return None + cache_key = self._get_cache_key(instance) + return self._cache.get(cache_key) + + def _cache_result(self, instance: Dict[str, Any], prediction): + """Cache a prediction result for an instance.""" + if not self.use_cache or self._cache is None or prediction is None: + return + cache_key = self._get_cache_key(instance) + self._cache[cache_key] = prediction + + def _apply_caching_to_streaming( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + streaming_func, + return_meta_data: bool = False, + ): + """Apply caching logic to a streaming inference function.""" + if not self.use_cache: + # No caching, just use streaming directly + instances_iter = ((i, instance) for i, instance in enumerate(dataset)) + results = {} + for index, prediction in streaming_func(instances_iter, return_meta_data): + results[index] = prediction + return [results[i] for i in range(len(dataset))] + + # Initialize cache if needed + self._initialize_cache() + + # Phase 1: Identify cached vs missing instances + cached_results = {} + missing_instances = [] + cached_count = 0 + + for i, instance in enumerate(dataset): + cached_value = self._get_cached_result(instance) + + if cached_value is not None: + cached_results[i] = cached_value + cached_count += 1 + else: + missing_instances.append((i, instance)) + + message = ( + f"Found {cached_count} cached instances, inferring {len(missing_instances)} instances" + + ( + f" (cache: {self._cache.directory})" + if hasattr(self._cache, "directory") + else "" + ) + ) + logger.info(message) + + # Phase 2: Stream missing instances and cache results immediately + if missing_instances: + missing_iter = iter(missing_instances) + + for index, prediction in streaming_func(missing_iter, return_meta_data): + # Cache immediately when ready, but always store result (even if None) + if prediction is not None: + instance = dataset[index] + self._cache_result(instance, prediction) + cached_results[index] = prediction + + # Phase 3: Reconstruct results in original order + return [cached_results[i] for i in range(len(dataset))] + + +class InferenceEngine(abc.ABC, MockInferenceMixin, CachedInferenceMixin): + """Abstract base class for inference.""" @abc.abstractmethod def _infer( @@ -225,6 +325,34 @@ def _infer( """ pass + def _infer_streaming( + self, + instances: Iterable[Tuple[int, Dict[str, Any]]], + return_meta_data: bool = False, + ) -> 
Iterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: + """Default streaming wrapper that uses existing _infer method. + + Engines can override this for true streaming behavior. + + Args: + instances: Iterable of (index, instance) tuples + return_meta_data: Whether to return metadata + + Yields: + (index, prediction) tuples as results become ready + """ + # Collect all instances (maintains backwards compatibility) + all_instances = list(instances) + if not all_instances: + return + + indices = [idx for idx, _ in all_instances] + batch_instances = [inst for _, inst in all_instances] + + # Use existing _infer method + predictions = self._infer(batch_instances, return_meta_data) + yield from zip(indices, predictions) + @abc.abstractmethod def prepare_engine(self): """Perform inference on the input dataset.""" @@ -232,21 +360,13 @@ def prepare_engine(self): def prepare(self): if not self.is_mock: - super().prepare() # no need to prepare a mock + super().prepare() # This will call CachedInferenceMixin.prepare() which initializes cache with error_context( self, stage="Prepare Inference Engine", help="https://www.unitxt.ai/en/latest/docs/inference.html", ): self.prepare_engine() - if self.use_cache: - from diskcache import Cache - - self._cache = Cache( - os.path.join( - settings.inference_engine_cache_path, self.__class__.__name__ - ) - ) def __call__( self, @@ -302,62 +422,10 @@ def infer( result = self._mock_infer(dataset, return_meta_data) else: if self.use_cache: - with error_context( - self, - stage="Inference Cache Handling", - help="https://www.unitxt.ai/en/latest/docs/inference.html", - ): - number_of_batches = math.ceil(len(dataset) / self.cache_batch_size) - result = [] - for batch_index, batch in enumerate( - batched(dataset, self.cache_batch_size) - ): - cached_results = [] - missing_examples = [] - for i, item in enumerate(batch): - cache_key = self._get_cache_key(item) - cached_value = self._cache.get(cache_key) - if cached_value is not None: - cached_results.append( - (i, cached_value) - ) # each element is index in batch, and value - else: - missing_examples.append( - (i, item) - ) # each element is index in batch and example - # infere on missing examples only, without indices - - logger.info( - f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})" - ) - if len(missing_examples) > 0: - with error_context( - self, - stage="Running Inference", - help="https://www.unitxt.ai/en/latest/docs/inference.html", - ): - inferred_results = self._infer( - [e[1] for e in missing_examples], return_meta_data - ) - # recombined to index and value - inferred_results = list( - zip([e[0] for e in missing_examples], inferred_results) - ) - # Add missing examples to cache - for (_, item), (_, prediction) in zip( - missing_examples, inferred_results - ): - if prediction is None: - continue - cache_key = self._get_cache_key(item) - self._cache[cache_key] = prediction - else: - inferred_results = [] - # Combine cached and inferred results in original order - batch_predictions = [ - p[1] for p in sorted(cached_results + inferred_results) - ] - result.extend(batch_predictions) + # Use the mixin's caching functionality + result = self._apply_caching_to_streaming( + dataset, self._infer_streaming, return_meta_data + ) else: with error_context( self, @@ -452,6 +520,36 @@ def _infer_log_probs( """ pass + def _infer_log_probs_streaming( + self, + instances: Iterable[Tuple[int, Dict[str, 
Any]]], + return_meta_data: bool = False, + ) -> Iterable[Tuple[int, Union[Dict, TextGenerationInferenceOutput]]]: + """Default streaming wrapper for log probs inference that uses existing _infer_log_probs method. + + Engines can override this for true streaming behavior. + + Args: + instances: Iterable of (index, instance) tuples + return_meta_data: Whether to return metadata + + Yields: + (index, prediction) tuples as results become ready + """ + # Collect all instances (maintains backwards compatibility) + all_instances = list(instances) + if not all_instances: + return + + indices = [idx for idx, _ in all_instances] + batch_instances = [inst for _, inst in all_instances] + + # Use existing _infer_log_probs method + predictions = self._infer_log_probs(batch_instances, return_meta_data) + + # Yield results with original indices + yield from zip(indices, predictions) + def infer_log_probs( self, dataset: Union[List[Dict[str, Any]], Dataset], diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index 48261b8f01..bd8696d696 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -484,7 +484,7 @@ def test_cache(self): ) inference_model = HFPipelineBasedInferenceEngine( model_name=model_name, - max_new_tokens=32, + max_new_tokens=1, # Very small for fast testing temperature=0, top_p=1, use_cache=False, @@ -496,7 +496,7 @@ def test_cache(self): # Set seed for reproducibility inference_model = HFPipelineBasedInferenceEngine( model_name=model_name, - max_new_tokens=32, + max_new_tokens=1, # Very small for fast testing temperature=0, top_p=1, use_cache=True, @@ -543,7 +543,7 @@ def test_cache(self): inference_model = HFPipelineBasedInferenceEngine( model_name=model_name, - max_new_tokens=32, + max_new_tokens=1, # Very small for fast testing temperature=0, top_p=1, use_cache=True, From 396ce8803139d89281123f5f777b38fc848f608d Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 13 Aug 2025 13:09:38 +0300 Subject: [PATCH 03/16] Refactor inference and caching Signed-off-by: elronbandel --- ...existing_dataset_by_llm_as_judge_direct.py | 6 +- src/unitxt/api.py | 6 - src/unitxt/inference.py | 1515 ++++++++--------- src/unitxt/llm_as_judge_from_template.py | 6 +- src/unitxt/metrics.py | 3 +- tests/inference/test_inference_engine.py | 218 ++- 6 files changed, 820 insertions(+), 934 deletions(-) diff --git a/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py b/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py index 800a17b78a..fc09406d72 100644 --- a/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py +++ b/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py @@ -21,7 +21,7 @@ ] dataset = load_dataset( card="cards.squad", - metrics=metrics, + # metrics=metrics, loader_limit=2, max_test_instances=2, split="test", @@ -39,8 +39,8 @@ For the arguments these inference engines can receive, please refer to the classes documentation or read about the the open ai api arguments the CrossProviderInferenceEngine follows. """ -predictions = inference_model.infer(dataset) - +predictions = inference_model(dataset) +exit() gold_answers = [d[0] for d in dataset["references"]] # Evaluate the predictions using the defined metric. 
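The test_cache changes above exercise the disk cache introduced in PATCH 02. What follows is a minimal, self-contained sketch of the cache-key recipe from CachedInferenceMixin._get_cache_key, not code from this patch; it assumes the engine's configuration can be stood in for by a plain dict and the version by a literal string (both are assumptions for illustration):

import hashlib
import json
from typing import Any, Dict


def cache_key(instance: Dict[str, Any], engine_config: Dict[str, Any], version: str = "1.0.0") -> str:
    # Only these instance fields participate in the key, as in get_instance_cache_key().
    record = {k: v for k, v in instance.items() if k in ("media", "source", "task_data")}
    record["version"] = version      # the patch uses constants.version here
    record.update(engine_config)     # the patch folds in self.to_dict() of the engine
    return hashlib.md5(json.dumps(record, sort_keys=True).encode()).hexdigest()


print(cache_key({"source": "hello"}, {"model_name": "m", "max_new_tokens": 1}))

Two calls with the same instance and the same engine configuration therefore map to the same key, which is why the second infer() call in the tests above is expected to be served from cache.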
diff --git a/src/unitxt/api.py b/src/unitxt/api.py index 23de331bd4..c96a163896 100644 --- a/src/unitxt/api.py +++ b/src/unitxt/api.py @@ -15,7 +15,6 @@ from .error_utils import UnitxtError from .inference import ( InferenceEngine, - LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine, ) from .loaders import LoadFromDictionary @@ -380,11 +379,6 @@ def add_previous_messages(example, index): dataset = dataset.map(add_previous_messages, with_indices=True) engine, _ = fetch_artifact(engine) if return_log_probs: - if not isinstance(engine, LogProbInferenceEngine): - raise NotImplementedError( - f"Error in infer: return_log_probs set to True but supplied engine " - f"{engine.__class__.__name__} does not support logprobs." - ) infer_outputs = engine.infer_log_probs(dataset, return_meta_data) raw_predictions = ( [output.prediction for output in infer_outputs] diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 33be6dc15d..75c61c7258 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -14,9 +14,9 @@ from collections import Counter from datetime import datetime from itertools import islice -from multiprocessing.pool import ThreadPool from typing import ( Any, + AsyncIterable, Dict, Iterable, List, @@ -32,7 +32,6 @@ from datasets import Dataset, DatasetDict, Image from tqdm import tqdm, trange -from tqdm.asyncio import tqdm_asyncio from .artifact import Artifact from .base_metric import Metric @@ -241,6 +240,7 @@ def _get_cached_result(self, instance: Dict[str, Any]): """Get cached result for an instance, returns None if not found.""" if not self.use_cache or self._cache is None: return None + cache_key = self._get_cache_key(instance) return self._cache.get(cache_key) @@ -251,68 +251,15 @@ def _cache_result(self, instance: Dict[str, Any], prediction): cache_key = self._get_cache_key(instance) self._cache[cache_key] = prediction - def _apply_caching_to_streaming( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - streaming_func, - return_meta_data: bool = False, - ): - """Apply caching logic to a streaming inference function.""" - if not self.use_cache: - # No caching, just use streaming directly - instances_iter = ((i, instance) for i, instance in enumerate(dataset)) - results = {} - for index, prediction in streaming_func(instances_iter, return_meta_data): - results[index] = prediction - return [results[i] for i in range(len(dataset))] - - # Initialize cache if needed - self._initialize_cache() - - # Phase 1: Identify cached vs missing instances - cached_results = {} - missing_instances = [] - cached_count = 0 - - for i, instance in enumerate(dataset): - cached_value = self._get_cached_result(instance) - - if cached_value is not None: - cached_results[i] = cached_value - cached_count += 1 - else: - missing_instances.append((i, instance)) - - message = ( - f"Found {cached_count} cached instances, inferring {len(missing_instances)} instances" - + ( - f" (cache: {self._cache.directory})" - if hasattr(self._cache, "directory") - else "" - ) - ) - logger.info(message) - - # Phase 2: Stream missing instances and cache results immediately - if missing_instances: - missing_iter = iter(missing_instances) - - for index, prediction in streaming_func(missing_iter, return_meta_data): - # Cache immediately when ready, but always store result (even if None) - if prediction is not None: - instance = dataset[index] - self._cache_result(instance, prediction) - cached_results[index] = prediction - - # Phase 3: Reconstruct results in original order - return 
[cached_results[i] for i in range(len(dataset))] - class InferenceEngine(abc.ABC, MockInferenceMixin, CachedInferenceMixin): """Abstract base class for inference.""" + concurrency_limit: int = 100 + support_log_probs: bool = False + @abc.abstractmethod - def _infer( + async def _infer( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, @@ -325,34 +272,6 @@ def _infer( """ pass - def _infer_streaming( - self, - instances: Iterable[Tuple[int, Dict[str, Any]]], - return_meta_data: bool = False, - ) -> Iterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: - """Default streaming wrapper that uses existing _infer method. - - Engines can override this for true streaming behavior. - - Args: - instances: Iterable of (index, instance) tuples - return_meta_data: Whether to return metadata - - Yields: - (index, prediction) tuples as results become ready - """ - # Collect all instances (maintains backwards compatibility) - all_instances = list(instances) - if not all_instances: - return - - indices = [idx for idx, _ in all_instances] - batch_instances = [inst for _, inst in all_instances] - - # Use existing _infer method - predictions = self._infer(batch_instances, return_meta_data) - yield from zip(indices, predictions) - @abc.abstractmethod def prepare_engine(self): """Perform inference on the input dataset.""" @@ -360,6 +279,7 @@ def prepare_engine(self): def prepare(self): if not self.is_mock: + self._initialize_cache() super().prepare() # This will call CachedInferenceMixin.prepare() which initializes cache with error_context( self, @@ -399,11 +319,6 @@ def verify_infer_inputs( raise Exception( "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset" ) - if return_meta_data and not hasattr(self, "get_return_object"): - raise NotImplementedError( - f"Inference engine {self.__class__.__name__} does not support return_meta_data as it " - f"does not contain a 'get_return_object' method. Please set return_meta_data=False." - ) [self.verify_instance(instance) for instance in dataset] @@ -417,22 +332,54 @@ def infer( If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string predictions. """ + return asyncio.run(self._async_infer(dataset, return_meta_data)) + + def infer_log_probs( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + if not self.support_log_probs: + raise NotImplementedError( + f"return_log_probs set to True but supplied engine " + f"{self.__class__.__name__} does not support logprobs." + ) + return asyncio.run( + self._async_infer(dataset, return_meta_data, return_log_probs=True) + ) + + @abc.abstractmethod + async def _infer_streaming( + self, + instances: Iterable[Tuple[int, Dict[str, Any]]], + total_len: int, + return_meta_data: bool = False, + ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: + ... 
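The abstract streaming contract declared above is easiest to see with a toy example. Below is a minimal, self-contained sketch, not code from this patch, assuming only what the signature shows: _infer_streaming receives (index, instance) pairs, yields (index, prediction) pairs as they complete, and the caller reassembles them by index (as _async_infer below does). The EchoStreamingEngine name and its echo behaviour are hypothetical, for illustration only.

import asyncio
from typing import Any, AsyncIterable, Dict, Iterable, List, Tuple


class EchoStreamingEngine:
    """Toy engine: echoes each instance's "source" back as the prediction."""

    async def _infer_streaming(
        self,
        instances: Iterable[Tuple[int, Dict[str, Any]]],
        total_len: int,
        return_meta_data: bool = False,
    ) -> AsyncIterable[Tuple[int, str]]:
        for index, instance in instances:
            # A real engine would await an API call or flush a batch here.
            await asyncio.sleep(0)
            yield index, str(instance["source"])

    async def _collect(self, dataset: List[Dict[str, Any]]) -> List[str]:
        # Reassemble results by index, mirroring the surrounding _async_infer.
        results: Dict[int, str] = {}
        async for index, prediction in self._infer_streaming(
            enumerate(dataset), len(dataset)
        ):
            results[index] = prediction
        return [results[i] for i in range(len(dataset))]

    def infer(self, dataset: List[Dict[str, Any]]) -> List[str]:
        return asyncio.run(self._collect(dataset))


if __name__ == "__main__":
    data = [{"source": "hello"}, {"source": "world"}]
    print(EchoStreamingEngine().infer(data))  # -> ['hello', 'world']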
+ + async def _async_infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + """Internal async method that handles all inference logic.""" self.verify_infer_inputs(dataset, return_meta_data) + if self.is_mock: - result = self._mock_infer(dataset, return_meta_data) - else: - if self.use_cache: - # Use the mixin's caching functionality - result = self._apply_caching_to_streaming( - dataset, self._infer_streaming, return_meta_data - ) + logger.info("Running inference with mock") + if return_log_probs: + result = self._mock_infer_log_probs(dataset) else: - with error_context( - self, - stage="Running Inference", - help="https://www.unitxt.ai/en/latest/docs/inference.html", - ): - result = self._infer(dataset, return_meta_data) + result = self._mock_infer(dataset, return_meta_data) + else: + logger.info(f"Running inference with {self.get_engine_id()}") + results: Dict[int, Any] = {} + async for index, prediction in self._infer_streaming( + enumerate(dataset), len(dataset), return_meta_data, return_log_probs + ): + results[index] = prediction + result = [results[i] for i in range(len(dataset))] return ListWithMetadata( result, metadata={ @@ -503,78 +450,216 @@ def to_tools(self, instance): return None -class LogProbInferenceEngine(abc.ABC, MockInferenceMixin): - """Abstract base class for inference with log probs.""" +class BatchInferenceEngine(InferenceEngine): + """Base class for inference engines that process instances in batches locally. + + This is designed for local compute engines (like HF Pipeline/AutoModel) that: + - Benefit from batching multiple instances together + - Process all instances in a single batch call + - Are typically GPU/CPU bound rather than network bound + """ + + batch_size: int = NonPositionalField(default=1) @abc.abstractmethod - def _infer_log_probs( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - """Perform inference on the input dataset that returns log probs. + return_log_probs: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + """Process a single batch of instances. - If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the logprob dicts. - return_meta_data is only supported for some InferenceEngines. - predictions. + Args: + instances: List of instances to process in this batch + return_meta_data: Whether to return metadata + return_log_probs: Whether to return log probabilities per token + + Returns: + List of predictions corresponding to the input instances """ pass - def _infer_log_probs_streaming( + async def _infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + """Process dataset by batching instances.""" + raise NotImplementedError + + async def _infer_streaming( self, instances: Iterable[Tuple[int, Dict[str, Any]]], + total_len: int, return_meta_data: bool = False, - ) -> Iterable[Tuple[int, Union[Dict, TextGenerationInferenceOutput]]]: - """Default streaming wrapper for log probs inference that uses existing _infer_log_probs method. 
+ return_log_probs: bool = False, + ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: + """Stream results by processing instances in batches, with tqdm progress.""" + current_batch = [] + current_indices = [] + total_generated = 0 + total_loaded = 0 + pbar = tqdm( + total=total_len, desc=f"Inference {self.get_engine_id()}", unit="inst" + ) + try: + for idx, instance in instances: + cached_result = self._get_cached_result(instance) + if cached_result is not None: + pbar.update(1) + total_loaded += 1 + yield idx, cached_result + continue - Engines can override this for true streaming behavior. + current_batch.append(instance) + current_indices.append(idx) - Args: - instances: Iterable of (index, instance) tuples - return_meta_data: Whether to return metadata + if len(current_batch) >= self.batch_size: + results = self._infer_batch( + current_batch, return_meta_data, return_log_probs + ) + for original_idx, instance, result in zip( + current_indices, current_batch, results + ): + self._cache_result(instance, result) + pbar.update(1) + total_generated += 1 + yield original_idx, result + current_batch = [] + current_indices = [] + + if current_batch: + results = self._infer_batch( + current_batch, return_meta_data, return_log_probs + ) + for original_idx, result in zip(current_indices, results): + self._cache_result(instance, result) + pbar.update(1) + total_generated += 1 + yield original_idx, result + finally: + # Process remaining instances + if current_batch: + results = self._infer_batch( + current_batch, return_meta_data, return_log_probs + ) + for original_idx, result in zip(current_indices, results): + self._cache_result(instance, result) + pbar.update(1) + total_generated += 1 + yield original_idx, result + pbar.close() + + logger.info( + f"Inference Summary: {total_generated} generated, {total_loaded} loaded from cache." + ) - Yields: - (index, prediction) tuples as results become ready - """ - # Collect all instances (maintains backwards compatibility) - all_instances = list(instances) - if not all_instances: - return - indices = [idx for idx, _ in all_instances] - batch_instances = [inst for _, inst in all_instances] +class SingleInferenceEngine(InferenceEngine): + """Base class for inference engines that process instances individually via API calls. - # Use existing _infer_log_probs method - predictions = self._infer_log_probs(batch_instances, return_meta_data) + This is designed for API-based engines (like OpenAI, LiteLLM, Ollama, VLLM) that: + - Make individual API calls per instance + - Have built-in concurrency/batching at the API level + - Benefit from async concurrent calls + - Handle rate limiting and retries at the API level + """ - # Yield results with original indices - yield from zip(indices, predictions) + concurrency_limit: int = 100 + max_concurrent_calls: int = NonPositionalField(default=10) - def infer_log_probs( + @abc.abstractmethod + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens. + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: + """Process a single instance via API call. - For each instance , generates a list of top tokens per position. - [ "top_tokens": [ { "text": ..., "logprob": ...} , ... 
] - If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns the list of the logprob dicts. - return_meta_data is only supported for some InferenceEngines. + Args: + instance: Single instance to process + return_meta_data: Whether to return metadata + return_log_probs: Whether to return log probabilities per token + + Returns: + Single prediction result """ - if return_meta_data and not hasattr(self, "get_return_object"): - raise NotImplementedError( - f"Inference engine {self.__class__.__name__} does not support return_meta_data as it " - f"does not contain a 'get_return_object' method. Please set return_meta_data=False." - ) + pass - [self.verify_instance(instance) for instance in dataset] + async def _infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + raise NotImplementedError - if self.is_mock: - result = self._mock_infer_log_probs(dataset) - else: - result = self._infer_log_probs(dataset, return_meta_data) - return result + async def _infer_streaming( + self, + instances: Iterable[Tuple[int, Dict[str, Any]]], + total_len: int, + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: + """Stream results concurrently without realizing the input iterable, with tqdm progress.""" + sem = asyncio.Semaphore(self.concurrency_limit) + it = iter(instances) + pending: set[asyncio.Task] = set() + + # Initialize tqdm with total length + pbar = tqdm( + total=total_len, desc=f"Inference ({self.get_engine_id()})", unit="item" + ) + total_loaded = 0 + total_generated = 0 + + async def bounded_infer(idx: int, instance: Dict[str, Any]): + nonlocal total_loaded, total_generated + cached_result = self._get_cached_result(instance) + if cached_result is not None: + total_loaded += 1 + return idx, cached_result + async with sem: + result = await self._infer_single( + instance, return_meta_data, return_log_probs + ) + total_generated += 1 + self._cache_result(instance, result) + return idx, result + + # Prime the pool + for _ in range(self.concurrency_limit): + try: + idx, inst = next(it) + except StopIteration: + break + pending.add(asyncio.create_task(bounded_infer(idx, inst))) + + # Drain while refilling + try: + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + idx, result = await task + pbar.update(1) # update progress for each finished task + yield idx, result + + while len(pending) < self.concurrency_limit: + try: + idx, inst = next(it) + except StopIteration: + break + pending.add(asyncio.create_task(bounded_infer(idx, inst))) + finally: + pbar.close() + + logger.info( + f"Inference Summary: {total_generated} generated, {total_loaded} loaded from cache." 
+ ) class LazyLoadMixin(Artifact): @@ -598,8 +683,7 @@ class HFGenerationParamsMixin(Artifact): class HFInferenceEngineBase( - InferenceEngine, - LogProbInferenceEngine, + BatchInferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin, @@ -788,14 +872,6 @@ def infer( self._prepare_engine() return super().infer(dataset, return_meta_data) - @abc.abstractmethod - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - raise NotImplementedError - def infer_log_probs( self, dataset: Union[List[Dict[str, Any]], Dataset], @@ -806,11 +882,12 @@ def infer_log_probs( return super().infer_log_probs(dataset, return_meta_data) @abc.abstractmethod - def _infer_log_probs( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: + return_log_probs: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: raise NotImplementedError @@ -825,6 +902,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase): padding: bool = True truncation: bool = True padding_side: str = "left" # for decoder only models + support_log_probs: bool = True def _init_processor(self): from transformers import AutoTokenizer @@ -913,123 +991,86 @@ def prepare_inputs(self, data: Iterable, tools: Iterable) -> Mapping: **tokenizer_kargs, ).to(self.device or self.device_map) - def _infer_fn( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool, - return_logprobs: bool, - ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: - """Performs inference on the dataset in batches. - - Args: - dataset: A list of dictionaries or a Dataset object containing the input data. - Each item should have a "source" key. - return_meta_data: Whether to include metadata in the output. - return_logprobs: Whether to return log probabilities along with the output. - - Returns: - A list of outputs, which can be strings, dictionaries (if metadata is returned), - or TextGenerationInferenceOutput objects (if logprobs are returned). 
- """ - all_final_outputs = [] # List to store results from all batches - - for batch in tqdm( - batched(dataset, self.batch_size), - desc=f"Running inference in batches of {self.batch_size}", - total=len(dataset) // self.batch_size, - ): - # Get the current batch - sources = [] - tools = [] - for instance in batch: - sources.append(instance["source"]) - if "task_data" in instance: - task_data = instance["task_data"] - if isinstance(task_data, str): - task_data = json.loads(task_data) - if "__tools__" in task_data: - tools.append(task_data["__tools__"]) - else: - tools.append(None) + instances: List[Dict[str, Any]], + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + sources = [] + tools = [] + for instance in instances: + sources.append(instance["source"]) + if "task_data" in instance: + task_data = instance["task_data"] + if isinstance(task_data, str): + task_data = json.loads(task_data) + if "__tools__" in task_data: + tools.append(task_data["__tools__"]) else: tools.append(None) - # Tokenize inputs for the batch - - tokenized_inputs = self.prepare_inputs(sources, tools) - - # Determine input length (handle encoder-decoder models) - input_length = ( - 1 - if self.model.config.is_encoder_decoder - else tokenized_inputs.input_ids.shape[1] - ) - - # Make predictions for the batch - predictions = self.make_predictions(tokenized_inputs) - sequences = predictions.sequences # Sequences for the current batch + else: + tools.append(None) + # Tokenize inputs for the batch - output_tokens = sequences[:, input_length:] + tokenized_inputs = self.prepare_inputs(sources, tools) - output_tokens_strings = [] - for tokens in output_tokens: - output_tokens_strings.append( - [ - self.processor.decode(token, skip_special_tokens=True) - for token in tokens - ] - ) + # Determine input length (handle encoder-decoder models) + input_length = ( + 1 + if self.model.config.is_encoder_decoder + else tokenized_inputs.input_ids.shape[1] + ) - output_strings = [] - for tokens in output_tokens: - output_strings.append( - self.processor.decode(tokens, skip_special_tokens=True) - ) + # Make predictions for the batch + predictions = self.make_predictions(tokenized_inputs) + sequences = predictions.sequences # Sequences for the current batch - if return_logprobs: - outputs = self.get_logprobs(predictions, output_tokens_strings) - else: - outputs = output_strings - - # Create return objects for the batch - batch_results = [] - for i in range(len(sequences)): - batch_results.append( - self.get_return_object( - output=outputs[i], - generated_text=output_strings[i], - output_tokens=len(output_tokens_strings[i]), - inp=sources[i], - inp_tokens=len(tokenized_inputs.encodings[i].tokens) - if tokenized_inputs.encodings is not None - else None, - return_meta_data=return_meta_data, - ) - ) + output_tokens = sequences[:, input_length:] - all_final_outputs.extend(batch_results) + output_tokens_strings = [] + for tokens in output_tokens: + output_tokens_strings.append( + [ + self.processor.decode(token, skip_special_tokens=True) + for token in tokens + ] + ) - return all_final_outputs + output_strings = [] + for tokens in output_tokens: + output_strings.append( + self.processor.decode(tokens, skip_special_tokens=True) + ) - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return self._infer_fn(dataset, return_meta_data, False) + if 
return_log_probs: + outputs = self.get_logprobs(predictions, output_tokens_strings) + else: + outputs = output_strings - def _infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - self.verify_not_chat_api(dataset) - return self._infer_fn(dataset, return_meta_data, True) + # Create return objects for the batch + results = [] + for i in range(len(sequences)): + results.append( + self.get_return_object( + output=outputs[i], + generated_text=output_strings[i], + output_tokens=len(output_tokens_strings[i]), + inp=sources[i], + inp_tokens=len(tokenized_inputs.encodings[i].tokens) + if tokenized_inputs.encodings is not None + else None, + return_meta_data=return_meta_data, + ) + ) + return results class HFLlavaInferenceEngine(HFInferenceEngineBase): lazy_load: bool = True label: str = "hf_lava" image_token: str = "" + support_log_probs: bool = True def compute_transition_scores( self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int] @@ -1094,15 +1135,15 @@ def prepare_inputs(self, data: Iterable) -> Mapping: return inputs - def _infer_fn( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool, - return_logprobs: bool, - ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: + instances: List[Dict[str, Any]], + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> Union[List[str], List[TextGenerationInferenceOutput]]: results = [] - for instance in tqdm(dataset): + for instance in instances: processed_inputs = self.prepare_inputs(instance) input_len = len(processed_inputs["input_ids"][0]) @@ -1127,7 +1168,7 @@ def _infer_fn( self.processor.decode(tokens, skip_special_tokens=True) ) - if return_logprobs: + if return_log_probs: final_outputs = self.get_logprobs(predictions, output_tokens_strings) else: final_outputs = output_strings @@ -1145,20 +1186,6 @@ def _infer_fn( return results - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return self._infer_fn(dataset, return_meta_data, False) - - def _infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - return self._infer_fn(dataset, return_meta_data, True) - class HFPeftInferenceEngine(HFAutoModelInferenceEngine): label: str = "hf_peft_auto_model" @@ -1236,7 +1263,7 @@ def _init_model(self): class HFPipelineBasedInferenceEngine( - InferenceEngine, + BatchInferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin, @@ -1377,42 +1404,58 @@ def prepare_engine(self): if not self.lazy_load: self._prepare_engine() - def _infer( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + return_log_probs: bool = False, + ) -> List[Union[str, TextGenerationInferenceOutput]]: + """Run a synchronous batch through the model and return outputs (with optional metadata).""" if not self._is_loaded(): self._prepare_engine() - outputs = self.model([instance["source"] for instance in dataset]) + # Prepare input texts + inputs = [inst["source"] for inst in instances] - return [ - self.get_return_object(output[0], instance["source"], return_meta_data) - if 
isinstance(output, list) - else self.get_return_object(output, instance["source"], return_meta_data) - for output, instance in zip(outputs, dataset) - ] + # Synchronous model call + outputs = self.model(inputs) - def get_return_object(self, output, inp, return_meta_data): - if return_meta_data: - return TextGenerationInferenceOutput( - prediction=output["generated_text"], - generated_text=output["generated_text"], - model_name=self.model_name, - inference_type=self.label, - input_text=inp, - ) - return output["generated_text"] + results: List[Union[str, TextGenerationInferenceOutput]] = [] + for output, inst in zip(outputs, instances): + # Normalize single-item list outputs + if isinstance(output, list) and output: + output = output[0] + + # Extract text if output is a dict + if isinstance(output, dict): + text = output.get("generated_text", "") + else: + text = str(output) + + if return_meta_data: + results.append( + TextGenerationInferenceOutput( + prediction=text, + generated_text=text, + model_name=self.model_name, + inference_type=self.label, + input_text=inst["source"], + ) + ) + else: + results.append(text) + + return results -class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine): +class MockInferenceEngine(SingleInferenceEngine): model_name: str default_inference_value: str = "[[10]]" default_inference_value_logprob: List[Dict[str, Any]] = dataclasses.field( default_factory=MockInferenceMixin.mock_logprobs_default_value_factory, ) label: str = "mock_inference_engine" + support_log_probs: bool = True def get_engine_id(self): return get_model_and_label_id(self.model_name, "mock") @@ -1438,33 +1481,18 @@ def _mock_infer( result.append(self.default_inference_value) return result - def _infer( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return [ - self.get_return_object( - self.default_inference_value, instance, return_meta_data - ) - for instance in dataset - ] - - def _infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - return [ - self.get_return_object( - self.default_inference_value_logprob, instance, return_meta_data - ) - for instance in dataset - ] + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: + predict_result = ( + self.default_inference_value_logprob + if return_log_probs + else self.default_inference_value + ) - def get_return_object( - self, predict_result, generated_text, instance, return_meta_data - ): if return_meta_data: return TextGenerationInferenceOutput( prediction=predict_result, @@ -1477,12 +1505,11 @@ def get_return_object( seed=111, stop_reason="", ) + return predict_result -class GenericInferenceEngine( - InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine -): +class GenericInferenceEngine(InferenceEngine, ArtifactFetcherMixin): default: Optional[str] = None def prepare_engine(self): @@ -1500,34 +1527,17 @@ def prepare_engine(self): engine_reference = self.default self.engine = self.get_artifact(engine_reference) - def get_engine_id(self): - # If mock_inference_mode is set, no engine is prepared. 
- if hasattr(self, "engine"): - return f"generic_{self.engine.get_engine_id()}" - return "generic_inference_engine" - - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return self.engine._infer(dataset) - - def _infer_log_probs( + def _infer_streaming( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: Iterable[Tuple[Union[int, Dict[str, Any]]]], + total_len: int, return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if not isinstance(self.engine, LogProbInferenceEngine): - raise NotImplementedError( - f"Error in infer: inference engine used by the GenericInferenceEngine" - f"({self.engine.__class__.__name__}) does not support logprobs." - ) - return self.engine._infer_log_probs(dataset) + ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: + return self.engine._infer_streaming(instances, total_len, return_meta_data) class OllamaInferenceEngine( - InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin + SingleInferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin ): label: str = "ollama" _requirements_list = { @@ -1539,32 +1549,39 @@ def get_engine_id(self): return get_model_and_label_id(self.model, self.label) def prepare_engine(self): - from ollama import Client + from ollama import AsyncClient - self.client = Client( + self.client = AsyncClient( host=self.credentials["api_base"] if self.credentials is not None and "api_base" in self.credentials else None ) - def _infer( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: args = self.to_dict([StandardAPIParamsMixin]) - results = [] model = args.pop("model") - for instance in dataset: - messages = self.to_messages(instance) - response = self.client.chat( - messages=messages, - model=model, - options=args, - ) - results.append(response) + messages = self.to_messages(instance) + response = await self.client.chat( + messages=messages, + model=model, + options=args, + ) - return [element["message"]["content"] for element in results] + prediction = response["message"]["content"] + + if return_meta_data: + return TextGenerationInferenceOutput( + prediction=prediction, + generated_text=prediction, + model_name=self.model, + inference_type=self.label, + ) + return prediction class OptionSelectingByLogProbsInferenceEngine: @@ -1710,8 +1727,7 @@ def inner(self, args): class OpenAiInferenceEngine( - InferenceEngine, - LogProbInferenceEngine, + SingleInferenceEngine, OpenAiInferenceEngineParamsMixin, PackageRequirementsMixin, ): @@ -1726,6 +1742,7 @@ class OpenAiInferenceEngine( default_headers: Dict[str, str] = {} credentials: CredentialsOpenAi = {} num_parallel_requests: int = 20 + support_log_probs: bool = True def get_engine_id(self) -> str: return get_model_and_label_id(self.model_name, self.label) @@ -1749,17 +1766,19 @@ def get_default_headers(self) -> Dict[str, str]: return self.default_headers def create_client(self): - from openai import OpenAI + from openai import AsyncOpenAI self.credentials = self._prepare_credentials() - return OpenAI( + return AsyncOpenAI( api_key=self.credentials["api_key"], base_url=self.base_url or self.credentials["api_url"], 
default_headers=self.get_default_headers(), ) def prepare_engine(self): - self.client = self.create_client() + from openai import AsyncOpenAI + + self.client: AsyncOpenAI = self.create_client() self._set_inference_parameters() def _get_completion_kwargs(self): @@ -1769,96 +1788,24 @@ def _get_completion_kwargs(self): if v is not None } - def _parallel_infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - infer_func, - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - inputs = [(instance, return_meta_data) for instance in dataset] - outputs = [] - with ThreadPool(processes=self.num_parallel_requests) as pool: - for output in tqdm( - pool.imap(infer_func, inputs), - total=len(inputs), - desc=f"Inferring with {self.__class__.__name__}", - ): - outputs.append(output) - - return outputs - - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return self._parallel_infer( - dataset=dataset, - return_meta_data=return_meta_data, - infer_func=self._get_chat_completion, - ) - - def _infer_log_probs( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - return self._parallel_infer( - dataset=dataset, - return_meta_data=return_meta_data, - infer_func=self._get_logprobs, - ) - - def get_client_model_name(self): - return self.model_name - - @run_with_imap - def _get_chat_completion(self, instance, return_meta_data): - import openai - + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: tools = self.to_tools(instance) messages = self.to_messages(instance) - try: - response = self.client.chat.completions.create( - messages=messages, - tools=tools, - model=self.get_client_model_name(), - **self._get_completion_kwargs(), - # tool_choice="auto" - ) - - if tools is None: - prediction = response.choices[0].message.content - else: - try: - func_call = response.choices[0].message.tool_calls[0].function - prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}' - except: - prediction = response.choices[0].message.content or "" - - return self.get_return_object(prediction, response, return_meta_data) - # catch in case of content_filtering failure - except openai.BadRequestError as e: - logging.error( - f"Error predicting instance {messages}:{e}. 
Returning empty prediction" - ) - return TextGenerationInferenceOutput( - prediction="-", generated_text="-", input_tokens=0, output_tokens=0 - ) - - @run_with_imap - def _get_logprobs(self, instance, return_meta_data): - import openai + response = await self.client.chat.completions.create( + messages=messages, + tools=tools, + model=self.get_client_model_name(), + **self._get_completion_kwargs(), + ) - messages = self.to_messages(instance) - try: - response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - **self._get_completion_kwargs(), - ) + if return_log_probs: top_logprobs_response = response.choices[0].logprobs.content - pred_output = [ + prediction = [ { "text": generated_token.token, "logprob": generated_token.logprob, @@ -1869,33 +1816,29 @@ def _get_logprobs(self, instance, return_meta_data): } for generated_token in top_logprobs_response ] - return self.get_return_object(pred_output, response, return_meta_data) - # catch in case of content_filtering failure - except openai.BadRequestError as e: - logging.error( - f"Error predicting instance {messages}:{e}. Returning empty prediction" - ) - prediction = [ - {"text": "-", "logprob": 0, "top_tokens": [{"text": "-", "logprob": 0}]} - ] - return TextGenerationInferenceOutput( - prediction=prediction, - generated_text=prediction, - input_tokens=0, - output_tokens=0, - ) + else: + if tools is None: + prediction = response.choices[0].message.content + else: + try: + func_call = response.choices[0].message.tool_calls[0].function + prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}' + except: + prediction = response.choices[0].message.content or "" - def get_return_object(self, predict_result, response, return_meta_data): if return_meta_data: return TextGenerationInferenceOutput( - prediction=predict_result, + prediction=prediction, generated_text=response.choices[0].message.content, input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, model_name=self.model_name, inference_type=self.label, ) - return predict_result + return prediction + + def get_client_model_name(self): + return self.model_name class AzureOpenAIInferenceEngine(OpenAiInferenceEngine): @@ -1926,10 +1869,10 @@ def _prepare_credentials(self) -> CredentialsOpenAi: return {"api_key": api_key, "api_url": api_url, "api_version": api_version} def create_client(self): - from openai import AzureOpenAI + from openai import AsyncAzureOpenAI self.credentials = self._prepare_credentials() - return AzureOpenAI( + return AsyncAzureOpenAI( api_key=self.credentials["api_key"], base_url=self.credentials["api_url"], api_version=self.credentials["api_version"], @@ -2027,7 +1970,9 @@ class TogetherAiInferenceEngineParamsMixin(Artifact): class TogetherAiInferenceEngine( - InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin + SingleInferenceEngine, + TogetherAiInferenceEngineParamsMixin, + PackageRequirementsMixin, ): label: str = "together" model_name: str @@ -2041,7 +1986,7 @@ def get_engine_id(self): return get_model_and_label_id(self.model_name, self.label) def prepare_engine(self): - from together import Together + from together import AsyncTogether from together.types.models import ModelType api_key_env_var_name = "TOGETHER_API_KEY" # pragma: allowlist secret @@ -2050,7 +1995,7 @@ def prepare_engine(self): f"Error while trying to run TogetherAiInferenceEngine." f" Please set the environment param '{api_key_env_var_name}'." 
) - self.client = Together(api_key=api_key) + self.client = AsyncTogether(api_key=api_key) self._set_inference_parameters() # Get model type from Together List Models API @@ -2075,39 +2020,34 @@ def _get_infer_kwargs(self): if v is not None } - def _infer_chat(self, instance: Dict[str, Any]) -> str: + async def _infer_chat(self, instance: Dict[str, Any]) -> str: messages = self.to_messages(instance) - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model_name, messages=messages, **self._get_infer_kwargs(), ) return response.choices[0].message.content - def _infer_text(self, instance: Dict[str, Any]) -> str: - response = self.client.completions.create( + async def _infer_text(self, instance: Dict[str, Any]) -> str: + response = await self.client.completions.create( model=self.model_name, prompt=instance["source"], **self._get_infer_kwargs(), ) return response.choices[0].text - def _infer( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: from together.types.models import ModelType - outputs = [] if self.model_type == ModelType.CHAT: - for instance in tqdm(dataset, desc="Inferring with Together AI Chat API"): - outputs.append(self._infer_chat(instance)) - else: - self.verify_not_chat_api(dataset) - for instance in tqdm(dataset, desc="Inferring with Together AI Text API"): - outputs.append(self._infer_text(instance)) - return outputs + return await self._infer_chat(instance) + return await self._infer_text(instance) @deprecation( @@ -2196,9 +2136,8 @@ class WMLChatParamsMixin(Artifact): class WMLInferenceEngineBase( - InferenceEngine, + SingleInferenceEngine, PackageRequirementsMixin, - LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine, ): """Base for classes running inference using ibm-watsonx-ai. @@ -2260,15 +2199,6 @@ def verify(self): or (self.deployment_id and not (self.model_name and self.deployment_id)) ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time." - # def process_data_before_dump(self, data): - # if "credentials" in data: - # for key, value in data["credentials"].items(): - # if key != "url": - # data["credentials"][key] = "" - # else: - # data["credentials"][key] = value - # return data - def _initialize_wml_client(self): if self.external_client: return self.external_client @@ -2390,51 +2320,6 @@ def _load_model(self): api_client=self._client, ) - @abc.abstractmethod - def _send_requests( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_logprobs: bool, - return_meta_data: bool, - ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: - raise NotImplementedError( - f"The class '{self.get_pretty_print_name()}' is an abstract class. " - f"Please used either 'WMLInferenceEngineGeneration' or " - f"'WMLInferenceEngineChat' instead, depending on your task." 
- ) - - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if self._model is None: - self._load_model() - - return self._send_requests( - dataset=dataset, - return_logprobs=False, - return_meta_data=return_meta_data, - ) - - def _infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - if self._model is None: - self._load_model() - - return self._send_requests( - dataset=dataset, - return_logprobs=True, - return_meta_data=return_meta_data, - ) - - @abc.abstractmethod - def get_return_object(self, predict_result, result, input_text, return_meta_data): - raise NotImplementedError - def get_model_details(self) -> Dict: return self._model.get_details() @@ -2557,55 +2442,61 @@ def _set_logprobs_params(self, params: Dict[str, Any]) -> Dict[str, Any]: "return_options": logprobs_return_options, } - def _send_requests( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_logprobs: bool, - return_meta_data: bool, - ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: - self.verify_not_chat_api(dataset) + instance, + return_meta_data: bool = False, + return_log_probs: bool = False, + ): + if self._model is None: + self._load_model() params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) + inp = instance["source"] - if return_logprobs: - generation_type = "generated_tokens" - params = self._set_logprobs_params(params) - else: - generation_type = "generated_text" + response = await self._model.agenerate(prompt=inp, params=params) + result = response["results"][0] + pred = result.get("generated_text", "") - inputs: List[str] = [instance["source"] for instance in dataset] + if return_meta_data: + return TextGenerationInferenceOutput( + prediction=pred, + generated_text=result.get("generated_text"), + input_tokens=result.get("input_token_count"), + output_tokens=result.get("generated_token_count"), + model_name=self.model_name or self.deployment_id, + inference_type=self.label, + stop_reason=result.get("stop_reason"), + seed=self.random_seed, + input_text=inp, + ) + return pred - results = self._model.generate( - prompt=inputs, - params=params, - concurrency_limit=self.concurrency_limit, - ) + async def _infer_log_probs_single(self, instance, return_meta_data: bool = False): + if self._model is None: + self._load_model() - final_results = [] - for result, inp in zip(results, inputs): - result_metadata = result["results"][0] - generated_content = result_metadata[generation_type] - final_results.append( - self.get_return_object( - generated_content, result_metadata, inp, return_meta_data - ) - ) - return final_results + params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) + params = self._set_logprobs_params(params) + inp = instance["source"] + + response = await self._model.agenerate(prompt=inp, params=params) + result = response["results"][0] + pred = result.get("generated_tokens", {}) - def get_return_object(self, predict_result, result, input_text, return_meta_data): if return_meta_data: return TextGenerationInferenceOutput( - prediction=predict_result, - generated_text=result["generated_text"], - input_tokens=result["input_token_count"], - output_tokens=result["generated_token_count"], + prediction=pred, + generated_text=result.get("generated_text"), + 
input_tokens=result.get("input_token_count"), + output_tokens=result.get("generated_token_count"), model_name=self.model_name or self.deployment_id, inference_type=self.label, - stop_reason=result["stop_reason"], + stop_reason=result.get("stop_reason"), seed=self.random_seed, - input_text=input_text, + input_text=inp, ) - return predict_result + return pred class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin): @@ -2659,6 +2550,17 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin): image_encoder: Optional[EncodeImageToString] = NonPositionalField( default_factory=EncodeImageToString ) + support_log_probs: bool = True + + def _async_infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + if self._model is None: + self._load_model() + return super()._async_infer(dataset, return_meta_data, return_log_probs) @staticmethod def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]: @@ -2722,7 +2624,7 @@ def _create_messages_from_instance( } ) - messages.append([message]) + messages.append(message) return messages @@ -2773,7 +2675,7 @@ def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]] self.verify_messages(messages) # This is done to be compatible with inputs containing # images as SDK allows sending only one image per message. - return [messages] + return messages def to_tools( self, instance: Dict[str, Any] @@ -2792,110 +2694,92 @@ def to_tools( return {"tools": None, "tool_choice": None} - def _handle_async_requests( - self, - data: List[Dict[str, Any]], - params: Dict[str, Any], - ) -> List[Dict[str, Any]]: - async def handle_async_requests(start_idx, end_idx): - coroutines = [ - self._model.achat( - messages=data[idx]["msg"], - params=params, - tools=data[idx]["tools"]["tools"], - tool_choice=data[idx]["tools"]["tool_choice"], - ) - for idx in range(start_idx, end_idx) - ] - batch_results = await asyncio.gather(*coroutines) - return list(batch_results) - - loop = asyncio.get_event_loop() - results = [] - - for batch_idx in range(0, len(data), self.concurrency_limit): - batch_results = loop.run_until_complete( - handle_async_requests( - batch_idx, min(batch_idx + self.concurrency_limit, len(data)) - ) - ) - results.extend(batch_results) - - return results - - def _send_requests( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_logprobs: bool, - return_meta_data: bool, - ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]: + instance, + return_meta_data: bool = False, + return_log_probs: bool = False, + ): + # Build params params = self.to_dict([WMLChatParamsMixin], keep_empty=False) + params["logprobs"] = return_log_probs + output_type = "logprobs" if return_log_probs else "message" - if return_logprobs: - output_type = "logprobs" - params["logprobs"] = True - else: - output_type = "message" - params["logprobs"] = False + # Prepare inputs + messages = self.to_messages(instance) + tools_info = self.to_tools(instance) or {} + tools = tools_info.get("tools") + tool_choice = tools_info.get("tool_choice") + tool_call_expected = tools is not None - data = [ - { - "idx": i, - "msg": message, - "tools": self.to_tools(dataset[i]), - } - for i in range(len(dataset)) - for message in self.to_messages(dataset[i]) - ] + # Single request + response = await self._model.achat( + 
messages=messages, + params=params, + tools=tools, + tool_choice=tool_choice, + ) - responses = self._handle_async_requests(data, params) + # Extract output + choice = response["choices"][0] + output = choice[output_type] - results = [] - for inp, response in zip(data, responses): - idx = inp["idx"] - tool_call = data[idx]["tools"]["tools"] is not None - - output = response["choices"][0][output_type] - if "content" not in output: - output["content"] = "" - if tool_call: - if "tool_calls" in output: - func = output["tool_calls"][0]["function"] - prediction = f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}' - else: - prediction = output["content"] - else: - prediction = output["content"] + # normalize "content" for dict responses + if isinstance(output, dict) and "content" not in output: + output["content"] = "" - results.append( - self.get_return_object( - prediction, - response["choices"][0]["message"]["content"], - response, - str(inp), - return_meta_data, + # compute prediction (mirror your _send_requests logic) + if tool_call_expected: + if ( + isinstance(output, dict) + and "tool_calls" in output + and output["tool_calls"] + ): + func = output["tool_calls"][0]["function"] + prediction = ( + f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}' + ) + else: + prediction = ( + output.get("content", "") + if isinstance(output, dict) + else str(output) ) + else: + prediction = ( + output.get("content", "") if isinstance(output, dict) else str(output) ) - return results + # assistant-visible text for metadata + generated_text = ( + choice.get("message", {}).get("content", "") + if "message" in choice + else ( + output.get("content", "") if isinstance(output, dict) else str(output) + ) + ) + + # snapshot of input (like str(inp) in your _send_requests) + input_snapshot = str( + { + "idx": 0, + "msg": messages, + "tools": {"tools": tools, "tool_choice": tool_choice}, + } + ) - def get_return_object( - self, predict_result, generated_text, result, input_text, return_meta_data - ): if return_meta_data: return TextGenerationInferenceOutput( - prediction=predict_result, + prediction=prediction, generated_text=generated_text, - input_tokens=result["usage"]["prompt_tokens"], - output_tokens=len(predict_result) - if isinstance(predict_result, list) - else None, + input_tokens=response["usage"]["prompt_tokens"], + output_tokens=len(prediction) if isinstance(prediction, list) else None, model_name=self.model_name or self.deployment_id, inference_type=self.label, - stop_reason=result["choices"][0]["finish_reason"], - input_text=input_text, + stop_reason=response["choices"][0]["finish_reason"], + input_text=input_snapshot, ) - return predict_result + return prediction @deprecation( @@ -2946,7 +2830,7 @@ def get_text_without_images(instance, image_token=""): class LMMSEvalBaseInferenceEngine( - InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin + BatchInferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin ): label = "lmms-eval" model_type: str @@ -2995,10 +2879,11 @@ class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine): do_sample: bool = False generate_until: List[str] = ["\n\n"] - def _infer( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, + return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: if not self._is_loaded(): self._prepare_engine() @@ -3008,7 +2893,7 @@ def _infer( temp_task_name = 
str(uuid.uuid4()) requests = [] - for i, instance in enumerate(dataset): + for i, instance in enumerate(instances): requests.append( Instance( request_type="generate_until", @@ -3034,7 +2919,7 @@ def _infer( ) ) - self.model.task_dict[temp_task_name] = DatasetDict({"test": dataset}) + self.model.task_dict[temp_task_name] = DatasetDict({"test": instances}) responses = self.model.generate_until(requests) @@ -3067,10 +2952,11 @@ def make_instance(self, instance, special_args, index, task_name): }, ) - def _infer( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, + return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: if not self._is_loaded(): self._prepare_engine() @@ -3078,7 +2964,7 @@ def _infer( temp_task_name = str(uuid.uuid4()) requests = [] - for i, instance in enumerate(dataset): + for i, instance in enumerate(instances): task_data = instance["task_data"] if isinstance(task_data, str): @@ -3094,15 +2980,15 @@ def _infer( ) ) - self.model.task_dict[temp_task_name] = DatasetDict({"test": dataset}) + self.model.task_dict[temp_task_name] = DatasetDict({"test": instances}) self.model.metadata = {} responses = self.model.loglikelihood(requests) self.model.task_dict.pop(temp_task_name) - optimal_scores = [sys.float_info.max] * len(dataset) - optimal_responses = [None] * len(dataset) + optimal_scores = [sys.float_info.max] * len(instances) + optimal_responses = [None] * len(instances) for request, (score, _) in zip(requests, responses): if score < optimal_scores[request.idx]: @@ -3135,7 +3021,9 @@ class VLLMParamsMixin(Artifact): prompt_logprobs: Optional[int] = None -class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin): +class VLLMInferenceEngine( + SingleInferenceEngine, PackageRequirementsMixin, VLLMParamsMixin +): label = "vllm" def get_engine_id(self): @@ -3144,10 +3032,13 @@ def get_engine_id(self): def prepare_engine(self): args = self.to_dict([VLLMParamsMixin]) args.pop("model") - from vllm import LLM, SamplingParams + + from vllm import AsyncLLMEngine, SamplingParams + from vllm.engine.arg_utils import AsyncEngineArgs self.sampling_params = SamplingParams(**args) - self.llm = LLM( + # Use AsyncLLMEngine for single instance async processing + engine_args = AsyncEngineArgs( model=self.model, device="auto", trust_remote_code=True, @@ -3157,27 +3048,53 @@ def prepare_engine(self): max_num_seqs=64, enforce_eager=True, ) + self.llm = AsyncLLMEngine.from_engine_args(engine_args) - def _infer( + async def _infer_single( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instance: Dict[str, Any], return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - inputs = [] - for instance in dataset: - inputs.append(instance["source"]) - - if isinstance(inputs[0], list): - # outputs = self.llm.chat(inputs, self.sampling_params, chat_template="{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the 
right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- System message if there are no images, or if the user supplied one #}\n{%- if user_supplied_system_message or not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n' }}\n{%- endif %}") - outputs = self.llm.chat(inputs, self.sampling_params) + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: + import uuid + + source = instance["source"] + request_id = str(uuid.uuid4()) + + # Use the async VLLM generate method + if isinstance(source, list): + # Chat format + result_generator = self.llm.generate( + request_id, source, self.sampling_params + ) else: - outputs = self.llm.generate(inputs, self.sampling_params) + # Simple text format + result_generator = self.llm.generate( + request_id, source, self.sampling_params + ) - predictions = [] - for output in outputs: - predictions.append(output.outputs[0].text) + # Iterate through the async generator to get the final result + final_output = None + async for request_output in result_generator: + final_output = request_output - return predictions + if final_output and final_output.outputs: + generated_text = final_output.outputs[0].text + + if return_meta_data: + return TextGenerationInferenceOutput( + prediction=generated_text, + generated_text=generated_text, + input_tokens=len(final_output.prompt_token_ids) + if final_output.prompt_token_ids + else 0, + output_tokens=len(final_output.outputs[0].token_ids) + if final_output.outputs[0].token_ids + else 0, + ) + return generated_text + + return "" class AsyncTokenBucket: @@ -3218,7 +3135,7 @@ async def acquire(self, tokens=1): class LiteLLMInferenceEngine( - InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin + SingleInferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin ): label: str = "litellm" max_requests_per_second: float = 6 @@ -3241,84 +3158,59 @@ def prepare_engine(self): from litellm import acompletion self._completion = acompletion - # Initialize a semaphore to limit concurrency - self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second)) + # Use parent's concurrent calls limit instead of custom semaphore + 
self.max_concurrent_calls = round(self.max_requests_per_second) - async def _infer_instance( - self, index: int, instance: Dict[str, Any] - ) -> TextGenerationInferenceOutput: + async def _infer_single( + self, + instance: Dict[str, Any], + return_meta_data: bool = False, + return_log_probs: bool = False, + ) -> Union[str, TextGenerationInferenceOutput]: """Process a single inference request.""" - async with self._semaphore: - await self._rate_limiter.acquire() - # Introduce a slight delay to prevent burstiness - await asyncio.sleep(0.01) - messages = self.to_messages(instance) - tools = self.to_tools(instance) - kwargs = self.to_dict([StandardAPIParamsMixin]) - kwargs = {k: v for k, v in kwargs.items() if v is not None} - del kwargs["credentials"] - try: - response = await self._completion( - messages=messages, - tools=tools, - max_retries=self.max_retries, - drop_params=False, - **self.credentials, - **kwargs, - ) - except Exception as e: - raise RuntimeError( - f"Error inferring the following instance:\n{instance}" - ) from e + await self._rate_limiter.acquire() + # Introduce a slight delay to prevent burstiness + await asyncio.sleep(0.01) + messages = self.to_messages(instance) + tools = self.to_tools(instance) + kwargs = self.to_dict([StandardAPIParamsMixin]) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + del kwargs["credentials"] - usage = response.get("usage", {}) + response = await self._completion( + messages=messages, + tools=tools, + max_retries=self.max_retries, + drop_params=False, + **self.credentials, + **kwargs, + ) - if tools is None: - prediction = response["choices"][0]["message"]["content"] - else: - try: - func_call = response["choices"][0]["message"]["tool_calls"][0][ - "function" - ] - prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}' - except: - prediction = response["choices"][0]["message"]["content"] or "" - return TextGenerationInferenceOutput( - prediction=prediction, - generated_text=response["choices"][0]["message"]["content"], - input_tokens=usage.get("prompt_tokens"), - output_tokens=usage.get("completion_tokens"), - model_name=response.get("model", self.model), - inference_type=self.inference_type, - ) + usage = response.get("usage", {}) - async def _infer_async( - self, dataset: List[Dict[str, Any]] - ) -> List[TextGenerationInferenceOutput]: - """Process multiple inference requests concurrently with a progress bar.""" - tasks = ( - self._infer_instance(i, instance) for i, instance in enumerate(dataset) - ) - # Use tqdm_asyncio.gather to display progress bar - return await tqdm_asyncio.gather( - *tasks, desc=f"LiteLLM Inference ({self.model})", total=len(dataset) + if tools is None: + prediction = response["choices"][0]["message"]["content"] + else: + try: + func_call = response["choices"][0]["message"]["tool_calls"][0][ + "function" + ] + prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}' + except: + prediction = response["choices"][0]["message"]["content"] or "" + + result = TextGenerationInferenceOutput( + prediction=prediction, + generated_text=response["choices"][0]["message"]["content"], + input_tokens=usage.get("prompt_tokens"), + output_tokens=usage.get("completion_tokens"), + model_name=response.get("model", self.model), + inference_type=self.inference_type, ) - def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - """Main inference entry point.""" - loop 
= asyncio.get_event_loop() - responses = loop.run_until_complete(self._infer_async(dataset)) - return self.get_return_object(responses, return_meta_data) - - def get_return_object(self, responses, return_meta_data): if return_meta_data: - return responses - - return [response.prediction for response in responses] + return result + return prediction _supported_apis = Literal[ @@ -3336,9 +3228,7 @@ def get_return_object(self, responses, return_meta_data): ] -class CrossProviderInferenceEngine( - InferenceEngine, StandardAPIParamsMixin, LogProbInferenceEngine -): +class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): """Inference engine capable of dynamically switching between multiple providers APIs. This class extends the InferenceEngine and OpenAiInferenceEngineParamsMixin @@ -3637,12 +3527,29 @@ def prepare_engine(self): self.engine: InferenceEngine = cls(**args) self.data_classification_policy = self.engine.data_classification_policy - def _infer( + async def _infer( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return self.engine._infer(dataset, return_meta_data) + return await self.engine._infer(dataset, return_meta_data) + + async def _async_infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + return await self.engine._async_infer(dataset, return_meta_data) + + async def _infer_streaming( + self, + instances: Iterable[Tuple[Union[int, Dict[str, Any]]]], + total_len: int, + return_meta_data: bool = False, + ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: + return await self.engine._infer_streaming( + instances, total_len, return_meta_data + ) def get_engine_id(self): api = self.get_provider_name() @@ -3650,19 +3557,8 @@ def get_engine_id(self): return get_model_and_label_id(self.provider_model_map[api][self.model], api) return get_model_and_label_id(self.model, api) - def _infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - if not isinstance(self.engine, LogProbInferenceEngine): - raise UnitxtError( - f"The underlying inference engine of this instance of CrossProviderInferenceEngine ({self.engine.get_engine_id()}) must inherit from LogProbInferenceEngine and implement _infer_log_probs" - ) - return self.engine._infer_log_probs(dataset, return_meta_data) - -class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin): +class HFOptionSelectingInferenceEngine(BatchInferenceEngine, TorchDeviceMixin): """HuggingFace based class for inference engines that calculate log probabilities. This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs. 
@@ -3670,7 +3566,6 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin): label = "hf_option_selection" model_name: str - batch_size: int _requirements_list = { "transformers": "Install huggingface package using 'pip install --upgrade transformers" @@ -3740,10 +3635,11 @@ def get_log_probs(self, texts): return log_probs - def _infer( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, + return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: if return_meta_data and not hasattr(self.engine, "get_return_object"): raise NotImplementedError( @@ -3753,7 +3649,7 @@ def _infer( inputs = [] - for instance in dataset: + for instance in instances: for option in instance["task_data"]["options"]: if isinstance(instance["source"], list): inputs.append( @@ -3767,7 +3663,7 @@ def _infer( scores_iterator = iter(scores) predictions = [] - for instance in dataset: + for instance in instances: options_scores = Counter() for option in instance["task_data"]["options"]: score = next(scores_iterator) @@ -3777,31 +3673,32 @@ def _infer( return predictions -class MetricInferenceEngine(InferenceEngine): +class MetricInferenceEngine(BatchInferenceEngine): """An inference engine that uses the output of a metric as its prediction. Used to evaluate metrics like LLM as Judge or Granite Guardian. Args: - InferenceEngine (_type_): _description_ + BatchInferenceEngine (_type_): _description_ """ metric: Metric prediction_field: Optional[str] = None - def _infer( + def _infer_batch( self, - dataset: Union[List[Dict[str, Any]], Dataset], + instances: List[Dict[str, Any]], return_meta_data: bool = False, + return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: task_data = [ json.loads(instance["task_data"]) if "task_data" in instance else {} - for instance in dataset + for instance in instances ] predictions = ( [td[self.prediction_field] for td in task_data] if self.prediction_field else [] ) - references = [instance["references"] for instance in dataset] + references = [instance["references"] for instance in instances] return self.metric.compute( task_data=task_data, predictions=predictions, diff --git a/src/unitxt/llm_as_judge_from_template.py b/src/unitxt/llm_as_judge_from_template.py index df2d5abab8..5eab5a59b9 100644 --- a/src/unitxt/llm_as_judge_from_template.py +++ b/src/unitxt/llm_as_judge_from_template.py @@ -5,7 +5,7 @@ from .api import infer from .dataclass import Field from .formats import ChatAPIFormat, Format, SystemFormat -from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine +from .inference import InferenceEngine, OpenAiInferenceEngine from .metrics import BulkInstanceMetric from .operator import SequentialOperator from .operators import ArtifactFetcherMixin @@ -386,9 +386,7 @@ def preprocess_instance(self, instance): def verify(self): super().verify() - if self.infer_log_probs and not isinstance( - self.inference_model, LogProbInferenceEngine - ): + if self.infer_log_probs and not self.inference_model.support_log_probs: raise NotImplementedError( f"Error in TaskBasedLLMAsJudge: return_log_probs set to True but supplied engine " f"{self.inference_model.__class__.__name__} does not support logprobs." 
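The verify() change above replaces the isinstance check against LogProbInferenceEngine with the new support_log_probs capability flag carried by every engine. A minimal sketch of that pattern, not part of the patch, using stand-in classes (ToyEngine, ToyLogProbEngine, and ToyJudge are hypothetical names, not unitxt APIs):

class ToyEngine:
    """Stand-in for InferenceEngine: engines advertise log-prob support via a flag."""

    support_log_probs: bool = False


class ToyLogProbEngine(ToyEngine):
    """Stand-in for an engine that actually implements log-prob inference."""

    support_log_probs: bool = True


class ToyJudge:
    """Stand-in for a TaskBasedLLMAsJudge-style consumer of the flag."""

    def __init__(self, engine: ToyEngine, infer_log_probs: bool):
        self.engine = engine
        self.infer_log_probs = infer_log_probs

    def verify(self) -> None:
        # Mirrors the new check: consult the capability flag instead of
        # isinstance(engine, LogProbInferenceEngine).
        if self.infer_log_probs and not self.engine.support_log_probs:
            raise NotImplementedError(
                f"{self.engine.__class__.__name__} does not support logprobs."
            )


ToyJudge(ToyLogProbEngine(), infer_log_probs=True).verify()  # passes
# ToyJudge(ToyEngine(), infer_log_probs=True).verify()       # would raise NotImplementedError

This capability-flag style is consistent with the metrics.py hunk below, which relaxes the GraniteGuardianBase annotation from LogProbInferenceEngine to plain InferenceEngine.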
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index de05034250..c631f24a07 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -49,7 +49,6 @@ from .inference import ( HFPipelineBasedInferenceEngine, InferenceEngine, - LogProbInferenceEngine, TorchDeviceMixin, WMLInferenceEngineGeneration, ) @@ -6770,7 +6769,7 @@ class GraniteGuardianBase(InstanceMetric): safe_token = "No" unsafe_token = "Yes" - inference_engine: LogProbInferenceEngine = None + inference_engine: InferenceEngine = None generation_params: Dict = None risk_name: str = None risk_type: RiskType = None diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index bd8696d696..84587243fa 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -1,7 +1,5 @@ import os -import random -import shutil -import time +import tempfile from functools import lru_cache from typing import Any, Dict, List, cast @@ -93,8 +91,9 @@ def get_text_dataset(format=None): class TestInferenceEngine(UnitxtInferenceTestCase): def test_pipeline_based_inference_engine(self): model = HFPipelineBasedInferenceEngine( - model_name=local_decoder_model, # pragma: allowlist secret + model_name=local_decoder_model, max_new_tokens=2, + use_cache=False, ) dataset = get_text_dataset() @@ -108,6 +107,7 @@ def test_pipeline_based_inference_engine_lazy_load(self): model_name=local_decoder_model, # pragma: allowlist secret max_new_tokens=2, lazy_load=True, + use_cache=False, ) dataset = get_text_dataset() @@ -121,6 +121,7 @@ def test_dataset_verification_inference_engine(self): max_new_tokens=2, lazy_load=True, data_classification_policy=["public"], + use_cache=False, ) dataset = [{"source": "", "data_classification_policy": ["pii"]}] with self.assertRaises(UnitxtError) as e: @@ -141,6 +142,7 @@ def test_llava_inference_engine(self): model_name="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=3, temperature=0.0, + use_cache=False, ) dataset = get_image_dataset(format="formats.chat_api") @@ -168,6 +170,7 @@ def test_watsonx_inference(self): top_k=1, repetition_penalty=1.5, decoding_method="greedy", + use_cache=True, ) dataset = get_text_dataset() @@ -181,6 +184,7 @@ def test_watsonx_chat_inference(self): model_name="ibm/granite-3-8b-instruct", data_classification_policy=["public"], temperature=0, + use_cache=False, ) dataset = get_text_dataset() @@ -228,6 +232,7 @@ def test_rits_inference(self): model = RITSInferenceEngine( model_name="microsoft/phi-4", max_tokens=128, + use_cache=False, ) dataset = get_text_dataset() @@ -292,17 +297,18 @@ def test_option_selecting_by_log_prob_inference_engines(self): def test_hf_auto_model_inference_engine_batching(self): model = HFAutoModelInferenceEngine( - model_name=local_decoder_model, # pragma: allowlist secret + model_name=local_decoder_model, max_new_tokens=2, batch_size=2, data_classification_policy=["public"], + use_cache=False, ) dataset = get_text_dataset() predictions = list(model(dataset)) - self.assertListEqual(predictions, ["7\n", "12"]) + self.assertListEqual(predictions, ["", "12"]) def test_hf_auto_model_inference_engine(self): data = get_text_dataset() @@ -312,6 +318,7 @@ def test_hf_auto_model_inference_engine(self): repetition_penalty=1.5, top_k=5, data_classification_policy=["public"], + use_cache=False, ) self.assertEqual(engine.get_engine_id(), "flan_t5_small_hf_auto_model") @@ -346,6 +353,7 @@ def test_watsonx_inference_with_images(self): max_tokens=128, top_logprobs=3, temperature=0.0, + 
use_cache=False, ) results = inference_engine.infer_log_probs( @@ -374,6 +382,7 @@ def test_lite_llm_inference_engine(self): temperature=0, top_p=1, seed=42, + use_cache=False, ) dataset = get_text_dataset(format="formats.chat_api") @@ -388,6 +397,7 @@ def test_lite_llm_inference_engine_without_task_data_not_failing(self): temperature=0, top_p=1, seed=42, + use_cache=False, ).infer([{"source": "say hello."}]) def test_log_prob_scoring_inference_engine(self): @@ -453,8 +463,9 @@ def test_hugginface_pipeline_inference_engine_chat_api(self): model_name=local_decoder_model, max_new_tokens=1, top_k=1, + use_cache=False, ) - predictions = engine.infer(dataset) + predictions = engine(dataset) self.assertEqual(predictions[0], "hi") self.assertEqual(predictions[1], "I") @@ -470,109 +481,96 @@ def test_ollama_inference_engine(self): self.assertTrue("Ottawa" in predictions[0], predictions[0]) def test_cache(self): - unitxt.settings.allow_unverified_code = True - if os.path.exists(unitxt.settings.inference_engine_cache_path): - shutil.rmtree(unitxt.settings.inference_engine_cache_path) - - model_name = local_decoder_model # pragma: allowlist secret - - dataset = load_dataset( - card="cards.openbook_qa", - split="test", - # format="formats.chat_api", - loader_limit=20, - ) - inference_model = HFPipelineBasedInferenceEngine( - model_name=model_name, - max_new_tokens=1, # Very small for fast testing - temperature=0, - top_p=1, - use_cache=False, - device="cpu", - ) - start_time = time.time() - predictions_without_cache = inference_model.infer(dataset) - inference_without_cache_time = time.time() - start_time - # Set seed for reproducibility - inference_model = HFPipelineBasedInferenceEngine( - model_name=model_name, - max_new_tokens=1, # Very small for fast testing - temperature=0, - top_p=1, - use_cache=True, - cache_batch_size=5, - device="cpu", - ) - start_time = time.time() - predictions_with_cache = inference_model.infer(dataset) - inference_with_cache_time = time.time() - start_time - - self.assertEqual(len(predictions_without_cache), len(predictions_with_cache)) - for p1, p2 in zip(predictions_without_cache, predictions_with_cache): - self.assertEqual(p1, p2) - - logger.info( - f"Time of inference without cache: {inference_without_cache_time}, " - f"with cache (cache is empty): {inference_with_cache_time}" - ) - - start_time = time.time() - predictions_with_cache = inference_model.infer(dataset) - inference_with_cache_time = time.time() - start_time - - self.assertEqual(len(predictions_without_cache), len(predictions_with_cache)) - for p1, p2 in zip(predictions_without_cache, predictions_with_cache): - self.assertEqual(p1, p2) - - logger.info( - f"Time of inference without cache: {inference_without_cache_time}, " - f"with cache (cache is full): {inference_with_cache_time}" - ) - - self.assertGreater(inference_without_cache_time, 2) - self.assertLess(inference_with_cache_time, 0.5) - - # Ensure that even in the case of failures, the cache allows incremental addition of predictions, - # enabling the run to complete. To test this, introduce noise that causes the inference engine's - # `infer` method to return empty results 20% of the time (empty results are not stored in the cache). - # Verify that after enough runs, all predictions are successfully cached and the final results - # match those obtained without caching. 
- - if os.path.exists(unitxt.settings.inference_engine_cache_path): - shutil.rmtree(unitxt.settings.inference_engine_cache_path) - - inference_model = HFPipelineBasedInferenceEngine( - model_name=model_name, - max_new_tokens=1, # Very small for fast testing - temperature=0, - top_p=1, - use_cache=True, - cache_batch_size=5, - device="cpu", - ) - - def my_wrapper(original_method): - random.seed(int(time.time())) - - def wrapped(*args, **kwargs): - predictions = original_method(*args, **kwargs) - return [p if random.random() < 0.6 else None for p in predictions] - - return wrapped - - inference_model._infer = my_wrapper(inference_model._infer) - predictions = [None] - while predictions.count(None) > 0: - start_time = time.time() - predictions = inference_model.infer(dataset) - inference_time = time.time() - start_time - logger.info( - f"Inference time: {inference_time}, predictions contains {predictions.count(None)} Nones" - ) - - self.assertEqual(len(predictions_without_cache), len(predictions_with_cache)) - for p1, p2 in zip(predictions_without_cache, predictions_with_cache): - self.assertEqual(p1, p2) + with tempfile.TemporaryDirectory() as temp_dir: + with unitxt.settings.context(inference_engine_cache_path=temp_dir): + + def raise_error(*args, **kwargs): + raise NotImplementedError + + dataset = get_text_dataset(format="formats.chat_api") + + model = HFPipelineBasedInferenceEngine( + model_name=local_decoder_model, + max_new_tokens=1, + temperature=0, + top_p=1, + use_cache=True, + device="cpu", + ) + + model._infer_batch = raise_error + + with self.assertRaises(NotImplementedError): + model(dataset) + + model = HFPipelineBasedInferenceEngine( + model_name=local_decoder_model, + max_new_tokens=1, + temperature=0, + top_p=1, + use_cache=True, + device="cpu", + ) + + predictions_without_cache = model(dataset) + + model = HFPipelineBasedInferenceEngine( + model_name=local_decoder_model, + max_new_tokens=1, + temperature=0, + top_p=1, + use_cache=True, + device="cpu", + ) + + model._infer_batch = raise_error + + predictions_with_cache = model(dataset) + + self.assertEqual( + len(predictions_without_cache), len(predictions_with_cache) + ) + for p1, p2 in zip(predictions_without_cache, predictions_with_cache): + self.assertEqual(p1, p2) + + # # Ensure that even in the case of failures, the cache allows incremental addition of predictions, + # # enabling the run to complete. To test this, introduce noise that causes the inference engine's + # # `infer` method to return empty results 20% of the time (empty results are not stored in the cache). + # # Verify that after enough runs, all predictions are successfully cached and the final results + # # match those obtained without caching. 
+ + # inference_model = HFPipelineBasedInferenceEngine( + # model_name=model_name, + # max_new_tokens=1, # Very small for fast testing + # temperature=0, + # top_p=1, + # use_cache=True, + # cache_batch_size=5, + # device="cpu", + # ) + + # def my_wrapper(original_method): + # random.seed(int(time.time())) + + # async def wrapped(*args, **kwargs): + # predictions = await original_method(*args, **kwargs) + # return [p if random.random() < 0.6 else None for p in predictions] + + # return wrapped + + # inference_model._infer = my_wrapper(inference_model._infer) + # predictions = [None] + # while predictions.count(None) > 0: + # start_time = time.time() + # predictions = inference_model(dataset) + # inference_time = time.time() - start_time + # logger.info( + # f"Inference time: {inference_time}, predictions contains {predictions.count(None)} Nones" + # ) + + # self.assertEqual(len(predictions_without_cache), len(predictions_with_cache)) + # for p1, p2 in zip(predictions_without_cache, predictions_with_cache): + # self.assertEqual(p1, p2) def test_wml_chat_tool_calling(self): instance = { From bd1e7d625fe99e3d450a28b07acc1306367506b2 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 19 Aug 2025 20:20:10 +0300 Subject: [PATCH 04/16] Some fixes Signed-off-by: elronbandel --- src/unitxt/inference.py | 40 ++++------------------------------------ 1 file changed, 4 insertions(+), 36 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 75c61c7258..f648b158cc 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -258,20 +258,6 @@ class InferenceEngine(abc.ABC, MockInferenceMixin, CachedInferenceMixin): concurrency_limit: int = 100 support_log_probs: bool = False - @abc.abstractmethod - async def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - """Perform inference on the input dataset. - - If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string. - return_meta_data is only supported for some InferenceEngines. - predictions. 
- """ - pass - @abc.abstractmethod def prepare_engine(self): """Perform inference on the input dataset.""" @@ -480,14 +466,6 @@ def _infer_batch( """ pass - async def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - """Process dataset by batching instances.""" - raise NotImplementedError - async def _infer_streaming( self, instances: Iterable[Tuple[int, Dict[str, Any]]], @@ -588,13 +566,6 @@ async def _infer_single( """ pass - async def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - raise NotImplementedError - async def _infer_streaming( self, instances: Iterable[Tuple[int, Dict[str, Any]]], @@ -3527,18 +3498,13 @@ def prepare_engine(self): self.engine: InferenceEngine = cls(**args) self.data_classification_policy = self.engine.data_classification_policy - async def _infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - return await self.engine._infer(dataset, return_meta_data) - async def _async_infer( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + if not hasattr(self, "engine"): + self.prepare_engine() return await self.engine._async_infer(dataset, return_meta_data) async def _infer_streaming( @@ -3547,6 +3513,8 @@ async def _infer_streaming( total_len: int, return_meta_data: bool = False, ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: + if not hasattr(self, "engine"): + self.prepare_engine() return await self.engine._infer_streaming( instances, total_len, return_meta_data ) From 2d1e376d6a9e40a2406a81ad9c88e75175b21185 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 20 Aug 2025 14:07:28 +0300 Subject: [PATCH 05/16] Some fixes Signed-off-by: elronbandel --- src/unitxt/inference.py | 56 ++++++++++++------------ tests/inference/test_inference_engine.py | 2 +- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index f648b158cc..e6d74e96b9 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -252,7 +252,22 @@ def _cache_result(self, instance: Dict[str, Any], prediction): self._cache[cache_key] = prediction -class InferenceEngine(abc.ABC, MockInferenceMixin, CachedInferenceMixin): +class PersistentAsyncLoopMixin: + _loop = None + + def _get_persistent_loop(self): + if self._loop is None or self._loop.is_closed(): + self._loop = asyncio.new_event_loop() + return self._loop + + def _run_coroutine(self, coroutine): + loop = self._get_persistent_loop() + return loop.run_until_complete(coroutine) + + +class InferenceEngine( + abc.ABC, MockInferenceMixin, CachedInferenceMixin, PersistentAsyncLoopMixin +): """Abstract base class for inference.""" concurrency_limit: int = 100 @@ -318,7 +333,7 @@ def infer( If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string predictions. 
""" - return asyncio.run(self._async_infer(dataset, return_meta_data)) + return self._run_coroutine(self._async_infer(dataset, return_meta_data)) def infer_log_probs( self, @@ -330,7 +345,7 @@ def infer_log_probs( f"return_log_probs set to True but supplied engine " f"{self.__class__.__name__} does not support logprobs." ) - return asyncio.run( + return self._run_coroutine( self._async_infer(dataset, return_meta_data, return_log_probs=True) ) @@ -2377,6 +2392,8 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi results = wml_inference.infer(dataset["test"]) """ + support_log_probs: bool = True + def verify(self): super().verify() @@ -2422,38 +2439,21 @@ async def _infer_single( if self._model is None: self._load_model() + # Base params params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) - inp = instance["source"] - - response = await self._model.agenerate(prompt=inp, params=params) - result = response["results"][0] - pred = result.get("generated_text", "") - - if return_meta_data: - return TextGenerationInferenceOutput( - prediction=pred, - generated_text=result.get("generated_text"), - input_tokens=result.get("input_token_count"), - output_tokens=result.get("generated_token_count"), - model_name=self.model_name or self.deployment_id, - inference_type=self.label, - stop_reason=result.get("stop_reason"), - seed=self.random_seed, - input_text=inp, - ) - return pred - async def _infer_log_probs_single(self, instance, return_meta_data: bool = False): - if self._model is None: - self._load_model() + if return_log_probs: + params = self._set_logprobs_params(params) - params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) - params = self._set_logprobs_params(params) inp = instance["source"] response = await self._model.agenerate(prompt=inp, params=params) result = response["results"][0] - pred = result.get("generated_tokens", {}) + + if return_log_probs: + pred = result.get("generated_tokens", {}) + else: + pred = result.get("generated_text", "") if return_meta_data: return TextGenerationInferenceOutput( diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index 84587243fa..90f5775b32 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -308,7 +308,7 @@ def test_hf_auto_model_inference_engine_batching(self): predictions = list(model(dataset)) - self.assertListEqual(predictions, ["", "12"]) + self.assertListEqual(predictions, ["7\n", "12"]) def test_hf_auto_model_inference_engine(self): data = get_text_dataset() From 5abc29fd7f31023bc2390453f0dba346ed22ee43 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Wed, 20 Aug 2025 14:37:43 +0300 Subject: [PATCH 06/16] Make event loop thread safe Signed-off-by: elronbandel --- src/unitxt/inference.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index e6d74e96b9..f7e1fa43f4 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -9,6 +9,7 @@ import os import re import sys +import threading import time import uuid from collections import Counter @@ -17,6 +18,7 @@ from typing import ( Any, AsyncIterable, + Awaitable, Dict, Iterable, List, @@ -253,16 +255,25 @@ def _cache_result(self, instance: Dict[str, Any], prediction): class PersistentAsyncLoopMixin: - _loop = None + """Mixin providing a per-thread asyncio event loop. 
- def _get_persistent_loop(self): - if self._loop is None or self._loop.is_closed(): - self._loop = asyncio.new_event_loop() - return self._loop + Each thread gets its own persistent loop stored in thread-local memory. + This avoids 'event loop is closed' and 'loop already running' errors when running async code from multiple threads. + """ + + _thread_local = threading.local() # separate storage per thread + + def _get_loop(self) -> asyncio.AbstractEventLoop: + loop = getattr(self._thread_local, "loop", None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) # bind to this thread + self._thread_local.loop = loop # stored only in this thread + return loop - def _run_coroutine(self, coroutine): - loop = self._get_persistent_loop() - return loop.run_until_complete(coroutine) + def _run_async(self, coro: Awaitable[Any]) -> Any: + loop = self._get_loop() + return loop.run_until_complete(coro) class InferenceEngine( @@ -333,7 +344,7 @@ def infer( If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string predictions. """ - return self._run_coroutine(self._async_infer(dataset, return_meta_data)) + return self._run_async(self._async_infer(dataset, return_meta_data)) def infer_log_probs( self, @@ -345,7 +356,7 @@ def infer_log_probs( f"return_log_probs set to True but supplied engine " f"{self.__class__.__name__} does not support logprobs." ) - return self._run_coroutine( + return self._run_async( self._async_infer(dataset, return_meta_data, return_log_probs=True) ) From f698e63773ea041e484367fe6ae3eac8b1479a2d Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 13:18:19 +0300 Subject: [PATCH 07/16] Fix generic engine Signed-off-by: elronbandel --- src/unitxt/inference.py | 69 ++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index f7e1fa43f4..7f74c85e08 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -30,6 +30,7 @@ TypedDict, TypeVar, Union, + cast, ) from datasets import Dataset, DatasetDict, Image @@ -1506,7 +1507,38 @@ async def _infer_single( return predict_result -class GenericInferenceEngine(InferenceEngine, ArtifactFetcherMixin): +class DecoratedInferenceEngine(InferenceEngine, LazyLoadMixin): + engine: InferenceEngine = InternalField(default=None) + + def _is_loaded(self): + return hasattr(self, "engine") and self.engine is not None + + async def _async_infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + if not self._is_loaded(): + self.prepare_engine() + return await self.engine._async_infer(dataset, return_meta_data) + + async def _infer_streaming( + self, + instances: Iterable[Tuple[Union[int, Dict[str, Any]]]], + total_len: int, + return_meta_data: bool = False, + ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: + if not self._is_loaded(): + self.prepare_engine() + return await self.engine._infer_streaming( + instances, total_len, return_meta_data + ) + + def get_engine_id(self): + return self.engine.get_engine_id() + + +class GenericInferenceEngine(DecoratedInferenceEngine, ArtifactFetcherMixin): default: Optional[str] = None def prepare_engine(self): @@ -1522,15 +1554,9 @@ def prepare_engine(self): "\nor passing a similar required engine in the default argument" ) engine_reference = 
self.default - self.engine = self.get_artifact(engine_reference) - - def _infer_streaming( - self, - instances: Iterable[Tuple[Union[int, Dict[str, Any]]]], - total_len: int, - return_meta_data: bool = False, - ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: - return self.engine._infer_streaming(instances, total_len, return_meta_data) + self.engine: InferenceEngine = cast( + InferenceEngine, self.get_artifact(engine_reference) + ) class OllamaInferenceEngine( @@ -3210,7 +3236,7 @@ async def _infer_single( ] -class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): +class CrossProviderInferenceEngine(DecoratedInferenceEngine, StandardAPIParamsMixin): """Inference engine capable of dynamically switching between multiple providers APIs. This class extends the InferenceEngine and OpenAiInferenceEngineParamsMixin @@ -3509,27 +3535,6 @@ def prepare_engine(self): self.engine: InferenceEngine = cls(**args) self.data_classification_policy = self.engine.data_classification_policy - async def _async_infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: - if not hasattr(self, "engine"): - self.prepare_engine() - return await self.engine._async_infer(dataset, return_meta_data) - - async def _infer_streaming( - self, - instances: Iterable[Tuple[Union[int, Dict[str, Any]]]], - total_len: int, - return_meta_data: bool = False, - ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: - if not hasattr(self, "engine"): - self.prepare_engine() - return await self.engine._infer_streaming( - instances, total_len, return_meta_data - ) - def get_engine_id(self): api = self.get_provider_name() if self.model in self.provider_model_map[api]: From b2ffe9fbddba87679a8b3faeb2f277580d0ef9f2 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 15:02:47 +0300 Subject: [PATCH 08/16] Small fix Signed-off-by: elronbandel --- src/unitxt/inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 7f74c85e08..0ff67f7e43 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1535,6 +1535,8 @@ async def _infer_streaming( ) def get_engine_id(self): + if not self._is_loaded(): + return self.prepare_engine() return self.engine.get_engine_id() From 96cd9893ff83d654855e29dd7693b61485eb9b10 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 17:19:29 +0300 Subject: [PATCH 09/16] Fix lazy preparation Signed-off-by: elronbandel --- src/unitxt/inference.py | 210 ++++++++++------------- tests/inference/test_inference_engine.py | 5 +- 2 files changed, 96 insertions(+), 119 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 0ff67f7e43..0db0cdd196 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -277,29 +277,48 @@ def _run_async(self, coro: Awaitable[Any]) -> Any: return loop.run_until_complete(coro) +class LazyPrepareMixin(Artifact): + lazy_prepare: bool = NonPositionalField(default=True) + + @abc.abstractmethod + def _is_prepared(self): + pass + + @abc.abstractmethod + def _prepare_engine(self): + pass + + def prepare_engine(self): + if not self._is_prepared(): + with error_context( + self, + stage="Prepare Inference Engine", + help="https://www.unitxt.ai/en/latest/docs/inference.html", + ): + self._prepare_engine() + + def lazy_prepare_engine(self): + if not self.lazy_prepare: + 
self.prepare_engine() + + class InferenceEngine( - abc.ABC, MockInferenceMixin, CachedInferenceMixin, PersistentAsyncLoopMixin + abc.ABC, + LazyPrepareMixin, + MockInferenceMixin, + CachedInferenceMixin, + PersistentAsyncLoopMixin, ): """Abstract base class for inference.""" concurrency_limit: int = 100 support_log_probs: bool = False - @abc.abstractmethod - def prepare_engine(self): - """Perform inference on the input dataset.""" - pass - def prepare(self): if not self.is_mock: self._initialize_cache() super().prepare() # This will call CachedInferenceMixin.prepare() which initializes cache - with error_context( - self, - stage="Prepare Inference Engine", - help="https://www.unitxt.ai/en/latest/docs/inference.html", - ): - self.prepare_engine() + self.lazy_prepare_engine() def __call__( self, @@ -377,6 +396,7 @@ async def _async_infer( return_log_probs: bool = False, ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: """Internal async method that handles all inference logic.""" + self.prepare_engine() self.verify_infer_inputs(dataset, return_meta_data) if self.is_mock: @@ -660,14 +680,6 @@ async def bounded_infer(idx: int, instance: Dict[str, Any]): ) -class LazyLoadMixin(Artifact): - lazy_load: bool = NonPositionalField(default=False) - - @abc.abstractmethod - def _is_loaded(self): - pass - - class HFGenerationParamsMixin(Artifact): max_new_tokens: Optional[int] = None do_sample: bool = False @@ -680,10 +692,20 @@ class HFGenerationParamsMixin(Artifact): eos_token_id: Optional[int] = None +class ClientMixin: + def _is_prepared(self): + return hasattr(self, "client") and self.client is not None + + +class ModelMixin: + def _is_prepared(self): + return hasattr(self, "model") and self.model is not None + + class HFInferenceEngineBase( + ModelMixin, BatchInferenceEngine, PackageRequirementsMixin, - LazyLoadMixin, HFGenerationParamsMixin, TorchDeviceMixin, ): @@ -712,9 +734,6 @@ class HFInferenceEngineBase( "accelerate": "pip install accelerate", } - def _is_loaded(self): - return hasattr(self, "model") and self.model is not None - def _set_inference_device(self): if self.device is not None and self.device_map is not None: raise ValueError( @@ -766,10 +785,6 @@ def _prepare_engine(self): self._init_processor() self._init_model() - def prepare_engine(self): - if not self.lazy_load: - self._prepare_engine() - def get_engine_id(self): return get_model_and_label_id(self.model_name, self.label) @@ -861,24 +876,6 @@ def get_return_object( ) return output - def infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if not self._is_loaded(): - self._prepare_engine() - return super().infer(dataset, return_meta_data) - - def infer_log_probs( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]: - if not self._is_loaded(): - self._prepare_engine() - return super().infer_log_probs(dataset, return_meta_data) - @abc.abstractmethod def _infer_batch( self, @@ -1263,7 +1260,7 @@ def _init_model(self): class HFPipelineBasedInferenceEngine( BatchInferenceEngine, PackageRequirementsMixin, - LazyLoadMixin, + LazyPrepareMixin, HFGenerationParamsMixin, TorchDeviceMixin, ): @@ -1287,7 +1284,7 @@ class HFPipelineBasedInferenceEngine( "accelerate": "pip install accelerate", } - def _is_loaded(self): + def _is_prepared(self): return hasattr(self, "model") and self.model is not None 
def get_engine_id(self): @@ -1398,10 +1395,6 @@ def _prepare_engine(self): model_args = self._get_model_args() self._create_pipeline(model_args) - def prepare_engine(self): - if not self.lazy_load: - self._prepare_engine() - def _infer_batch( self, instances: List[Dict[str, Any]], @@ -1409,9 +1402,6 @@ def _infer_batch( return_log_probs: bool = False, ) -> List[Union[str, TextGenerationInferenceOutput]]: """Run a synchronous batch through the model and return outputs (with optional metadata).""" - if not self._is_loaded(): - self._prepare_engine() - # Prepare input texts inputs = [inst["source"] for inst in instances] @@ -1458,7 +1448,7 @@ class MockInferenceEngine(SingleInferenceEngine): def get_engine_id(self): return get_model_and_label_id(self.model_name, "mock") - def prepare_engine(self): + def _prepare_engine(self): return def _mock_infer( @@ -1507,19 +1497,21 @@ async def _infer_single( return predict_result -class DecoratedInferenceEngine(InferenceEngine, LazyLoadMixin): +class DecoratedInferenceEngine(InferenceEngine, LazyPrepareMixin): engine: InferenceEngine = InternalField(default=None) - def _is_loaded(self): - return hasattr(self, "engine") and self.engine is not None + def _is_prepared(self): + return ( + hasattr(self, "engine") + and self.engine is not None + and self.engine._is_prepared + ) async def _async_infer( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: - if not self._is_loaded(): - self.prepare_engine() return await self.engine._async_infer(dataset, return_meta_data) async def _infer_streaming( @@ -1528,22 +1520,19 @@ async def _infer_streaming( total_len: int, return_meta_data: bool = False, ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: - if not self._is_loaded(): - self.prepare_engine() return await self.engine._infer_streaming( instances, total_len, return_meta_data ) def get_engine_id(self): - if not self._is_loaded(): - return self.prepare_engine() + self.prepare_engine() return self.engine.get_engine_id() class GenericInferenceEngine(DecoratedInferenceEngine, ArtifactFetcherMixin): default: Optional[str] = None - def prepare_engine(self): + def _prepare_engine(self): if "UNITXT_INFERENCE_ENGINE" in os.environ: engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"] else: @@ -1562,7 +1551,7 @@ def prepare_engine(self): class OllamaInferenceEngine( - SingleInferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin + ClientMixin, SingleInferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin ): label: str = "ollama" _requirements_list = { @@ -1573,7 +1562,7 @@ class OllamaInferenceEngine( def get_engine_id(self): return get_model_and_label_id(self.model, self.label) - def prepare_engine(self): + def _prepare_engine(self): from ollama import AsyncClient self.client = AsyncClient( @@ -1752,6 +1741,7 @@ def inner(self, args): class OpenAiInferenceEngine( + ClientMixin, SingleInferenceEngine, OpenAiInferenceEngineParamsMixin, PackageRequirementsMixin, @@ -1800,7 +1790,7 @@ def create_client(self): default_headers=self.get_default_headers(), ) - def prepare_engine(self): + def _prepare_engine(self): from openai import AsyncOpenAI self.client: AsyncOpenAI = self.create_client() @@ -1932,13 +1922,13 @@ class RITSInferenceEngine( def get_default_headers(self): return {"RITS_API_KEY": self.credentials["api_key"]} - def prepare_engine(self): + def _prepare_engine(self): # inference endpoint need the '/v1' 
path self.base_url = ( RITSInferenceEngine.get_base_url_from_model_name(self.model_name) + "/v1" ) logger.info(f"Created RITS inference engine with base url: {self.base_url}") - super().prepare_engine() + super()._prepare_engine() def get_client_model_name(self): if self.model_name.startswith("byom-"): @@ -1995,6 +1985,7 @@ class TogetherAiInferenceEngineParamsMixin(Artifact): class TogetherAiInferenceEngine( + ClientMixin, SingleInferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin, @@ -2010,7 +2001,7 @@ class TogetherAiInferenceEngine( def get_engine_id(self): return get_model_and_label_id(self.model_name, self.label) - def prepare_engine(self): + def _prepare_engine(self): from together import AsyncTogether from together.types.models import ModelType @@ -2329,13 +2320,18 @@ def _verify_wml_credentials(credentials: CredentialsWML) -> None: "as keys for WML credentials dict." ) - def prepare_engine(self): + def _prepare_engine(self): self.check_missing_requirements() self._client = self._initialize_wml_client() self._set_inference_parameters() + self._load_model() + + def _is_prepared(self): + return self._client is not None and self._model is not None + def _load_model(self): from ibm_watsonx_ai.foundation_models.inference import ModelInference @@ -2349,8 +2345,7 @@ def get_model_details(self) -> Dict: return self._model.get_details() def get_token_count(self, dataset): - if self._model is None: - self._load_model() + self.prepare_engine() texts = [instance["source"] for instance in dataset] @@ -2364,8 +2359,7 @@ def get_token_count(self, dataset): def get_options_log_probs(self, dataset): """Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}.""" - if self._model is None: - self._load_model() + self.prepare_engine() texts = [x["source"] for x in dataset] @@ -2475,10 +2469,6 @@ async def _infer_single( return_meta_data: bool = False, return_log_probs: bool = False, ): - if self._model is None: - self._load_model() - - # Base params params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False) if return_log_probs: @@ -2562,16 +2552,6 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin): ) support_log_probs: bool = True - def _async_infer( - self, - dataset: Union[List[Dict[str, Any]], Dataset], - return_meta_data: bool = False, - return_log_probs: bool = False, - ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: - if self._model is None: - self._load_model() - return super()._async_infer(dataset, return_meta_data, return_log_probs) - @staticmethod def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]: task_data = instance["task_data"] @@ -2710,7 +2690,6 @@ async def _infer_single( return_meta_data: bool = False, return_log_probs: bool = False, ): - # Build params params = self.to_dict([WMLChatParamsMixin], keep_empty=False) params["logprobs"] = return_log_probs output_type = "logprobs" if return_log_probs else "message" @@ -2840,7 +2819,11 @@ def get_text_without_images(instance, image_token=""): class LMMSEvalBaseInferenceEngine( - BatchInferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin + ModelMixin, + BatchInferenceEngine, + PackageRequirementsMixin, + LazyPrepareMixin, + TorchDeviceMixin, ): label = "lmms-eval" model_type: str @@ -2855,10 +2838,6 @@ class LMMSEvalBaseInferenceEngine( def get_engine_id(self): return get_model_and_label_id(self.model_type, self.label) - def 
prepare_engine(self): - if not self.lazy_load: - self._prepare_engine() - def _prepare_engine(self): from lmms_eval.api.instance import Instance from lmms_eval.models import get_model @@ -2879,9 +2858,6 @@ def _prepare_engine(self): }, ) - def _is_loaded(self): - return hasattr(self, "model") and self.model is not None - class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine): max_new_tokens: int = 32 @@ -2895,9 +2871,6 @@ def _infer_batch( return_meta_data: bool = False, return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if not self._is_loaded(): - self._prepare_engine() - from lmms_eval.api.instance import Instance temp_task_name = str(uuid.uuid4()) @@ -2968,9 +2941,6 @@ def _infer_batch( return_meta_data: bool = False, return_log_probs: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if not self._is_loaded(): - self._prepare_engine() - temp_task_name = str(uuid.uuid4()) requests = [] @@ -3032,14 +3002,14 @@ class VLLMParamsMixin(Artifact): class VLLMInferenceEngine( - SingleInferenceEngine, PackageRequirementsMixin, VLLMParamsMixin + ModelMixin, SingleInferenceEngine, PackageRequirementsMixin, VLLMParamsMixin ): label = "vllm" def get_engine_id(self): return get_model_and_label_id(self.model, self.label) - def prepare_engine(self): + def _prepare_engine(self): args = self.to_dict([VLLMParamsMixin]) args.pop("model") @@ -3058,7 +3028,7 @@ def prepare_engine(self): max_num_seqs=64, enforce_eager=True, ) - self.llm = AsyncLLMEngine.from_engine_args(engine_args) + self.model = AsyncLLMEngine.from_engine_args(engine_args) async def _infer_single( self, @@ -3074,12 +3044,12 @@ async def _infer_single( # Use the async VLLM generate method if isinstance(source, list): # Chat format - result_generator = self.llm.generate( + result_generator = self.model.generate( request_id, source, self.sampling_params ) else: # Simple text format - result_generator = self.llm.generate( + result_generator = self.model.generate( request_id, source, self.sampling_params ) @@ -3156,7 +3126,10 @@ class LiteLLMInferenceEngine( def get_engine_id(self): return get_model_and_label_id(self.model, self.label) - def prepare_engine(self): + def _is_prepared(self): + return hasattr(self, "_completion") and self._completion is not None + + def _prepare_engine(self): if self.credentials is None: self.credentials = {} # Initialize the token bucket rate limiter @@ -3504,7 +3477,7 @@ def get_return_object(self, **kwargs): def get_provider_name(self): return self.provider if self.provider is not None else settings.default_provider - def prepare_engine(self): + def _prepare_engine(self): provider = self.get_provider_name() if provider not in self._provider_to_base_class: raise UnitxtError( @@ -3544,7 +3517,9 @@ def get_engine_id(self): return get_model_and_label_id(self.model, api) -class HFOptionSelectingInferenceEngine(BatchInferenceEngine, TorchDeviceMixin): +class HFOptionSelectingInferenceEngine( + ModelMixin, BatchInferenceEngine, TorchDeviceMixin +): """HuggingFace based class for inference engines that calculate log probabilities. This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs. 
@@ -3561,7 +3536,7 @@ def get_engine_id(self): return get_model_and_label_id(self.model_name, self.label) @retry_connection_with_exponential_backoff(backoff_factor=2) - def prepare_engine(self): + def _prepare_engine(self): from transformers import AutoModelForCausalLM, AutoTokenizer self.device = self.get_device() @@ -3691,8 +3666,11 @@ def _infer_batch( references=references, ) - def prepare_engine(self): + def _prepare_engine(self): pass + def _is_prepared(self): + return hasattr(self, "metric") and self.metric is not None + def get_engine_id(self): return "metric_inference_engine" diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index 90f5775b32..6944e8caf4 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -104,9 +104,9 @@ def test_pipeline_based_inference_engine(self): def test_pipeline_based_inference_engine_lazy_load(self): model = HFPipelineBasedInferenceEngine( - model_name=local_decoder_model, # pragma: allowlist secret + model_name=local_decoder_model, max_new_tokens=2, - lazy_load=True, + lazy_prepare=True, use_cache=False, ) dataset = get_text_dataset() @@ -119,7 +119,6 @@ def test_dataset_verification_inference_engine(self): inference_model = HFPipelineBasedInferenceEngine( model_name=local_decoder_model, # pragma: allowlist secret max_new_tokens=2, - lazy_load=True, data_classification_policy=["public"], use_cache=False, ) From 6ad0554ac34669c103c2cb390e44ab2f688742a1 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 20:16:37 +0300 Subject: [PATCH 10/16] Fix mock Signed-off-by: elronbandel --- src/unitxt/inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 0db0cdd196..9a144e646e 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1444,6 +1444,7 @@ class MockInferenceEngine(SingleInferenceEngine): ) label: str = "mock_inference_engine" support_log_probs: bool = True + lazy_prepare = False def get_engine_id(self): return get_model_and_label_id(self.model_name, "mock") @@ -1451,6 +1452,9 @@ def get_engine_id(self): def _prepare_engine(self): return + def _is_prepared(self): + return True + def _mock_infer( self, dataset: Union[List[Dict[str, Any]], Dataset], From fa80d31d34ca349efd389910938898a396886e4d Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 20:56:03 +0300 Subject: [PATCH 11/16] Some more fixes Signed-off-by: elronbandel --- src/unitxt/inference.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 9a144e646e..317847c724 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -296,6 +296,10 @@ def prepare_engine(self): help="https://www.unitxt.ai/en/latest/docs/inference.html", ): self._prepare_engine() + if not self._is_prepared(): + raise RuntimeError( + "After calling _prepare_engine then _is_prepared() must return True." 
+ ) def lazy_prepare_engine(self): if not self.lazy_prepare: @@ -1511,11 +1515,15 @@ def _is_prepared(self): and self.engine._is_prepared ) + def _prepare_engine(self): + return self.engine._prepare_engine() + async def _async_infer( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: + self.prepare_engine() return await self.engine._async_infer(dataset, return_meta_data) async def _infer_streaming( @@ -1524,6 +1532,7 @@ async def _infer_streaming( total_len: int, return_meta_data: bool = False, ) -> AsyncIterable[Tuple[Union[int, str, TextGenerationInferenceOutput]]]: + self.prepare_engine() return await self.engine._infer_streaming( instances, total_len, return_meta_data ) From 8531cb308355aff2e805e16257b7e0c23f33ef3e Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 22:14:55 +0300 Subject: [PATCH 12/16] Some minor improvements Signed-off-by: elronbandel --- src/unitxt/inference.py | 86 ++++++++++++++--------------------------- 1 file changed, 28 insertions(+), 58 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 317847c724..a26090e3bb 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -336,18 +336,6 @@ def __call__( ): return self.infer(dataset=dataset, return_meta_data=return_meta_data) - def get_instance_cache_key(self, instance): - instance_key_fields = ["media", "source", "task_data"] - return {key: instance[key] for key in instance if key in instance_key_fields} - - def _get_cache_key(self, instance: Dict[str, Any]) -> str: - """Generate a unique cache key for each input.""" - record = self.get_instance_cache_key(instance) - record["version"] = constants.version - record.update(self.to_dict()) - instance_str = json.dumps(record, sort_keys=True) - return hashlib.md5(instance_str.encode()).hexdigest() - def verify_infer_inputs( self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool ): @@ -355,8 +343,8 @@ def verify_infer_inputs( raise Exception( "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset" ) - - [self.verify_instance(instance) for instance in dataset] + for instance in dataset: + self.verify_instance(instance) def infer( self, @@ -400,7 +388,6 @@ async def _async_infer( return_log_probs: bool = False, ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]: """Internal async method that handles all inference logic.""" - self.prepare_engine() self.verify_infer_inputs(dataset, return_meta_data) if self.is_mock: @@ -410,6 +397,7 @@ async def _async_infer( else: result = self._mock_infer(dataset, return_meta_data) else: + self.prepare_engine() logger.info(f"Running inference with {self.get_engine_id()}") results: Dict[int, Any] = {} async for index, prediction in self._infer_streaming( @@ -530,7 +518,7 @@ async def _infer_streaming( total_generated = 0 total_loaded = 0 pbar = tqdm( - total=total_len, desc=f"Inference {self.get_engine_id()}", unit="inst" + total=total_len, desc=f"Inference ({self.get_engine_id()})", unit="item" ) try: for idx, instance in instances: @@ -562,22 +550,14 @@ async def _infer_streaming( results = self._infer_batch( current_batch, return_meta_data, return_log_probs ) - for original_idx, result in zip(current_indices, results): + for original_idx, instance, result in zip( + current_indices, current_batch, results + ): self._cache_result(instance, result) pbar.update(1) total_generated += 1 yield 
original_idx, result finally: - # Process remaining instances - if current_batch: - results = self._infer_batch( - current_batch, return_meta_data, return_log_probs - ) - for original_idx, result in zip(current_indices, results): - self._cache_result(instance, result) - pbar.update(1) - total_generated += 1 - yield original_idx, result pbar.close() logger.info( @@ -624,61 +604,57 @@ async def _infer_streaming( return_meta_data: bool = False, return_log_probs: bool = False, ) -> AsyncIterable[Tuple[int, Union[str, TextGenerationInferenceOutput]]]: - """Stream results concurrently without realizing the input iterable, with tqdm progress.""" sem = asyncio.Semaphore(self.concurrency_limit) it = iter(instances) pending: set[asyncio.Task] = set() - # Initialize tqdm with total length pbar = tqdm( total=total_len, desc=f"Inference ({self.get_engine_id()})", unit="item" ) total_loaded = 0 total_generated = 0 - async def bounded_infer(idx: int, instance: Dict[str, Any]): - nonlocal total_loaded, total_generated + async def process_instance(idx: int, instance: Dict[str, Any]): cached_result = self._get_cached_result(instance) if cached_result is not None: - total_loaded += 1 - return idx, cached_result + return idx, cached_result, True + async with sem: result = await self._infer_single( instance, return_meta_data, return_log_probs ) - total_generated += 1 self._cache_result(instance, result) - return idx, result + return idx, result, False - # Prime the pool - for _ in range(self.concurrency_limit): - try: - idx, inst = next(it) - except StopIteration: - break - pending.add(asyncio.create_task(bounded_infer(idx, inst))) - - # Drain while refilling try: + for _ in range(self.concurrency_limit): + try: + idx, inst = next(it) + pending.add(asyncio.create_task(process_instance(idx, inst))) + except StopIteration: + break + while pending: done, pending = await asyncio.wait( pending, return_when=asyncio.FIRST_COMPLETED ) for task in done: - idx, result = await task - pbar.update(1) # update progress for each finished task + idx, result, from_cache = await task + total_loaded += from_cache + total_generated += not from_cache + pbar.update(1) yield idx, result while len(pending) < self.concurrency_limit: try: idx, inst = next(it) + pending.add(asyncio.create_task(process_instance(idx, inst))) except StopIteration: break - pending.add(asyncio.create_task(bounded_infer(idx, inst))) + finally: pbar.close() - logger.info( f"Inference Summary: {total_generated} generated, {total_loaded} loaded from cache." 
) @@ -1509,14 +1485,7 @@ class DecoratedInferenceEngine(InferenceEngine, LazyPrepareMixin): engine: InferenceEngine = InternalField(default=None) def _is_prepared(self): - return ( - hasattr(self, "engine") - and self.engine is not None - and self.engine._is_prepared - ) - - def _prepare_engine(self): - return self.engine._prepare_engine() + return hasattr(self, "engine") and self.engine is not None async def _async_infer( self, @@ -1538,7 +1507,8 @@ async def _infer_streaming( ) def get_engine_id(self): - self.prepare_engine() + if not self.is_mock: + self.prepare_engine() return self.engine.get_engine_id() @@ -3253,7 +3223,7 @@ class CrossProviderInferenceEngine(DecoratedInferenceEngine, StandardAPIParamsMi label: str = "cross_provider" provider: Optional[_supported_apis] = None provider_specific_args: Optional[Dict[str, Dict[str, str]]] = None - + lazy_prepare: bool = False provider_model_map: Dict[_supported_apis, Dict[str, str]] = { "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels "granite-20b-code-instruct": "ibm/granite-20b-code-instruct", From efc200a42897b0aea5970cb5855c7489cdeb03cc Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 21 Aug 2025 22:17:23 +0300 Subject: [PATCH 13/16] fix Signed-off-by: elronbandel --- src/unitxt/inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index a26090e3bb..a277a9a8b7 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3544,6 +3544,8 @@ def get_log_probs(self, texts): import torch from tqdm import tqdm + self.prepare_engine() + log_probs = [] # Process texts in batches From 186976756bbc743e9a9750d6d62cb4b8527f1d04 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 24 Aug 2025 20:41:02 +0300 Subject: [PATCH 14/16] Fix mock id access Signed-off-by: elronbandel --- src/unitxt/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index a277a9a8b7..9fa70449b6 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1507,8 +1507,9 @@ async def _infer_streaming( ) def get_engine_id(self): - if not self.is_mock: - self.prepare_engine() + if self.is_mock: + return "mock" + self.prepare_engine() return self.engine.get_engine_id() From 1e81f9add6549538fe60d71bc113ef21dba59e4b Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 25 Aug 2025 10:15:21 +0300 Subject: [PATCH 15/16] Fix example Signed-off-by: elronbandel --- examples/evaluate_existing_dataset_by_llm_as_judge_direct.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py b/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py index fc09406d72..71085e51b4 100644 --- a/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py +++ b/examples/evaluate_existing_dataset_by_llm_as_judge_direct.py @@ -21,7 +21,7 @@ ] dataset = load_dataset( card="cards.squad", - # metrics=metrics, + metrics=metrics, loader_limit=2, max_test_instances=2, split="test", @@ -40,7 +40,6 @@ about the the open ai api arguments the CrossProviderInferenceEngine follows. """ predictions = inference_model(dataset) -exit() gold_answers = [d[0] for d in dataset["references"]] # Evaluate the predictions using the defined metric. 
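As a rough usage sketch of the mock-id fix above (illustrative only, not part of the patch series): when mocking is active, the decorated engine reports a generic "mock" id instead of delegating to the wrapped provider engine. The settings context helper, the provider/model pair, and the per-instance lazy_prepare override below are assumptions made for illustration.

from unitxt.inference import CrossProviderInferenceEngine
from unitxt.settings_utils import get_settings

settings = get_settings()

# Assumed helper: temporarily enable global mock inference mode.
with settings.context(mock_inference_mode=True):
    engine = CrossProviderInferenceEngine(
        model="granite-20b-code-instruct",  # placeholder model name
        provider="watsonx-sdk",             # placeholder provider
        lazy_prepare=True,                  # defer building the wrapped engine
    )
    # With mocking active, get_engine_id() returns the generic "mock" label
    # rather than asking the wrapped engine for its id.
    print(engine.get_engine_id())  # -> "mock"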
From dd9b456117cee0dea6f98e590033d5dbe1468070 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 25 Aug 2025 10:18:17 +0300 Subject: [PATCH 16/16] remove old lazy_load Signed-off-by: elronbandel --- src/unitxt/inference.py | 1 - src/unitxt/metrics.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 9fa70449b6..9f4c6863f1 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1042,7 +1042,6 @@ def _infer_batch( class HFLlavaInferenceEngine(HFInferenceEngineBase): - lazy_load: bool = True label: str = "hf_lava" image_token: str = "" support_log_probs: bool = True diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index c631f24a07..31c3c802e5 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -6450,7 +6450,7 @@ def prepare(self): IsCodeMixed.inference_model = HFPipelineBasedInferenceEngine( model_name="Nexusflow/Starling-LM-7B-beta", max_new_tokens=1, - lazy_load=True, + lazy_prepare=True, ) # the processing steps for preparing the prompt (instruction, answer prefix etc.) # that we send to the generative model
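Across the series the old lazy_load switch is retired in favor of lazy_prepare, as the final hunk to metrics.py shows. A minimal sketch of the renamed flag in use (illustrative only; the model name is a placeholder standing in for the small local model used by the tests):

from unitxt.inference import HFPipelineBasedInferenceEngine

engine = HFPipelineBasedInferenceEngine(
    model_name="sshleifer/tiny-gpt2",  # placeholder for a small causal LM
    max_new_tokens=2,
    lazy_prepare=True,   # the HF pipeline is only built on first use
    use_cache=False,
)

# Construction stays cheap; the pipeline is created when the engine is first
# asked to generate, i.e. when prepare_engine() runs during infer().
predictions = engine.infer([{"source": "The capital of France is"}])
print(predictions)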