From 4e1f8d44b79abb6c0d89f68bfff722cb285c466c Mon Sep 17 00:00:00 2001 From: Jon Irastorza Date: Wed, 15 Oct 2025 12:09:47 +0000 Subject: [PATCH 1/3] feat: parallelize tool execution. --- src/raglite/_rag.py | 66 ++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 25fc2bd4..5d1b5df9 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -2,6 +2,8 @@ import json from collections.abc import AsyncIterator, Callable, Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial from typing import Any import numpy as np @@ -141,35 +143,55 @@ def _get_tools( return tools, tool_choice +def _run_tool( + tool_call: ChatCompletionMessageToolCall, + on_retrieval: Callable[[list[ChunkSpan]], None] | None, + config: RAGLiteConfig, +) -> dict[str, Any]: + """Run a single tool to search the knowledge base for RAG context.""" + if tool_call.function.name == "search_knowledge_base": + kwargs = json.loads(tool_call.function.arguments) + kwargs["config"] = config + chunk_spans = retrieve_context(**kwargs) + message = { + "role": "tool", + "content": '{{"documents": [{elements}]}}'.format( + elements=", ".join( + chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) + ) + ), + "tool_call_id": tool_call.id, + } + + if chunk_spans and callable(on_retrieval): + on_retrieval(chunk_spans) + else: + error_message = f"Unknown function `{tool_call.function.name}`." 
+ raise ValueError(error_message) + return message + + def _run_tools( tool_calls: list[ChatCompletionMessageToolCall], on_retrieval: Callable[[list[ChunkSpan]], None] | None, config: RAGLiteConfig, + max_workers: int | None = None, ) -> list[dict[str, Any]]: """Run tools to search the knowledge base for RAG context.""" tool_messages: list[dict[str, Any]] = [] - for tool_call in tool_calls: - if tool_call.function.name == "search_knowledge_base": - kwargs = json.loads(tool_call.function.arguments) - kwargs["config"] = config - chunk_spans = retrieve_context(**kwargs) - tool_messages.append( - { - "role": "tool", - "content": '{{"documents": [{elements}]}}'.format( - elements=", ".join( - chunk_span.to_json(index=i + 1) - for i, chunk_span in enumerate(chunk_spans) - ) - ), - "tool_call_id": tool_call.id, - } - ) - if chunk_spans and callable(on_retrieval): - on_retrieval(chunk_spans) - else: - error_message = f"Unknown function `{tool_call.function.name}`." - raise ValueError(error_message) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(partial(_run_tool, tool_call, on_retrieval, config)) + for tool_call in tool_calls + ] + for future in as_completed(futures): + try: + message = future.result() + except Exception as e: + executor.shutdown(cancel_futures=True) # Cancel remaining work. + error_message = f"Error processing document: {e}" + raise ValueError(error_message) from e + tool_messages.append(message) return tool_messages From 2c0b0f0d3d1afa944d4e3de8244e28aa951ee49d Mon Sep 17 00:00:00 2001 From: Jon Irastorza Date: Wed, 22 Oct 2025 09:10:23 +0000 Subject: [PATCH 2/3] fix: error message and tool messages ordering. 
--- src/raglite/_rag.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 5d1b5df9..3854486f 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -178,20 +178,21 @@ def _run_tools( max_workers: int | None = None, ) -> list[dict[str, Any]]: """Run tools to search the knowledge base for RAG context.""" - tool_messages: list[dict[str, Any]] = [] + tool_messages: list[dict[str, Any]] = [None] * len(tool_calls) # type: ignore[list-item] with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit(partial(_run_tool, tool_call, on_retrieval, config)) for tool_call in tool_calls ] + future_to_index = {future: i for i, future in enumerate(futures)} for future in as_completed(futures): try: message = future.result() except Exception as e: executor.shutdown(cancel_futures=True) # Cancel remaining work. - error_message = f"Error processing document: {e}" + error_message = f"Error executing tool: {e}" raise ValueError(error_message) from e - tool_messages.append(message) + tool_messages[future_to_index[future]] = message return tool_messages From edf54b48c847997502508a462e7da40be5b7eaa5 Mon Sep 17 00:00:00 2001 From: Jon Irastorza Date: Mon, 27 Oct 2025 14:27:53 +0000 Subject: [PATCH 3/3] fix: try-except modified on _run_tools --- src/raglite/_rag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 3854486f..ecef7f67 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -188,11 +188,11 @@ def _run_tools( for future in as_completed(futures): try: message = future.result() - except Exception as e: + tool_messages[future_to_index[future]] = message + except Exception as e: # noqa: PERF203 executor.shutdown(cancel_futures=True) # Cancel remaining work. 
error_message = f"Error executing tool: {e}" raise ValueError(error_message) from e - tool_messages[future_to_index[future]] = message return tool_messages