Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ def new_init(self, *args, **kwargs):

# Python builtin type names -> JSON schema type names.
CONVERSION_DICT = {
    "str": "string",
    "int": "integer",
    "float": "number",
}

# JSON schema type names -> Python type hint strings, i.e. the opposite direction
# of CONVERSION_DICT. Per the original note it likewise reverses
# _BASE_TYPE_MAPPING from _function_type_hints_utils (not visible here).
JSON_SCHEMA_TO_PYTHON_TYPE = {
    # Scalar types
    "string": "str",
    "integer": "int",
    "number": "float",
    "boolean": "bool",
    # Container types (bare form, without item/value types)
    "array": "list",
    "object": "dict",
    # Unconstrained and null
    "any": "Any",
    "null": "None",
    # Special types without a direct Python equivalent: the name is kept as-is
    "image": "image",
    "audio": "audio",
}


class BaseTool(ABC):
name: str
Expand Down Expand Up @@ -255,12 +271,64 @@ def setup(self):
"""
self.is_initialized = True

def _schema_to_python_type(self, schema: dict | str) -> str:
    """
    Convert a JSON schema type to a Python type hint string.

    Recursively translates a JSON schema fragment into the type hint text used in
    generated function signatures. Simple types are looked up in
    ``JSON_SCHEMA_TO_PYTHON_TYPE``; arrays recurse into ``items``; objects recurse into
    ``additionalProperties``; a list-valued ``type`` becomes a ``|`` union whose members
    are de-duplicated and sorted alphabetically so the output is deterministic.

    Args:
        schema: JSON schema dictionary (e.g., {"type": "string"}) or type string (e.g., "string").
            Any other value (for instance a boolean schema such as ``True``) is treated as
            unconstrained.

    Returns:
        Python type hint string (e.g., "str", "list[int]", "dict[str, str]").
        A schema without a "type" yields "Any". A type name that is missing from
        ``JSON_SCHEMA_TO_PYTHON_TYPE`` is returned unchanged.

    Examples:
        >>> self._schema_to_python_type("string")
        'str'
        >>> self._schema_to_python_type({"type": "array", "items": {"type": "integer"}})
        'list[int]'
        >>> self._schema_to_python_type({"type": "object", "additionalProperties": {"type": "string"}})
        'dict[str, str]'
        >>> self._schema_to_python_type({"type": "object", "additionalProperties": False})
        'dict'
        >>> self._schema_to_python_type({"type": "array", "items": {}})
        'list[Any]'
    """
    # Normalize input: a bare type name is shorthand for {"type": name}
    if isinstance(schema, str):
        schema = {"type": schema}

    # Boolean schemas (true/false) and other non-mapping values carry no type
    # information; calling .get() on them would raise AttributeError.
    if not isinstance(schema, dict):
        return "Any"

    schema_type = schema.get("type")

    # A schema without "type" (e.g. {} or one using only anyOf) is unconstrained here.
    # Returning "Any" keeps the declared `-> str` contract instead of leaking None.
    if schema_type is None:
        return "Any"

    # Handle union types (list of types like ["string", "integer"]).
    # Each member is resolved against the full schema so that "array"/"object"
    # members keep their items/additionalProperties detail.
    if isinstance(schema_type, list):
        members = {self._schema_to_python_type({**schema, "type": member}) for member in schema_type}
        return " | ".join(sorted(members))

    # Handle array types: recurse into "items" only when it is itself a schema
    # (dict or type name); booleans and tuple-style lists fall back to plain "list".
    if schema_type == "array":
        items = schema.get("items")
        if isinstance(items, (dict, str)):
            return f"list[{self._schema_to_python_type(items)}]"
        return "list"

    # Handle object/dict types: "additionalProperties" may legally be a boolean,
    # in which case there is no value type to describe.
    if schema_type == "object":
        value_schema = schema.get("additionalProperties")
        if isinstance(value_schema, (dict, str)):
            return f"dict[str, {self._schema_to_python_type(value_schema)}]"
        return "dict"

    # Handle all other types using the centralized mapping
    return JSON_SCHEMA_TO_PYTHON_TYPE.get(schema_type, schema_type)

def to_code_prompt(self) -> str:
args_signature = ", ".join(f"{arg_name}: {arg_schema['type']}" for arg_name, arg_schema in self.inputs.items())
args_signature = ", ".join(
f"{arg_name}: {self._schema_to_python_type(arg_schema)}" for arg_name, arg_schema in self.inputs.items()
)

# Use dict type for tools with output schema to indicate structured return
has_schema = hasattr(self, "output_schema") and self.output_schema is not None
output_type = "dict" if has_schema else self.output_type
output_type = "dict" if has_schema else self._schema_to_python_type({"type": self.output_type})
tool_signature = f"({args_signature}) -> {output_type}"
tool_doc = self.description

Expand Down
138 changes: 134 additions & 4 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,18 @@ def forward(self, text) -> str:
@pytest.mark.parametrize(
"tool_fixture, expected_output",
[
("no_input_tool", 'def no_input_tool() -> string:\n """Tool with no inputs\n """'),
("no_input_tool", 'def no_input_tool() -> str:\n """Tool with no inputs\n """'),
(
"single_input_tool",
'def single_input_tool(text: string) -> string:\n """Tool with one input\n\n Args:\n text: Input text\n """',
'def single_input_tool(text: str) -> str:\n """Tool with one input\n\n Args:\n text: Input text\n """',
),
(
"multi_input_tool",
'def multi_input_tool(text: string, count: integer) -> object:\n """Tool with multiple inputs\n\n Args:\n text: Text input\n count: Number count\n """',
'def multi_input_tool(text: str, count: int) -> dict:\n """Tool with multiple inputs\n\n Args:\n text: Text input\n count: Number count\n """',
),
(
"multiline_description_tool",
'def multiline_description_tool(input: string) -> string:\n """This is a tool with\n multiple lines\n in the description\n\n Args:\n input: Some input\n """',
'def multiline_description_tool(input: str) -> str:\n """This is a tool with\n multiple lines\n in the description\n\n Args:\n input: Some input\n """',
),
],
)
Expand Down Expand Up @@ -174,6 +174,136 @@ def test_tool_to_tool_calling_prompt_output_format(self, tool_fixture, expected_
tool_calling_prompt = tool.to_tool_calling_prompt()
assert tool_calling_prompt == expected_output

def test_tool_to_code_prompt_array_type_extraction(self):
"""Test that to_code_prompt properly extracts array item types and shows Python type hints.

