Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import json
from collections.abc import AsyncIterator, Callable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from typing import Any

import numpy as np
Expand Down Expand Up @@ -141,35 +143,56 @@ def _get_tools(
return tools, tool_choice


def _run_tool(
tool_call: ChatCompletionMessageToolCall,
on_retrieval: Callable[[list[ChunkSpan]], None] | None,
config: RAGLiteConfig,
) -> dict[str, Any]:
"""Run a single tool to search the knowledge base for RAG context."""
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
chunk_spans = retrieve_context(**kwargs)
message = {
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_call.id,
}
Comment on lines +155 to +164

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If chunk_spans is empty, the content field will be {"documents": []}, which might be fine, but might not be.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should confirm if retrieve_context can return 0 chunks.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not change this behavior with respect to the main branch.
Without metadata_filter, I think retrieve_context always returns a non-empty list of ChunkSpans, even if they are not that relevant (low similarity). With a metadata_filter applied (e.g. via self-query), the list could be empty.
I'm not sure an empty list should be treated as an issue; I find it more correct than retrieving irrelevant chunks for a query that has no related documents in the database.
Open to discuss this :)
@Robbe-Superlinear

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that retrieve_context should return an empty list when no relevant chunks are found after applying metadata filtering.

However, how should we handle this case in the response? Currently, we return a message like this:

{
    "role": "tool",
    "content": "{\"documents\": []}",
    "tool_call_id": tool_call.id
}

We could consider a few alternatives for better clarity:
Option 1:

{
    "role": "tool",
    "content": "{\"documents\": [], \"message\": \"No results found\"}",
    "tool_call_id": tool_call.id
}

Option 2:

{
    "role": "tool",
    "content": "{\"message\": \"No results found\"}",
    "tool_call_id": tool_call.id
}

Option 3: We could return None and let the _run_tools function handle empty tool_call responses.


if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
else:
error_message = f"Unknown function `{tool_call.function.name}`."
raise ValueError(error_message)
return message


def _run_tools(
tool_calls: list[ChatCompletionMessageToolCall],
on_retrieval: Callable[[list[ChunkSpan]], None] | None,
config: RAGLiteConfig,
max_workers: int | None = None,
) -> list[dict[str, Any]]:
"""Run tools to search the knowledge base for RAG context."""
tool_messages: list[dict[str, Any]] = []
for tool_call in tool_calls:
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
chunk_spans = retrieve_context(**kwargs)
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1)
for i, chunk_span in enumerate(chunk_spans)
)
),
"tool_call_id": tool_call.id,
}
)
if chunk_spans and callable(on_retrieval):
on_retrieval(chunk_spans)
else:
error_message = f"Unknown function `{tool_call.function.name}`."
raise ValueError(error_message)
tool_messages: list[dict[str, Any]] = [None] * len(tool_calls) # type: ignore[list-item]
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(partial(_run_tool, tool_call, on_retrieval, config))
for tool_call in tool_calls
]
future_to_index = {future: i for i, future in enumerate(futures)}
for future in as_completed(futures):
try:
message = future.result()
tool_messages[future_to_index[future]] = message
except Exception as e: # noqa: PERF203
executor.shutdown(cancel_futures=True) # Cancel remaining work.
error_message = f"Error executing tool: {e}"
raise ValueError(error_message) from e
return tool_messages


Expand Down