[CI/Build] Add common tool call parser test suite #27599
Status: Open

bbrowning wants to merge 4 commits into vllm-project:main from bbrowning:20251006-tool-parser-tests
Diff stats: +1,185 −1
Changes from all commits (4):
- 3dfa4d8 [CI/Build] Add common tool call parser test suite (bbrowning)
- 4485580 [CI/Build] More strict data type checking in common tool parser suite (bbrowning)
- 39fda21 Merge branch 'main' into 20251006-tool-parser-tests (DarkLight1337)
- ffdc985 Merge remote-tracking branch 'upstream/main' into 20251006-tool-parse… (bbrowning)
The diff adds a single new file (@@ -0,0 +1,378 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from dataclasses import dataclass, field
from types import NoneType
from typing import Any

import pytest

from tests.entrypoints.openai.tool_parsers.utils import run_tool_extraction
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer


@dataclass
class ToolParserTestConfig:
    """Configuration for a tool parser's common tests.

    This dataclass contains all the test data and expected results needed
    to run the common test suite for a parser. Each parser test file
    creates one instance of this config with parser-specific values.

    Attributes:
        parser_name: Name used with ToolParserManager (e.g., "mistral")

    Test data (model outputs):
        no_tool_calls_output: Plain text without any tool syntax
        single_tool_call_output: One tool call with simple arguments
        parallel_tool_calls_output: Multiple tool calls in one response
        various_data_types_output: Tool with various data types
        empty_arguments_output: Tool call with no parameters
        surrounding_text_output: Tool call mixed with regular text
        escaped_strings_output: Tool call with escaped chars
        malformed_input_outputs: List of invalid inputs

    Expected results:
        single_tool_call_expected_name: Expected function name
        single_tool_call_expected_args: Expected arguments dict
        parallel_tool_calls_count: Number of tools in parallel test
        parallel_tool_calls_names: Function names in order
        single_tool_call_expected_content: Content field when tool called
        parallel_tool_calls_expected_content: Content for parallel test

    xfail markers:
        xfail_streaming: Mapping test name to xfail reason (streaming only)
        xfail_nonstreaming: Mapping test name to xfail reason (non-streaming)

    Special flags:
        allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args
        supports_typed_arguments: True if the parser supports typed function arguments
    """

    # Parser identification
    parser_name: str

    # Test data - model outputs for each common test
    no_tool_calls_output: str
    single_tool_call_output: str
    parallel_tool_calls_output: str
    various_data_types_output: str
    empty_arguments_output: str
    surrounding_text_output: str
    escaped_strings_output: str
    malformed_input_outputs: list[str]

    # Expected results for specific tests (optional overrides)
    single_tool_call_expected_name: str = "get_weather"
    single_tool_call_expected_args: dict[str, Any] = field(
        default_factory=lambda: {"city": "Tokyo"}
    )
    parallel_tool_calls_count: int = 2
    parallel_tool_calls_names: list[str] = field(
        default_factory=lambda: ["get_weather", "get_time"]
    )

    # xfail configuration - maps test name to xfail reason
    xfail_streaming: dict[str, str] = field(default_factory=dict)
    xfail_nonstreaming: dict[str, str] = field(default_factory=dict)

    # Content expectations (some parsers strip content, others don't)
    single_tool_call_expected_content: str | None = None
    parallel_tool_calls_expected_content: str | None = None

    # Special assertions for edge cases
    allow_empty_or_json_empty_args: bool = True  # "{}" or "" for empty args
    supports_typed_arguments: bool = True


class ToolParserTests:
    """Mixin class providing common test suite for tool parsers.

    To use this mixin in a parser test file:

    1. Create a test_config fixture that returns a ToolParserTestConfig instance
    2. Inherit from this class
    3. Add parser-specific tests as additional methods

    Example:
        class TestMistralToolParser(ToolParserTests):
            @pytest.fixture
            def test_config(self) -> ToolParserTestConfig:
                return ToolParserTestConfig(
                    parser_name="mistral",
                    no_tool_calls_output="Plain text...",
                    # ... other config ...
                )

            # Parser-specific tests
            def test_mistral_specific_feature(self, tool_parser):
                # Custom test logic
                pass
    """

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Override this to provide parser-specific configuration."""
        raise NotImplementedError(
            "Subclass must provide test_config fixture returning ToolParserTestConfig"
        )

    @pytest.fixture
    def tokenizer(self, default_tokenizer: AnyTokenizer) -> AnyTokenizer:
        """Override this to provide parser-specific tokenizer."""
        return default_tokenizer

    @pytest.fixture
    def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: AnyTokenizer):
        return ToolParserManager.get_tool_parser(test_config.parser_name)(tokenizer)

    @pytest.fixture(params=[True, False])
    def streaming(self, request: pytest.FixtureRequest) -> bool:
        return request.param

    def test_no_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles plain text without tool syntax."""
        # Apply xfail markers if configured
        test_name = "test_no_tool_calls"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.no_tool_calls_output, streaming=streaming
        )
        assert content == test_config.no_tool_calls_output, (
            f"Expected content to match input, got {content}"
        )
        assert len(tool_calls) == 0, f"Expected no tool calls, got {len(tool_calls)}"

    def test_single_tool_call_simple_args(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser extracts one tool with simple arguments."""
        # Apply xfail markers if configured
        test_name = "test_single_tool_call_simple_args"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.single_tool_call_output, streaming=streaming
        )

        # Content check (some parsers strip it)
        if test_config.single_tool_call_expected_content is not None:
            assert content == test_config.single_tool_call_expected_content

        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
        assert tool_calls[0].type == "function"
        assert tool_calls[0].function.name == test_config.single_tool_call_expected_name

        args = json.loads(tool_calls[0].function.arguments)
        for key, value in test_config.single_tool_call_expected_args.items():
            assert args.get(key) == value, (
                f"Expected {key}={value}, got {args.get(key)}"
            )

    def test_parallel_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles multiple tools in one response."""
        # Apply xfail markers if configured
        test_name = "test_parallel_tool_calls"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser,
            test_config.parallel_tool_calls_output,
            streaming=streaming,
        )

        assert len(tool_calls) == test_config.parallel_tool_calls_count, (
            f"Expected {test_config.parallel_tool_calls_count} "
            f"tool calls, got {len(tool_calls)}"
        )

        # Verify tool names match expected
        for i, expected_name in enumerate(test_config.parallel_tool_calls_names):
            assert tool_calls[i].type == "function"
            assert tool_calls[i].function.name == expected_name

        # Verify unique IDs
        ids = [tc.id for tc in tool_calls]
        assert len(ids) == len(set(ids)), "Tool call IDs should be unique"

    def test_various_data_types(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles all JSON types in arguments."""
        # Apply xfail markers if configured
        test_name = "test_various_data_types"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser,
            test_config.various_data_types_output,
            streaming=streaming,
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        args = json.loads(tool_calls[0].function.arguments)
        # Verify all expected fields present
        required_fields_types = {
            "string_field": str,
            "int_field": int,
            "float_field": float,
            "bool_field": bool,
            "null_field": NoneType,
            "array_field": list,
            "object_field": dict,
        }
        for required_field, expected_type in required_fields_types.items():
            assert required_field in args, (
                f"Expected field '{required_field}' in arguments"
            )
            if test_config.supports_typed_arguments:
                found_type = type(args[required_field])
                assert found_type is expected_type, (
                    f"Expected field '{required_field}' to have type {expected_type}, "
                    f"got {found_type}"
                )

    def test_empty_arguments(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles parameterless tool calls."""
        # Apply xfail markers if configured
        test_name = "test_empty_arguments"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.empty_arguments_output, streaming=streaming
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        args = tool_calls[0].function.arguments
        if test_config.allow_empty_or_json_empty_args:
            assert args in ["{}", ""], f"Expected empty args, got {args}"
        else:
            assert args == "{}", f"Expected {{}}, got {args}"

    def test_surrounding_text(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser extracts tools from mixed content."""
        # Apply xfail markers if configured
        test_name = "test_surrounding_text"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.surrounding_text_output, streaming=streaming
        )
        assert len(tool_calls) >= 1, (
            f"Expected at least 1 tool call, got {len(tool_calls)}"
        )

    def test_escaped_strings(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles escaped characters in arguments."""
        # Apply xfail markers if configured
        test_name = "test_escaped_strings"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.escaped_strings_output, streaming=streaming
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        args = json.loads(tool_calls[0].function.arguments)
        # At minimum, verify we can parse and have expected fields
        # Exact escaping behavior varies by parser
        assert len(args) > 0, "Expected some arguments with escaped strings"

    def test_malformed_input(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser gracefully handles invalid syntax."""
        # Apply xfail markers if configured
        test_name = "test_malformed_input"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        for malformed_input in test_config.malformed_input_outputs:
            # Should not raise exception
            content, tool_calls = run_tool_extraction(
                tool_parser, malformed_input, streaming=streaming
            )
            # Parser should handle gracefully (exact behavior varies)

    def test_streaming_reconstruction(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
    ):
        """Verify streaming produces same result as non-streaming."""
        test_name = "test_streaming_reconstruction"
        self.apply_xfail_mark(request, test_config, test_name, True)

        test_output = test_config.single_tool_call_output

        # Non-streaming result
        content_non, tools_non = run_tool_extraction(
            tool_parser, test_output, streaming=False
        )

        # Streaming result
        content_stream, tools_stream = run_tool_extraction(
            tool_parser, test_output, streaming=True
        )

        # Compare results
        assert content_non == content_stream, "Content should match between modes"
        assert len(tools_non) == len(tools_stream), "Tool count should match"
        if len(tools_non) > 0:
            assert tools_non[0].function.name == tools_stream[0].function.name
            assert tools_non[0].function.arguments == tools_stream[0].function.arguments

    def apply_xfail_mark(self, request, test_config, test_name, streaming):
        reason = None
        if streaming and test_name in test_config.xfail_streaming:
            reason = test_config.xfail_streaming[test_name]
        elif not streaming and test_name in test_config.xfail_nonstreaming:
            reason = test_config.xfail_nonstreaming[test_name]
        if reason is not None:
            mark = pytest.mark.xfail(reason=reason, strict=True)
            request.node.add_marker(mark)
```
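For reference, below is a minimal sketch of what a parser-specific test file built on this mixin could look like. It is not part of the PR's diff: the parser name ("hermes"), the import module name, and all model-output strings are assumptions for illustration, and the `default_tokenizer` fixture is expected to be supplied by the surrounding test package (the mixin's `tokenizer` fixture depends on it).

```python
# Hypothetical usage sketch, not taken from the PR. The module name in the
# import, the parser name, and the model outputs below are assumptions.
import pytest

# Module name is assumed; the new file's exact path is not visible in this
# capture, but its own imports suggest it sits under
# tests/entrypoints/openai/tool_parsers/.
from tests.entrypoints.openai.tool_parsers.common import (  # hypothetical path
    ToolParserTestConfig,
    ToolParserTests,
)


class TestHermesToolParser(ToolParserTests):
    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        return ToolParserTestConfig(
            parser_name="hermes",  # assumed ToolParserManager registry name
            no_tool_calls_output="The weather in Tokyo is sunny today.",
            single_tool_call_output=(
                '<tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</tool_call>'
            ),
            parallel_tool_calls_output=(
                '<tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</tool_call>'
                '<tool_call>{"name": "get_time", '
                '"arguments": {"timezone": "JST"}}</tool_call>'
            ),
            various_data_types_output=(
                '<tool_call>{"name": "log_event", "arguments": {'
                '"string_field": "a", "int_field": 1, "float_field": 1.5, '
                '"bool_field": true, "null_field": null, '
                '"array_field": [1, 2], "object_field": {"k": "v"}}}</tool_call>'
            ),
            empty_arguments_output=(
                '<tool_call>{"name": "get_time", "arguments": {}}</tool_call>'
            ),
            surrounding_text_output=(
                'Let me check that. <tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</tool_call>'
            ),
            escaped_strings_output=(
                '<tool_call>{"name": "echo", "arguments": '
                '{"text": "line1\\nline2 \\"quoted\\""}}</tool_call>'
            ),
            malformed_input_outputs=[
                '<tool_call>{"name": "get_weather"',  # truncated JSON
                "<tool_call>not json</tool_call>",
            ],
            # A known gap in this hypothetical parser's streaming path,
            # recorded as an expected failure rather than skipped:
            xfail_streaming={
                "test_surrounding_text": "streaming drops leading content",
            },
        )
```

With a subclass like this, each common test runs twice via the parametrized `streaming` fixture, and any entry in `xfail_streaming` or `xfail_nonstreaming` is turned into a strict `pytest.mark.xfail` by `apply_xfail_mark`, so a silently fixed parser surfaces as an unexpected pass.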