import json
from collections.abc import AsyncIterator, Callable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from typing import Any

import numpy as np


def _run_tool(
    tool_call: ChatCompletionMessageToolCall,
    on_retrieval: Callable[[list[ChunkSpan]], None] | None,
    config: RAGLiteConfig,
) -> dict[str, Any]:
    """Run a single tool to search the knowledge base for RAG context.

    Parameters
    ----------
    tool_call
        The tool call to execute. Only the `search_knowledge_base` function is supported; its
        JSON-encoded arguments are forwarded to `retrieve_context` as keyword arguments.
    on_retrieval
        Optional callback that receives the retrieved chunk spans when the result is non-empty.
    config
        The configuration passed to `retrieve_context` as the `config` keyword argument. It
        overrides any `config` key present in the tool call's own arguments.

    Returns
    -------
    dict[str, Any]
        A `tool` role message whose content is a JSON object with a `documents` array and whose
        `tool_call_id` matches the given tool call.

    Raises
    ------
    ValueError
        If the tool call names a function other than `search_knowledge_base`.
    """
    if tool_call.function.name != "search_knowledge_base":
        error_message = f"Unknown function `{tool_call.function.name}`."
        raise ValueError(error_message)
    kwargs = json.loads(tool_call.function.arguments)
    kwargs["config"] = config
    chunk_spans = retrieve_context(**kwargs)
    # Documents are numbered from 1 in the serialized tool message.
    elements = ", ".join(
        chunk_span.to_json(index=i) for i, chunk_span in enumerate(chunk_spans, start=1)
    )
    message = {
        "role": "tool",
        "content": f'{{"documents": [{elements}]}}',
        "tool_call_id": tool_call.id,
    }
    # Notify the caller only after the message was built successfully.
    if chunk_spans and callable(on_retrieval):
        on_retrieval(chunk_spans)
    return message


def _run_tools(
    tool_calls: list[ChatCompletionMessageToolCall],
    on_retrieval: Callable[[list[ChunkSpan]], None] | None,
    config: RAGLiteConfig,
    max_workers: int | None = None,
) -> list[dict[str, Any]]:
    """Run tools concurrently to search the knowledge base for RAG context.

    Parameters
    ----------
    tool_calls
        The tool calls to execute. Each one is run with `_run_tool` in a thread pool.
    on_retrieval
        Optional callback that receives the chunk spans retrieved by each tool call. It is invoked
        on the calling thread, in the order of `tool_calls`, and only once every tool call has
        succeeded, so it never runs concurrently with itself.
    config
        The configuration forwarded to every tool call.
    max_workers
        The maximum number of worker threads, passed to `ThreadPoolExecutor`.

    Returns
    -------
    list[dict[str, Any]]
        One tool message per tool call, in the same order as `tool_calls`.

    Raises
    ------
    ValueError
        If any tool call (or the `on_retrieval` callback) fails. The original exception is
        attached as `__cause__`.
    """
    if not tool_calls:
        # Nothing to run: avoid creating a thread pool for no work.
        return []
    notify = callable(on_retrieval)
    # Each worker gets a private collector so that retrieved chunk spans can be replayed to
    # `on_retrieval` deterministically instead of from worker threads in completion order.
    retrieved: list[list[ChunkSpan]] = [[] for _ in tool_calls]
    # Placeholders that are all overwritten before returning; a failure raises instead.
    tool_messages: list[dict[str, Any]] = [{} for _ in tool_calls]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(_run_tool, tool_call, retrieved[i].extend if notify else None, config): i
            for i, tool_call in enumerate(tool_calls)
        }
        for future in as_completed(future_to_index):
            try:
                message = future.result()
            except Exception as e:  # noqa: PERF203
                executor.shutdown(cancel_futures=True)  # Cancel remaining work.
                error_message = f"Error executing tool: {e}"
                raise ValueError(error_message) from e
            tool_messages[future_to_index[future]] = message
    if notify:
        try:
            for chunk_spans in retrieved:
                if chunk_spans:
                    on_retrieval(chunk_spans)  # type: ignore[misc]
        except Exception as e:
            # Keep the same error contract as when the callback ran inside the worker.
            error_message = f"Error executing tool: {e}"
            raise ValueError(error_message) from e
    return tool_messages