From ee025b44b662f4f57f455d86ec264ce0b5803350 Mon Sep 17 00:00:00 2001 From: jjovalle99 Date: Sat, 4 Apr 2026 17:47:10 +0100 Subject: [PATCH] fix: propagate parameter descriptions in create_tool_call _get_function_parameters mutates field_info.description after FieldInfo construction. Pydantic v2 ignores this because _attributes_set is not updated. All docstring-derived parameter descriptions are silently dropped from the generated tool schema. For fresh FieldInfos, pass description to Field() at construction. For existing FieldInfos, override via Annotated stacking (public API). --- src/mistralai/extra/run/tools.py | 27 ++- src/mistralai/extra/tests/test_tools.py | 288 ++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 8 deletions(-) create mode 100644 src/mistralai/extra/tests/test_tools.py diff --git a/src/mistralai/extra/run/tools.py b/src/mistralai/extra/run/tools.py index 95dc21a9..7ba7ff70 100644 --- a/src/mistralai/extra/run/tools.py +++ b/src/mistralai/extra/run/tools.py @@ -3,7 +3,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, Callable, ForwardRef, Sequence, cast, get_type_hints +from typing import Annotated, Any, Callable, ForwardRef, Sequence, cast, get_type_hints import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes from griffe import ( @@ -17,6 +17,7 @@ from opentelemetry.trace import Status, StatusCode from pydantic import Field, create_model from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined as _PYDANTIC_UNDEFINED from mistralai.client.models import ( Function, @@ -95,7 +96,7 @@ def _get_function_parameters( param_annotations[param.name] = type_hints.get(param.name) # resolve all params into Field and create the parameters schema - fields: dict[str, tuple[type, FieldInfo]] = {} + fields: dict[str, Any] = {} for p in params_from_sig: default = p.default if p.default is not inspect.Parameter.empty else ... annotation = ( @@ -127,15 +128,25 @@ def _get_function_parameters( if isinstance(annotation, ForwardRef): annotation = param_annotations[p.name] - # no Field + description = param_descriptions[p.name] or None + if field_info is None: if default is ...: - field_info = Field() + field_info = Field(description=description) else: - field_info = Field(default=default) - - field_info.description = param_descriptions[p.name] - fields[p.name] = (cast(type, annotation), field_info) + field_info = Field(default=default, description=description) + fields[p.name] = (cast(type, annotation), field_info) + elif description: + typed = Annotated[ # type: ignore[valid-type] + cast(type, annotation), field_info, Field(description=description) + ] + raw_default = field_info.default + if raw_default is not _PYDANTIC_UNDEFINED: + fields[p.name] = (typed, raw_default) + else: + fields[p.name] = (typed, ...) + else: + fields[p.name] = (cast(type, annotation), field_info) schema = create_model("_", **fields).model_json_schema() # type: ignore[call-overload] schema.pop("title", None) diff --git a/src/mistralai/extra/tests/test_tools.py b/src/mistralai/extra/tests/test_tools.py new file mode 100644 index 00000000..46d81ea0 --- /dev/null +++ b/src/mistralai/extra/tests/test_tools.py @@ -0,0 +1,288 @@ +"""Unit tests for create_tool_call parameter description propagation. + +Validates that parameter descriptions from docstrings and Annotated[T, Field(...)] +annotations correctly appear in the JSON schema produced by create_tool_call(). + +This is a regression test for a Pydantic v2 bug where post-construction mutation +of FieldInfo.description is silently ignored by model_json_schema(). + +Fixtures are defined inline so each test is self-contained. +""" + +import unittest +from typing import Annotated, Optional + +from pydantic import Field + +from ..run.tools import create_tool_call + + +def _props(func): + """Shorthand: create a tool call and return its parameter properties.""" + return create_tool_call(func).function.parameters["properties"] + + +class TestCreateToolCallDescriptions(unittest.TestCase): + """Descriptions from docstrings must appear in the generated JSON schema.""" + + # -- Docstring descriptions (Path 3: no existing FieldInfo) ---------------- + + def test_required_param_gets_docstring_description(self): + def search(query: str) -> str: + """Search the web. + + Args: + query: The search query to execute. + """ + return "" + + props = _props(search) + self.assertEqual(props["query"]["description"], "The search query to execute.") + + def test_optional_param_with_default_gets_docstring_description(self): + def search(query: str, limit: int = 10) -> str: + """Search the web. + + Args: + query: The search query. + limit: Maximum number of results. + """ + return "" + + props = _props(search) + self.assertEqual(props["limit"]["description"], "Maximum number of results.") + self.assertEqual(props["limit"]["default"], 10) + + def test_multiple_params_all_get_descriptions(self): + def fetch(url: str, timeout: int = 30, verbose: bool = False) -> str: + """Fetch a URL. + + Args: + url: The URL to fetch. + timeout: Request timeout in seconds. + verbose: Enable verbose logging. + """ + return "" + + props = _props(fetch) + self.assertEqual(props["url"]["description"], "The URL to fetch.") + self.assertEqual(props["timeout"]["description"], "Request timeout in seconds.") + self.assertEqual(props["verbose"]["description"], "Enable verbose logging.") + + # -- Annotated + docstring (Path 2: existing FieldInfo) -------------------- + + def test_annotated_field_description_overridden_by_docstring(self): + def search(query: Annotated[str, Field(description="original")]) -> str: + """Search. + + Args: + query: From docstring. + """ + return "" + + props = _props(search) + self.assertEqual(props["query"]["description"], "From docstring.") + + def test_annotated_field_description_preserved_when_no_docstring_entry(self): + """When the docstring has no Args entry for a param, the Field(description=...) + from Annotated must be preserved, not clobbered with empty string.""" + + def search(query: Annotated[str, Field(description="keep me")]) -> str: + """Search the web.""" + return "" + + props = _props(search) + self.assertEqual(props["query"]["description"], "keep me") + + def test_annotated_field_constraints_preserved_with_docstring(self): + def count(n: Annotated[int, Field(ge=0, le=100)]) -> str: + """Count items. + + Args: + n: Number of items. + """ + return "" + + props = _props(count) + self.assertEqual(props["n"]["description"], "Number of items.") + self.assertEqual(props["n"]["minimum"], 0) + self.assertEqual(props["n"]["maximum"], 100) + + def test_annotated_field_constraints_preserved_without_docstring_entry(self): + def count( + n: Annotated[int, Field(ge=0, le=100, description="original")], + ) -> str: + """Count items.""" + return "" + + props = _props(count) + self.assertEqual(props["n"]["description"], "original") + self.assertEqual(props["n"]["minimum"], 0) + self.assertEqual(props["n"]["maximum"], 100) + + # -- Field as default value (Path 1: isinstance(default, FieldInfo)) ------- + + def test_field_default_value_with_docstring(self): + def search(query: str, limit: int = Field(default=10, ge=1)) -> str: + """Search. + + Args: + query: The query. + limit: Max results. + """ + return "" + + props = _props(search) + self.assertEqual(props["limit"]["description"], "Max results.") + self.assertEqual(props["limit"]["default"], 10) + self.assertEqual(props["limit"]["minimum"], 1) + + def test_field_default_value_without_docstring_entry(self): + """Field(default=..., ge=...) without a docstring entry should preserve + constraints and not inject a spurious empty description.""" + + def search(query: str, limit: int = Field(default=10, ge=1)) -> str: + """Search. + + Args: + query: The query. + """ + return "" + + props = _props(search) + self.assertEqual(props["limit"]["default"], 10) + self.assertEqual(props["limit"]["minimum"], 1) + + # -- Edge cases ------------------------------------------------------------ + + def test_undocumented_param_has_no_description_key(self): + """Params without any docstring entry or Field description should not + have a description key in the schema (not even an empty string).""" + + def search(query: str) -> str: + """Search the web.""" + return "" + + props = _props(search) + self.assertIn("query", props) + self.assertNotIn("description", props["query"]) + + def test_required_params_in_required_list(self): + def search(query: str, limit: int = 10) -> str: + """Search. + + Args: + query: The query. + limit: Max results. + """ + return "" + + tool = create_tool_call(search) + required = tool.function.parameters.get("required", []) + self.assertIn("query", required) + self.assertNotIn("limit", required) + + def test_optional_type_annotation(self): + def search(query: str, tag: Optional[str] = None) -> str: + """Search. + + Args: + query: The query. + tag: Optional tag filter. + """ + return "" + + props = _props(search) + self.assertEqual(props["tag"]["description"], "Optional tag filter.") + + def test_list_type_annotation(self): + def search(queries: list[str]) -> str: + """Batch search. + + Args: + queries: List of search queries. + """ + return "" + + props = _props(search) + self.assertEqual(props["queries"]["description"], "List of search queries.") + + def test_function_level_description(self): + def search(query: str) -> str: + """Search the web for information. + + Args: + query: The search query. + """ + return "" + + tool = create_tool_call(search) + self.assertEqual(tool.function.description, "Search the web for information.") + + def test_no_docstring_at_all(self): + def search(query: str) -> str: + return "" + + tool = create_tool_call(search) + self.assertIsNotNone(tool.function.parameters) + self.assertIn("query", tool.function.parameters["properties"]) + + def test_shared_field_info_no_cross_contamination(self): + """Two functions sharing the same FieldInfo instance via Annotated must + not cross-contaminate descriptions.""" + + shared_field = Field(ge=0) + + def func_a(n: Annotated[int, shared_field]) -> str: + """A. + + Args: + n: Description A. + """ + return "" + + def func_b(n: Annotated[int, shared_field]) -> str: + """B. + + Args: + n: Description B. + """ + return "" + + props_a = _props(func_a) + props_b = _props(func_b) + self.assertEqual(props_a["n"]["description"], "Description A.") + self.assertEqual(props_b["n"]["description"], "Description B.") + # Calling func_a again after func_b must still produce "Description A." + props_a_again = _props(func_a) + self.assertEqual(props_a_again["n"]["description"], "Description A.") + # Original shared instance must be unmodified + self.assertIsNone(shared_field.description) + + +class TestCreateToolCallRegressionPydanticV2(unittest.TestCase): + """Regression: post-construction FieldInfo.description mutation is broken in Pydantic v2.""" + + def test_description_appears_in_schema_not_silently_dropped(self): + """The original bug: docstring descriptions were silently dropped from the + JSON schema because FieldInfo.description was mutated after construction, + which Pydantic v2 ignores in model_json_schema().""" + + def get_weather(city: str, units: str = "celsius") -> str: + """Get weather for a city. + + Args: + city: The city name. + units: Temperature units. + """ + return "" + + tool = create_tool_call(get_weather) + props = tool.function.parameters["properties"] + self.assertEqual(props["city"]["description"], "The city name.") + self.assertEqual(props["units"]["description"], "Temperature units.") + self.assertEqual(props["units"]["default"], "celsius") + + +if __name__ == "__main__": + unittest.main()