diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index c4ba589a..c63bb17a 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -1,4 +1,6 @@ -from typing import List, Optional, Type +import asyncio +import logging +from typing import Any, Dict, List, Optional, Type, Union from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import ( @@ -7,13 +9,20 @@ VectorSearchRetrieverToolMixin, vector_search_retriever_tool_trace, ) +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool from pydantic import BaseModel, Field, PrivateAttr, model_validator from databricks_langchain import DatabricksEmbeddings +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, + DatabricksMultiServerMCPClient, +) from databricks_langchain.vectorstores import DatabricksVectorSearch +_logger = logging.getLogger(__name__) + class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): """ @@ -48,6 +57,7 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput _vector_store: DatabricksVectorSearch = PrivateAttr() + _mcp_tool: Optional[BaseTool] = PrivateAttr(default=None) @model_validator(mode="after") def _validate_tool_inputs(self): @@ -83,11 +93,68 @@ def _validate_tool_inputs(self): return self - @vector_search_retriever_tool_trace - def _run(self, query: str, filters: Optional[List[FilterItem]] = None, **kwargs) -> str: + def _create_or_get_mcp_tool(self) -> BaseTool: + """Create or return existing MCP tool using LangChain MCP Server.""" + if self._mcp_tool is not None: + return self._mcp_tool + + catalog, schema, index = self._parse_index_name() + + try: + server = DatabricksMCPServer.from_vector_search( + catalog=catalog, + schema=schema, + index_name=index, + name=f"vs-{index}", + workspace_client=self.workspace_client, + ) + client = DatabricksMultiServerMCPClient([server]) + except Exception as e: + self._handle_mcp_creation_error(e) + + tools = asyncio.run(client.get_tools()) + self._validate_mcp_tools(tools) + + self._mcp_tool = tools[0] + return self._mcp_tool + + def _parse_mcp_response(self, mcp_response: Any) -> List[Document]: + """Parse MCP tool response into LangChain Documents.""" + if isinstance(mcp_response, list) and mcp_response: + first_item = mcp_response[0] + if isinstance(first_item, dict) and first_item.get("type") == "text": + # Extract the actual JSON string from the content block + mcp_response = first_item.get("text", "") + + dicts = self._parse_mcp_response_to_dicts(mcp_response, strict=True) + return [Document(page_content=d["page_content"], metadata=d["metadata"]) for d in dicts] + + def _execute_mcp_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Document]: + """Execute vector search via LangChain MCP infrastructure.""" + try: + mcp_tool = self._create_or_get_mcp_tool() + mcp_input = self._build_mcp_params(filters, query=query, **kwargs) + # MCP tools only support async invocation + result = asyncio.run(mcp_tool.ainvoke(mcp_input)) + return self._parse_mcp_response(result) + except Exception as e: + self._handle_mcp_execution_error(e) + + def _execute_direct_api_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Document]: + """Execute vector search via direct DatabricksVectorSearch API.""" kwargs = {**kwargs, **(self.model_extra or {})} - # Since LLM can generate either a dict or FilterItem, convert to dict always - filters_dict = {dict(item)["key"]: dict(item)["value"] for item in (filters or [])} + # Normalize filters to dict format + filters_dict = self._normalize_filters(filters) combined_filters = {**filters_dict, **(self.filters or {})} # Allow kwargs to override the default values upon invocation @@ -104,3 +171,18 @@ def _run(self, query: str, filters: Optional[List[FilterItem]] = None, **kwargs) } ) return self._vector_store.similarity_search(**kwargs) + + @vector_search_retriever_tool_trace + def _run( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs, + ) -> List[Document]: + """Execute vector search with automatic routing.""" + index_details = IndexDetails(self._vector_store.index) + + if index_details.is_databricks_managed_embeddings(): + return self._execute_mcp_path(query, filters, **kwargs) + else: + return self._execute_direct_api_path(query, filters, **kwargs) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index a8ed42ac..80dc7f49 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,8 +1,9 @@ import json import os import threading +import uuid from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch import mlflow import pytest @@ -13,12 +14,14 @@ ALL_INDEX_NAMES, DELTA_SYNC_INDEX, DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, + DIRECT_ACCESS_INDEX, INPUT_TEXTS, _get_index, mock_vs_client, mock_workspace_client, ) from databricks_ai_bridge.vector_search_retriever_tool import FilterItem +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool from mlflow.entities import SpanType @@ -41,6 +44,88 @@ ) +def _create_mcp_response(texts: List[str] = None) -> List[Dict[str, Any]]: + """Create a mock MCP response in LangChain MCP adapter content block format.""" + texts = texts or INPUT_TEXTS + search_results = [ + {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} + for i, text in enumerate(texts) + ] + return [{"type": "text", "text": json.dumps(search_results), "id": f"lc_{uuid.uuid4()}"}] + + +def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): + """Assert MCP tool was called with expected args, handling JSON-stringified filters.""" + mock_tool.ainvoke.assert_called_once() + call_args = mock_tool.ainvoke.call_args[0][0] + for key, value in expected_args.items(): + if key == "filters": + assert json.loads(call_args["filters"]) == value + else: + assert call_args[key] == value + + +@pytest.fixture +def mock_mcp_infrastructure(): + """Mock MCP infrastructure for tests that need it.""" + # Create mock MCP tool that returns content block format + # (matching what langchain-mcp-adapters actually returns) + # MCP tools are async-only, so we mock ainvoke + mock_tool = MagicMock() + mock_tool.ainvoke = AsyncMock(return_value=_create_mcp_response()) + + # Create mock MCP client + mock_client_instance = MagicMock() + mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool]) + + # Create mock MCP server + mock_server_instance = MagicMock() + + with ( + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" + ) as mock_client_class, + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" + ) as mock_server_class, + ): + mock_client_class.return_value = mock_client_instance + mock_server_class.from_vector_search.return_value = mock_server_instance + yield { + "client_class": mock_client_class, + "client_instance": mock_client_instance, + "server_class": mock_server_class, + "server_instance": mock_server_instance, + "tool": mock_tool, + } + + +@pytest.fixture(params=["mcp", "direct_api"]) +def execution_path(request, mock_mcp_infrastructure): + """Parametrized fixture that sets up mocks for MCP or Direct API path.""" + if request.param == "mcp": + yield { + "path": "mcp", + "index_name": DELTA_SYNC_INDEX, + "mock_tool": mock_mcp_infrastructure["tool"], + "mock_mcp": mock_mcp_infrastructure, + } + else: + # For direct API, use an index that requires self-managed embeddings + yield { + "path": "direct_api", + "index_name": DIRECT_ACCESS_INDEX, + "mock_tool": None, + "mock_mcp": mock_mcp_infrastructure, + } + + +def setup_tool_for_path(execution_path, tool): + """Set up mock for the tool based on execution path.""" + if execution_path["path"] == "direct_api": + tool._vector_store.similarity_search = MagicMock(return_value=[]) + + def init_vector_search_tool( index_name: str, columns: Optional[List[str]] = None, @@ -93,40 +178,57 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: assert isinstance(response, AIMessage) -def test_filters_are_passed_through() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_filters_are_passed_through(execution_path) -> None: + """Test filters are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) - vector_search_tool.invoke( + tool.invoke( { "query": "what cities are in Germany", "filters": [FilterItem(key="country", value="Germany")], } ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - filter={"country": "Germany"}, - query_type=vector_search_tool.query_type, - ) + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "filters": {"country": "Germany"}}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter={"country": "Germany"}, + query_type=tool.query_type, + ) -def test_filters_are_combined() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) - vector_search_tool._vector_store.similarity_search = MagicMock() - vector_search_tool.invoke( +def test_filters_are_combined(execution_path) -> None: + """Test filters are combined correctly (predefined + runtime) on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], filters={"city LIKE": "Berlin"}) + setup_tool_for_path(execution_path, tool) + + tool.invoke( { "query": "what cities are in Germany", "filters": [FilterItem(key="country", value="Germany")], } ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - filter={"city LIKE": "Berlin", "country": "Germany"}, - query_type=vector_search_tool.query_type, - ) + + expected_filters = {"city LIKE": "Berlin", "country": "Germany"} + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "filters": expected_filters}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -136,6 +238,7 @@ def test_filters_are_combined() -> None: @pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) @pytest.mark.parametrize("text_column", [None, "text"]) def test_vector_search_retriever_tool_combinations( + mock_mcp_infrastructure, index_name: str, columns: Optional[List[str]], tool_name: Optional[str], @@ -160,7 +263,8 @@ def test_vector_search_retriever_tool_combinations( assert result is not None -def test_vector_search_retriever_tool_combinations() -> None: +def test_vector_search_retriever_tool_doc_uri_primary_key(mock_mcp_infrastructure) -> None: + """Test that doc_uri and primary_key work correctly with MCP path.""" vector_search_tool = init_vector_search_tool( index_name=DELTA_SYNC_INDEX, doc_uri="uri", @@ -168,8 +272,13 @@ def test_vector_search_retriever_tool_combinations() -> None: ) assert isinstance(vector_search_tool, BaseTool) result = vector_search_tool.invoke("Databricks Agent Framework") - assert all(item.metadata.keys() == {"doc_uri", "chunk_id"} for item in result) - assert all(item.page_content for item in result) + # With MCP path, results are parsed from mock JSON response + assert result is not None + assert len(result) > 0 + assert all(isinstance(doc, Document) for doc in result) + # Verify Documents have expected structure from mock response + assert all(doc.page_content for doc in result) + assert all("id" in doc.metadata for doc in result) @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -191,16 +300,22 @@ def test_vector_search_retriever_tool_description_generation(index_name: str) -> @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) -def test_vs_tool_tracing(index_name: str, tool_name: Optional[str]) -> None: +def test_vs_tool_tracing( + mock_mcp_infrastructure, index_name: str, tool_name: Optional[str] +) -> None: vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name) vector_search_tool._run("Databricks Agent Framework") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER) assert len(spans) == 1 inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"]) assert inputs["query"] == "Databricks Agent Framework" outputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanOutputs"]) - assert [d["page_content"] in INPUT_TEXTS for d in outputs] + # Verify outputs are Documents with page_content + assert len(outputs) > 0 + assert all("page_content" in d for d in outputs) + assert all(d["page_content"] for d in outputs) # page_content is not empty @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -344,36 +459,47 @@ def test_vector_search_client_with_sp_workspace_client(): ) -def test_kwargs_are_passed_through() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_kwargs_are_passed_through(execution_path) -> None: + """Test kwargs are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], score_threshold=0.5) + setup_tool_for_path(execution_path, tool) - vector_search_tool.invoke( - {"query": "what cities are in Germany", "extra_param": "something random"}, - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter={}, - score_threshold=0.5, - extra_param="something random", - ) + tool.invoke({"query": "what cities are in Germany"}) + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "score_threshold": 0.5}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter={}, + query_type=tool.query_type, + score_threshold=0.5, + ) -def test_kwargs_override_both_num_results_and_query_type() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") - vector_search_tool._vector_store.similarity_search = MagicMock() - vector_search_tool.invoke( - {"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}, - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=3, # Should use overridden value - query_type="HYBRID", # Should use overridden value - filter={}, - ) +def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None: + """Test kwargs can override num_results and query_type on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], num_results=10, query_type="ANN") + setup_tool_for_path(execution_path, tool) + + tool.invoke({"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}) + + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "num_results": 3, "query_type": "HYBRID"}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=3, + filter={}, + query_type="HYBRID", + ) def test_enhanced_filter_description_with_column_metadata() -> None: @@ -458,34 +584,38 @@ def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: ) -def test_predefined_filters_work_without_dynamic_filter() -> None: - """Test that predefined filters work correctly when dynamic_filter is False.""" - # Initialize tool with only predefined filters (dynamic_filter=False by default) - vector_search_tool = init_vector_search_tool( - DELTA_SYNC_INDEX, filters={"status": "active", "category": "electronics"} +def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: + """Test that predefined filters work correctly when dynamic_filter is False on both paths.""" + tool = init_vector_search_tool( + execution_path["index_name"], filters={"status": "active", "category": "electronics"} ) + setup_tool_for_path(execution_path, tool) # The filters parameter should NOT be exposed since dynamic_filter=False - args_schema = vector_search_tool.args_schema + args_schema = tool.args_schema assert "filters" not in args_schema.model_fields - # Test that predefined filters are used - vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke({"query": "what electronics are available"}) + tool.invoke({"query": "what electronics are available"}) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what electronics are available", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter={"status": "active", "category": "electronics"}, # Only predefined filters - ) + expected_filters = {"status": "active", "category": "electronics"} + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what electronics are available", "filters": expected_filters}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="what electronics are available", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) -def test_filter_item_serialization() -> None: - """Test that FilterItem objects are properly converted to dictionaries.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_filter_item_serialization(execution_path) -> None: + """Test that FilterItem objects are properly converted to dictionaries on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) # Test various filter types filters = [ @@ -495,7 +625,7 @@ def test_filter_item_serialization() -> None: FilterItem(key="tags", value=["wireless", "bluetooth"]), ] - vector_search_tool.invoke({"query": "find products", "filters": filters}) + tool.invoke({"query": "find products", "filters": filters}) expected_filters = { "category": "electronics", @@ -504,9 +634,84 @@ def test_filter_item_serialization() -> None: "tags": ["wireless", "bluetooth"], } - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="find products", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter=expected_filters, + if execution_path["path"] == "mcp": + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "find products", "filters": expected_filters}, + ) + else: + tool._vector_store.similarity_search.assert_called_once_with( + query="find products", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) + + +# ============================================================================= +# MCP Path Specific Tests +# ============================================================================= + + +def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that MCP path is used for Databricks-managed embeddings indexes.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Invoke the tool (should use MCP path for DELTA_SYNC_INDEX which has managed embeddings) + result = vector_search_tool._run("test query") + + # Verify MCP server was created with correct parameters + mock_mcp_infrastructure["server_class"].from_vector_search.assert_called_once() + call_kwargs = mock_mcp_infrastructure["server_class"].from_vector_search.call_args[1] + assert call_kwargs["catalog"] == "test" + assert call_kwargs["schema"] == "delta_sync" + assert call_kwargs["index_name"] == "index" + + # Verify MCP client was used + mock_mcp_infrastructure["client_class"].assert_called_once() + + # Verify MCP tool was invoked with expected query (ainvoke since MCP tools are async-only) + mock_mcp_infrastructure["tool"].ainvoke.assert_called_once_with( + { + "query": "test query", + "num_results": vector_search_tool.num_results, + "query_type": vector_search_tool.query_type, + "include_score": "false", + } ) + + +def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that direct API path is used for self-managed embeddings indexes.""" + # Use an index that requires self-managed embeddings + index_name = "test.direct_access.index" + vector_search_tool = init_vector_search_tool(index_name) + vector_search_tool._vector_store.similarity_search = MagicMock(return_value=[]) + + # Invoke the tool (should use direct API path) + result = vector_search_tool._run("test query") + + # Verify similarity_search was called directly + vector_search_tool._vector_store.similarity_search.assert_called_once() + + # Verify MCP was NOT used for self-managed embeddings + mock_mcp_infrastructure["tool"].ainvoke.assert_not_called() + + +def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: + """Test that MCP tool is cached and not recreated on subsequent calls.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Call _run multiple times + vector_search_tool._run("query 1") + vector_search_tool._run("query 2") + vector_search_tool._run("query 3") + + # MCP server should only be created once + assert mock_mcp_infrastructure["server_class"].from_vector_search.call_count == 1 + + # MCP client should only be created once + assert mock_mcp_infrastructure["client_class"].call_count == 1 + + # But MCP tool should be invoked 3 times + assert mock_mcp_infrastructure["tool"].ainvoke.call_count == 3 diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index b026b4ef..713d72f6 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,4 @@ import inspect -import json import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -198,40 +197,6 @@ def _validate_tool_inputs(self): return self - @vector_search_retriever_tool_trace - def execute( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - openai_client: OpenAI = None, - **kwargs: Any, - ) -> List[Dict]: - """ - Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the - self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. - - Execute vector search with automatic routing: - - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) - - Direct API path: Used for self-managed embeddings (requires openai_client) - - Args: - query: The query text to use for the retrieval. - filters: Optional filters to refine vector search results. - openai_client: The OpenAI client object used to generate embeddings for retrieval queries. - Only used for self-managed embeddings. If not provided, the default OpenAI - client in the current environment will be used. - **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). - For Databricks-managed embeddings, these are passed as MCP metadata. - For self-managed embeddings, these are passed to similarity_search(). - - Returns: - A list of document dictionaries. Format may vary between MCP and Direct API paths. - """ - if self._index_details.is_databricks_managed_embeddings(): - return self._execute_mcp_path(query, filters, **kwargs) - else: - return self._execute_direct_api_path(query, filters, openai_client, **kwargs) - def _create_or_get_mcp_toolkit(self) -> Callable: """ If it does not exist, create the MCP tool execution function for this index. @@ -243,12 +208,7 @@ def _create_or_get_mcp_toolkit(self) -> Callable: if self._mcp_tool_execute is not None: return self._mcp_tool_execute - parts = self.index_name.split(".") - if len(parts) != 3: - raise ValueError( - f"Invalid index name format: {self.index_name}. Expected 'catalog.schema.index'" - ) - catalog, schema, index = parts + catalog, schema, index = self._parse_index_name() try: self._mcp_toolkit = McpServerToolkit.from_vector_search( @@ -258,125 +218,27 @@ def _create_or_get_mcp_toolkit(self) -> Callable: workspace_client=self.workspace_client, ) except Exception as e: - raise RuntimeError( - f"Failed to initialize MCP toolkit for index {self.index_name}. " - f"Ensure the index exists and is configured for Databricks-managed embeddings. " - f"Error: {e}" - ) from e + self._handle_mcp_creation_error(e) tools = self._mcp_toolkit.get_tools() - if len(tools) < 1: - raise ValueError( - f"Expected exactly 1 MCP tool for index {self.index_name}, but got {len(tools)}" - ) + self._validate_mcp_tools(tools) self._mcp_tool_execute = tools[0].execute return self._mcp_tool_execute - def _normalize_filters( - self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] - ) -> Dict[str, Any]: - """ - Normalize filters to a dict format. - - Args: - filters: Either a dict or List[FilterItem] - - Returns: - Dict of filter key-value pairs - """ - if filters is None: - return {} - if isinstance(filters, dict): - return filters - return {item.model_dump()["key"]: item.model_dump()["value"] for item in filters} - - def _build_mcp_meta( - self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, **kwargs: Any - ) -> Dict[str, Any]: - kwargs = {**(self.model_extra or {}), **kwargs} - - meta = {} - - num_results = kwargs.pop("num_results", self.num_results) - meta["num_results"] = num_results - - if self.query_type or "query_type" in kwargs: - query_type = kwargs.pop("query_type", self.query_type) - if query_type: - meta["query_type"] = query_type - - if self.columns: - meta["columns"] = ",".join(self.columns) - - combined_filters = {**self._normalize_filters(filters), **(self.filters or {})} - if combined_filters: - try: - meta["filters"] = json.dumps(combined_filters) - except (TypeError, ValueError) as e: - raise ValueError(f"Filters must be JSON serializable: {e}") from e - - if "score_threshold" in kwargs: - meta["score_threshold"] = float(kwargs.pop("score_threshold")) - - # Always send include_score explicitly to override backend defaults - meta["include_score"] = "true" if self.include_score else "false" - - reranker = kwargs.pop("reranker", self.reranker) - if reranker and hasattr(reranker, "columns_to_rerank"): - meta["columns_to_rerank"] = ",".join(reranker.columns_to_rerank) - - # Warn about any unknown kwargs - if kwargs: - _logger.warning( - f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}" - ) - - return meta - - def _normalize_mcp_result(self, result: Dict) -> Dict: - """ - Normalize MCP result to page_content/metadata format for backward compatibility. - - MCP returns: {"id": "doc1", "text": "content", "score": 0.95} - We convert to: {"page_content": "content", "metadata": {"id": "doc1", "score": 0.95}} - - This ensures callers get consistent output regardless of MCP vs Direct API path. - """ - text_column = self.text_column - page_content = result.get(text_column, "") - - metadata = {k: v for k, v in result.items() if k != text_column} - - return {"page_content": page_content, "metadata": metadata} - - def _parse_mcp_response(self, mcp_response: str) -> List[Dict]: - """ - Parse MCP JSON response and normalize to page_content/metadata format. - - The Vector Search MCP server returns a JSON array of flat result dicts. - We parse and normalize each result for consistent output format. - """ + def _execute_mcp_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Dict]: try: - parsed = json.loads(mcp_response) - except json.JSONDecodeError as e: - _logger.error(f"Failed to parse MCP response as JSON: {mcp_response[:200]}...") - raise ValueError( - f"Unable to parse MCP response. Expected JSON format. Error: {e}" - ) from e - - if not isinstance(parsed, list): - # Show preview of what we got (limit to 500 chars for readability) - response_preview = str(parsed)[:500] - _logger.error( - f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}" - ) - raise ValueError( - f"Expected MCP vector search to return a JSON array of results, " - f"but got {type(parsed).__name__}: {response_preview}" - ) - - return [self._normalize_mcp_result(result) for result in parsed] + mcp_execute = self._create_or_get_mcp_toolkit() + meta = self._build_mcp_params(filters, **kwargs) + mcp_response = mcp_execute(query=query, _meta=meta) + return self._parse_mcp_response_to_dicts(mcp_response, strict=True) + except Exception as e: + self._handle_mcp_execution_error(e) def _execute_direct_api_path( self, @@ -385,13 +247,9 @@ def _execute_direct_api_path( openai_client: OpenAI = None, **kwargs: Any, ) -> List[Dict]: - from openai import OpenAI + from databricks_openai import DatabricksOpenAI - oai_client = openai_client or OpenAI() - if not oai_client.api_key: - raise ValueError( - "OpenAI API key is required to generate embeddings for retrieval queries." - ) + oai_client = openai_client or DatabricksOpenAI(workspace_client=self.workspace_client) signature = inspect.signature(self._index.similarity_search) kwargs = {**kwargs, **(self.model_extra or {})} @@ -439,20 +297,36 @@ def _execute_direct_api_path( ) return [doc for doc, _ in docs_with_score] - def _execute_mcp_path( + @vector_search_retriever_tool_trace + def execute( self, query: str, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + openai_client: OpenAI = None, **kwargs: Any, ) -> List[Dict]: - try: - mcp_execute = self._create_or_get_mcp_toolkit() - meta = self._build_mcp_meta(filters, **kwargs) - mcp_response = mcp_execute(query=query, _meta=meta) - documents = self._parse_mcp_response(mcp_response) - return documents - except Exception as e: - _logger.error(f"MCP vector search failed: {e}", exc_info=True) - raise RuntimeError( - f"Vector search via MCP failed for index {self.index_name}. Error: {e}" - ) from e + """ + Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the + self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. + + Execute vector search with automatic routing: + - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) + - Direct API path: Used for self-managed embeddings (requires openai_client) + + Args: + query: The query text to use for the retrieval. + filters: Optional filters to refine vector search results. + openai_client: The OpenAI client object used to generate embeddings for retrieval queries. + Only used for self-managed embeddings. If not provided, a DatabricksOpenAI + client will be created using the workspace_client for authentication. + **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). + For Databricks-managed embeddings, these are passed as MCP metadata. + For self-managed embeddings, these are passed to similarity_search(). + + Returns: + A list of document dictionaries. Format may vary between MCP and Direct API paths. + """ + if self._index_details.is_databricks_managed_embeddings(): + return self._execute_mcp_path(query, filters, **kwargs) + else: + return self._execute_direct_api_path(query, filters, openai_client, **kwargs) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 4ac64f20..2cb0f5a9 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -45,7 +45,8 @@ def mock_openai_client(): mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3, 0.4])] mock_client.embeddings.create.return_value = mock_response with patch("openai.OpenAI", return_value=mock_client): - yield mock_client + with patch("databricks_openai.DatabricksOpenAI", return_value=mock_client): + yield mock_client @pytest.fixture @@ -289,38 +290,6 @@ def test_open_ai_client_from_env( assert all(["id" in d["metadata"] for d in docs]) -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_vector_search_retriever_index_name_rewrite( - index_name: str, -) -> None: - if index_name == DELTA_SYNC_INDEX: - self_managed_embeddings_test = SelfManagedEmbeddingsTest() - else: - from openai import OpenAI - - self_managed_embeddings_test = SelfManagedEmbeddingsTest( - "text", "text-embedding-3-small", OpenAI(api_key="your-api-key") - ) - - vector_search_tool = init_vector_search_tool( - index_name=index_name, - text_column=self_managed_embeddings_test.text_column, - embedding_model_name=self_managed_embeddings_test.embedding_model_name, - ) - assert vector_search_tool.tool["function"]["name"] == index_name.replace(".", "__") - - -@pytest.mark.parametrize( - "index_name", - ["catalog.schema.really_really_really_long_tool_name_that_should_be_truncated_to_64_chars"], -) -def test_vector_search_retriever_long_index_name( - index_name: str, -) -> None: - vector_search_tool = init_vector_search_tool(index_name=index_name) - assert len(vector_search_tool.tool["function"]["name"]) <= 64 - - def test_vector_search_client_model_serving_environment(): with patch("os.path.isfile", return_value=True): # Simulate Model Serving Environment @@ -578,123 +547,6 @@ def test_include_score_always_sent_in_meta(mock_mcp_toolkit) -> None: assert call_kwargs["_meta"]["include_score"] == "false" -def test_get_filter_param_description_with_column_metadata() -> None: - """Test that _get_filter_param_description includes column metadata when available.""" - # Mock table info with column metadata - mock_column1 = Mock() - mock_column1.name = "category" - mock_column1.type_name.name = "STRING" - - mock_column2 = Mock() - mock_column2.name = "price" - mock_column2.type_name.name = "FLOAT" - - mock_column3 = Mock() - mock_column3.name = "__internal_column" # Should be excluded - mock_column3.type_name.name = "STRING" - - mock_table_info = Mock() - mock_table_info.columns = [mock_column1, mock_column2, mock_column3] - - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = Mock() - mock_ws_client.tables.get.return_value = mock_table_info - mock_ws_client_class.return_value = mock_ws_client - - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # Test the _get_filter_param_description method directly - description = vector_search_tool._get_filter_param_description() - - # Should include available columns in description - assert "Available columns for filtering: category (STRING), price (FLOAT)" in description - - # Should include comprehensive filter syntax - assert "Inclusion:" in description - assert "Exclusion:" in description - assert "Comparisons:" in description - assert "Pattern match:" in description - assert "OR logic:" in description - - # Should include examples - assert "Examples:" in description - assert "Filter by category:" in description - assert "Filter by price range:" in description - - -def test_enhanced_filter_description_used_in_tool_schema() -> None: - """Test that the tool schema includes comprehensive filter descriptions.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - # Check that the tool schema includes enhanced filter description - tool_schema = vector_search_tool.tool - filter_param = tool_schema["function"]["parameters"]["properties"]["filters"] - - # Check that it includes the comprehensive filter syntax - assert "Inclusion:" in filter_param["description"] - assert "Exclusion:" in filter_param["description"] - assert "Comparisons:" in filter_param["description"] - assert "Pattern match:" in filter_param["description"] - assert "OR logic:" in filter_param["description"] - - # Check that it includes useful filter information - assert "array of key-value pairs" in filter_param["description"] - assert "column" in filter_param["description"] - - -def test_enhanced_filter_description_fails_on_table_metadata_error() -> None: - """Test that tool initialization fails with clear error when table metadata cannot be retrieved.""" - # Mock WorkspaceClient to raise an exception when accessing table metadata - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_ws_client.tables.get.side_effect = Exception("Permission denied") - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because we can't get table metadata - with pytest.raises( - ValueError, - match="Failed to retrieve table metadata for index.*Permission denied", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_enhanced_filter_description_fails_on_empty_columns() -> None: - """Test that tool initialization fails when table has no valid columns.""" - # Mock WorkspaceClient to return a table with no valid columns (all start with __) - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_table = MagicMock() - mock_column = MagicMock() - mock_column.name = "__internal_column" - mock_column.type_name = MagicMock() - mock_column.type_name.name = "STRING" - mock_table.columns = [mock_column] - mock_ws_client.tables.get.return_value = mock_table - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because there are no valid columns - with pytest.raises( - ValueError, - match="No valid columns found in table metadata for index", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: - """Test that using both dynamic_filter and predefined filters raises an error.""" - # Try to initialize tool with both dynamic_filter=True and predefined filters - with pytest.raises( - ValueError, match="Cannot use both dynamic_filter=True and predefined filters" - ): - init_vector_search_tool( - DELTA_SYNC_INDEX, - filters={"status": "active", "category": "electronics"}, - dynamic_filter=True, - ) - - def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: """Test that predefined filters work correctly when dynamic_filter is False.""" predefined_filters = {"status": "active", "category": "electronics"} @@ -732,51 +584,6 @@ def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: assert call_kwargs["query_type"] == vector_search_tool.query_type -def test_filter_item_serialization(execution_path) -> None: - """Test that FilterItem objects are properly converted to dictionaries.""" - vector_search_tool = init_vector_search_tool(execution_path["index_name"]) - setup_tool_for_path(execution_path, vector_search_tool) - - # Test various filter types - filters = [ - FilterItem(key="category", value="electronics"), - FilterItem(key="price >=", value=100), - FilterItem(key="status NOT", value="discontinued"), - FilterItem(key="tags", value=["wireless", "bluetooth"]), - ] - - vector_search_tool.execute("find products", filters=filters) - - expected_filters = { - "category": "electronics", - "price >=": 100, - "status NOT": "discontinued", - "tags": ["wireless", "bluetooth"], - } - - if execution_path["path"] == "mcp": - mock_tool = execution_path["mock_tool"] - mock_tool.execute.assert_called_once() - call_kwargs = mock_tool.execute.call_args.kwargs - - assert call_kwargs["query"] == "find products" - - meta = call_kwargs["_meta"] - # Filters should be serialized as JSON - assert json.loads(meta["filters"]) == expected_filters - assert meta["num_results"] == vector_search_tool.num_results - assert meta["query_type"] == vector_search_tool.query_type - assert meta["columns"] == ",".join(vector_search_tool.columns) - else: - vector_search_tool._index.similarity_search.assert_called_once() - call_kwargs = vector_search_tool._index.similarity_search.call_args.kwargs - - assert call_kwargs["filters"] == expected_filters - assert call_kwargs["num_results"] == vector_search_tool.num_results - assert call_kwargs["query_type"] == vector_search_tool.query_type - assert call_kwargs["columns"] == vector_search_tool.columns - - def test_reranker_is_passed_through(execution_path) -> None: reranker = DatabricksReranker(columns_to_rerank=["country"]) vector_search_tool = init_vector_search_tool(execution_path["index_name"], reranker=reranker) @@ -844,72 +651,3 @@ def test_reranker_is_overriden(execution_path) -> None: assert call_kwargs["filters"] == {"country": "Germany"} assert call_kwargs["num_results"] == vector_search_tool.num_results assert call_kwargs["query_type"] == vector_search_tool.query_type - - -# ============================================================================ -# Response Format Normalization Tests -# ============================================================================ - - -class TestMCPResponseNormalization: - """Test that MCP responses are normalized to match Direct API format.""" - - def test_normalize_mcp_result_basic(self) -> None: - """Test basic normalization of a single MCP result.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_result = { - "id": "doc-123", - "text": "This is the document content", - "score": 0.95, - } - - normalized = vector_search_tool._normalize_mcp_result(mcp_result) - - assert normalized["page_content"] == "This is the document content" - assert normalized["metadata"]["id"] == "doc-123" - assert normalized["metadata"]["score"] == 0.95 - assert "text" not in normalized["metadata"] # text column moved to page_content - - def test_normalize_mcp_result_missing_text_column(self) -> None: - """Test normalization handles missing text column gracefully.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_result = { - "id": "doc-789", - "score": 0.75, - # "text" column is missing - } - - normalized = vector_search_tool._normalize_mcp_result(mcp_result) - - assert normalized["page_content"] == "" # Empty string when text column missing - assert normalized["metadata"]["id"] == "doc-789" - assert normalized["metadata"]["score"] == 0.75 - - def test_parse_mcp_response_empty_list(self) -> None: - """Test parsing empty MCP response.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_response = json.dumps([]) - - results = vector_search_tool._parse_mcp_response(mcp_response) - - assert results == [] - - def test_parse_mcp_response_invalid_json(self) -> None: - """Test parsing invalid JSON raises ValueError.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - with pytest.raises(ValueError, match="Unable to parse MCP response"): - vector_search_tool._parse_mcp_response("not valid json {") - - def test_parse_mcp_response_not_a_list(self) -> None: - """Test parsing non-list JSON raises ValueError.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # MCP should return a list, not a dict - mcp_response = json.dumps({"error": "something went wrong"}) - - with pytest.raises(ValueError, match="Expected MCP vector search to return a JSON array"): - vector_search_tool._parse_mcp_response(mcp_response) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 6b2aed2d..0fd56f71 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -1,7 +1,8 @@ +import json import logging import re from functools import wraps -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import mlflow from databricks.sdk import WorkspaceClient @@ -298,3 +299,131 @@ def _get_tool_name(self) -> str: ) return tool_name[-64:] return tool_name + + def _normalize_filters( + self, filters: Optional[Union[Dict[str, Any], List["FilterItem"]]] + ) -> Dict[str, Any]: + """Normalize filters to dict format.""" + if filters is None: + return {} + if isinstance(filters, dict): + return filters + return {item.model_dump()["key"]: item.model_dump()["value"] for item in filters} + + def _parse_index_name(self) -> Tuple[str, str, str]: + """Parse index_name into (catalog, schema, index) tuple.""" + parts = self.index_name.split(".") + if len(parts) != 3: + raise ValueError( + f"Invalid index name format: {self.index_name}. Expected 'catalog.schema.index'" + ) + return parts[0], parts[1], parts[2] + + def _handle_mcp_creation_error(self, error: Exception) -> None: + """Raise standardized error for MCP initialization failures.""" + raise RuntimeError( + f"Failed to initialize MCP tool for index {self.index_name}. " + f"Ensure the index exists and is configured for Databricks-managed embeddings. " + f"Error: {error}" + ) from error + + def _validate_mcp_tools(self, tools: list) -> None: + """Validate that exactly one MCP tool was returned.""" + if not tools: + raise ValueError(f"No MCP tools found for index {self.index_name}") + if len(tools) != 1: + raise ValueError( + f"Expected exactly 1 MCP tool for index {self.index_name}, but got {len(tools)}" + ) + + def _handle_mcp_execution_error(self, error: Exception) -> None: + """Log and raise standardized error for MCP execution failures.""" + _logger.error(f"MCP vector search failed: {error}", exc_info=True) + raise RuntimeError( + f"Vector search via MCP failed for index {self.index_name}. Error: {error}" + ) from error + + def _build_mcp_params( + self, + filters: Optional[Union[Dict[str, Any], List["FilterItem"]]] = None, + query: Optional[str] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Build common MCP parameters dict.""" + kwargs = {**(self.model_extra or {}), **kwargs} + params: Dict[str, Any] = {} + + if query is not None: + params["query"] = query + + num_results = kwargs.pop("num_results", kwargs.pop("k", self.num_results)) + if num_results: + params["num_results"] = num_results + + query_type = kwargs.pop("query_type", self.query_type) + if query_type: + params["query_type"] = query_type + + combined_filters = {**self._normalize_filters(filters), **(self.filters or {})} + if combined_filters: + try: + params["filters"] = json.dumps(combined_filters) + except (TypeError, ValueError) as e: + raise ValueError(f"Filters must be JSON serializable: {e}") from e + + if self.columns: + params["columns"] = ",".join(self.columns) + + if "score_threshold" in kwargs: + params["score_threshold"] = float(kwargs.pop("score_threshold")) + + params["include_score"] = "true" if self.include_score else "false" + + reranker = kwargs.pop("reranker", self.reranker) + if reranker and hasattr(reranker, "columns_to_rerank"): + params["columns_to_rerank"] = ",".join(reranker.columns_to_rerank) + + if kwargs: + _logger.warning( + f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}" + ) + + return params + + def _parse_mcp_response_to_dicts( + self, mcp_response: str, text_column: Optional[str] = None, strict: bool = True + ) -> List[Dict[str, Any]]: + """Parse MCP JSON response to list of dicts with page_content/metadata structure.""" + text_col = text_column or getattr(self, "text_column", None) or "text" + + try: + parsed = json.loads(mcp_response) + except json.JSONDecodeError as e: + if strict: + _logger.error(f"Failed to parse MCP response as JSON: {mcp_response[:200]}...") + raise ValueError( + f"Unable to parse MCP response. Expected JSON format. Error: {e}" + ) from e + return [{"page_content": mcp_response, "metadata": {}}] + + if not isinstance(parsed, list): + if strict: + response_preview = str(parsed)[:500] + _logger.error( + f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}" + ) + raise ValueError( + f"Expected JSON array, got {type(parsed).__name__}: {response_preview}" + ) + return [{"page_content": str(parsed), "metadata": {}}] + + results = [] + for item in parsed: + if isinstance(item, dict): + page_content = item.get(text_col, str(item)) + metadata = {k: v for k, v in item.items() if k != text_col} + results.append({"page_content": page_content, "metadata": metadata}) + else: + results.append({"page_content": str(item), "metadata": {}}) + + return results diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index dbf72a45..5a90c1cf 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import json from unittest.mock import MagicMock import pytest @@ -5,7 +6,10 @@ from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client # noqa: F401 from databricks_ai_bridge.utils.vector_search import IndexDetails -from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin +from databricks_ai_bridge.vector_search_retriever_tool import ( + FilterItem, + VectorSearchRetrieverToolMixin, +) class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): @@ -78,3 +82,377 @@ def test_describe_columns(): "country (STRING): Name of the country\n" "description (STRING): Detailed description of the city" ) + + +# ============================================================================= +# Tests for _normalize_filters +# ============================================================================= + + +def test_normalize_filters_with_filter_items(): + """Test that FilterItem list is normalized to dict.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + filters = [ + FilterItem(key="category", value="electronics"), + FilterItem(key="price >=", value=100), + ] + + result = tool._normalize_filters(filters) + + assert result == {"category": "electronics", "price >=": 100} + + +# ============================================================================= +# Tests for _parse_mcp_response_to_dicts +# ============================================================================= + + +def test_parse_mcp_response_to_dicts_json_array(): + """Test that JSON array response is parsed correctly into dicts.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps( + [ + {"id": "doc1", "text": "content1", "score": 0.9}, + {"id": "doc2", "text": "content2", "score": 0.8}, + ] + ) + + dicts = tool._parse_mcp_response_to_dicts(json_response) + + assert len(dicts) == 2 + assert dicts[0]["page_content"] == "content1" + assert dicts[0]["metadata"] == {"id": "doc1", "score": 0.9} + assert dicts[1]["page_content"] == "content2" + assert dicts[1]["metadata"] == {"id": "doc2", "score": 0.8} + + +def test_parse_mcp_response_to_dicts_non_json_strict(): + """Test that non-JSON response raises ValueError when strict=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + plain_text_response = "This is a plain text response" + + with pytest.raises(ValueError, match="Unable to parse MCP response"): + tool._parse_mcp_response_to_dicts(plain_text_response, strict=True) + + +def test_parse_mcp_response_to_dicts_non_json_non_strict(): + """Test that non-JSON response is treated as single document when strict=False.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + plain_text_response = "This is a plain text response" + + dicts = tool._parse_mcp_response_to_dicts(plain_text_response, strict=False) + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == plain_text_response + assert dicts[0]["metadata"] == {} + + +def test_parse_mcp_response_to_dicts_non_list_json_strict(): + """Test that non-list JSON raises ValueError when strict=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps({"message": "single object response"}) + + with pytest.raises(ValueError, match="Expected JSON array, got"): + tool._parse_mcp_response_to_dicts(json_response, strict=True) + + +def test_parse_mcp_response_to_dicts_non_list_json_non_strict(): + """Test that non-list JSON is converted to single document when strict=False.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps({"message": "single object response"}) + + dicts = tool._parse_mcp_response_to_dicts(json_response, strict=False) + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == "{'message': 'single object response'}" + assert dicts[0]["metadata"] == {} + + +def test_parse_mcp_response_to_dicts_empty_list(): + """Test parsing empty list response.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps([]) + + dicts = tool._parse_mcp_response_to_dicts(json_response) + + assert dicts == [] + + +def test_parse_mcp_response_to_dicts_custom_text_column(): + """Test that custom text column is used for page_content.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps( + [ + {"id": "doc1", "content": "custom content", "score": 0.9}, + ] + ) + + dicts = tool._parse_mcp_response_to_dicts(json_response, text_column="content") + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == "custom content" + assert dicts[0]["metadata"] == {"id": "doc1", "score": 0.9} + + +# ============================================================================= +# Tests for _build_mcp_params +# ============================================================================= + + +def test_build_mcp_params_basic(): + """Test basic MCP params building.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + params = tool._build_mcp_params(None) + + assert params["num_results"] == tool.num_results + assert params["query_type"] == tool.query_type + assert params["include_score"] == "false" + assert "filters" not in params + + +def test_build_mcp_params_with_filters(): + """Test MCP params building with filters.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + filters = [FilterItem(key="category", value="electronics")] + params = tool._build_mcp_params(filters) + + assert json.loads(params["filters"]) == {"category": "electronics"} + + +def test_build_mcp_params_combines_filters(): + """Test MCP params building combines predefined and runtime filters.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, filters={"status": "active"}) + + runtime_filters = [FilterItem(key="category", value="electronics")] + params = tool._build_mcp_params(runtime_filters) + + expected_filters = {"status": "active", "category": "electronics"} + assert json.loads(params["filters"]) == expected_filters + + +def test_build_mcp_params_kwargs_override_defaults(): + """Test that kwargs override default values.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, num_results=10, query_type="ANN") + + params = tool._build_mcp_params(None, num_results=5, query_type="HYBRID") + + assert params["num_results"] == 5 + assert params["query_type"] == "HYBRID" + + +def test_build_mcp_params_with_columns(): + """Test MCP params building with columns.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, columns=["id", "text", "score"]) + + params = tool._build_mcp_params(None) + + assert params["columns"] == "id,text,score" + + +def test_build_mcp_params_with_include_score(): + """Test MCP params building with include_score=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, include_score=True) + + params = tool._build_mcp_params(None) + + assert params["include_score"] == "true" + + +def test_build_mcp_params_k_alias_for_num_results(): + """Test that 'k' kwarg is treated as alias for num_results.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, num_results=10) + + params = tool._build_mcp_params(None, k=3) + + assert params["num_results"] == 3 + + +def test_build_mcp_params_with_reranker(): + """Test MCP params building with reranker.""" + from databricks.vector_search.reranker import DatabricksReranker + + reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) + tool = DummyVectorSearchRetrieverTool(index_name=index_name, reranker=reranker) + + params = tool._build_mcp_params(None) + + assert params["columns_to_rerank"] == "text,title" + + +# ============================================================================= +# Tests for _parse_index_name +# ============================================================================= + + +def test_parse_index_name_invalid(): + """Test parsing invalid index name raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name="invalid_index_name") + + with pytest.raises(ValueError, match="Invalid index name format"): + tool._parse_index_name() + + +# ============================================================================= +# Tests for validate_filter_configuration +# ============================================================================= + + +def test_cannot_use_both_dynamic_filter_and_predefined_filters(): + """Test that using both dynamic_filter and predefined filters raises an error.""" + # Try to initialize tool with both dynamic_filter=True and predefined filters + with pytest.raises( + ValueError, match="Cannot use both dynamic_filter=True and predefined filters" + ): + DummyVectorSearchRetrieverTool( + index_name=index_name, + filters={"status": "active", "category": "electronics"}, + dynamic_filter=True, + ) + + +# ============================================================================= +# Tests for _get_tool_name +# ============================================================================= + + +def test_get_tool_name_replaces_dots(): + """Test that dots in index name are replaced with underscores.""" + tool = DummyVectorSearchRetrieverTool(index_name="catalog.schema.my_index") + assert tool._get_tool_name() == "catalog__schema__my_index" + + +def test_get_tool_name_uses_custom_name(): + """Test that custom tool_name is used when provided.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, tool_name="custom_tool") + assert tool._get_tool_name() == "custom_tool" + + +def test_get_tool_name_truncates_long_names(): + """Test that long tool names are truncated to 64 characters.""" + long_index = ( + "catalog.schema.really_really_really_long_tool_name_that_should_be_truncated_to_64_chars" + ) + tool = DummyVectorSearchRetrieverTool(index_name=long_index) + result = tool._get_tool_name() + assert len(result) <= 64 + + +# ============================================================================= +# Tests for _validate_mcp_tools +# ============================================================================= + + +def test_validate_mcp_tools_empty_list(): + """Test that empty tools list raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + with pytest.raises(ValueError, match="No MCP tools found for index"): + tool._validate_mcp_tools([]) + + +def test_validate_mcp_tools_multiple_tools(): + """Test that multiple tools raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + with pytest.raises(ValueError, match="Expected exactly 1 MCP tool"): + tool._validate_mcp_tools([MagicMock(), MagicMock()]) + + +def test_validate_mcp_tools_single_tool(): + """Test that single tool passes validation.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + # Should not raise + tool._validate_mcp_tools([MagicMock()]) + + +# ============================================================================= +# Tests for _get_filter_param_description +# ============================================================================= + + +def test_get_filter_param_description_includes_column_metadata(): + """Test that _get_filter_param_description includes column metadata when available.""" + from unittest.mock import Mock, patch + + mock_column1 = Mock() + mock_column1.name = "category" + mock_column1.type_name.name = "STRING" + + mock_column2 = Mock() + mock_column2.name = "price" + mock_column2.type_name.name = "FLOAT" + + mock_column3 = Mock() + mock_column3.name = "__internal_column" # Should be excluded + mock_column3.type_name.name = "STRING" + + mock_table_info = Mock() + mock_table_info.columns = [mock_column1, mock_column2, mock_column3] + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = Mock() + mock_ws_client.tables.get.return_value = mock_table_info + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + description = tool._get_filter_param_description() + + # Should include available columns in description + assert "Available columns for filtering: category (STRING), price (FLOAT)" in description + + # Should include comprehensive filter syntax + assert "Inclusion:" in description + assert "Exclusion:" in description + assert "Comparisons:" in description + assert "Pattern match:" in description + assert "OR logic:" in description + + +def test_get_filter_param_description_fails_on_table_metadata_error(): + """Test that _get_filter_param_description fails with clear error when table metadata cannot be retrieved.""" + from unittest.mock import patch + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_ws_client.tables.get.side_effect = Exception("Permission denied") + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + with pytest.raises( + ValueError, + match="Failed to retrieve table metadata for index.*Permission denied", + ): + tool._get_filter_param_description() + + +def test_get_filter_param_description_fails_on_empty_columns(): + """Test that _get_filter_param_description fails when table has no valid columns.""" + from unittest.mock import patch + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_table = MagicMock() + mock_column = MagicMock() + mock_column.name = "__internal_column" + mock_column.type_name = MagicMock() + mock_column.type_name.name = "STRING" + mock_table.columns = [mock_column] + mock_ws_client.tables.get.return_value = mock_table + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + with pytest.raises( + ValueError, + match="No valid columns found in table metadata for index", + ): + tool._get_filter_param_description()