diff --git a/integrations/langchain/tests/integration_tests/test_fmapi_tool_calling.py b/integrations/langchain/tests/integration_tests/test_fmapi_tool_calling.py new file mode 100644 index 00000000..312949c1 --- /dev/null +++ b/integrations/langchain/tests/integration_tests/test_fmapi_tool_calling.py @@ -0,0 +1,253 @@ +""" +End-to-end FMAPI tool calling tests for ChatDatabricks via LangGraph. + +Prerequisites: +- FMAPI endpoints must be available on the test workspace +""" + +from __future__ import annotations + +import os + +import pytest +from databricks_ai_bridge.test_utils.fmapi import ( + LANGCHAIN_SKIP_MODELS, + async_retry, + discover_foundation_models, + max_tokens_for_model, + retry, +) +from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage +from langchain_core.tools import tool +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import create_react_agent + +from databricks_langchain import ChatDatabricks + +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_FMAPI_TOOL_CALLING_TESTS") != "1", + reason="FMAPI tool calling tests disabled. Set RUN_FMAPI_TOOL_CALLING_TESTS=1 to enable.", +) + +_FOUNDATION_MODELS = discover_foundation_models(LANGCHAIN_SKIP_MODELS) + + +@tool +def add(a: int, b: int) -> int: + """Add two integers. + + Args: + a: First integer + b: Second integer + """ + return a + b + + +@tool +def multiply(a: int, b: int) -> int: + """Multiply two integers. 
+ + Args: + a: First integer + b: Second integer + """ + return a * b + + +# ============================================================================= +# Sync LangGraph Agent +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.parametrize("model", _FOUNDATION_MODELS) +class TestLangGraphSync: + """Sync LangGraph agent tests using ChatDatabricks + create_react_agent.""" + + def test_single_turn(self, model): + """Single-turn: agent calls tools and produces a final answer.""" + + def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply]) + + response = agent.invoke( + { + "messages": [ + ( + "human", + "Use the add tool to compute 10 + 5, then use the multiply tool " + "to multiply the result by 3. You MUST use the tools.", + ) + ] + } + ) + + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "45" in last_message.content + + tool_messages = [m for m in response["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) > 0, "Expected tool calls in conversation history" + + retry(_run) + + def test_multi_turn(self, model): + """Multi-turn: agent maintains conversation context across turns.""" + + def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply], checkpointer=MemorySaver()) + config = {"configurable": {"thread_id": f"test-sync-multi-turn-{model}"}} + + response = agent.invoke({"messages": [("human", "What is 10 + 5?")]}, config=config) + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "15" in last_message.content + + response = agent.invoke({"messages": [("human", "Multiply that by 3")]}, config=config) + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "45" in last_message.content + 
+ retry(_run) + + def test_streaming(self, model): + """Streaming: agent streams node updates and tool execution events.""" + + def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply]) + + events = list( + agent.stream( + { + "messages": [ + ( + "human", + "Use the add tool to compute 10 + 5, then use the multiply tool " + "to multiply the result by 3. You MUST use the tools.", + ) + ] + }, + stream_mode="updates", + ) + ) + + assert len(events) > 0, "No stream events received" + + nodes_seen = set() + for event in events: + nodes_seen.update(event.keys()) + + assert "agent" in nodes_seen, f"Expected 'agent' node, got: {nodes_seen}" + assert "tools" in nodes_seen, f"Expected 'tools' node, got: {nodes_seen}" + + last_event = events[-1] + last_messages = list(last_event.values())[0]["messages"] + assert any("45" in str(m.content) for m in last_messages) + + retry(_run) + + +# ============================================================================= +# Async LangGraph Agent +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("model", _FOUNDATION_MODELS) +class TestLangGraphAsync: + """Async LangGraph agent tests using ChatDatabricks + create_react_agent.""" + + async def test_single_turn(self, model): + """Single-turn via ainvoke.""" + + async def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply]) + + response = await agent.ainvoke( + { + "messages": [ + ( + "human", + "Use the add tool to compute 10 + 5, then use the multiply tool " + "to multiply the result by 3. 
You MUST use the tools.", + ) + ] + } + ) + + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "45" in last_message.content + + tool_messages = [m for m in response["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) > 0, "Expected tool calls in conversation history" + + await async_retry(_run) + + async def test_multi_turn(self, model): + """Multi-turn via ainvoke with MemorySaver checkpointer.""" + + async def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply], checkpointer=MemorySaver()) + config = {"configurable": {"thread_id": f"test-async-multi-turn-{model}"}} + + response = await agent.ainvoke( + {"messages": [("human", "What is 10 + 5?")]}, config=config + ) + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "15" in last_message.content + + response = await agent.ainvoke( + {"messages": [("human", "Multiply that by 3")]}, config=config + ) + last_message = response["messages"][-1] + assert isinstance(last_message, AIMessage) + assert "45" in last_message.content + + await async_retry(_run) + + async def test_streaming(self, model): + """Streaming via astream with updates + messages stream modes.""" + + async def _run(): + llm = ChatDatabricks(model=model, max_tokens=max_tokens_for_model(model)) + agent = create_react_agent(llm, [add, multiply]) + + nodes_seen = set() + got_message_chunks = False + event_count = 0 + + async for event in agent.astream( + { + "messages": [ + ( + "human", + "Use the add tool to compute 10 + 5, then use the multiply tool " + "to multiply the result by 3. 
You MUST use the tools.", + ) + ] + }, + stream_mode=["updates", "messages"], + ): + event_count += 1 + mode, data = event + if mode == "updates": + nodes_seen.update(data.keys()) + elif mode == "messages": + chunk, _metadata = data + if isinstance(chunk, AIMessageChunk): + got_message_chunks = True + + assert event_count > 0, "No stream events received" + assert "agent" in nodes_seen, f"Expected 'agent' node, got: {nodes_seen}" + assert "tools" in nodes_seen, f"Expected 'tools' node, got: {nodes_seen}" + assert got_message_chunks, "Expected AIMessageChunk tokens in message stream" + + await async_retry(_run) diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 7d138ee8..70044482 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -1,4 +1,5 @@ import os +from collections.abc import AsyncIterator, Iterator from typing import Any, Generator from databricks.sdk import WorkspaceClient @@ -61,6 +62,77 @@ def _should_strip_strict(model: str | None) -> bool: return "gpt" not in model.lower() +# Gemini FMAPI compat: rejects list content in tool messages (request side) +# and returns list content in responses (response side). We flatten both. +# Note: Gemini 3.x requires thought_signature which is an Agents SDK issue, not fixable here. 
+ + +def _is_gemini_model(model: str | None) -> bool: + """Returns True if the model is a Gemini variant.""" + if not model: + return False + return "gemini" in model.lower() or "gemma" in model.lower() + + +def _fix_gemini_messages(messages: Any) -> None: + """Flatten list content in outbound tool messages for Gemini.""" + if not messages: + return + for msg in messages: + if ( + isinstance(msg, dict) + and msg.get("role") == "tool" + and isinstance(msg.get("content"), list) + ): + msg["content"] = "".join( + p.get("text", "") if isinstance(p, dict) else getattr(p, "text", p) + for p in msg["content"] + ) + + +def _fix_gemini_content(response: Any) -> None: + """Flatten list content in response messages/deltas for Gemini.""" + if not hasattr(response, "choices"): + return + for choice in response.choices: + obj = getattr(choice, "message", None) or getattr(choice, "delta", None) + if obj is not None and isinstance(getattr(obj, "content", None), list): + obj.content = "".join( + p.get("text", "") if isinstance(p, dict) else getattr(p, "text", p) + for p in obj.content + ) + + +class _GeminiStreamWrapper: + """Wraps a sync Stream, flattening list content in each chunk.""" + + def __init__(self, stream: Any): + self._stream = stream + + def __iter__(self) -> Iterator: + for chunk in self._stream: + _fix_gemini_content(chunk) + yield chunk + + def __getattr__(self, name: str): + return getattr(self._stream, name) + + +class _AsyncGeminiStreamWrapper: + """Wraps an AsyncStream, flattening list content in each chunk.""" + + def __init__(self, stream: Any): + self._stream = stream + + async def __aiter__(self) -> AsyncIterator: + async for chunk in self._stream: + _fix_gemini_content(chunk) + yield chunk + + def __getattr__(self, name: str): + return getattr(self._stream, name) + + def _is_claude_model(model: str | None) -> bool: """Returns True if the model is a Claude variant.""" if not model: @@ -199,7 +271,14 @@ def create(self, **kwargs): 
_strip_strict_from_tools(kwargs.get("tools")) if _is_claude_model(model): _fix_empty_assistant_content_in_messages(kwargs.get("messages")) - return super().create(**kwargs) + if _is_gemini_model(model): + _fix_gemini_messages(kwargs.get("messages")) + response = super().create(**kwargs) + if _is_gemini_model(model): + if kwargs.get("stream"): + return _GeminiStreamWrapper(response) + _fix_gemini_content(response) + return response class DatabricksChat(Chat): @@ -346,7 +425,14 @@ async def create(self, **kwargs): _strip_strict_from_tools(kwargs.get("tools")) if _is_claude_model(model): _fix_empty_assistant_content_in_messages(kwargs.get("messages")) - return await super().create(**kwargs) + if _is_gemini_model(model): + _fix_gemini_messages(kwargs.get("messages")) + response = await super().create(**kwargs) + if _is_gemini_model(model): + if kwargs.get("stream"): + return _AsyncGeminiStreamWrapper(response) + _fix_gemini_content(response) + return response class AsyncDatabricksChat(AsyncChat): diff --git a/integrations/openai/tests/integration_tests/test_fmapi_tool_calling.py b/integrations/openai/tests/integration_tests/test_fmapi_tool_calling.py new file mode 100644 index 00000000..2207239a --- /dev/null +++ b/integrations/openai/tests/integration_tests/test_fmapi_tool_calling.py @@ -0,0 +1,348 @@ +""" +End-to-end FMAPI tool calling tests for DatabricksOpenAI (sync + async via Agents SDK). 
+ +Prerequisites: +- FMAPI endpoints must be available on the test workspace +- echo_message UC function in integration_testing.databricks_ai_bridge_mcp_test +""" + +from __future__ import annotations + +import json +import os + +import pytest +from databricks.sdk import WorkspaceClient +from databricks_ai_bridge.test_utils.fmapi import ( + COMMON_SKIP_MODELS, + async_retry, + discover_foundation_models, + retry, +) + +from databricks_openai import AsyncDatabricksOpenAI, DatabricksOpenAI + +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_FMAPI_TOOL_CALLING_TESTS") != "1", + reason="FMAPI tool calling tests disabled. Set RUN_FMAPI_TOOL_CALLING_TESTS=1 to enable.", +) + +_FOUNDATION_MODELS = discover_foundation_models(COMMON_SKIP_MODELS) + + +# MCP test infrastructure +_MCP_CATALOG = "integration_testing" +_MCP_SCHEMA = "databricks_ai_bridge_mcp_test" +_MCP_FUNCTION = "echo_message" + + +@pytest.fixture(scope="module") +def workspace_client(): + return WorkspaceClient() + + +@pytest.fixture(scope="module") +def sync_client(workspace_client): + return DatabricksOpenAI(workspace_client=workspace_client) + + +@pytest.fixture(scope="module") +def async_client(workspace_client): + return AsyncDatabricksOpenAI(workspace_client=workspace_client) + + +# ============================================================================= +# Async DatabricksOpenAI +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("model", _FOUNDATION_MODELS) +class TestAgentToolCalling: + """Async agent tests using the OpenAI Agents SDK with MCP tools. 
+ + Pattern: AsyncDatabricksOpenAI -> McpServer -> Agent -> Runner.run + """ + + async def test_single_turn(self, async_client, workspace_client, model): + """Single-turn: user sends one message, agent calls a tool and responds.""" + from agents import Agent, Runner, set_default_openai_api, set_default_openai_client + from agents.items import MessageOutputItem, ToolCallItem, ToolCallOutputItem + + from databricks_openai.agents import McpServer + + async def _run(): + set_default_openai_client(async_client) + set_default_openai_api("chat_completions") + + async with McpServer.from_uc_function( + catalog=_MCP_CATALOG, + schema=_MCP_SCHEMA, + function_name=_MCP_FUNCTION, + workspace_client=workspace_client, + timeout=60, + ) as server: + agent = Agent( + name="echo-agent", + instructions="Use the echo_message tool to echo messages when asked.", + model=model, + mcp_servers=[server], + ) + result = await Runner.run(agent, "Echo the message 'hello from FMAPI test'") + + assert result.final_output is not None + assert "hello from FMAPI test" in result.final_output + + item_types = [type(item) for item in result.new_items] + assert ToolCallItem in item_types, f"Expected a tool call, got: {item_types}" + assert ToolCallOutputItem in item_types, f"Expected tool output, got: {item_types}" + assert MessageOutputItem in item_types, f"Expected a message, got: {item_types}" + + input_list = result.to_input_list() + assert len(input_list) > 1, "Expected multi-item conversation history" + + await async_retry(_run) + + async def test_multi_turn(self, async_client, workspace_client, model): + """Multi-turn: two-turn conversation with full history replay.""" + from agents import Agent, Runner, set_default_openai_api, set_default_openai_client + from agents.items import ToolCallItem + + from databricks_openai.agents import McpServer + + async def _run(): + set_default_openai_client(async_client) + set_default_openai_api("chat_completions") + + async with McpServer.from_uc_function( + 
catalog=_MCP_CATALOG, + schema=_MCP_SCHEMA, + function_name=_MCP_FUNCTION, + workspace_client=workspace_client, + timeout=60, + ) as server: + agent = Agent( + name="echo-agent", + instructions="Use the echo_message tool to echo messages when asked.", + model=model, + mcp_servers=[server], + ) + + first_result = await Runner.run(agent, "Echo the message 'hello'") + assert first_result.final_output is not None + assert "hello" in first_result.final_output + + first_item_types = [type(item) for item in first_result.new_items] + assert ToolCallItem in first_item_types + + history = first_result.to_input_list() + history.append({"role": "user", "content": "Now echo the message 'world'"}) + + second_result = await Runner.run(agent, history) + assert second_result.final_output is not None + assert "world" in second_result.final_output + + second_item_types = [type(item) for item in second_result.new_items] + assert ToolCallItem in second_item_types + + second_history = second_result.to_input_list() + assert len(second_history) > len(history), ( + f"Expected history to grow: {len(history)} -> {len(second_history)}" + ) + + await async_retry(_run) + + async def test_streaming(self, async_client, workspace_client, model): + """Streaming: verify stream events via Runner.run_streamed().""" + from agents import Agent, Runner, set_default_openai_api, set_default_openai_client + from agents.stream_events import RunItemStreamEvent + + from databricks_openai.agents import McpServer + + async def _run(): + set_default_openai_client(async_client) + set_default_openai_api("chat_completions") + + async with McpServer.from_uc_function( + catalog=_MCP_CATALOG, + schema=_MCP_SCHEMA, + function_name=_MCP_FUNCTION, + workspace_client=workspace_client, + timeout=60, + ) as server: + agent = Agent( + name="echo-agent", + instructions="Use the echo_message tool to echo messages when asked.", + model=model, + mcp_servers=[server], + ) + result = Runner.run_streamed(agent, input="Echo the 
message 'streaming test'") + + run_item_events = [] + event_count = 0 + async for event in result.stream_events(): + event_count += 1 + if isinstance(event, RunItemStreamEvent): + run_item_events.append(event) + + assert event_count > 0, "No stream events received" + + event_names = [e.name for e in run_item_events] + assert "tool_called" in event_names, ( + f"Expected tool_called event, got: {event_names}" + ) + assert "tool_output" in event_names, ( + f"Expected tool_output event, got: {event_names}" + ) + assert "message_output_created" in event_names, ( + f"Expected message_output_created event, got: {event_names}" + ) + + assert result.final_output is not None + assert "streaming test" in result.final_output + + await async_retry(_run) + + +# ============================================================================= +# Sync DatabricksOpenAI — direct chat.completions.create() +# ============================================================================= + +# echo_message tool definition for direct chat.completions.create() tests +_ECHO_MESSAGE_TOOL = { + "type": "function", + "function": { + "name": "echo_message", + "description": "Echo back the provided message", + "parameters": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The message to echo back", + } + }, + "required": ["message"], + }, + }, +} + + +@pytest.mark.integration +@pytest.mark.parametrize("model", _FOUNDATION_MODELS) +class TestSyncClientToolCalling: + """Sync DatabricksOpenAI tests using direct chat.completions.create().""" + + def test_single_turn(self, sync_client, model): + """Single-turn: model receives tool, produces a tool call.""" + + def _run(): + response = sync_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Echo the message 'sync hello'"}], + tools=[_ECHO_MESSAGE_TOOL], + max_tokens=200, + ) + message = response.choices[0].message + assert message.tool_calls is not None + assert 
len(message.tool_calls) >= 1 + + tool_call = message.tool_calls[0] + assert tool_call.id is not None + assert tool_call.type == "function" + assert tool_call.function.name == "echo_message" + args = json.loads(tool_call.function.arguments) + assert "message" in args + + retry(_run) + + def test_multi_turn(self, sync_client, model): + """Multi-turn: tool_call -> tool result -> text response.""" + + def _run(): + # Turn 1: get tool call + response = sync_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Echo the message 'sync world'"}], + tools=[_ECHO_MESSAGE_TOOL], + max_tokens=200, + ) + assistant_msg = response.choices[0].message + assert assistant_msg.tool_calls is not None + tool_call = assistant_msg.tool_calls[0] + + # Turn 2: send tool result back, get text response + # Manually construct the assistant message to avoid extra fields + # (e.g. "annotations") that model_dump() includes but FMAPI rejects + assistant_dict = { + "role": "assistant", + "content": assistant_msg.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in assistant_msg.tool_calls + ], + } + response = sync_client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": "Echo the message 'sync world'"}, + assistant_dict, + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "sync world", + }, + ], + tools=[_ECHO_MESSAGE_TOOL], + max_tokens=200, + ) + followup = response.choices[0].message + assert followup.content is not None + assert "sync world" in followup.content + + retry(_run) + + def test_streaming(self, sync_client, model): + """Streaming: tool call arrives as chunked deltas.""" + + def _run(): + stream = sync_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Echo the message 'sync stream'"}], + tools=[_ECHO_MESSAGE_TOOL], + max_tokens=200, + stream=True, 
+ ) + chunks = list(stream) + assert len(chunks) > 0 + + # Reassemble tool call from streamed deltas + tool_call_name = "" + tool_call_args = "" + tool_call_id = None + for chunk in chunks: + delta = chunk.choices[0].delta if chunk.choices else None + if delta and delta.tool_calls: + tc = delta.tool_calls[0] + if tc.id: + tool_call_id = tc.id + if tc.function: + if tc.function.name: + tool_call_name += tc.function.name + if tc.function.arguments: + tool_call_args += tc.function.arguments + + assert tool_call_id is not None, "No tool call ID found in stream" + assert tool_call_name == "echo_message" + args = json.loads(tool_call_args) + assert "message" in args + + retry(_run) diff --git a/src/databricks_ai_bridge/test_utils/fmapi.py b/src/databricks_ai_bridge/test_utils/fmapi.py new file mode 100644 index 00000000..55c3c235 --- /dev/null +++ b/src/databricks_ai_bridge/test_utils/fmapi.py @@ -0,0 +1,130 @@ +"""Shared test utilities for FMAPI tool calling integration tests. + +Used by both databricks-openai and databricks-langchain test suites. +""" + +from __future__ import annotations + +import logging + +from databricks.sdk import WorkspaceClient + +log = logging.getLogger(__name__) + +# Max retries for flaky models (e.g. 
# Retry budget for flaky models (transient FMAPI errors, model non-determinism)
MAX_RETRIES = 3

# Default max_tokens for test requests
DEFAULT_MAX_TOKENS = 200

# Models excluded from BOTH the OpenAI and LangChain test suites
COMMON_SKIP_MODELS = {
    "databricks-gpt-5-nano",  # too small for reliable tool calling
    "databricks-gpt-oss-20b",  # hallucinates tool names in agent loop
    "databricks-gpt-oss-120b",  # hallucinates tool names in agent loop
    "databricks-llama-4-maverick",  # hallucinates tool names in agent loop
    "databricks-gemini-3-flash",  # requires thought_signature on function calls
    "databricks-gemini-3-pro",  # requires thought_signature on function calls
    "databricks-gemini-3-1-pro",  # requires thought_signature on function calls
    "databricks-gpt-5-1-codex-max",  # Responses API only, no Chat Completions support
    "databricks-gpt-5-1-codex-mini",  # Responses API only, no Chat Completions support
    "databricks-gpt-5-2-codex",  # Responses API only, no Chat Completions support
    "databricks-gpt-5-3-codex",  # Responses API only, no Chat Completions support
}

# Superset of COMMON_SKIP_MODELS: extra exclusions for the LangChain suite only
LANGCHAIN_SKIP_MODELS = COMMON_SKIP_MODELS | {
    "databricks-gemma-3-12b",  # outputs raw tool call text instead of executing tools
}

# Per-model max_tokens overrides. Reasoning models spend part of the budget on
# hidden reasoning tokens (Gemini 2.5 Pro needs 200-600 with 2 tools), so the
# 200-token default is too small for them.
MODEL_MAX_TOKENS: dict[str, int] = {
    "databricks-gemini-2-5-pro": 1000,
}


def max_tokens_for_model(model: str) -> int:
    """Return the max_tokens budget to use for *model* in test requests.

    Models listed in MODEL_MAX_TOKENS get their override; every other model
    gets DEFAULT_MAX_TOKENS.
    """
    try:
        return MODEL_MAX_TOKENS[model]
    except KeyError:
        return DEFAULT_MAX_TOKENS
# Fallback list if dynamic discovery fails (e.g. auth not configured at collection time)
FALLBACK_MODELS = [
    "databricks-claude-sonnet-4-6",
    "databricks-claude-opus-4-6",
    "databricks-meta-llama-3-3-70b-instruct",
    "databricks-gpt-5-2",
    "databricks-gpt-5-1",
    "databricks-qwen3-next-80b-a3b-instruct",
]

# Module logger; idempotent if a `log` binding already exists above.
log = logging.getLogger(__name__)


def has_function_calling(w: "WorkspaceClient", endpoint_name: str) -> bool:
    """Check if an endpoint supports function calling via the capabilities API.

    Any failure (endpoint missing, auth error, unexpected payload shape) is
    treated as "no function calling" so discovery degrades gracefully.
    """
    try:
        resp = w.api_client.do("GET", f"/api/2.0/serving-endpoints/{endpoint_name}")
        # bool() guarantees the annotated return type even if the API returns
        # a truthy non-bool value.
        return bool(resp.get("capabilities", {}).get("function_calling", False))
    except Exception:
        return False


def discover_foundation_models(skip_models: set[str]) -> list[str]:
    """Discover all FMAPI chat models that support tool calling.

    1. List all serving endpoints with databricks- prefix and llm/v1/chat task
    2. Drop models in skip_models (checked first: it is free, while the
       capabilities check below costs one HTTP round-trip per endpoint)
    3. Check capabilities.function_calling via the serving-endpoints API

    Falls back to FALLBACK_MODELS when listing fails (e.g. no auth configured
    at collection time).
    """
    try:
        w = WorkspaceClient()
        endpoints = list(w.serving_endpoints.list())
    except Exception as exc:
        log.warning("Could not discover FMAPI models, using fallback list: %s", exc)
        return FALLBACK_MODELS

    chat_endpoints = [
        e
        for e in endpoints
        if e.name and e.name.startswith("databricks-") and e.task == "llm/v1/chat"
    ]

    models = []
    for endpoint in sorted(chat_endpoints, key=lambda ep: ep.name or ""):
        name = endpoint.name or ""
        if name in skip_models:
            log.info("Skipping %s: in skip list", name)
            continue
        if not has_function_calling(w, name):
            log.info("Skipping %s: does not support function calling", name)
            continue
        models.append(name)

    log.info("Discovered %d FMAPI models with function calling support", len(models))
    return models


def retry(fn, retries: "int | None" = None):
    """Retry a test function up to ``retries`` times (default MAX_RETRIES).

    Returns the first successful result; only fails (re-raising the last
    exception) if every attempt fails.

    Raises:
        ValueError: if ``retries`` is less than 1 (previously this path
            re-raised ``None``, producing a confusing TypeError).
    """
    if retries is None:
        retries = MAX_RETRIES
    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")
    last_exc = None
    for attempt in range(retries):
        try:
            return fn()
        except Exception as exc:
            last_exc = exc
            if attempt < retries - 1:
                log.warning("Attempt %d/%d failed: %s — retrying", attempt + 1, retries, exc)
    # retries >= 1 guarantees at least one iteration ran, so last_exc is set.
    raise last_exc


async def async_retry(fn, retries: "int | None" = None):
    """Async counterpart of :func:`retry`; same semantics for ``retries``."""
    if retries is None:
        retries = MAX_RETRIES
    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")
    last_exc = None
    for attempt in range(retries):
        try:
            return await fn()
        except Exception as exc:
            last_exc = exc
            if attempt < retries - 1:
                log.warning("Attempt %d/%d failed: %s — retrying", attempt + 1, retries, exc)
    raise last_exc