diff --git a/Cargo.lock b/Cargo.lock index e461772131..e18a6bcbe0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2776,6 +2776,7 @@ dependencies = [ "num-traits", "openai-harmony", "regex", + "rstest 0.25.0", "rustpython-parser", "serde", "serde_json", @@ -8761,6 +8762,18 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "rstest" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros 0.25.0", + "rustc_version", +] + [[package]] name = "rstest_macros" version = "0.18.2" @@ -8796,6 +8809,24 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "rstest_macros" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746" +dependencies = [ + "cfg-if 1.0.4", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.110", + "unicode-ident", +] + [[package]] name = "rstest_reuse" version = "0.7.0" diff --git a/lib/parsers/Cargo.toml b/lib/parsers/Cargo.toml index c4d6fb93ea..9255bdb2d9 100644 --- a/lib/parsers/Cargo.toml +++ b/lib/parsers/Cargo.toml @@ -38,3 +38,6 @@ openai-harmony = "0.0.3" lazy_static = "1.5.0" rustpython-parser = "0.4.0" num-traits = "0.2" + +[dev-dependencies] +rstest = "0.25" diff --git a/lib/parsers/src/tool_calling/config.rs b/lib/parsers/src/tool_calling/config.rs index 9fa4874e9c..57f8d37f96 100644 --- a/lib/parsers/src/tool_calling/config.rs +++ b/lib/parsers/src/tool_calling/config.rs @@ -55,6 +55,7 @@ impl Default for JsonParserConfig { } /// Configuration for parsing tool calls with different formats +// TODO(2ez4bz): refactor to allow other parser configs than `JsonParserConfig`. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ToolCallConfig { /// The format type for tool calls @@ -192,4 +193,12 @@ impl ToolCallConfig { }, } } + + pub fn qwen3_coder() -> Self { + // value + Self { + format: ToolCallParserType::Xml, + json: JsonParserConfig::default(), // Not used for qwen3_coder but kept for consistency. + } + } } diff --git a/lib/parsers/src/tool_calling/mod.rs b/lib/parsers/src/tool_calling/mod.rs index 82240961fa..b32e8c583c 100644 --- a/lib/parsers/src/tool_calling/mod.rs +++ b/lib/parsers/src/tool_calling/mod.rs @@ -10,6 +10,7 @@ pub mod response; #[cfg(test)] pub mod tests; pub mod tools; +pub mod xml; // Re-export main types and functions for convenience pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType}; @@ -22,3 +23,4 @@ pub use parsers::{ pub use pythonic::try_tool_call_parse_pythonic; pub use response::{CalledFunction, ToolCallResponse, ToolCallType}; pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream}; +pub use xml::try_tool_call_parse_xml; diff --git a/lib/parsers/src/tool_calling/parsers.rs b/lib/parsers/src/tool_calling/parsers.rs index 3366987146..acf866159a 100644 --- a/lib/parsers/src/tool_calling/parsers.rs +++ b/lib/parsers/src/tool_calling/parsers.rs @@ -14,6 +14,9 @@ use super::pythonic::{ try_tool_call_parse_pythonic, }; use super::response::ToolCallResponse; +use super::xml::{ + detect_tool_call_start_xml, find_tool_call_end_position_xml, try_tool_call_parse_xml, +}; use std::collections::HashMap; use std::sync::OnceLock; @@ -32,6 +35,7 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> { map.insert("harmony", ToolCallConfig::harmony()); map.insert("deepseek_v3", ToolCallConfig::deepseek_v3()); map.insert("deepseek_v3_1", ToolCallConfig::deepseek_v3_1()); + map.insert("qwen3_coder", ToolCallConfig::qwen3_coder()); map.insert("default", ToolCallConfig::default()); map }) @@ -64,7 +68,8 @@ pub async fn try_tool_call_parse( anyhow::bail!("Typescript parser not implemented"); } ToolCallParserType::Xml => { - anyhow::bail!("Xml parser not implemented"); + let (results, normal_content) = try_tool_call_parse_xml(message)?; + Ok((results, normal_content)) } } } @@ -113,9 +118,7 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow:: ToolCallParserType::Typescript => { anyhow::bail!("Typescript parser not implemented"); } - ToolCallParserType::Xml => { - anyhow::bail!("Xml parser not implemented"); - } + ToolCallParserType::Xml => Ok(detect_tool_call_start_xml(chunk)), }, None => anyhow::bail!( "Parser '{}' is not implemented. Available parsers: {:?}", @@ -149,10 +152,7 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi // Typescript parser not implemented chunk.len() } - ToolCallParserType::Xml => { - // Xml parser not implemented - chunk.len() - } + ToolCallParserType::Xml => find_tool_call_end_position_xml(chunk), }, None => { // Unknown parser, return full content length @@ -188,6 +188,7 @@ mod tests { "pythonic", "deepseek_v3", "deepseek_v3_1", + "qwen3_coder", ]; for parser in available_parsers { assert!(parsers.contains(&parser)); @@ -1681,13 +1682,13 @@ mod parallel_tool_calling_tests { validate_weather_tool_calls(&result, &[("Dallas", "TX"), ("Orlando", "FL")]); } - // ============================================================================= - // 2. QWEN3CODER TOOL PARSER FORMAT (XML-style tags) - Testing via hermes parser - // ============================================================================= + // ================================================= + // 2. QWEN3CODER TOOL PARSER FORMAT (XML-style tags) + // ================================================= #[tokio::test] async fn test_parallel_qwen3coder_format_two_cities() { - let _input = r#" + let input = r#" Dallas @@ -1714,12 +1715,7 @@ fahrenheit "#; - // Note: This format would need a specialized parser, but for now we test with hermes - // which handles multiple tags - let input_hermes_format = r#"{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}} -{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}"#; - - let (result, content) = detect_and_parse_tool_call(input_hermes_format, Some("hermes")) + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) .await .unwrap(); @@ -2471,4 +2467,335 @@ mod detect_parser_tests { let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap(); assert!(result); } + + #[test] + fn test_e2e_detect_tool_call_start_xml() { + let text = r#"Dallas"#; + let result = detect_tool_call_start(text, Some("qwen3_coder")).unwrap(); + assert!(result); + } + + #[test] + fn test_e2e_detect_tool_call_start_xml_partial() { + let text = r#" (String, serde_json::Value) { + let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap(); + (call.function.name, args) + } + + #[tokio::test] + async fn test_qwen3_coder_simple_tool_call() { + let input = r#" + + +pwd && ls + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "execute_bash"); + assert_eq!(args["command"], "pwd && ls"); + } + + #[tokio::test] + async fn test_qwen3_coder_multiple_parameters() { + let input = r#" + + +Dallas + + +TX + + +fahrenheit + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "get_current_weather"); + assert_eq!(args["city"], "Dallas"); + assert_eq!(args["state"], "TX"); + assert_eq!(args["unit"], "fahrenheit"); + } + + #[tokio::test] + async fn test_qwen3_coder_with_normal_text() { + let input = r#"I'll help you check the weather. + + +San Francisco + + +fahrenheit + + + Let me get that information for you."#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!( + content, + Some( + "I'll help you check the weather. Let me get that information for you." + .to_string() + ) + ); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "get_current_weather"); + assert_eq!(args["city"], "San Francisco"); + assert_eq!(args["unit"], "fahrenheit"); + } + + #[tokio::test] + async fn test_qwen3_coder_parallel_tool_calls() { + let input = r#" + + +Dallas + + +TX + + +fahrenheit + + + + + + +Orlando + + +FL + + +fahrenheit + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 2); + + let (name1, args1) = extract_name_and_args(result[0].clone()); + assert_eq!(name1, "get_current_weather"); + assert_eq!(args1["city"], "Dallas"); + assert_eq!(args1["state"], "TX"); + assert_eq!(args1["unit"], "fahrenheit"); + + let (name2, args2) = extract_name_and_args(result[1].clone()); + assert_eq!(name2, "get_current_weather"); + assert_eq!(args2["city"], "Orlando"); + assert_eq!(args2["state"], "FL"); + assert_eq!(args2["unit"], "fahrenheit"); + } + + #[tokio::test] + async fn test_qwen3_coder_json_parameter_value() { + let input = r#" + + +{"timeout": 30, "retries": 3} + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "process_data"); + assert!(args["config"].is_object()); + assert_eq!(args["config"]["timeout"], 30); + assert_eq!(args["config"]["retries"], 3); + } + + #[tokio::test] + async fn test_qwen3_coder_numeric_parameters() { + let input = r#" + + +42 + + +3.15 + + +true + + +"#; + let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "calculate"); + assert_eq!(args["x"], 42); + assert_eq!(args["y"], 3.15); + assert_eq!(args["enabled"], true); + } + + #[tokio::test] + async fn test_qwen3_coder_no_tool_calls() { + let input = "This is just normal text without any tool calls."; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(result.len(), 0); + assert_eq!(content, Some(input.to_string())); + } + + #[tokio::test] + async fn test_qwen3_coder_compact_format() { + let input = r#"rust programming10"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "search"); + assert_eq!(args["query"], "rust programming"); + assert_eq!(args["limit"], 10); + } + + #[tokio::test] + async fn test_qwen3_coder_html_entities() { + let input = r#" + + +<div>Hello & Welcome</div> + + +"#; + let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "print_message"); + assert_eq!(args["text"], "
Hello & Welcome
"); + } + + #[tokio::test] + async fn test_qwen3_coder_three_parallel_calls() { + let input = r#" + + +Dallas + + + + + + +Orlando + + + + + + +Seattle + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 3); + + let cities = ["Dallas", "Orlando", "Seattle"]; + for (i, expected_city) in cities.iter().enumerate() { + let (name, args) = extract_name_and_args(result[i].clone()); + assert_eq!(name, "get_current_weather"); + assert_eq!(args["city"], *expected_city); + } + } + + #[tokio::test] + async fn test_qwen3_coder_mixed_tool_types() { + let input = r#" + + +Dallas + + +fahrenheit + + + + + + +weather forecasting + + +5 + + +"#; + let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(content, Some("".to_string())); + assert_eq!(result.len(), 2); + + let (name1, args1) = extract_name_and_args(result[0].clone()); + assert_eq!(name1, "get_current_weather"); + assert_eq!(args1["city"], "Dallas"); + assert_eq!(args1["unit"], "fahrenheit"); + + let (name2, args2) = extract_name_and_args(result[1].clone()); + assert_eq!(name2, "web_search"); + assert_eq!(args2["query"], "weather forecasting"); + assert_eq!(args2["max_results"], 5); + } + + #[tokio::test] + async fn test_qwen3_coder_array_parameter_value() { + let input = r#" + + +[1, 2, 3, 4, 5] + + +"#; + let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder")) + .await + .unwrap(); + assert_eq!(result.len(), 1); + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "process_list"); + assert!(args["items"].is_array()); + assert_eq!(args["items"], serde_json::json!([1, 2, 3, 4, 5])); + } } diff --git a/lib/parsers/src/tool_calling/xml/mod.rs b/lib/parsers/src/tool_calling/xml/mod.rs new file mode 100644 index 0000000000..ff46c3ba75 --- /dev/null +++ b/lib/parsers/src/tool_calling/xml/mod.rs @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod parser; + +pub use super::response; +pub use parser::{ + detect_tool_call_start_xml, find_tool_call_end_position_xml, try_tool_call_parse_xml, +}; diff --git a/lib/parsers/src/tool_calling/xml/parser.rs b/lib/parsers/src/tool_calling/xml/parser.rs new file mode 100644 index 0000000000..3089e3c1d2 --- /dev/null +++ b/lib/parsers/src/tool_calling/xml/parser.rs @@ -0,0 +1,469 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Reference implementation: +// https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py + +use std::collections::HashMap; +use std::sync::OnceLock; + +use regex::Regex; +use uuid::Uuid; + +use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; + +/// Check if a chunk contains the start of a xml-style tool call. +/// Format: ... +// TODO(2ez4bz): Add a parser config struct that allows parameterizing: +// * the tool call start / end tokens +// * the function start / end tokens +// * the parameter start / end tokens +pub fn detect_tool_call_start_xml(chunk: &str) -> bool { + // Check for complete or partial start token. + let start_token = ""; + + // Check if we have the complete start token. + if chunk.contains(start_token) { + return true; + } + + // Check for partial match at the end of the chunk (for streaming). + for i in 1..start_token.len() { + if chunk.ends_with(&start_token[..i]) { + return true; + } + } + + false +} + +/// Find the end position of a Qwen3Coder tool call. +/// Returns the position after or the length of the chunk if not found. +pub fn find_tool_call_end_position_xml(chunk: &str) -> usize { + let end_token = "
"; + + if let Some(pos) = chunk.find(end_token) { + pos + end_token.len() + } else { + chunk.len() + } +} + +/// Try to parse Qwen3Coder formatted tool calls from a message. +/// Format: value +/// Returns (parsed_tool_calls, normal_text_content) +pub fn try_tool_call_parse_xml( + message: &str, +) -> anyhow::Result<(Vec, Option)> { + let (normal_text, tool_calls) = extract_tool_calls(message)?; + + let normal_content = if normal_text.is_empty() { + Some("".to_string()) + } else { + Some(normal_text) + }; + + Ok((tool_calls, normal_content)) +} + +/// Extract tool calls and normal text from message. +fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec)> { + let mut normal_parts = Vec::new(); + let mut calls = Vec::new(); + let mut cursor = 0; + + let start_token = ""; + let end_token = ""; + + while cursor < text.len() { + // Find next tool call start. + if let Some(start_pos) = text[cursor..].find(start_token) { + let abs_start = cursor + start_pos; + + // Add text before tool call to normal parts. + normal_parts.push(&text[cursor..abs_start]); + + // Find the corresponding end token. + if let Some(end_pos) = text[abs_start..].find(end_token) { + let abs_end = abs_start + end_pos + end_token.len(); + let block = &text[abs_start..abs_end]; + + // Parse this tool call block. + if let Ok(mut parsed_calls) = parse_tool_call_block(block) { + calls.append(&mut parsed_calls); + } + + cursor = abs_end; + } else { + // No end token found -> treat the rest as normal text. + normal_parts.push(&text[abs_start..]); + break; + } + } else { + // No more tool calls. + normal_parts.push(&text[cursor..]); + break; + } + } + + let normal_text = normal_parts.join("").trim().to_string(); + Ok((normal_text, calls)) +} + +/// Parse a single tool call block +/// Format: value... +fn parse_tool_call_block(block: &str) -> anyhow::Result> { + static FUNCTION_REGEX: OnceLock = OnceLock::new(); + static PARAMETER_REGEX: OnceLock = OnceLock::new(); + + let function_regex = FUNCTION_REGEX.get_or_init(|| { + // Match content or partial content + // (?s) makes . match newlines + Regex::new(r"(?s)]+)>(.*?)(?:|$)").unwrap() + }); + + let parameter_regex = PARAMETER_REGEX.get_or_init(|| { + // Match value or partial value + // (?s) makes . match newlines + Regex::new(r"(?s)]+)>(.*?)(?:|$)").unwrap() + }); + + let mut results = Vec::new(); + + // Find all function blocks. + for func_cap in function_regex.captures_iter(block) { + let function_name = func_cap.get(1).map(|m| m.as_str().trim()).unwrap_or(""); + let function_body = func_cap.get(2).map(|m| m.as_str()).unwrap_or(""); + + if function_name.is_empty() { + continue; + } + + // Parse parameters from the function body. + let mut parameters: HashMap = HashMap::new(); + + for param_cap in parameter_regex.captures_iter(function_body) { + let param_name = param_cap.get(1).map(|m| m.as_str().trim()).unwrap_or(""); + let param_value = param_cap.get(2).map(|m| m.as_str()).unwrap_or(""); + + if !param_name.is_empty() { + let parsed_value = safe_parse_value(param_value); + parameters.insert(param_name.to_string(), parsed_value); + } + } + + // Create tool call response. + let arguments_json = serde_json::to_string(¶meters)?; + + let tool_call = ToolCallResponse { + id: format!("call-{}", Uuid::new_v4()), + tp: ToolCallType::Function, + function: CalledFunction { + name: function_name.to_string(), + arguments: arguments_json, + }, + }; + + results.push(tool_call); + } + + Ok(results) +} + +/// Safely parse a value - tries JSON, then falls back to string. +/// Mimics SGLang's `_safe_val` function in spirit. +fn safe_parse_value(raw: &str) -> serde_json::Value { + // HTML unescape + let unescaped = html_unescape(raw.trim()); + + if let Ok(value) = serde_json::from_str::(&unescaped) { + return value; + } + + if let Ok(num) = unescaped.parse::() { + return serde_json::Value::Number(num.into()); + } + + if let Ok(num) = unescaped.parse::() + && let Some(num_val) = serde_json::Number::from_f64(num) + { + return serde_json::Value::Number(num_val); + } + + match unescaped.to_lowercase().as_str() { + "true" => return serde_json::Value::Bool(true), + "false" => return serde_json::Value::Bool(false), + "null" | "none" => return serde_json::Value::Null, + _ => {} + } + + // Default to string, stripping newlines from start and end. + serde_json::Value::String(unescaped.trim_matches('\n').to_string()) +} + +/// Simple HTML unescape for common entities. +fn html_unescape(s: &str) -> String { + s.replace("<", "<") + .replace(">", ">") + .replace("&", "&") + .replace(""", "\"") + .replace("'", "'") + .replace("'", "'") +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[test] + fn test_detect_tool_call_start() { + assert!(detect_tool_call_start_xml("")); + assert!(detect_tool_call_start_xml("text ")); + assert!(detect_tool_call_start_xml(" + assert_eq!(&text[pos..], "more text"); + + let text_no_end = ""; + let pos = find_tool_call_end_position_xml(text_no_end); + assert_eq!(pos, text_no_end.len()); + } + + #[rstest] + #[case(r#"{"key": "value"}"#, serde_json::json!({"key": "value"}), "JSON object")] + #[case(r#"[1, 2, 3]"#, serde_json::json!([1, 2, 3]), "JSON array")] + #[case("42", serde_json::json!(42), "integer")] + #[case("3.15", serde_json::json!(3.15), "float")] + #[case("true", serde_json::json!(true), "boolean true")] + #[case("false", serde_json::json!(false), "boolean false")] + #[case("null", serde_json::json!(null), "null")] + #[case("hello", serde_json::json!("hello"), "unquoted string")] + #[case(" text ", serde_json::json!("text"), "trimmed string")] + fn test_safe_parse_value( + #[case] input: &str, + #[case] expected: serde_json::Value, + #[case] _description: &str, + ) { + assert_eq!(safe_parse_value(input), expected); + } + + #[rstest] + #[case("<div>", "
", "HTML tags")] + #[case("a & b", "a & b", "ampersand")] + #[case(""quoted"", "\"quoted\"", "quotes")] + fn test_html_unescape(#[case] input: &str, #[case] expected: &str, #[case] _description: &str) { + assert_eq!(html_unescape(input), expected); + } + + #[test] + fn test_parse_simple_tool_call() { + let input = r#" + + +pwd && ls + + +"#; + + let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "execute_bash"); + assert_eq!(normal, Some("".to_string())); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert_eq!(args["command"], "pwd && ls"); + } + + #[test] + fn test_parse_multiple_parameters() { + let input = r#" + + +San Francisco + + +CA + + +fahrenheit + + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "get_weather"); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert_eq!(args["city"], "San Francisco"); + assert_eq!(args["state"], "CA"); + assert_eq!(args["unit"], "fahrenheit"); + } + + #[test] + fn test_parse_with_normal_text() { + let input = r#"I'll help you with that. + + +Dallas + + + Let me check that for you."#; + + let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "get_weather"); + assert_eq!( + normal, + Some("I'll help you with that. Let me check that for you.".to_string()) + ); + } + + #[test] + fn test_parse_multiple_tool_calls() { + let input = r#" + + +Dallas + + + + + + +Orlando + + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "get_weather"); + assert_eq!(calls[1].function.name, "get_weather"); + + let args0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + let args1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap(); + assert_eq!(args0["city"], "Dallas"); + assert_eq!(args1["city"], "Orlando"); + } + + #[test] + fn test_parse_json_parameter_value() { + let input = r#" + + +{"setting": "value", "count": 42} + + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert!(args["config"].is_object()); + assert_eq!(args["config"]["setting"], "value"); + assert_eq!(args["config"]["count"], 42); + } + + #[test] + fn test_parse_no_tool_calls() { + let input = "This is just normal text without any tool calls."; + let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 0); + assert_eq!(normal, Some(input.to_string())); + } + + #[test] + fn test_parse_malformed_tool_call() { + let input = r#" + + +value +"#; + + // Should handle gracefully - might parse or return empty + let result = try_tool_call_parse_xml(input); + assert!(result.is_ok()); + } + + #[test] + fn test_parse_missing_parameter_closing_tag() { + let input = r#" + + +ls -la + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "execute_bash"); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert_eq!(args["command"], "ls -la"); + } + + #[test] + fn test_parse_missing_function_closing_tag() { + let input = r#" + + +Boston + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "get_weather"); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert_eq!(args["city"], "Boston"); + } + + #[test] + fn test_parse_missing_both_closing_tags() { + let input = r#" + + +SELECT * FROM users +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "run_query"); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + // This matches the original SGLang python implementation. + assert_eq!(args["sql"], "SELECT * FROM users\n"); + } + + #[test] + fn test_parse_multiple_parameters_missing_closing_tags() { + let input = r#" + + +rust programming + +10 + +"#; + + let (calls, _) = try_tool_call_parse_xml(input).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "search"); + + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + // This matches the original SGLang python implementation. + assert_eq!(args["query"], "rust programming\n\n10"); + } +}