Five tools are built and their to_code_prompt() output is checked by substring:
two via the @tool decorator (list[str] input; list[int] plus dict[str, str] inputs)
and three hand-written Tool subclasses (array without "items", nested arrays,
nullable array of objects).
"""

# Test 1: Array of strings
# NOTE(review): `tool` is defined outside this view; presumably it derives the input
# schema from the annotations and the Args section of the docstring below - confirm.
@tool
def get_weather(locations: list[str]) -> dict[str, float]:
"""
Get weather at given locations.

Args:
locations: The locations to get the weather for.
"""
return {"temp": 72.5}

code_prompt = get_weather.to_code_prompt()
# Should show list[str] not just array
assert "locations: list[str]" in code_prompt, f"Expected 'locations: list[str]' in output but got: {code_prompt}"
assert "locations: array" not in code_prompt, f"Should not contain 'locations: array' but got: {code_prompt}"
# Only the "-> dict" prefix is checked, so a parameterized dict return hint also passes
assert "-> dict" in code_prompt, f"Expected '-> dict' in output but got: {code_prompt}"

# Test 2: Array of integers with dict return type
@tool
def process_data(items: list[int], config: dict[str, str]) -> str:
"""
Process data items with configuration.

Args:
items: List of integer items to process
config: Configuration dictionary
"""
return "done"

code_prompt = process_data.to_code_prompt()
# Should show list[int] not just array
assert "items: list[int]" in code_prompt, f"Expected 'items: list[int]' in output but got: {code_prompt}"
assert "config: dict[str, str]" in code_prompt, f"Expected 'config: dict[str, str]' in output but got: {code_prompt}"
assert "-> str" in code_prompt, f"Expected '-> str' in output but got: {code_prompt}"

# Test 3: Tool with simple array (no items specification in manual definition)
class SimpleArrayTool(Tool):
name = "simple_array_tool"
description = "Tool with simple array"
inputs = {
"items": {
"type": "array",
"description": "Some items"
}
}
output_type = "string"

def forward(self, items):
return "done"

simple_tool = SimpleArrayTool()
code_prompt = simple_tool.to_code_prompt()
# Should show list (not array) even without items
# NOTE(review): this is a substring check, so "items: list[...]" would satisfy it too
assert "items: list" in code_prompt, f"Expected 'items: list' in output but got: {code_prompt}"
assert "items: array" not in code_prompt, f"Should not contain 'items: array' but got: {code_prompt}"

# Test 4: Nested arrays
class NestedArrayTool(Tool):
name = "nested_array_tool"
description = "Tool with nested arrays"
inputs = {
"matrix": {
"type": "array",
"items": {
"type": "array",
"items": {"type": "integer"}
},
"description": "2D matrix"
}
}
output_type = "string"

def forward(self, matrix):
return "done"

nested_tool = NestedArrayTool()
code_prompt = nested_tool.to_code_prompt()
# Should show list[list[int]] for nested arrays
assert "matrix: list[list[int]]" in code_prompt, f"Expected 'matrix: list[list[int]]' in output but got: {code_prompt}"

# Test 5: Array of objects with properties (complex nested structure)
# The item schema uses "properties" (not "additionalProperties") and the input is
# marked "nullable"; the assertions below only require the container types to appear.
class ComplexArrayTool(Tool):
name = "complex_array_tool"
description = "Tool with array of option objects"
inputs = {
"options": {
"type": "array",
"description": (
"Required for single_choice/multiple_choice. Omit for text input. "
"Array of option objects where each has 'value' and 'label', optionally 'description'."
),
"nullable": True,
"items": {
"type": "object",
"properties": {
"value": {
"type": "string",
"description": "The value to return when selected (e.g., 'grid', 'modern', 'analytics')",
},
"label": {
"type": "string",
"description": (
"Display text shown to user (e.g., 'Grid Layout', 'Modern Style', 'Analytics Dashboard')"
),
},
"description": {
"type": "string",
"description": (
"Optional help text explaining the option (e.g., 'Cards arranged in a responsive grid')"
),
"nullable": True,
},
},
},
}
}
output_type = "string"

# The default of None pairs with "nullable": True in the schema above
def forward(self, options: list[dict] | None = None):
return "done"

complex_tool = ComplexArrayTool()
code_prompt = complex_tool.to_code_prompt()
# Should show list[dict] for array of objects with properties
assert "options: list[dict]" in code_prompt, f"Expected 'options: list[dict]' in output but got: {code_prompt}"
assert "options: array" not in code_prompt, f"Should not contain 'options: array' but got: {code_prompt}"

def test_tool_init_with_decorator(self):
@tool
def coolfunc(a: str, b: int) -> float:
Expand Down