From 25bff98ccdab8722a119a31f6f3ed530837fb32b Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 15:01:17 +0800 Subject: [PATCH 1/8] chore: add .worktrees/ to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6f57097..a0137e2 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ docs/ vendor/ +.worktrees/ From dc3335d9c112ed892b2a5b4e9a1e4144ae962f5c Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 15:08:47 +0800 Subject: [PATCH 2/8] chore: add glob, walkdir, which dependencies for tools Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/Cargo.lock | 61 +++++++++++++++++++++++++++++++++++++++++ crewforge-rs/Cargo.toml | 3 ++ 2 files changed, 64 insertions(+) diff --git a/crewforge-rs/Cargo.lock b/crewforge-rs/Cargo.lock index 454ba26..6aafc98 100644 --- a/crewforge-rs/Cargo.lock +++ b/crewforge-rs/Cargo.lock @@ -445,6 +445,7 @@ dependencies = [ "crossterm", "directories", "futures", + "glob", "predicates", "rand", "ratatui", @@ -465,6 +466,8 @@ dependencies = [ "unicode-width 0.2.0", "urlencoding", "uuid", + "walkdir", + "which", ] [[package]] @@ -613,6 +616,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.2" @@ -820,6 +829,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.4.13" @@ -1880,6 +1895,15 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.28" @@ -2597,6 +2621,16 @@ dependencies = [ "libc", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2756,6 +2790,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix 1.1.4", + "winsafe", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2772,6 +2818,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -3079,6 +3134,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/crewforge-rs/Cargo.toml b/crewforge-rs/Cargo.toml index 2813fab..f98f882 100644 --- a/crewforge-rs/Cargo.toml +++ b/crewforge-rs/Cargo.toml @@ -34,6 +34,9 @@ tui-textarea = { version = "0.7", features = ["crossterm"] } unicode-width = "0.2" uuid = { version = "1", features = ["v4"] } urlencoding = "2" +glob = "0.3" +walkdir = "2" +which = "7" [dev-dependencies] assert_cmd = "2" From adc54cc5871e7c6012bd9fe54c9683fca7d6406e Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 15:27:39 +0800 Subject: [PATCH 3/8] refactor: change Tool trait from call()->String to execute()->ToolResult ToolResult carries success/output/error fields, enabling tools to report security denials as business logic rather than program errors. All callsites (dispatcher, loop_, agent_cmd, agentctl) updated. Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/agent/dispatcher.rs | 52 +++++++++++++++++++--------- crewforge-rs/src/agent/loop_.rs | 37 +++++++++++++------- crewforge-rs/src/agent/mod.rs | 12 ++++++- crewforge-rs/src/agent_cmd.rs | 16 ++++++--- crewforge-rs/src/bin/agentctl.rs | 17 ++++++--- 5 files changed, 96 insertions(+), 38 deletions(-) diff --git a/crewforge-rs/src/agent/dispatcher.rs b/crewforge-rs/src/agent/dispatcher.rs index f6923e6..8285a10 100644 --- a/crewforge-rs/src/agent/dispatcher.rs +++ b/crewforge-rs/src/agent/dispatcher.rs @@ -12,9 +12,8 @@ pub struct ParsedToolCall { #[derive(Debug, Clone)] pub struct ToolExecutionResult { + pub tool_result: super::ToolResult, pub name: String, - pub output: String, - pub success: bool, pub tool_call_id: Option, } @@ -95,11 +94,16 @@ impl ToolDispatcher for XmlToolDispatcher { fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { let mut content = String::new(); for result in results { - let status = if result.success { "ok" } else { "error" }; + let status = if result.tool_result.success { "ok" } else { "error" }; + let output = if let Some(ref err) = result.tool_result.error { + err.as_str() + } else { + &result.tool_result.output + }; let _ = writeln!( content, "\n{}\n", - result.name, status, result.output + result.name, status, output ); } ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}"))) @@ -162,12 +166,19 @@ impl ToolDispatcher for NativeToolDispatcher { fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { let messages: Vec = results .iter() - .map(|result| ToolResultMessage { - tool_call_id: result - .tool_call_id - .clone() - .unwrap_or_else(|| "unknown".to_string()), - content: result.output.clone(), + .map(|result| { + let content = if let Some(ref err) = result.tool_result.error { + format!("{}\n{}", result.tool_result.output, err) + } else { + result.tool_result.output.clone() + }; + ToolResultMessage { + tool_call_id: result + .tool_call_id + .clone() + .unwrap_or_else(|| "unknown".to_string()), + content, + } }) .collect(); ConversationMessage::ToolResults(messages) @@ -222,9 +233,12 @@ mod tests { assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1")); let msg = dispatcher.format_results(&[ToolExecutionResult { + tool_result: crate::agent::ToolResult { + success: true, + output: "hello".into(), + error: None, + }, name: "file_read".into(), - output: "hello".into(), - success: true, tool_call_id: Some("tc1".into()), }]); match msg { @@ -240,9 +254,12 @@ mod tests { fn xml_format_results_contains_tool_result_tags() { let dispatcher = XmlToolDispatcher; let msg = dispatcher.format_results(&[ToolExecutionResult { + tool_result: crate::agent::ToolResult { + success: true, + output: "ok".into(), + error: None, + }, name: "shell".into(), - output: "ok".into(), - success: true, tool_call_id: None, }]); let rendered = match msg { @@ -257,9 +274,12 @@ mod tests { fn native_format_results_keeps_tool_call_id() { let dispatcher = NativeToolDispatcher; let msg = dispatcher.format_results(&[ToolExecutionResult { + tool_result: crate::agent::ToolResult { + success: true, + output: "ok".into(), + error: None, + }, name: "shell".into(), - output: "ok".into(), - success: true, tool_call_id: Some("tc-1".into()), }]); diff --git a/crewforge-rs/src/agent/loop_.rs b/crewforge-rs/src/agent/loop_.rs index 61466f6..6562016 100644 --- a/crewforge-rs/src/agent/loop_.rs +++ b/crewforge-rs/src/agent/loop_.rs @@ -279,8 +279,12 @@ impl AgentSession { }); let result = execute_tool(&self.tools, call).await; - let success = result.success; - let output = result.output.clone(); + let success = result.tool_result.success; + let output = if let Some(ref err) = result.tool_result.error { + err.clone() + } else { + result.tool_result.output.clone() + }; events.push(AgentEvent::ToolCallFinished { name: call.name.clone(), @@ -333,24 +337,29 @@ fn tool_call_signature(name: &str, arguments: &serde_json::Value) -> (String, St async fn execute_tool(tools: &[Box], call: &ParsedToolCall) -> ToolExecutionResult { let tool = tools.iter().find(|t| t.name() == call.name); match tool { - Some(t) => match t.call(call.arguments.clone()).await { - Ok(output) => ToolExecutionResult { + Some(t) => match t.execute(call.arguments.clone()).await { + Ok(tool_result) => ToolExecutionResult { name: call.name.clone(), - output, - success: true, + tool_result, tool_call_id: call.tool_call_id.clone(), }, Err(e) => ToolExecutionResult { name: call.name.clone(), - output: format!("Error: {e}"), - success: false, + tool_result: super::ToolResult { + success: false, + output: String::new(), + error: Some(format!("Error: {e}")), + }, tool_call_id: call.tool_call_id.clone(), }, }, None => ToolExecutionResult { name: call.name.clone(), - output: format!("Unknown tool: {}", call.name), - success: false, + tool_result: super::ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown tool: {}", call.name)), + }, tool_call_id: call.tool_call_id.clone(), }, } @@ -612,8 +621,12 @@ mod tests { fn name(&self) -> &str { "noop" } fn description(&self) -> &str { "no-op" } fn parameters(&self) -> serde_json::Value { serde_json::json!({}) } - async fn call(&self, _args: serde_json::Value) -> anyhow::Result { - Ok("done".to_string()) + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(crate::agent::ToolResult { + success: true, + output: "done".to_string(), + error: None, + }) } } diff --git a/crewforge-rs/src/agent/mod.rs b/crewforge-rs/src/agent/mod.rs index 8b39b07..d22eded 100644 --- a/crewforge-rs/src/agent/mod.rs +++ b/crewforge-rs/src/agent/mod.rs @@ -8,6 +8,16 @@ pub use loop_::{AgentEvent, AgentSession, AgentSessionConfig, StopReason}; use async_trait::async_trait; use crate::provider::traits::ToolSpec; +/// Structured result from tool execution. +/// Security denials use `success: false` with an `error` message — +/// they are business logic, not program errors. +#[derive(Debug, Clone)] +pub struct ToolResult { + pub success: bool, + pub output: String, + pub error: Option, +} + /// Generic tool interface. Implement this for any tool the agent can use. /// CrewForge hub tools (HubGet, HubAck, HubPost) implement this trait. #[async_trait] @@ -24,5 +34,5 @@ pub trait Tool: Send + Sync { } } - async fn call(&self, args: serde_json::Value) -> anyhow::Result; + async fn execute(&self, args: serde_json::Value) -> anyhow::Result; } diff --git a/crewforge-rs/src/agent_cmd.rs b/crewforge-rs/src/agent_cmd.rs index c15e0a0..0ee02b1 100644 --- a/crewforge-rs/src/agent_cmd.rs +++ b/crewforge-rs/src/agent_cmd.rs @@ -78,9 +78,13 @@ impl Tool for EchoTool { "required": ["message"] }) } - async fn call(&self, args: serde_json::Value) -> anyhow::Result { + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let msg = args.get("message").and_then(|v| v.as_str()).unwrap_or("[no message]"); - Ok(format!("Echo: {msg}")) + Ok(crewforge::agent::ToolResult { + success: true, + output: format!("Echo: {msg}"), + error: None, + }) } } @@ -93,7 +97,7 @@ impl Tool for DatetimeTool { fn parameters(&self) -> serde_json::Value { serde_json::json!({"type": "object", "properties": {}, "required": []}) } - async fn call(&self, _args: serde_json::Value) -> anyhow::Result { + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { use std::time::{SystemTime, UNIX_EPOCH}; let secs = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -103,7 +107,11 @@ impl Tool for DatetimeTool { let m = (secs / 60) % 60; let h = (secs / 3600) % 24; let days = secs / 86400; - Ok(format!("UTC unix_day={days} {:02}:{:02}:{:02}", h, m, s)) + Ok(crewforge::agent::ToolResult { + success: true, + output: format!("UTC unix_day={days} {:02}:{:02}:{:02}", h, m, s), + error: None, + }) } } diff --git a/crewforge-rs/src/bin/agentctl.rs b/crewforge-rs/src/bin/agentctl.rs index d55bc35..b22962e 100644 --- a/crewforge-rs/src/bin/agentctl.rs +++ b/crewforge-rs/src/bin/agentctl.rs @@ -88,12 +88,16 @@ impl Tool for EchoTool { }) } - async fn call(&self, args: serde_json::Value) -> anyhow::Result { + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let msg = args .get("message") .and_then(|v| v.as_str()) .unwrap_or("[no message]"); - Ok(format!("Echo: {msg}")) + Ok(crewforge::agent::ToolResult { + success: true, + output: format!("Echo: {msg}"), + error: None, + }) } } @@ -118,18 +122,21 @@ impl Tool for DatetimeTool { }) } - async fn call(&self, _args: serde_json::Value) -> anyhow::Result { + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { use std::time::{SystemTime, UNIX_EPOCH}; let secs = SystemTime::now() .duration_since(UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); - // Simple ISO-like format without chrono let s = secs % 60; let m = (secs / 60) % 60; let h = (secs / 3600) % 24; let days = secs / 86400; - Ok(format!("UTC unix_day={days} {:02}:{:02}:{:02}", h, m, s)) + Ok(crewforge::agent::ToolResult { + success: true, + output: format!("UTC unix_day={days} {:02}:{:02}:{:02}", h, m, s), + error: None, + }) } } From 281d03b657c2ed7e13065ff9f4a1d7832eeab678 Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 15:36:10 +0800 Subject: [PATCH 4/8] feat(security): add SecurityPolicy with path ACL, command filter, rate-limiting Ported from zeroclaw's SecurityPolicy with layered defense: - AutonomyLevel (ReadOnly/Supervised/Full) - Path validation (traversal, null bytes, workspace confinement, symlink escape) - Command allowlist (quote-aware lexer, injection blocking, risk classification) - Sliding-window rate limiting (ActionTracker) - 57 tests covering security bypass vectors Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/security/mod.rs | 2 + crewforge-rs/src/security/policy.rs | 1396 +++++++++++++++++++++++++++ 2 files changed, 1398 insertions(+) create mode 100644 crewforge-rs/src/security/policy.rs diff --git a/crewforge-rs/src/security/mod.rs b/crewforge-rs/src/security/mod.rs index 55467d7..b14a582 100644 --- a/crewforge-rs/src/security/mod.rs +++ b/crewforge-rs/src/security/mod.rs @@ -1,2 +1,4 @@ +pub mod policy; pub mod secrets; +pub use policy::*; pub use secrets::SecretStore; diff --git a/crewforge-rs/src/security/policy.rs b/crewforge-rs/src/security/policy.rs new file mode 100644 index 0000000..cba93cc --- /dev/null +++ b/crewforge-rs/src/security/policy.rs @@ -0,0 +1,1396 @@ +use std::path::{Path, PathBuf}; +use std::sync::Mutex; +use std::time::Instant; + +/// How much autonomy the agent has. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum AutonomyLevel { + /// Read-only: can observe but not act + ReadOnly, + /// Supervised: acts but requires approval for risky operations + #[default] + Supervised, + /// Full: autonomous execution within policy bounds + Full, +} + +/// Risk score for shell command execution. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandRiskLevel { + Low, + Medium, + High, +} + +/// Classifies whether a tool operation is read-only or side-effecting. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolOperation { + Read, + Act, +} + +/// Sliding-window action tracker for rate limiting. +#[derive(Debug)] +pub struct ActionTracker { + actions: Mutex>, +} + +impl Default for ActionTracker { + fn default() -> Self { + Self { + actions: Mutex::new(Vec::new()), + } + } +} + +impl ActionTracker { + pub fn new() -> Self { + Self::default() + } + + /// Record an action and return the current count within the window. + pub fn record(&self) -> usize { + let mut actions = self.actions.lock().unwrap(); + let cutoff = Instant::now() + .checked_sub(std::time::Duration::from_secs(3600)) + .unwrap_or_else(Instant::now); + actions.retain(|t| *t > cutoff); + actions.push(Instant::now()); + actions.len() + } + + /// Count of actions in the current window without recording. + pub fn count(&self) -> usize { + let mut actions = self.actions.lock().unwrap(); + let cutoff = Instant::now() + .checked_sub(std::time::Duration::from_secs(3600)) + .unwrap_or_else(Instant::now); + actions.retain(|t| *t > cutoff); + actions.len() + } +} + +impl Clone for ActionTracker { + fn clone(&self) -> Self { + let actions = self.actions.lock().unwrap(); + Self { + actions: Mutex::new(actions.clone()), + } + } +} + +/// Security policy enforced on all tool executions. +#[derive(Debug, Clone)] +pub struct SecurityPolicy { + pub autonomy: AutonomyLevel, + pub workspace_dir: PathBuf, + pub workspace_only: bool, + pub allowed_commands: Vec, + pub forbidden_paths: Vec, + pub allowed_roots: Vec, + pub max_actions_per_hour: u32, + pub require_approval_for_medium_risk: bool, + pub block_high_risk_commands: bool, + pub shell_env_passthrough: Vec, + pub tracker: ActionTracker, +} + +impl Default for SecurityPolicy { + fn default() -> Self { + Self { + autonomy: AutonomyLevel::Supervised, + workspace_dir: PathBuf::from("."), + workspace_only: true, + allowed_commands: vec![ + "git".into(), + "npm".into(), + "cargo".into(), + "ls".into(), + "cat".into(), + "grep".into(), + "find".into(), + "echo".into(), + "pwd".into(), + "wc".into(), + "head".into(), + "tail".into(), + "date".into(), + ], + forbidden_paths: vec![ + "/etc".into(), + "/root".into(), + "/home".into(), + "/usr".into(), + "/bin".into(), + "/sbin".into(), + "/lib".into(), + "/opt".into(), + "/boot".into(), + "/dev".into(), + "/proc".into(), + "/sys".into(), + "/var".into(), + "/tmp".into(), + "~/.ssh".into(), + "~/.gnupg".into(), + "~/.aws".into(), + "~/.config".into(), + ], + allowed_roots: Vec::new(), + max_actions_per_hour: 60, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, + shell_env_passthrough: vec![], + tracker: ActionTracker::new(), + } + } +} + +// ── Shell Command Parsing Utilities ───────────────────────────────────────── + +fn home_dir() -> Option { + std::env::var_os("HOME").map(PathBuf::from) +} + +fn expand_user_path(path: &str) -> PathBuf { + let home = home_dir(); + if let (true, Some(h)) = (path == "~", &home) { + return h.clone(); + } + if let (Some(stripped), Some(h)) = (path.strip_prefix("~/"), &home) { + return h.join(stripped); + } + PathBuf::from(path) +} + +/// Skip leading environment variable assignments (e.g. `FOO=bar cmd args`). +fn skip_env_assignments(s: &str) -> &str { + let mut rest = s; + loop { + let Some(word) = rest.split_whitespace().next() else { + return rest; + }; + if word.contains('=') + && word + .chars() + .next() + .is_some_and(|c| c.is_ascii_alphabetic() || c == '_') + { + rest = rest[word.len()..].trim_start(); + } else { + return rest; + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum QuoteState { + None, + Single, + Double, +} + +/// Split a shell command into sub-commands by unquoted separators. +fn split_unquoted_segments(command: &str) -> Vec { + let mut segments = Vec::new(); + let mut current = String::new(); + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + let push_segment = |segments: &mut Vec, current: &mut String| { + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + current.clear(); + }; + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::Double => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::None => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + + match ch { + '\'' => { + quote = QuoteState::Single; + current.push(ch); + } + '"' => { + quote = QuoteState::Double; + current.push(ch); + } + ';' | '\n' => push_segment(&mut segments, &mut current), + '|' => { + if chars.next_if_eq(&'|').is_some() { + // `||` + } + push_segment(&mut segments, &mut current); + } + '&' => { + if chars.next_if_eq(&'&').is_some() { + push_segment(&mut segments, &mut current); + } else { + current.push(ch); + } + } + _ => current.push(ch), + } + } + } + } + + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + + segments +} + +/// Detect a single unquoted `&` operator (background/chain). `&&` is allowed. +fn contains_unquoted_single_ampersand(command: &str) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + '&' => { + if chars.next_if_eq(&'&').is_none() { + return true; + } + } + _ => {} + } + } + } + } + + false +} + +/// Detect an unquoted character in a shell command. +fn contains_unquoted_char(command: &str, target: char) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + + for ch in command.chars() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + _ if ch == target => return true, + _ => {} + } + } + } + } + + false +} + +fn is_valid_env_var_name(name: &str) -> bool { + !name.is_empty() + && name + .chars() + .next() + .is_some_and(|c| c.is_ascii_alphabetic() || c == '_') + && name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') +} + +/// Detect unquoted shell variable expansions that are not explicitly allowlisted. +fn contains_disallowed_unquoted_shell_variable_expansion( + command: &str, + allowed_vars: &[String], +) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + let chars: Vec = command.chars().collect(); + let mut i = 0usize; + + while i < chars.len() { + let ch = chars[i]; + + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + i += 1; + continue; + } + QuoteState::Double => { + if escaped { + escaped = false; + i += 1; + continue; + } + if ch == '\\' { + escaped = true; + i += 1; + continue; + } + if ch == '"' { + quote = QuoteState::None; + i += 1; + continue; + } + } + QuoteState::None => { + if escaped { + escaped = false; + i += 1; + continue; + } + if ch == '\\' { + escaped = true; + i += 1; + continue; + } + if ch == '\'' { + quote = QuoteState::Single; + i += 1; + continue; + } + if ch == '"' { + quote = QuoteState::Double; + i += 1; + continue; + } + } + } + + if ch != '$' { + i += 1; + continue; + } + + let Some(next) = chars.get(i + 1).copied() else { + i += 1; + continue; + }; + + match next { + '(' => return true, + '{' => { + let mut j = i + 2; + while j < chars.len() && chars[j] != '}' { + j += 1; + } + if j >= chars.len() { + return true; + } + let inner: String = chars[i + 2..j].iter().collect(); + if !is_valid_env_var_name(&inner) + || !allowed_vars.iter().any(|allowed| allowed == &inner) + { + return true; + } + i = j + 1; + continue; + } + c if c.is_ascii_alphabetic() || c == '_' => { + let mut j = i + 2; + while j < chars.len() && (chars[j].is_ascii_alphanumeric() || chars[j] == '_') { + j += 1; + } + let name: String = chars[i + 1..j].iter().collect(); + if !allowed_vars.iter().any(|allowed| allowed == &name) { + return true; + } + i = j; + continue; + } + c if c.is_ascii_digit() || matches!(c, '#' | '?' | '!' | '$' | '*' | '@' | '-') => { + return true; + } + _ => {} + } + + i += 1; + } + + false +} + +fn strip_wrapping_quotes(token: &str) -> &str { + token.trim_matches(|c| c == '"' || c == '\'') +} + +fn looks_like_path(candidate: &str) -> bool { + candidate.starts_with('/') + || candidate.starts_with("./") + || candidate.starts_with("../") + || candidate.starts_with('~') + || candidate == "." + || candidate == ".." + || candidate.contains('/') +} + +fn is_allowlist_entry_match(allowed: &str, executable: &str, executable_base: &str) -> bool { + let allowed = strip_wrapping_quotes(allowed).trim(); + if allowed.is_empty() { + return false; + } + if allowed == "*" { + return true; + } + if looks_like_path(allowed) { + let allowed_path = expand_user_path(allowed); + let executable_path = expand_user_path(executable); + return executable_path == allowed_path; + } + allowed == executable_base +} + +// ── SecurityPolicy Methods ────────────────────────────────────────────────── + +impl SecurityPolicy { + /// Classify command risk. Any high-risk segment marks the whole command high. + pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { + let mut saw_medium = false; + + for segment in split_unquoted_segments(command) { + let cmd_part = skip_env_assignments(&segment); + let mut words = cmd_part.split_whitespace(); + let Some(base_raw) = words.next() else { + continue; + }; + + let base = base_raw + .rsplit('/') + .next() + .unwrap_or("") + .to_ascii_lowercase(); + + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + let joined_segment = cmd_part.to_ascii_lowercase(); + + if matches!( + base.as_str(), + "rm" | "mkfs" + | "dd" + | "shutdown" + | "reboot" + | "halt" + | "poweroff" + | "sudo" + | "su" + | "chown" + | "chmod" + | "useradd" + | "userdel" + | "usermod" + | "passwd" + | "mount" + | "umount" + | "iptables" + | "ufw" + | "firewall-cmd" + | "curl" + | "wget" + | "nc" + | "ncat" + | "netcat" + | "scp" + | "ssh" + | "ftp" + | "telnet" + ) { + return CommandRiskLevel::High; + } + + if joined_segment.contains("rm -rf /") + || joined_segment.contains("rm -fr /") + || joined_segment.contains(":(){:|:&};:") + { + return CommandRiskLevel::High; + } + + let medium = match base.as_str() { + "git" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "commit" + | "push" + | "reset" + | "clean" + | "rebase" + | "merge" + | "cherry-pick" + | "revert" + | "branch" + | "checkout" + | "switch" + | "tag" + ) + }), + "npm" | "pnpm" | "yarn" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "install" | "add" | "remove" | "uninstall" | "update" | "publish" + ) + }), + "cargo" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "add" | "remove" | "install" | "clean" | "publish" + ) + }), + "touch" | "mkdir" | "mv" | "cp" | "ln" => true, + _ => false, + }; + + saw_medium |= medium; + } + + if saw_medium { + CommandRiskLevel::Medium + } else { + CommandRiskLevel::Low + } + } + + /// Validate full command execution policy (allowlist + risk gate). + pub fn validate_command_execution( + &self, + command: &str, + approved: bool, + ) -> Result { + if !self.is_command_allowed(command) { + return Err(format!("Command not allowed by security policy: {command}")); + } + + let risk = self.command_risk_level(command); + + if risk == CommandRiskLevel::High { + if self.block_high_risk_commands { + return Err("Command blocked: high-risk command is disallowed by policy".into()); + } + if self.autonomy == AutonomyLevel::Supervised && !approved { + return Err( + "Command requires explicit approval (approved=true): high-risk operation" + .into(), + ); + } + } + + if risk == CommandRiskLevel::Medium + && self.autonomy == AutonomyLevel::Supervised + && self.require_approval_for_medium_risk + && !approved + { + return Err( + "Command requires explicit approval (approved=true): medium-risk operation".into(), + ); + } + + Ok(risk) + } + + /// Check if a shell command is allowed. + pub fn is_command_allowed(&self, command: &str) -> bool { + if self.autonomy == AutonomyLevel::ReadOnly { + return false; + } + + if command.contains('`') + || contains_disallowed_unquoted_shell_variable_expansion( + command, + &self.shell_env_passthrough, + ) + || command.contains("<(") + || command.contains(">(") + { + return false; + } + + if contains_unquoted_char(command, '>') || contains_unquoted_char(command, '<') { + return false; + } + + if command + .split_whitespace() + .any(|w| w == "tee" || w.ends_with("/tee")) + { + return false; + } + + if contains_unquoted_single_ampersand(command) { + return false; + } + + let segments = split_unquoted_segments(command); + for segment in &segments { + let cmd_part = skip_env_assignments(segment); + let mut words = cmd_part.split_whitespace(); + let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim(); + let base_cmd = executable.rsplit('/').next().unwrap_or(""); + + if base_cmd.is_empty() { + continue; + } + + if !self + .allowed_commands + .iter() + .any(|allowed| is_allowlist_entry_match(allowed, executable, base_cmd)) + { + return false; + } + + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + if !self.is_args_safe(base_cmd, &args) { + return false; + } + } + + segments.iter().any(|s| { + let s = skip_env_assignments(s.trim()); + s.split_whitespace().next().is_some_and(|w| !w.is_empty()) + }) + } + + /// Check for dangerous arguments that allow sub-command execution. + fn is_args_safe(&self, base: &str, args: &[String]) -> bool { + let base = base.to_ascii_lowercase(); + match base.as_str() { + "find" => !args.iter().any(|arg| arg == "-exec" || arg == "-ok"), + "git" => !args.iter().any(|arg| { + arg == "config" + || arg.starts_with("config.") + || arg == "alias" + || arg.starts_with("alias.") + || arg == "-c" + }), + _ => true, + } + } + + /// Return the first path-like argument blocked by path policy. + pub fn forbidden_path_argument(&self, command: &str) -> Option { + let forbidden_candidate = |raw: &str| { + let candidate = strip_wrapping_quotes(raw).trim(); + if candidate.is_empty() || candidate.contains("://") { + return None; + } + if looks_like_path(candidate) && !self.is_path_allowed(candidate) { + Some(candidate.to_string()) + } else { + None + } + }; + + for segment in split_unquoted_segments(command) { + let cmd_part = skip_env_assignments(&segment); + let mut words = cmd_part.split_whitespace(); + let Some(_executable) = words.next() else { + continue; + }; + + for token in words { + let candidate = strip_wrapping_quotes(token).trim(); + if candidate.is_empty() || candidate.contains("://") { + continue; + } + + if candidate.starts_with('-') { + if let Some((_, value)) = candidate.split_once('=') { + let blocked = forbidden_candidate(value); + if blocked.is_some() { + return blocked; + } + } + continue; + } + + if let Some(blocked) = forbidden_candidate(candidate) { + return Some(blocked); + } + } + } + + None + } + + /// Check if a file path is allowed (no path traversal, within workspace). + pub fn is_path_allowed(&self, path: &str) -> bool { + if path.contains('\0') { + return false; + } + + if Path::new(path) + .components() + .any(|c| matches!(c, std::path::Component::ParentDir)) + { + return false; + } + + let lower = path.to_lowercase(); + if lower.contains("..%2f") || lower.contains("%2f..") { + return false; + } + + if path.starts_with('~') && path != "~" && !path.starts_with("~/") { + return false; + } + + let expanded_path = expand_user_path(path); + + if self.workspace_only && expanded_path.is_absolute() { + return false; + } + + for forbidden in &self.forbidden_paths { + let forbidden_path = expand_user_path(forbidden); + if expanded_path.starts_with(forbidden_path) { + return false; + } + } + + true + } + + /// Validate that a resolved path is inside the workspace or an allowed root. + pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool { + let workspace_root = self + .workspace_dir + .canonicalize() + .unwrap_or_else(|_| self.workspace_dir.clone()); + if resolved.starts_with(&workspace_root) { + return true; + } + + for root in &self.allowed_roots { + let canonical = root.canonicalize().unwrap_or_else(|_| root.clone()); + if resolved.starts_with(&canonical) { + return true; + } + } + + for forbidden in &self.forbidden_paths { + let forbidden_path = expand_user_path(forbidden); + if resolved.starts_with(&forbidden_path) { + return false; + } + } + + if !self.workspace_only { + return true; + } + + false + } + + /// Returns human-readable guidance on how to fix path violations. + pub fn resolved_path_violation_message(&self, resolved: &Path) -> String { + let guidance = if self.allowed_roots.is_empty() { + "Add the directory to allowed_roots, or move the file into the workspace." + } else { + "Add a matching parent directory to allowed_roots, or move the file into the workspace." + }; + format!( + "Resolved path escapes workspace allowlist: {}. {}", + resolved.display(), + guidance + ) + } + + /// Check if autonomy level permits any action at all. + pub fn can_act(&self) -> bool { + self.autonomy != AutonomyLevel::ReadOnly + } + + /// Enforce policy for a tool operation. + pub fn enforce_tool_operation( + &self, + operation: ToolOperation, + operation_name: &str, + ) -> Result<(), String> { + match operation { + ToolOperation::Read => Ok(()), + ToolOperation::Act => { + if !self.can_act() { + return Err(format!( + "Security policy: read-only mode, cannot perform '{operation_name}'" + )); + } + if !self.record_action() { + return Err("Rate limit exceeded: action budget exhausted".to_string()); + } + Ok(()) + } + } + } + + /// Record an action and check if the rate limit has been exceeded. + pub fn record_action(&self) -> bool { + let count = self.tracker.record(); + count <= self.max_actions_per_hour as usize + } + + /// Check if the rate limit would be exceeded without recording. + pub fn is_rate_limited(&self) -> bool { + self.tracker.count() >= self.max_actions_per_hour as usize + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_policy() -> SecurityPolicy { + SecurityPolicy::default() + } + + // ── AutonomyLevel tests ───────────────────────────────────────────── + + #[test] + fn default_autonomy_is_supervised() { + assert_eq!(AutonomyLevel::default(), AutonomyLevel::Supervised); + } + + #[test] + fn can_act_readonly_is_false() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }; + assert!(!p.can_act()); + } + + #[test] + fn can_act_supervised_is_true() { + assert!(test_policy().can_act()); + } + + #[test] + fn can_act_full_is_true() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Full, + ..SecurityPolicy::default() + }; + assert!(p.can_act()); + } + + // ── Path validation tests ─────────────────────────────────────────── + + #[test] + fn path_relative_allowed() { + assert!(test_policy().is_path_allowed("src/main.rs")); + } + + #[test] + fn path_traversal_blocked() { + assert!(!test_policy().is_path_allowed("../../../etc/passwd")); + } + + #[test] + fn path_absolute_blocked_workspace_only() { + assert!(!test_policy().is_path_allowed("/tmp/file.txt")); + } + + #[test] + fn path_absolute_allowed_workspace_not_only() { + let p = SecurityPolicy { + workspace_only: false, + ..SecurityPolicy::default() + }; + // /some/safe/path not in forbidden list + assert!(p.is_path_allowed("/some/safe/path")); + } + + #[test] + fn path_forbidden_blocked() { + let p = SecurityPolicy { + workspace_only: false, + ..SecurityPolicy::default() + }; + assert!(!p.is_path_allowed("/etc/passwd")); + } + + #[test] + fn path_null_byte_blocked() { + assert!(!test_policy().is_path_allowed("file\0.txt")); + } + + #[test] + fn path_url_encoded_traversal_blocked() { + assert!(!test_policy().is_path_allowed("..%2f..%2fetc/passwd")); + } + + #[test] + fn path_tilde_user_blocked() { + assert!(!test_policy().is_path_allowed("~root/.ssh/id_rsa")); + } + + #[test] + fn path_dotfile_in_workspace_allowed() { + assert!(test_policy().is_path_allowed(".env")); + } + + // ── Resolved path tests ───────────────────────────────────────────── + + #[test] + fn resolved_path_inside_workspace() { + let dir = tempfile::tempdir().unwrap(); + let p = SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + ..SecurityPolicy::default() + }; + let inside = dir.path().join("src/main.rs"); + assert!(p.is_resolved_path_allowed(&inside)); + } + + #[test] + fn resolved_path_outside_workspace_blocked() { + let dir = tempfile::tempdir().unwrap(); + let p = SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + ..SecurityPolicy::default() + }; + assert!(!p.is_resolved_path_allowed(Path::new("/etc/passwd"))); + } + + #[test] + fn resolved_path_allowed_roots() { + let dir = tempfile::tempdir().unwrap(); + let extra = tempfile::tempdir().unwrap(); + let p = SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + allowed_roots: vec![extra.path().to_path_buf()], + ..SecurityPolicy::default() + }; + let inside_extra = extra.path().join("file.txt"); + assert!(p.is_resolved_path_allowed(&inside_extra)); + } + + #[cfg(unix)] + #[test] + fn resolved_path_symlink_escape_blocked() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + let link_path = workspace.path().join("link"); + std::os::unix::fs::symlink(outside.path(), &link_path).unwrap(); + let resolved = link_path.canonicalize().unwrap(); + + let p = SecurityPolicy { + workspace_dir: workspace.path().to_path_buf(), + ..SecurityPolicy::default() + }; + assert!(!p.is_resolved_path_allowed(&resolved)); + } + + // ── Command allowlist tests ───────────────────────────────────────── + + #[test] + fn command_allowed_basic() { + assert!(test_policy().is_command_allowed("ls -la")); + } + + #[test] + fn command_blocked_unknown() { + assert!(!test_policy().is_command_allowed("python3 script.py")); + } + + #[test] + fn command_readonly_blocks_all() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }; + assert!(!p.is_command_allowed("ls")); + } + + #[test] + fn command_pipe_both_sides_checked() { + assert!(test_policy().is_command_allowed("grep foo | head -5")); + assert!(!test_policy().is_command_allowed("grep foo | python3")); + } + + #[test] + fn command_semicolon_injection_blocked() { + assert!(!test_policy().is_command_allowed("ls; rm -rf /")); + } + + #[test] + fn command_backtick_injection_blocked() { + assert!(!test_policy().is_command_allowed("echo `rm -rf /`")); + } + + #[test] + fn command_dollar_paren_blocked() { + assert!(!test_policy().is_command_allowed("echo $(rm -rf /)")); + } + + #[test] + fn command_redirect_blocked() { + assert!(!test_policy().is_command_allowed("echo secret > /tmp/file")); + } + + #[test] + fn command_background_blocked() { + assert!(!test_policy().is_command_allowed("ls & rm -rf /")); + } + + #[test] + fn command_and_chain_allowed() { + assert!(test_policy().is_command_allowed("ls && echo ok")); + } + + #[test] + fn command_tee_blocked() { + assert!(!test_policy().is_command_allowed("echo secret | tee /tmp/out")); + } + + #[test] + fn command_find_exec_blocked() { + assert!(!test_policy().is_command_allowed("find . -exec rm {} \\;")); + } + + #[test] + fn command_git_config_blocked() { + assert!(!test_policy().is_command_allowed("git config core.editor vim")); + } + + #[test] + fn command_process_substitution_blocked() { + assert!(!test_policy().is_command_allowed("cat <(echo foo)")); + } + + #[test] + fn command_env_prefix_handled() { + assert!(test_policy().is_command_allowed("FOO=bar ls")); + } + + #[test] + fn command_shell_var_blocked() { + assert!(!test_policy().is_command_allowed("echo $HOME")); + } + + #[test] + fn command_shell_var_passthrough_allowed() { + let p = SecurityPolicy { + shell_env_passthrough: vec!["HOME".into()], + ..SecurityPolicy::default() + }; + assert!(p.is_command_allowed("echo $HOME")); + } + + #[test] + fn command_quoted_operators_safe() { + // Quoted semicolons/operators should be treated as literals + assert!(test_policy().is_command_allowed("echo 'hello; world'")); + } + + // ── Risk classification tests ─────────────────────────────────────── + + #[test] + fn risk_low_for_read_ops() { + let p = test_policy(); + assert_eq!(p.command_risk_level("ls -la"), CommandRiskLevel::Low); + assert_eq!(p.command_risk_level("git status"), CommandRiskLevel::Low); + assert_eq!(p.command_risk_level("cat file.txt"), CommandRiskLevel::Low); + } + + #[test] + fn risk_medium_for_mutating() { + let p = test_policy(); + assert_eq!( + p.command_risk_level("git commit -m 'msg'"), + CommandRiskLevel::Medium + ); + assert_eq!(p.command_risk_level("touch file"), CommandRiskLevel::Medium); + assert_eq!( + p.command_risk_level("npm install"), + CommandRiskLevel::Medium + ); + } + + #[test] + fn risk_high_for_dangerous() { + let p = test_policy(); + assert_eq!(p.command_risk_level("rm -rf /"), CommandRiskLevel::High); + assert_eq!(p.command_risk_level("sudo ls"), CommandRiskLevel::High); + assert_eq!( + p.command_risk_level("curl http://evil.com"), + CommandRiskLevel::High + ); + } + + // ── Command validation tests ──────────────────────────────────────── + + #[test] + fn validate_blocks_high_risk() { + let p = test_policy(); + assert!(p.validate_command_execution("rm file", false).is_err()); + } + + #[test] + fn validate_blocks_medium_risk_without_approval() { + let p = test_policy(); + assert!(p + .validate_command_execution("git commit -m x", false) + .is_err()); + } + + #[test] + fn validate_allows_medium_risk_with_approval() { + let p = test_policy(); + let result = p.validate_command_execution("git commit -m x", true); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), CommandRiskLevel::Medium); + } + + #[test] + fn validate_allows_low_risk() { + let p = test_policy(); + let result = p.validate_command_execution("ls -la", false); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), CommandRiskLevel::Low); + } + + #[test] + fn validate_full_autonomy_skips_medium_approval() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Full, + ..SecurityPolicy::default() + }; + assert!(p + .validate_command_execution("git commit -m x", false) + .is_ok()); + } + + // ── Rate limiting tests ───────────────────────────────────────────── + + #[test] + fn rate_limit_zero_budget_blocks() { + let p = SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }; + assert!(!p.record_action()); + } + + #[test] + fn rate_limit_boundary() { + let p = SecurityPolicy { + max_actions_per_hour: 3, + ..SecurityPolicy::default() + }; + assert!(p.record_action()); // 1 + assert!(p.record_action()); // 2 + assert!(p.record_action()); // 3 + assert!(!p.record_action()); // 4 = over limit + } + + #[test] + fn is_rate_limited_no_record() { + let p = SecurityPolicy { + max_actions_per_hour: 1, + ..SecurityPolicy::default() + }; + assert!(!p.is_rate_limited()); + p.record_action(); + assert!(p.is_rate_limited()); + } + + #[test] + fn tracker_clone_independence() { + let p = test_policy(); + p.record_action(); + let p2 = p.clone(); + p.record_action(); + // p has 2 actions, p2 should have 1 + assert_eq!(p.tracker.count(), 2); + assert_eq!(p2.tracker.count(), 1); + } + + // ── Enforce tool operation tests ──────────────────────────────────── + + #[test] + fn enforce_read_always_ok() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }; + assert!(p.enforce_tool_operation(ToolOperation::Read, "read").is_ok()); + } + + #[test] + fn enforce_act_blocked_readonly() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }; + assert!(p + .enforce_tool_operation(ToolOperation::Act, "write") + .is_err()); + } + + #[test] + fn enforce_act_rate_limited() { + let p = SecurityPolicy { + max_actions_per_hour: 1, + ..SecurityPolicy::default() + }; + assert!(p + .enforce_tool_operation(ToolOperation::Act, "write") + .is_ok()); + assert!(p + .enforce_tool_operation(ToolOperation::Act, "write") + .is_err()); + } + + // ── Forbidden path argument tests ─────────────────────────────────── + + #[test] + fn forbidden_path_argument_detects_absolute() { + let p = test_policy(); + assert!(p.forbidden_path_argument("cat /etc/passwd").is_some()); + } + + #[test] + fn forbidden_path_argument_safe_relative() { + let p = test_policy(); + assert!(p.forbidden_path_argument("cat src/main.rs").is_none()); + } + + #[test] + fn forbidden_path_argument_option_value() { + let p = test_policy(); + assert!(p + .forbidden_path_argument("cmd --file=/etc/passwd") + .is_some()); + } + + // ── Default policy sanity ─────────────────────────────────────────── + + #[test] + fn default_policy_sanity() { + let p = SecurityPolicy::default(); + assert_eq!(p.autonomy, AutonomyLevel::Supervised); + assert!(p.workspace_only); + assert!(!p.allowed_commands.is_empty()); + assert!(!p.forbidden_paths.is_empty()); + assert!(p.block_high_risk_commands); + assert!(p.require_approval_for_medium_risk); + } + + #[test] + fn security_checklist_root_path_blocked() { + assert!(!test_policy().is_path_allowed("/")); + } + + #[test] + fn security_checklist_all_system_dirs_blocked() { + let p = test_policy(); + for dir in &[ + "/etc", "/root", "/usr", "/bin", "/sbin", "/boot", "/dev", "/proc", "/sys", + ] { + assert!( + !p.is_path_allowed(dir), + "{dir} should be blocked" + ); + } + } + + #[test] + fn resolved_path_violation_message_content() { + let p = test_policy(); + let msg = p.resolved_path_violation_message(Path::new("/etc/passwd")); + assert!(msg.contains("escapes workspace")); + assert!(msg.contains("allowed_roots")); + } +} From 7ff43fbdc6e0cd23090f762d49109d04f42bcf61 Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 16:15:24 +0800 Subject: [PATCH 5/8] feat(tools): add 6 built-in tools with SecurityPolicy enforcement Implements: shell, file_read, file_write, file_edit, glob_search, content_search. All tools inject Arc for path ACL, command filtering, and rate-limiting. RuntimeAdapter trait abstracts shell execution for testability. Replaces demo echo/datetime tools in agent_cmd.rs with real tools. Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/agent_cmd.rs | 73 +-- crewforge-rs/src/lib.rs | 1 + crewforge-rs/src/tools/content_search.rs | 792 +++++++++++++++++++++++ crewforge-rs/src/tools/file_edit.rs | 475 ++++++++++++++ crewforge-rs/src/tools/file_read.rs | 427 ++++++++++++ crewforge-rs/src/tools/file_write.rs | 311 +++++++++ crewforge-rs/src/tools/glob_search.rs | 357 ++++++++++ crewforge-rs/src/tools/mod.rs | 43 ++ crewforge-rs/src/tools/shell.rs | 434 +++++++++++++ crewforge-rs/src/tools/traits.rs | 55 ++ 10 files changed, 2906 insertions(+), 62 deletions(-) create mode 100644 crewforge-rs/src/tools/content_search.rs create mode 100644 crewforge-rs/src/tools/file_edit.rs create mode 100644 crewforge-rs/src/tools/file_read.rs create mode 100644 crewforge-rs/src/tools/file_write.rs create mode 100644 crewforge-rs/src/tools/glob_search.rs create mode 100644 crewforge-rs/src/tools/mod.rs create mode 100644 crewforge-rs/src/tools/shell.rs create mode 100644 crewforge-rs/src/tools/traits.rs diff --git a/crewforge-rs/src/agent_cmd.rs b/crewforge-rs/src/agent_cmd.rs index 0ee02b1..6e3f85a 100644 --- a/crewforge-rs/src/agent_cmd.rs +++ b/crewforge-rs/src/agent_cmd.rs @@ -14,12 +14,13 @@ use std::io::{self, BufRead, Write}; use std::sync::Arc; use anyhow::Result; -use async_trait::async_trait; use clap::Args; use crewforge::{ agent::{AgentEvent, AgentSession, AgentSessionConfig, StopReason, Tool}, auth::{AuthService, default_state_dir}, provider::{self, default_api_key_env}, + security::SecurityPolicy, + tools::{default_tools, TokioRuntime}, }; // ── Clap args ───────────────────────────────────────────────────────────────── @@ -59,62 +60,6 @@ pub struct AgentArgs { temperature: f64, } -// ── Built-in test tools ─────────────────────────────────────────────────────── - -struct EchoTool; - -#[async_trait] -impl Tool for EchoTool { - fn name(&self) -> &str { "echo" } - fn description(&self) -> &str { - "Echo back the provided message. Useful for verifying that tool calling works end-to-end." - } - fn parameters(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": { - "message": {"type": "string", "description": "The message to echo back"} - }, - "required": ["message"] - }) - } - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { - let msg = args.get("message").and_then(|v| v.as_str()).unwrap_or("[no message]"); - Ok(crewforge::agent::ToolResult { - success: true, - output: format!("Echo: {msg}"), - error: None, - }) - } -} - -struct DatetimeTool; - -#[async_trait] -impl Tool for DatetimeTool { - fn name(&self) -> &str { "get_datetime" } - fn description(&self) -> &str { "Get the current UTC date and time." } - fn parameters(&self) -> serde_json::Value { - serde_json::json!({"type": "object", "properties": {}, "required": []}) - } - async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { - use std::time::{SystemTime, UNIX_EPOCH}; - let secs = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0); - let s = secs % 60; - let m = (secs / 60) % 60; - let h = (secs / 3600) % 24; - let days = secs / 86400; - Ok(crewforge::agent::ToolResult { - success: true, - output: format!("UTC unix_day={days} {:02}:{:02}:{:02}", h, m, s), - error: None, - }) - } -} - // ── Event rendering ─────────────────────────────────────────────────────────── fn print_event(event: &AgentEvent) { @@ -205,7 +150,13 @@ pub async fn run(args: AgentArgs) -> Result<()> { let tools: Vec> = if args.no_tools { vec![] } else { - vec![Box::new(EchoTool), Box::new(DatetimeTool)] + let workspace = std::env::current_dir().unwrap_or_else(|_| ".".into()); + let security = Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }); + let runtime = Arc::new(TokioRuntime); + default_tools(security, runtime) }; let config = AgentSessionConfig { @@ -214,17 +165,15 @@ pub async fn run(args: AgentArgs) -> Result<()> { ..Default::default() }; + let tool_names: Vec = tools.iter().map(|t| t.name().to_string()).collect(); let mut session = AgentSession::new(provider, &args.model, &args.system, tools, config); eprintln!( "\x1b[1mcrewforge agent\x1b[0m provider={} model={} tools={}", args.provider, args.model, - if args.no_tools { "off" } else { "echo,get_datetime" } + if args.no_tools { "off".to_string() } else { tool_names.join(", ") } ); - if !args.no_tools { - eprintln!("tools: echo(message), get_datetime()"); - } eprintln!("Type your message and press Enter. Ctrl-D to exit.\n"); let stdin = io::stdin(); diff --git a/crewforge-rs/src/lib.rs b/crewforge-rs/src/lib.rs index a6ac1b0..6729c88 100644 --- a/crewforge-rs/src/lib.rs +++ b/crewforge-rs/src/lib.rs @@ -2,3 +2,4 @@ pub mod agent; pub mod auth; pub mod provider; pub mod security; +pub mod tools; diff --git a/crewforge-rs/src/tools/content_search.rs b/crewforge-rs/src/tools/content_search.rs new file mode 100644 index 0000000..da4fd0d --- /dev/null +++ b/crewforge-rs/src/tools/content_search.rs @@ -0,0 +1,792 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use std::process::Stdio; +use std::sync::{Arc, OnceLock}; + +const MAX_RESULTS: usize = 1000; +const MAX_OUTPUT_BYTES: usize = 1_048_576; // 1 MB +const TIMEOUT_SECS: u64 = 30; + +pub struct ContentSearchTool { + security: Arc, + has_rg: bool, +} + +impl ContentSearchTool { + pub fn new(security: Arc) -> Self { + let has_rg = which::which("rg").is_ok(); + Self { security, has_rg } + } + + #[cfg(test)] + fn new_with_backend(security: Arc, has_rg: bool) -> Self { + Self { security, has_rg } + } +} + +#[async_trait] +impl crate::agent::Tool for ContentSearchTool { + fn name(&self) -> &str { + "content_search" + } + + fn description(&self) -> &str { + "Search file contents by regex pattern within the workspace. \ + Uses ripgrep (rg) with grep fallback. \ + Output modes: 'content', 'files_with_matches', 'count'." + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Regular expression pattern to search for" + }, + "path": { + "type": "string", + "description": "Directory to search in, relative to workspace root. Defaults to '.'", + "default": "." + }, + "output_mode": { + "type": "string", + "description": "Output format: 'content', 'files_with_matches', 'count'", + "enum": ["content", "files_with_matches", "count"], + "default": "content" + }, + "include": { + "type": "string", + "description": "File glob filter, e.g. '*.rs', '*.{ts,tsx}'" + }, + "case_sensitive": { + "type": "boolean", + "description": "Case-sensitive matching. Defaults to true", + "default": true + }, + "context_before": { + "type": "integer", + "description": "Lines of context before each match (content mode only)", + "default": 0 + }, + "context_after": { + "type": "integer", + "description": "Lines of context after each match (content mode only)", + "default": 0 + } + }, + "required": ["pattern"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let pattern = args + .get("pattern") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pattern' parameter"))?; + + if pattern.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Empty pattern is not allowed.".into()), + }); + } + + let search_path = args.get("path").and_then(|v| v.as_str()).unwrap_or("."); + + let output_mode = args + .get("output_mode") + .and_then(|v| v.as_str()) + .unwrap_or("content"); + + if !matches!(output_mode, "content" | "files_with_matches" | "count") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Invalid output_mode '{output_mode}'. Allowed: content, files_with_matches, count." + )), + }); + } + + let include = args.get("include").and_then(|v| v.as_str()); + + let case_sensitive = args + .get("case_sensitive") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + + #[allow(clippy::cast_possible_truncation)] + let context_before = args + .get("context_before") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + + #[allow(clippy::cast_possible_truncation)] + let context_after = args + .get("context_after") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + + #[allow(clippy::cast_possible_truncation)] + let max_results = args + .get("max_results") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + .unwrap_or(MAX_RESULTS) + .min(MAX_RESULTS); + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + if std::path::Path::new(search_path).is_absolute() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Absolute paths are not allowed. Use a relative path.".into()), + }); + } + + if search_path.contains("../") || search_path.contains("..\\") || search_path == ".." { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Path traversal ('..') is not allowed.".into()), + }); + } + + if !self.security.is_path_allowed(search_path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Path '{search_path}' is not allowed by security policy." + )), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let workspace = &self.security.workspace_dir; + let resolved_path = workspace.join(search_path); + + let resolved_canon = match std::fs::canonicalize(&resolved_path) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Cannot resolve path '{search_path}': {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_canon) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Resolved path for '{search_path}' is outside the allowed workspace." + )), + }); + } + + // Build command + let mut cmd = if self.has_rg { + build_rg_command( + pattern, + &resolved_canon, + output_mode, + include, + case_sensitive, + context_before, + context_after, + ) + } else { + build_grep_command( + pattern, + &resolved_canon, + output_mode, + include, + case_sensitive, + context_before, + context_after, + ) + }; + + // Clear environment, keep only safe variables + cmd.env_clear(); + for key in &["PATH", "HOME", "LANG", "LC_ALL", "LC_CTYPE"] { + if let Ok(val) = std::env::var(key) { + cmd.env(key, val); + } + } + + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::piped()); + + let output = match tokio::time::timeout( + std::time::Duration::from_secs(TIMEOUT_SECS), + tokio::process::Command::from(cmd).output(), + ) + .await + { + Ok(Ok(out)) => out, + Ok(Err(e)) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute search command: {e}")), + }); + } + Err(_) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Search timed out after {TIMEOUT_SECS} seconds.")), + }); + } + }; + + // Exit code: 0 = matches found, 1 = no matches (grep/rg), 2 = error + let exit_code = output.status.code().unwrap_or(-1); + if exit_code >= 2 { + let stderr = String::from_utf8_lossy(&output.stderr); + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Search error: {}", stderr.trim())), + }); + } + + let raw_stdout = String::from_utf8_lossy(&output.stdout); + + let workspace_canon = + std::fs::canonicalize(workspace).unwrap_or_else(|_| workspace.clone()); + + let formatted = format_line_output(&raw_stdout, &workspace_canon, output_mode, max_results); + + // Truncate if too large + let final_output = if formatted.len() > MAX_OUTPUT_BYTES { + let mut truncated = truncate_utf8(&formatted, MAX_OUTPUT_BYTES).to_string(); + truncated.push_str("\n\n[Output truncated: exceeded 1 MB limit]"); + truncated + } else { + formatted + }; + + Ok(ToolResult { + success: true, + output: final_output, + error: None, + }) + } +} + +fn build_rg_command( + pattern: &str, + search_path: &std::path::Path, + output_mode: &str, + include: Option<&str>, + case_sensitive: bool, + context_before: usize, + context_after: usize, +) -> std::process::Command { + let mut cmd = std::process::Command::new("rg"); + + cmd.arg("--no-heading"); + cmd.arg("--line-number"); + cmd.arg("--with-filename"); + + match output_mode { + "files_with_matches" => { + cmd.arg("--files-with-matches"); + } + "count" => { + cmd.arg("--count"); + } + _ => { + if context_before > 0 { + cmd.arg("-B").arg(context_before.to_string()); + } + if context_after > 0 { + cmd.arg("-A").arg(context_after.to_string()); + } + } + } + + if !case_sensitive { + cmd.arg("-i"); + } + + if let Some(glob) = include { + cmd.arg("--glob").arg(glob); + } + + cmd.arg("--"); + cmd.arg(pattern); + cmd.arg(search_path); + + cmd +} + +fn build_grep_command( + pattern: &str, + search_path: &std::path::Path, + output_mode: &str, + include: Option<&str>, + case_sensitive: bool, + context_before: usize, + context_after: usize, +) -> std::process::Command { + let mut cmd = std::process::Command::new("grep"); + + cmd.arg("-r"); + cmd.arg("-n"); + cmd.arg("-E"); + cmd.arg("--binary-files=without-match"); + + match output_mode { + "files_with_matches" => { + cmd.arg("-l"); + } + "count" => { + cmd.arg("-c"); + } + _ => { + if context_before > 0 { + cmd.arg("-B").arg(context_before.to_string()); + } + if context_after > 0 { + cmd.arg("-A").arg(context_after.to_string()); + } + } + } + + if !case_sensitive { + cmd.arg("-i"); + } + + if let Some(glob) = include { + cmd.arg("--include").arg(glob); + } + + cmd.arg("--"); + cmd.arg(pattern); + cmd.arg(search_path); + + cmd +} + +fn relativize_path(line: &str, workspace_prefix: &str) -> String { + if let Some(rest) = line.strip_prefix(workspace_prefix) { + let trimmed = rest + .strip_prefix('/') + .or_else(|| rest.strip_prefix('\\')) + .unwrap_or(rest); + return trimmed.to_string(); + } + line.to_string() +} + +fn parse_content_line(line: &str) -> Option<(&str, bool)> { + static MATCH_RE: OnceLock = OnceLock::new(); + static CONTEXT_RE: OnceLock = OnceLock::new(); + + let match_re = MATCH_RE.get_or_init(|| { + regex::Regex::new(r"^(?P.+?):\d+:").expect("match line regex must be valid") + }); + if let Some(caps) = match_re.captures(line) { + return caps.name("path").map(|m| (m.as_str(), true)); + } + + let context_re = CONTEXT_RE.get_or_init(|| { + regex::Regex::new(r"^(?P.+?)-\d+-").expect("context line regex must be valid") + }); + if let Some(caps) = context_re.captures(line) { + return caps.name("path").map(|m| (m.as_str(), false)); + } + + None +} + +fn parse_count_line(line: &str) -> Option<(&str, usize)> { + static COUNT_RE: OnceLock = OnceLock::new(); + let count_re = COUNT_RE.get_or_init(|| { + regex::Regex::new(r"^(?P.+?):(?P\d+)\s*$").expect("count line regex valid") + }); + + let caps = count_re.captures(line)?; + let path = caps.name("path")?.as_str(); + let count = caps.name("count")?.as_str().parse::().ok()?; + Some((path, count)) +} + +fn format_line_output( + raw: &str, + workspace_canon: &std::path::Path, + output_mode: &str, + max_results: usize, +) -> String { + if raw.trim().is_empty() { + return "No matches found.".to_string(); + } + + let workspace_prefix = workspace_canon.to_string_lossy(); + + let mut lines: Vec = Vec::new(); + let mut truncated = false; + let mut file_set = std::collections::HashSet::new(); + let mut total_matches: usize = 0; + + for line in raw.lines() { + if line.is_empty() { + continue; + } + + let relativized = relativize_path(line, &workspace_prefix); + + match output_mode { + "files_with_matches" => { + let path = relativized.trim(); + if !path.is_empty() && file_set.insert(path.to_string()) { + lines.push(path.to_string()); + if lines.len() >= max_results { + truncated = true; + break; + } + } + } + "count" => { + if let Some((path, count)) = parse_count_line(&relativized) { + if count > 0 { + file_set.insert(path.to_string()); + total_matches += count; + lines.push(format!("{path}:{count}")); + if lines.len() >= max_results { + truncated = true; + break; + } + } + } + } + _ => { + if relativized == "--" { + lines.push(relativized); + if lines.len() >= max_results { + truncated = true; + break; + } + continue; + } + if let Some((path, is_match)) = parse_content_line(&relativized) { + file_set.insert(path.to_string()); + if is_match { + total_matches += 1; + } + } else { + total_matches += 1; + } + lines.push(relativized); + if lines.len() >= max_results { + truncated = true; + break; + } + } + } + } + + if lines.is_empty() { + return "No matches found.".to_string(); + } + + use std::fmt::Write; + let mut buf = lines.join("\n"); + + if truncated { + let _ = write!( + buf, + "\n\n[Results truncated: showing first {max_results} results]" + ); + } + + match output_mode { + "files_with_matches" => { + let _ = write!(buf, "\n\nTotal: {} files", file_set.len()); + } + "count" => { + let _ = write!( + buf, + "\n\nTotal: {} matches in {} files", + total_matches, + file_set.len() + ); + } + _ => { + let _ = write!( + buf, + "\n\nTotal: {} matching lines in {} files", + total_matches, + file_set.len() + ); + } + } + + buf +} + +fn truncate_utf8(input: &str, max_bytes: usize) -> &str { + if input.len() <= max_bytes { + return input; + } + let mut end = max_bytes; + while end > 0 && !input.is_char_boundary(end) { + end -= 1; + } + &input[..end] +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use serde_json::json; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + fn create_test_files(dir: &tempfile::TempDir) { + std::fs::write( + dir.path().join("hello.rs"), + "fn main() {\n println!(\"hello\");\n}\n", + ) + .unwrap(); + std::fs::write( + dir.path().join("lib.rs"), + "pub fn greet() {\n println!(\"greet\");\n}\n", + ) + .unwrap(); + std::fs::write(dir.path().join("readme.txt"), "This is a readme file.\n").unwrap(); + } + + #[tokio::test] + async fn content_search_basic_match() { + let dir = tempfile::tempdir().unwrap(); + create_test_files(&dir); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"pattern": "fn main"})).await.unwrap(); + + assert!(result.success); + assert!(result.output.contains("hello.rs")); + assert!(result.output.contains("fn main")); + } + + #[tokio::test] + async fn content_search_files_with_matches_mode() { + let dir = tempfile::tempdir().unwrap(); + create_test_files(&dir); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "println", "output_mode": "files_with_matches"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("hello.rs")); + assert!(result.output.contains("lib.rs")); + assert!(!result.output.contains("readme.txt")); + assert!(result.output.contains("Total: 2 files")); + } + + #[tokio::test] + async fn content_search_count_mode() { + let dir = tempfile::tempdir().unwrap(); + create_test_files(&dir); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "println", "output_mode": "count"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("hello.rs")); + assert!(result.output.contains("lib.rs")); + assert!(result.output.contains("Total:")); + } + + #[tokio::test] + async fn content_search_case_insensitive() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("test.txt"), "Hello World\nhello world\n").unwrap(); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "HELLO", "case_sensitive": false})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("Hello World")); + assert!(result.output.contains("hello world")); + } + + #[tokio::test] + async fn content_search_include_filter() { + let dir = tempfile::tempdir().unwrap(); + create_test_files(&dir); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "fn", "include": "*.rs"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("hello.rs")); + assert!(!result.output.contains("readme.txt")); + } + + #[tokio::test] + async fn content_search_context_lines() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("ctx.rs"), + "line1\nline2\ntarget_line\nline4\nline5\n", + ) + .unwrap(); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "target_line", "context_before": 1, "context_after": 1})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("target_line")); + assert!(result.output.contains("line2")); + assert!(result.output.contains("line4")); + } + + #[tokio::test] + async fn content_search_no_matches() { + let dir = tempfile::tempdir().unwrap(); + create_test_files(&dir); + + let tool = ContentSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "nonexistent_string_xyz"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("No matches found")); + } + + #[tokio::test] + async fn content_search_empty_pattern_rejected() { + let tool = ContentSearchTool::new(test_security(std::env::temp_dir())); + let result = tool.execute(json!({"pattern": ""})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Empty pattern")); + } + + #[tokio::test] + async fn content_search_rejects_absolute_path() { + let tool = ContentSearchTool::new(test_security(std::env::temp_dir())); + let result = tool + .execute(json!({"pattern": "test", "path": "/etc"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Absolute paths")); + } + + #[tokio::test] + async fn content_search_rejects_path_traversal() { + let tool = ContentSearchTool::new(test_security(std::env::temp_dir())); + let result = tool + .execute(json!({"pattern": "test", "path": "../../../etc"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Path traversal")); + } + + #[tokio::test] + async fn content_search_rate_limited() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("file.txt"), "test content\n").unwrap(); + + let security = Arc::new(SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = ContentSearchTool::new(security); + let result = tool.execute(json!({"pattern": "test"})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Rate limit")); + } + + #[tokio::test] + async fn content_search_multiline_without_rg() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("test.txt"), "line1\nline2\n").unwrap(); + + let tool = ContentSearchTool::new_with_backend( + test_security(dir.path().to_path_buf()), + false, // no rg + ); + // Without multiline support in grep fallback, this should still work for basic patterns + let result = tool + .execute(json!({"pattern": "line1"})) + .await + .unwrap(); + + assert!(result.success); + } + + #[test] + fn relativize_path_strips_prefix() { + let result = relativize_path("/workspace/src/main.rs:42:fn main()", "/workspace"); + assert_eq!(result, "src/main.rs:42:fn main()"); + } + + #[test] + fn relativize_path_no_prefix() { + let result = relativize_path("src/main.rs:42:fn main()", "/workspace"); + assert_eq!(result, "src/main.rs:42:fn main()"); + } + + #[test] + fn truncate_utf8_keeps_char_boundary() { + let text = "abc\u{4f60}\u{597d}"; // "abc你好" + let truncated = truncate_utf8(text, 4); + assert_eq!(truncated, "abc"); + } +} diff --git a/crewforge-rs/src/tools/file_edit.rs b/crewforge-rs/src/tools/file_edit.rs new file mode 100644 index 0000000..b922fe3 --- /dev/null +++ b/crewforge-rs/src/tools/file_edit.rs @@ -0,0 +1,475 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use std::sync::Arc; + +pub struct FileEditTool { + security: Arc, +} + +impl FileEditTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +#[async_trait] +impl crate::agent::Tool for FileEditTool { + fn name(&self) -> &str { + "file_edit" + } + + fn description(&self) -> &str { + "Edit a file by replacing an exact string match with new content" + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative path to the file from workspace root" + }, + "old_string": { + "type": "string", + "description": "The exact text to find and replace (must appear exactly once)" + }, + "new_string": { + "type": "string", + "description": "The replacement text (empty string to delete the matched text)" + } + }, + "required": ["path", "old_string", "new_string"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + let old_string = args + .get("old_string") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'old_string' parameter"))?; + let new_string = args + .get("new_string") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'new_string' parameter"))?; + + if old_string.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("old_string must not be empty".into()), + }); + } + + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + if !self.security.is_path_allowed(path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed by security policy: {path}")), + }); + } + + let full_path = self.security.workspace_dir.join(path); + + let Some(parent) = full_path.parent() else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing parent directory".into()), + }); + }; + + let resolved_parent = match tokio::fs::canonicalize(parent).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_parent) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + self.security + .resolved_path_violation_message(&resolved_parent), + ), + }); + } + + let Some(file_name) = full_path.file_name() else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing file name".into()), + }); + }; + + let resolved_target = resolved_parent.join(file_name); + + // Symlink check + if let Ok(meta) = tokio::fs::symlink_metadata(&resolved_target).await { + if meta.file_type().is_symlink() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Refusing to edit through symlink: {}", + resolved_target.display() + )), + }); + } + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let content = match tokio::fs::read_to_string(&resolved_target).await { + Ok(c) => c, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to read file: {e}")), + }); + } + }; + + let match_count = content.matches(old_string).count(); + + if match_count == 0 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("old_string not found in file".into()), + }); + } + + if match_count > 1 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "old_string matches {match_count} times; must match exactly once" + )), + }); + } + + let new_content = content.replacen(old_string, new_string, 1); + + match tokio::fs::write(&resolved_target, &new_content).await { + Ok(()) => Ok(ToolResult { + success: true, + output: format!( + "Edited {path}: replaced 1 occurrence ({} bytes)", + new_content.len() + ), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to write file: {e}")), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use serde_json::json; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + #[tokio::test] + async fn file_edit_replaces_single_match() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "hello world") + .await + .unwrap(); + + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": "hello", + "new_string": "goodbye" + })) + .await + .unwrap(); + + assert!(result.success, "edit should succeed: {:?}", result.error); + assert!(result.output.contains("replaced 1 occurrence")); + + let content = tokio::fs::read_to_string(dir.path().join("test.txt")) + .await + .unwrap(); + assert_eq!(content, "goodbye world"); + } + + #[tokio::test] + async fn file_edit_not_found() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "hello world") + .await + .unwrap(); + + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": "nonexistent", + "new_string": "replacement" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("not found")); + } + + #[tokio::test] + async fn file_edit_multiple_matches() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "aaa bbb aaa") + .await + .unwrap(); + + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": "aaa", + "new_string": "ccc" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("matches 2 times")); + } + + #[tokio::test] + async fn file_edit_delete_via_empty_new_string() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "keep remove keep") + .await + .unwrap(); + + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": " remove", + "new_string": "" + })) + .await + .unwrap(); + + assert!(result.success); + let content = tokio::fs::read_to_string(dir.path().join("test.txt")) + .await + .unwrap(); + assert_eq!(content, "keep keep"); + } + + #[tokio::test] + async fn file_edit_rejects_empty_old_string() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "hello") + .await + .unwrap(); + + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": "", + "new_string": "x" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("must not be empty")); + } + + #[tokio::test] + async fn file_edit_blocks_path_traversal() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "../../etc/passwd", + "old_string": "root", + "new_string": "hacked" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("not allowed")); + } + + #[tokio::test] + async fn file_edit_blocks_readonly_mode() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "hello") + .await + .unwrap(); + + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + workspace_dir: dir.path().to_path_buf(), + ..SecurityPolicy::default() + }); + let tool = FileEditTool::new(security); + let result = tool + .execute(json!({ + "path": "test.txt", + "old_string": "hello", + "new_string": "world" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("read-only")); + } + + #[cfg(unix)] + #[tokio::test] + async fn file_edit_blocks_symlink_escape() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + + std::os::unix::fs::symlink(outside.path(), workspace.path().join("escape_dir")).unwrap(); + + let tool = FileEditTool::new(test_security(workspace.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "escape_dir/target.txt", + "old_string": "a", + "new_string": "b" + })) + .await + .unwrap(); + + assert!(!result.success); + } + + #[cfg(unix)] + #[tokio::test] + async fn file_edit_blocks_symlink_target_file() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + tokio::fs::write(outside.path().join("target.txt"), "original") + .await + .unwrap(); + std::os::unix::fs::symlink( + outside.path().join("target.txt"), + workspace.path().join("linked.txt"), + ) + .unwrap(); + + let tool = FileEditTool::new(test_security(workspace.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "linked.txt", + "old_string": "original", + "new_string": "hacked" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("symlink")); + + let content = tokio::fs::read_to_string(outside.path().join("target.txt")) + .await + .unwrap(); + assert_eq!(content, "original"); + } + + #[tokio::test] + async fn file_edit_blocks_null_byte_in_path() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "test\0evil.txt", + "old_string": "old", + "new_string": "new" + })) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_edit_nonexistent_file() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileEditTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({ + "path": "missing.txt", + "old_string": "a", + "new_string": "b" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Failed to read file")); + } +} diff --git a/crewforge-rs/src/tools/file_read.rs b/crewforge-rs/src/tools/file_read.rs new file mode 100644 index 0000000..1d2680d --- /dev/null +++ b/crewforge-rs/src/tools/file_read.rs @@ -0,0 +1,427 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use std::sync::Arc; + +const MAX_FILE_SIZE_BYTES: u64 = 10 * 1024 * 1024; + +pub struct FileReadTool { + security: Arc, +} + +impl FileReadTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +#[async_trait] +impl crate::agent::Tool for FileReadTool { + fn name(&self) -> &str { + "file_read" + } + + fn description(&self) -> &str { + "Read file contents with line numbers. Supports partial reading via offset and limit." + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative path to the file from workspace root" + }, + "offset": { + "type": "integer", + "description": "Starting line number (1-based, default: 1)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to return (default: all)" + } + }, + "required": ["path"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + if !self.security.is_path_allowed(path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed by security policy: {path}")), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let full_path = self.security.workspace_dir.join(path); + + let resolved_path = match tokio::fs::canonicalize(&full_path).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(self.security.resolved_path_violation_message(&resolved_path)), + }); + } + + match tokio::fs::metadata(&resolved_path).await { + Ok(meta) => { + if meta.len() > MAX_FILE_SIZE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "File too large: {} bytes (limit: {MAX_FILE_SIZE_BYTES} bytes)", + meta.len() + )), + }); + } + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to read file metadata: {e}")), + }); + } + } + + match tokio::fs::read_to_string(&resolved_path).await { + Ok(contents) => { + let lines: Vec<&str> = contents.lines().collect(); + let total = lines.len(); + + if total == 0 { + return Ok(ToolResult { + success: true, + output: String::new(), + error: None, + }); + } + + let offset = args + .get("offset") + .and_then(|v| v.as_u64()) + .map(|v| { + usize::try_from(v.max(1)) + .unwrap_or(usize::MAX) + .saturating_sub(1) + }) + .unwrap_or(0); + let start = offset.min(total); + + let end = match args.get("limit").and_then(|v| v.as_u64()) { + Some(l) => { + let limit = usize::try_from(l).unwrap_or(usize::MAX); + start.saturating_add(limit).min(total) + } + None => total, + }; + + if start >= end { + return Ok(ToolResult { + success: true, + output: format!("[No lines in range, file has {total} lines]"), + error: None, + }); + } + + let numbered: String = lines[start..end] + .iter() + .enumerate() + .map(|(i, line)| format!("{}: {}", start + i + 1, line)) + .collect::>() + .join("\n"); + + let partial = start > 0 || end < total; + let summary = if partial { + format!("\n[Lines {}-{} of {total}]", start + 1, end) + } else { + format!("\n[{total} lines total]") + }; + + Ok(ToolResult { + success: true, + output: format!("{numbered}{summary}"), + error: None, + }) + } + Err(_) => { + let bytes = tokio::fs::read(&resolved_path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read file: {e}"))?; + + let lossy = String::from_utf8_lossy(&bytes).into_owned(); + Ok(ToolResult { + success: true, + output: lossy, + error: None, + }) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use serde_json::json; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + #[tokio::test] + async fn file_read_existing_file() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "hello world") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"path": "test.txt"})).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("1: hello world")); + assert!(result.output.contains("[1 lines total]")); + } + + #[tokio::test] + async fn file_read_nonexistent_file() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "no_such_file.txt"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.is_some()); + } + + #[tokio::test] + async fn file_read_blocks_path_traversal() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "../../../etc/passwd"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("not allowed")); + } + + #[tokio::test] + async fn file_read_blocks_absolute_path() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "/etc/passwd"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_read_blocks_when_rate_limited() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("test.txt"), "ok") + .await + .unwrap(); + + let security = Arc::new(SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = FileReadTool::new(security); + let result = tool.execute(json!({"path": "test.txt"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Rate limit")); + } + + #[tokio::test] + async fn file_read_with_offset_and_limit() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("multi.txt"), "line1\nline2\nline3\nline4\nline5") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "multi.txt", "offset": 2, "limit": 2})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("2: line2")); + assert!(result.output.contains("3: line3")); + assert!(!result.output.contains("4: line4")); + } + + #[tokio::test] + async fn file_read_offset_beyond_end() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("short.txt"), "one\ntwo") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "short.txt", "offset": 100})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("No lines in range")); + } + + #[tokio::test] + async fn file_read_empty_file() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("empty.txt"), "") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "empty.txt"})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.is_empty()); + } + + #[tokio::test] + async fn file_read_nested_path() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::create_dir_all(dir.path().join("sub/dir")) + .await + .unwrap(); + tokio::fs::write(dir.path().join("sub/dir/file.txt"), "nested content") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "sub/dir/file.txt"})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("nested content")); + } + + #[cfg(unix)] + #[tokio::test] + async fn file_read_blocks_symlink_escape() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + tokio::fs::write(outside.path().join("secret.txt"), "secret data") + .await + .unwrap(); + std::os::unix::fs::symlink( + outside.path().join("secret.txt"), + workspace.path().join("link.txt"), + ) + .unwrap(); + + let tool = FileReadTool::new(test_security(workspace.path().to_path_buf())); + let result = tool.execute(json!({"path": "link.txt"})).await.unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_read_blocks_null_byte_in_path() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "file\0.txt"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_read_rejects_oversized_file() { + let dir = tempfile::tempdir().unwrap(); + let big_file = dir.path().join("big.bin"); + // Create a file just over the limit using sparse writing + let f = std::fs::File::create(&big_file).unwrap(); + f.set_len(MAX_FILE_SIZE_BYTES + 1).unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"path": "big.bin"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("too large")); + } + + #[tokio::test] + async fn file_read_lossy_reads_binary_file() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("binary.bin"), b"\xff\xfe\x00\x01hello") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "binary.bin"})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("hello")); + } + + #[tokio::test] + async fn file_read_blocks_readonly_not_applicable() { + // ReadOnly mode should still allow file reads (they don't use enforce_tool_operation) + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("file.txt"), "content") + .await + .unwrap(); + + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + workspace_dir: dir.path().to_path_buf(), + ..SecurityPolicy::default() + }); + // FileReadTool uses is_rate_limited + is_path_allowed + record_action + // ReadOnly doesn't block reads via is_path_allowed + let tool = FileReadTool::new(security); + let result = tool.execute(json!({"path": "file.txt"})).await.unwrap(); + assert!(result.success); + } +} diff --git a/crewforge-rs/src/tools/file_write.rs b/crewforge-rs/src/tools/file_write.rs new file mode 100644 index 0000000..e8d64ad --- /dev/null +++ b/crewforge-rs/src/tools/file_write.rs @@ -0,0 +1,311 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use std::sync::Arc; + +pub struct FileWriteTool { + security: Arc, +} + +impl FileWriteTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +#[async_trait] +impl crate::agent::Tool for FileWriteTool { + fn name(&self) -> &str { + "file_write" + } + + fn description(&self) -> &str { + "Write content to a file. Creates parent directories if needed. Refuses to write through symlinks." + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative path to write to" + }, + "content": { + "type": "string", + "description": "Content to write to the file" + } + }, + "required": ["path", "content"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + let content = args + .get("content") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?; + + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Security policy: read-only mode, cannot write files".into()), + }); + } + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + if !self.security.is_path_allowed(path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed by security policy: {path}")), + }); + } + + let full_path = self.security.workspace_dir.join(path); + + // Create parent directories + if let Some(parent) = full_path.parent() { + if let Err(e) = tokio::fs::create_dir_all(parent).await { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to create parent directories: {e}")), + }); + } + + // Canonicalize parent to check resolved path + let resolved_parent = match parent.canonicalize() { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve parent path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_parent) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + self.security + .resolved_path_violation_message(&resolved_parent), + ), + }); + } + } + + // Refuse to write through symlinks + #[cfg(unix)] + if full_path.exists() { + let meta = std::fs::symlink_metadata(&full_path); + if let Ok(m) = meta { + if m.file_type().is_symlink() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Refusing to write through symlink".into()), + }); + } + } + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + match tokio::fs::write(&full_path, content).await { + Ok(()) => Ok(ToolResult { + success: true, + output: format!("Wrote {} bytes to {path}", content.len()), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to write file: {e}")), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use serde_json::json; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + #[tokio::test] + async fn file_write_creates_file() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "new.txt", "content": "hello"})) + .await + .unwrap(); + assert!(result.success); + let content = tokio::fs::read_to_string(dir.path().join("new.txt")) + .await + .unwrap(); + assert_eq!(content, "hello"); + } + + #[tokio::test] + async fn file_write_creates_parent_dirs() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "sub/dir/file.txt", "content": "nested"})) + .await + .unwrap(); + assert!(result.success); + assert!(dir.path().join("sub/dir/file.txt").exists()); + } + + #[tokio::test] + async fn file_write_overwrites_existing() { + let dir = tempfile::tempdir().unwrap(); + tokio::fs::write(dir.path().join("exist.txt"), "old") + .await + .unwrap(); + + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "exist.txt", "content": "new"})) + .await + .unwrap(); + assert!(result.success); + let content = tokio::fs::read_to_string(dir.path().join("exist.txt")) + .await + .unwrap(); + assert_eq!(content, "new"); + } + + #[tokio::test] + async fn file_write_blocks_path_traversal() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "../escape.txt", "content": "bad"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_write_blocks_absolute_path() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "/tmp/bad.txt", "content": "bad"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn file_write_blocks_readonly_mode() { + let dir = tempfile::tempdir().unwrap(); + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + workspace_dir: dir.path().to_path_buf(), + ..SecurityPolicy::default() + }); + let tool = FileWriteTool::new(security); + let result = tool + .execute(json!({"path": "file.txt", "content": "no"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn file_write_blocks_when_rate_limited() { + let dir = tempfile::tempdir().unwrap(); + let security = Arc::new(SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = FileWriteTool::new(security); + let result = tool + .execute(json!({"path": "file.txt", "content": "no"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[cfg(unix)] + #[tokio::test] + async fn file_write_blocks_symlink_escape() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + std::os::unix::fs::symlink(outside.path(), workspace.path().join("link_dir")).unwrap(); + + let tool = FileWriteTool::new(test_security(workspace.path().to_path_buf())); + let result = tool + .execute(json!({"path": "link_dir/file.txt", "content": "bad"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[cfg(unix)] + #[tokio::test] + async fn file_write_blocks_symlink_target_file() { + let workspace = tempfile::tempdir().unwrap(); + let outside = tempfile::tempdir().unwrap(); + let target = outside.path().join("target.txt"); + tokio::fs::write(&target, "original").await.unwrap(); + std::os::unix::fs::symlink(&target, workspace.path().join("link.txt")).unwrap(); + + let tool = FileWriteTool::new(test_security(workspace.path().to_path_buf())); + let result = tool + .execute(json!({"path": "link.txt", "content": "overwrite"})) + .await + .unwrap(); + assert!(!result.success); + // Verify original file was not modified + let content = tokio::fs::read_to_string(&target).await.unwrap(); + assert_eq!(content, "original"); + } + + #[tokio::test] + async fn file_write_blocks_null_byte_in_path() { + let dir = tempfile::tempdir().unwrap(); + let tool = FileWriteTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"path": "file\0.txt", "content": "bad"})) + .await + .unwrap(); + assert!(!result.success); + } +} diff --git a/crewforge-rs/src/tools/glob_search.rs b/crewforge-rs/src/tools/glob_search.rs new file mode 100644 index 0000000..48da5a6 --- /dev/null +++ b/crewforge-rs/src/tools/glob_search.rs @@ -0,0 +1,357 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use std::sync::Arc; + +const MAX_RESULTS: usize = 1000; + +pub struct GlobSearchTool { + security: Arc, +} + +impl GlobSearchTool { + pub fn new(security: Arc) -> Self { + Self { security } + } +} + +#[async_trait] +impl crate::agent::Tool for GlobSearchTool { + fn name(&self) -> &str { + "glob_search" + } + + fn description(&self) -> &str { + "Search for files matching a glob pattern within the workspace. \ + Returns a sorted list of matching file paths relative to the workspace root." + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Glob pattern to match files, e.g. '**/*.rs', 'src/**/mod.rs'" + } + }, + "required": ["pattern"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let pattern = args + .get("pattern") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pattern' parameter"))?; + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + // Reject absolute paths + if pattern.starts_with('/') || pattern.starts_with('\\') { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Absolute paths are not allowed. Use a relative glob pattern.".into()), + }); + } + + // Reject path traversal + if pattern.contains("../") || pattern.contains("..\\") || pattern == ".." { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Path traversal ('..') is not allowed in glob patterns.".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let workspace = &self.security.workspace_dir; + let full_pattern = workspace.join(pattern).to_string_lossy().to_string(); + + let entries = match glob::glob(&full_pattern) { + Ok(paths) => paths, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid glob pattern: {e}")), + }); + } + }; + + let workspace_canon = match std::fs::canonicalize(workspace) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Cannot resolve workspace directory: {e}")), + }); + } + }; + + let mut results = Vec::new(); + let mut truncated = false; + + for entry in entries { + let path = match entry { + Ok(p) => p, + Err(_) => continue, + }; + + let resolved = match std::fs::canonicalize(&path) { + Ok(p) => p, + Err(_) => continue, + }; + + if !self.security.is_resolved_path_allowed(&resolved) { + continue; + } + + if resolved.is_dir() { + continue; + } + + if let Ok(rel) = resolved.strip_prefix(&workspace_canon) { + results.push(rel.to_string_lossy().to_string()); + } + + if results.len() >= MAX_RESULTS { + truncated = true; + break; + } + } + + results.sort(); + + let output = if results.is_empty() { + format!("No files matching pattern '{pattern}' found in workspace.") + } else { + use std::fmt::Write; + let mut buf = results.join("\n"); + if truncated { + let _ = write!( + buf, + "\n\n[Results truncated: showing first {MAX_RESULTS} of more matches]" + ); + } + let _ = write!(buf, "\n\nTotal: {} files", results.len()); + buf + }; + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use serde_json::json; + + fn test_security(workspace: std::path::PathBuf) -> Arc { + Arc::new(SecurityPolicy { + workspace_dir: workspace, + ..SecurityPolicy::default() + }) + } + + #[tokio::test] + async fn glob_search_single_file() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("hello.txt"), "content").unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "hello.txt"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("hello.txt")); + } + + #[tokio::test] + async fn glob_search_multiple_files() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("a.txt"), "").unwrap(); + std::fs::write(dir.path().join("b.txt"), "").unwrap(); + std::fs::write(dir.path().join("c.rs"), "").unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"pattern": "*.txt"})).await.unwrap(); + + assert!(result.success); + assert!(result.output.contains("a.txt")); + assert!(result.output.contains("b.txt")); + assert!(!result.output.contains("c.rs")); + } + + #[tokio::test] + async fn glob_search_recursive() { + let dir = tempfile::tempdir().unwrap(); + std::fs::create_dir_all(dir.path().join("sub/deep")).unwrap(); + std::fs::write(dir.path().join("root.txt"), "").unwrap(); + std::fs::write(dir.path().join("sub/mid.txt"), "").unwrap(); + std::fs::write(dir.path().join("sub/deep/leaf.txt"), "").unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "**/*.txt"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("root.txt")); + assert!(result.output.contains("mid.txt")); + assert!(result.output.contains("leaf.txt")); + } + + #[tokio::test] + async fn glob_search_no_matches() { + let dir = tempfile::tempdir().unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "*.nonexistent"})) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("No files matching pattern")); + } + + #[tokio::test] + async fn glob_search_rejects_absolute_path() { + let tool = GlobSearchTool::new(test_security(std::env::temp_dir())); + let result = tool + .execute(json!({"pattern": "/etc/**/*"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Absolute paths")); + } + + #[tokio::test] + async fn glob_search_rejects_path_traversal() { + let tool = GlobSearchTool::new(test_security(std::env::temp_dir())); + let result = tool + .execute(json!({"pattern": "../../../etc/passwd"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Path traversal")); + } + + #[cfg(unix)] + #[tokio::test] + async fn glob_search_filters_symlink_escape() { + let root = tempfile::tempdir().unwrap(); + let workspace = root.path().join("workspace"); + let outside = root.path().join("outside"); + + std::fs::create_dir_all(&workspace).unwrap(); + std::fs::create_dir_all(&outside).unwrap(); + std::fs::write(outside.join("secret.txt"), "leaked").unwrap(); + + std::os::unix::fs::symlink(outside.join("secret.txt"), workspace.join("escape.txt")) + .unwrap(); + std::fs::write(workspace.join("legit.txt"), "ok").unwrap(); + + let tool = GlobSearchTool::new(test_security(workspace.clone())); + let result = tool.execute(json!({"pattern": "*.txt"})).await.unwrap(); + + assert!(result.success); + assert!(result.output.contains("legit.txt")); + assert!(!result.output.contains("escape.txt")); + assert!(!result.output.contains("secret.txt")); + } + + #[tokio::test] + async fn glob_search_results_sorted() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("c.txt"), "").unwrap(); + std::fs::write(dir.path().join("a.txt"), "").unwrap(); + std::fs::write(dir.path().join("b.txt"), "").unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"pattern": "*.txt"})).await.unwrap(); + + assert!(result.success); + let lines: Vec<&str> = result.output.lines().collect(); + assert!(lines.len() >= 3); + assert_eq!(lines[0], "a.txt"); + assert_eq!(lines[1], "b.txt"); + assert_eq!(lines[2], "c.txt"); + } + + #[tokio::test] + async fn glob_search_excludes_directories() { + let dir = tempfile::tempdir().unwrap(); + std::fs::create_dir(dir.path().join("subdir")).unwrap(); + std::fs::write(dir.path().join("file.txt"), "").unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool.execute(json!({"pattern": "*"})).await.unwrap(); + + assert!(result.success); + assert!(result.output.contains("file.txt")); + assert!(!result.output.contains("subdir")); + } + + #[tokio::test] + async fn glob_search_rate_limited() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write(dir.path().join("file.txt"), "").unwrap(); + + let security = Arc::new(SecurityPolicy { + workspace_dir: dir.path().to_path_buf(), + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = GlobSearchTool::new(security); + let result = tool.execute(json!({"pattern": "*.txt"})).await.unwrap(); + + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("Rate limit")); + } + + #[tokio::test] + async fn glob_search_invalid_pattern() { + let dir = tempfile::tempdir().unwrap(); + + let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); + let result = tool + .execute(json!({"pattern": "[invalid"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_ref() + .unwrap() + .contains("Invalid glob pattern")); + } +} diff --git a/crewforge-rs/src/tools/mod.rs b/crewforge-rs/src/tools/mod.rs new file mode 100644 index 0000000..87cc976 --- /dev/null +++ b/crewforge-rs/src/tools/mod.rs @@ -0,0 +1,43 @@ +pub mod content_search; +pub mod file_edit; +pub mod file_read; +pub mod file_write; +pub mod glob_search; +pub mod shell; +pub mod traits; + +pub use content_search::ContentSearchTool; +pub use file_edit::FileEditTool; +pub use file_read::FileReadTool; +pub use file_write::FileWriteTool; +pub use glob_search::GlobSearchTool; +pub use shell::ShellTool; +pub use traits::{RuntimeAdapter, TokioRuntime}; + +use crate::agent::Tool; +use crate::security::SecurityPolicy; +use std::sync::Arc; + +/// All built-in tools backed by SecurityPolicy. +pub fn default_tools( + security: Arc, + runtime: Arc, +) -> Vec> { + vec![ + Box::new(ShellTool::new(security.clone(), runtime)), + Box::new(FileReadTool::new(security.clone())), + Box::new(FileWriteTool::new(security.clone())), + Box::new(FileEditTool::new(security.clone())), + Box::new(GlobSearchTool::new(security.clone())), + Box::new(ContentSearchTool::new(security)), + ] +} + +/// All tools including future memory/browser tools. +/// For now, same as default_tools. +pub fn all_tools( + security: Arc, + runtime: Arc, +) -> Vec> { + default_tools(security, runtime) +} diff --git a/crewforge-rs/src/tools/shell.rs b/crewforge-rs/src/tools/shell.rs new file mode 100644 index 0000000..90dd73e --- /dev/null +++ b/crewforge-rs/src/tools/shell.rs @@ -0,0 +1,434 @@ +use crate::agent::ToolResult; +use crate::security::SecurityPolicy; +use crate::tools::traits::RuntimeAdapter; +use async_trait::async_trait; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; + +const SHELL_TIMEOUT_SECS: u64 = 60; +const MAX_OUTPUT_BYTES: usize = 1_048_576; // 1 MB +const SAFE_ENV_VARS: &[&str] = &[ + "PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR", +]; + +pub struct ShellTool { + security: Arc, + runtime: Arc, +} + +impl ShellTool { + pub fn new(security: Arc, runtime: Arc) -> Self { + Self { security, runtime } + } +} + +fn is_valid_env_var_name(name: &str) -> bool { + if name.is_empty() { + return false; + } + let mut chars = name.chars(); + let first = chars.next().unwrap(); + if !first.is_ascii_alphabetic() && first != '_' { + return false; + } + chars.all(|c| c.is_ascii_alphanumeric() || c == '_') +} + +pub(crate) fn collect_allowed_shell_env_vars(security: &SecurityPolicy) -> Vec { + let mut out = Vec::new(); + let mut seen = HashSet::new(); + for key in SAFE_ENV_VARS + .iter() + .copied() + .chain(security.shell_env_passthrough.iter().map(|s| s.as_str())) + { + let candidate = key.trim(); + if candidate.is_empty() || !is_valid_env_var_name(candidate) { + continue; + } + if seen.insert(candidate.to_string()) { + out.push(candidate.to_string()); + } + } + out +} + +fn extract_command_argument(args: &serde_json::Value) -> Option { + if let Some(command) = args + .get("command") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|cmd| !cmd.is_empty()) + { + return Some(command.to_string()); + } + + for alias in [ + "cmd", + "script", + "shell_command", + "command_line", + "bash", + "sh", + "input", + ] { + if let Some(command) = args + .get(alias) + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|cmd| !cmd.is_empty()) + { + return Some(command.to_string()); + } + } + + args.as_str() + .map(str::trim) + .filter(|cmd| !cmd.is_empty()) + .map(ToString::to_string) +} + +fn truncate_utf8(s: &str, max_bytes: usize) -> usize { + if s.len() <= max_bytes { + return s.len(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + end +} + +#[async_trait] +impl crate::agent::Tool for ShellTool { + fn name(&self) -> &str { + "shell" + } + + fn description(&self) -> &str { + "Execute a shell command in the workspace directory" + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + }, + "approved": { + "type": "boolean", + "description": "Set true to explicitly approve medium/high-risk commands", + "default": false + } + }, + "required": ["command"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let command = extract_command_argument(&args) + .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?; + let approved = args + .get("approved") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded".into()), + }); + } + + match self.security.validate_command_execution(&command, approved) { + Ok(_) => {} + Err(reason) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(reason), + }); + } + } + + if let Some(path) = self.security.forbidden_path_argument(&command) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path blocked by security policy: {path}")), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + let mut cmd = match self + .runtime + .build_shell_command(&command, &self.security.workspace_dir) + { + Ok(cmd) => cmd, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to build runtime command: {e}")), + }); + } + }; + cmd.env_clear(); + + for var in collect_allowed_shell_env_vars(&self.security) { + if let Ok(val) = std::env::var(&var) { + cmd.env(&var, val); + } + } + + let result = + tokio::time::timeout(Duration::from_secs(SHELL_TIMEOUT_SECS), cmd.output()).await; + + match result { + Ok(Ok(output)) => { + let mut stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let mut stderr = String::from_utf8_lossy(&output.stderr).to_string(); + + if stdout.len() > MAX_OUTPUT_BYTES { + let boundary = truncate_utf8(&stdout, MAX_OUTPUT_BYTES); + stdout.truncate(boundary); + stdout.push_str("\n... [output truncated at 1MB]"); + } + if stderr.len() > MAX_OUTPUT_BYTES { + let boundary = truncate_utf8(&stderr, MAX_OUTPUT_BYTES); + stderr.truncate(boundary); + stderr.push_str("\n... [stderr truncated at 1MB]"); + } + + Ok(ToolResult { + success: output.status.success(), + output: stdout, + error: if stderr.is_empty() { + None + } else { + Some(stderr) + }, + }) + } + Ok(Err(e)) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute command: {e}")), + }), + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Command timed out after {SHELL_TIMEOUT_SECS}s and was killed" + )), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Tool; + use crate::security::AutonomyLevel; + use crate::tools::traits::TokioRuntime; + use serde_json::json; + + fn test_security(autonomy: AutonomyLevel) -> Arc { + Arc::new(SecurityPolicy { + autonomy, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + fn test_runtime() -> Arc { + Arc::new(TokioRuntime) + } + + #[test] + fn shell_name_and_schema() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + assert_eq!(tool.name(), "shell"); + let schema = tool.parameters(); + assert!(schema["properties"]["command"].is_object()); + assert!(schema["required"] + .as_array() + .unwrap() + .contains(&json!("command"))); + } + + #[tokio::test] + async fn shell_executes_simple_command() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + let result = tool + .execute(json!({"command": "echo hello"})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.trim().contains("hello")); + } + + #[tokio::test] + async fn shell_blocks_dangerous_command() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + let result = tool + .execute(json!({"command": "rm -rf /"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn shell_blocks_rate_limited() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + max_actions_per_hour: 0, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }); + let tool = ShellTool::new(security, test_runtime()); + let result = tool + .execute(json!({"command": "echo test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("Rate limit")); + } + + #[tokio::test] + async fn shell_blocks_readonly_mode() { + let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime()); + let result = tool + .execute(json!({"command": "ls"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("not allowed")); + } + + #[tokio::test] + async fn shell_forbidden_path_in_args() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + let result = tool + .execute(json!({"command": "cat /etc/passwd"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Path blocked")); + } + + #[tokio::test] + async fn shell_captures_exit_code() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + let result = tool + .execute(json!({"command": "ls /nonexistent_dir_xyz"})) + .await + .unwrap(); + assert!(!result.success); + } + + #[tokio::test] + async fn shell_captures_stderr() { + let tool = ShellTool::new(test_security(AutonomyLevel::Full), test_runtime()); + let result = tool + .execute(json!({"command": "echo error_msg >&2"})) + .await + .unwrap(); + assert!(result.error.as_deref().unwrap_or("").contains("error_msg")); + } + + #[tokio::test] + async fn shell_missing_command_param() { + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); + let result = tool.execute(json!({})).await; + assert!(result.is_err()); + } + + #[test] + fn extract_command_supports_aliases() { + assert_eq!( + extract_command_argument(&json!({"cmd": "echo from-cmd"})).as_deref(), + Some("echo from-cmd") + ); + assert_eq!( + extract_command_argument(&json!({"script": "echo from-script"})).as_deref(), + Some("echo from-script") + ); + assert_eq!( + extract_command_argument(&json!("echo from-string")).as_deref(), + Some("echo from-string") + ); + } + + #[test] + fn shell_safe_env_vars_excludes_secrets() { + for var in SAFE_ENV_VARS { + let lower = var.to_lowercase(); + assert!( + !lower.contains("key") && !lower.contains("secret") && !lower.contains("token"), + "SAFE_ENV_VARS must not include: {var}" + ); + } + } + + #[test] + fn invalid_env_var_names_are_filtered() { + let security = SecurityPolicy { + shell_env_passthrough: vec![ + "VALID_NAME".into(), + "BAD-NAME".into(), + "1NOPE".into(), + "ALSO_VALID".into(), + ], + ..SecurityPolicy::default() + }; + let vars = collect_allowed_shell_env_vars(&security); + assert!(vars.contains(&"VALID_NAME".to_string())); + assert!(vars.contains(&"ALSO_VALID".to_string())); + assert!(!vars.contains(&"BAD-NAME".to_string())); + assert!(!vars.contains(&"1NOPE".to_string())); + } + + #[tokio::test] + async fn shell_record_action_budget_exhaustion() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + max_actions_per_hour: 1, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }); + let tool = ShellTool::new(security, test_runtime()); + + let r1 = tool + .execute(json!({"command": "echo first"})) + .await + .unwrap(); + assert!(r1.success); + + let r2 = tool + .execute(json!({"command": "echo second"})) + .await + .unwrap(); + assert!(!r2.success); + assert!( + r2.error.as_deref().unwrap_or("").contains("Rate limit") + || r2.error.as_deref().unwrap_or("").contains("budget") + ); + } +} diff --git a/crewforge-rs/src/tools/traits.rs b/crewforge-rs/src/tools/traits.rs new file mode 100644 index 0000000..2e63ed6 --- /dev/null +++ b/crewforge-rs/src/tools/traits.rs @@ -0,0 +1,55 @@ +use std::path::Path; + +/// Abstracts shell command execution for testability. +pub trait RuntimeAdapter: Send + Sync { + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result; +} + +/// Native runtime: executes shell commands via `sh -c` on Unix. +pub struct TokioRuntime; + +impl RuntimeAdapter for TokioRuntime { + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c").arg(command); + cmd.current_dir(workspace_dir); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + Ok(cmd) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn tokio_runtime_echo() { + let rt = TokioRuntime; + let dir = std::env::current_dir().unwrap(); + let mut cmd = rt.build_shell_command("echo hello", &dir).unwrap(); + let output = cmd.output().await.unwrap(); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.trim() == "hello"); + } + + #[tokio::test] + async fn tokio_runtime_sets_cwd() { + let rt = TokioRuntime; + let dir = tempfile::tempdir().unwrap(); + let mut cmd = rt.build_shell_command("pwd", dir.path()).unwrap(); + let output = cmd.output().await.unwrap(); + let stdout = String::from_utf8_lossy(&output.stdout); + let canonical = dir.path().canonicalize().unwrap(); + assert_eq!(stdout.trim(), canonical.to_str().unwrap()); + } +} From 548b1524fb800fbf803ad81ece7bd45ec6161c92 Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 16:45:57 +0800 Subject: [PATCH 6/8] chore: fix all clippy warnings and formatting across codebase Resolve collapsible_if, dead_code, unused imports/labels/mut, too_many_arguments, and unnecessary_map_or clippy lints. Apply cargo fmt to all files. 488 tests pass, 0 warnings. Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/agent/dispatcher.rs | 12 +- crewforge-rs/src/agent/history.rs | 32 ++-- crewforge-rs/src/agent/loop_.rs | 62 +++++-- crewforge-rs/src/agent/mod.rs | 2 +- crewforge-rs/src/agent/prompt.rs | 15 +- crewforge-rs/src/agent_cmd.rs | 71 +++++--- crewforge-rs/src/auth/gemini_oauth.rs | 13 +- crewforge-rs/src/auth/mod.rs | 10 +- crewforge-rs/src/auth/oauth_common.rs | 14 +- crewforge-rs/src/auth/openai_oauth.rs | 17 +- crewforge-rs/src/auth_cmd.rs | 38 ++-- crewforge-rs/src/bin/agentctl.rs | 65 ++++--- crewforge-rs/src/chat.rs | 2 +- crewforge-rs/src/main.rs | 2 +- crewforge-rs/src/provider/anthropic.rs | 65 ++++--- crewforge-rs/src/provider/compatible.rs | 210 +++++++++++---------- crewforge-rs/src/provider/copilot.rs | 153 ++++++++-------- crewforge-rs/src/provider/gemini.rs | 14 +- crewforge-rs/src/provider/glm.rs | 27 +-- crewforge-rs/src/provider/mod.rs | 19 +- crewforge-rs/src/provider/ollama.rs | 211 +++++++++++----------- crewforge-rs/src/provider/openai.rs | 115 ++++++------ crewforge-rs/src/provider/openai_codex.rs | 13 +- crewforge-rs/src/provider/openrouter.rs | 127 ++++++------- crewforge-rs/src/provider/reliable.rs | 141 ++++++++------- crewforge-rs/src/provider/router.rs | 3 +- crewforge-rs/src/provider/traits.rs | 48 ++--- crewforge-rs/src/security/policy.rs | 58 +++--- crewforge-rs/src/security/secrets.rs | 2 + crewforge-rs/src/tools/content_search.rs | 24 +-- crewforge-rs/src/tools/file_edit.rs | 58 +++--- crewforge-rs/src/tools/file_read.rs | 34 ++-- crewforge-rs/src/tools/file_write.rs | 16 +- crewforge-rs/src/tools/glob_search.rs | 33 ++-- crewforge-rs/src/tools/shell.rs | 37 ++-- crewforge-rs/src/tui.rs | 35 +++- 36 files changed, 915 insertions(+), 883 deletions(-) diff --git a/crewforge-rs/src/agent/dispatcher.rs b/crewforge-rs/src/agent/dispatcher.rs index 8285a10..761359a 100644 --- a/crewforge-rs/src/agent/dispatcher.rs +++ b/crewforge-rs/src/agent/dispatcher.rs @@ -1,5 +1,5 @@ -use crate::provider::traits::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage}; use super::Tool; +use crate::provider::traits::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage}; use serde_json::Value; use std::fmt::Write; @@ -94,7 +94,11 @@ impl ToolDispatcher for XmlToolDispatcher { fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { let mut content = String::new(); for result in results { - let status = if result.tool_result.success { "ok" } else { "error" }; + let status = if result.tool_result.success { + "ok" + } else { + "error" + }; let output = if let Some(ref err) = result.tool_result.error { err.as_str() } else { @@ -361,9 +365,7 @@ mod tests { #[test] fn xml_dispatcher_skips_empty_name() { let response = ChatResponse { - text: Some( - "{\"name\":\"\",\"arguments\":{}}text after".into(), - ), + text: Some("{\"name\":\"\",\"arguments\":{}}text after".into()), tool_calls: vec![], usage: None, reasoning_content: None, diff --git a/crewforge-rs/src/agent/history.rs b/crewforge-rs/src/agent/history.rs index 239badc..6657f86 100644 --- a/crewforge-rs/src/agent/history.rs +++ b/crewforge-rs/src/agent/history.rs @@ -1,6 +1,4 @@ -use crate::provider::traits::{ - ChatMessage, ConversationMessage, Provider, ToolResultMessage, -}; +use crate::provider::traits::{ChatMessage, ConversationMessage, Provider}; use anyhow::Result; use std::fmt::Write; @@ -110,7 +108,9 @@ fn build_compaction_transcript(messages: &[ConversationMessage]) -> String { let role = chat.role.to_uppercase(); let _ = writeln!(transcript, "{role}: {}", chat.content.trim()); } - ConversationMessage::AssistantToolCalls { text, tool_calls, .. } => { + ConversationMessage::AssistantToolCalls { + text, tool_calls, .. + } => { let text_str = text.as_deref().unwrap_or("").trim(); let _ = writeln!(transcript, "ASSISTANT: {text_str}"); for tc in tool_calls { @@ -258,12 +258,10 @@ mod tests { #[test] fn to_provider_messages_native_tool_results() { - let history = vec![ConversationMessage::ToolResults(vec![ - ToolResultMessage { - tool_call_id: "tc1".into(), - content: "output".into(), - }, - ])]; + let history = vec![ConversationMessage::ToolResults(vec![ToolResultMessage { + tool_call_id: "tc1".into(), + content: "output".into(), + }])]; let msgs = to_provider_messages_native(&history); assert_eq!(msgs.len(), 1); assert_eq!(msgs[0].role, "tool"); @@ -291,12 +289,10 @@ mod tests { #[test] fn to_provider_messages_xml_tool_results_as_user() { - let history = vec![ConversationMessage::ToolResults(vec![ - ToolResultMessage { - tool_call_id: "tc1".into(), - content: "result_data".into(), - }, - ])]; + let history = vec![ConversationMessage::ToolResults(vec![ToolResultMessage { + tool_call_id: "tc1".into(), + content: "result_data".into(), + }])]; let msgs = to_provider_messages_xml(&history); assert_eq!(msgs.len(), 1); assert_eq!(msgs[0].role, "user"); @@ -318,9 +314,7 @@ mod tests { // system + 3 most recent non-system assert_eq!(history.len(), 4); // system is still first - assert!( - matches!(&history[0], ConversationMessage::Chat(m) if m.role == "system") - ); + assert!(matches!(&history[0], ConversationMessage::Chat(m) if m.role == "system")); // most recent messages are preserved assert!( matches!(&history[history.len()-1], ConversationMessage::Chat(m) if m.content == "msg3") diff --git a/crewforge-rs/src/agent/loop_.rs b/crewforge-rs/src/agent/loop_.rs index 6562016..3960e64 100644 --- a/crewforge-rs/src/agent/loop_.rs +++ b/crewforge-rs/src/agent/loop_.rs @@ -5,9 +5,6 @@ use std::collections::HashSet; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use crate::provider::traits::{ - ChatMessage, ChatRequest, ConversationMessage, Provider, ToolCall, ToolSpec, -}; use super::Tool; use super::dispatcher::{ NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher, @@ -15,12 +12,17 @@ use super::dispatcher::{ use super::history::{ auto_compact_history, to_provider_messages_native, to_provider_messages_xml, trim_history, }; +use crate::provider::traits::{ + ChatMessage, ChatRequest, ConversationMessage, Provider, ToolCall, ToolSpec, +}; // ── Public event/config/stop types ─────────────────────────────────────────── #[derive(Debug, Clone)] pub enum AgentEvent { - LlmThinking { iteration: usize }, + LlmThinking { + iteration: usize, + }, LlmResponse { text: Option, tool_call_count: usize, @@ -93,7 +95,9 @@ impl AgentSession { let system_prompt = system_prompt.into(); let mut history = Vec::new(); if !system_prompt.is_empty() { - history.push(ConversationMessage::Chat(ChatMessage::system(system_prompt))); + history.push(ConversationMessage::Chat(ChatMessage::system( + system_prompt, + ))); } Self { provider, @@ -126,7 +130,9 @@ impl AgentSession { // Add initial user message to history. self.history - .push(ConversationMessage::Chat(ChatMessage::user(initial_message))); + .push(ConversationMessage::Chat(ChatMessage::user( + initial_message, + ))); // Compact and trim history before starting. let _ = auto_compact_history( @@ -144,7 +150,7 @@ impl AgentSession { let mut final_text: Option = None; let mut iterations_used = 0; - 'outer: for iteration in 0..self.config.max_iterations { + for iteration in 0..self.config.max_iterations { if self.cancelled.load(Ordering::SeqCst) { events.push(AgentEvent::TurnFinished { final_text, @@ -221,7 +227,9 @@ impl AgentSession { if parsed_calls.is_empty() { // No tool calls — this is the final response for this turn. self.history - .push(ConversationMessage::Chat(ChatMessage::assistant(text.clone()))); + .push(ConversationMessage::Chat(ChatMessage::assistant( + text.clone(), + ))); final_text = Some(text); events.push(AgentEvent::TurnFinished { final_text: final_text.clone(), @@ -245,7 +253,11 @@ impl AgentSession { .collect() }; self.history.push(ConversationMessage::AssistantToolCalls { - text: if text.is_empty() { None } else { Some(text.clone()) }, + text: if text.is_empty() { + None + } else { + Some(text.clone()) + }, tool_calls: tool_calls_for_history, reasoning_content: response.reasoning_content.clone(), }); @@ -312,7 +324,6 @@ impl AgentSession { }); return events; } - } // Max iterations reached. @@ -370,8 +381,8 @@ async fn execute_tool(tools: &[Box], call: &ParsedToolCall) -> ToolExe /// Matches patterns such as `token=...`, `api_key: "..."`, `password=...`, etc. /// Preserves the first 4 characters of each value for context; redacts the rest. pub fn scrub_credentials(input: &str) -> String { - use std::sync::LazyLock; use regex::Regex; + use std::sync::LazyLock; // Matches key=value and key: value forms (token=, api_key:, bearer:, etc.) static SENSITIVE_KV_REGEX: LazyLock = LazyLock::new(|| { @@ -527,7 +538,11 @@ mod tests { ))); // Should have LlmThinking at start. - assert!(events.iter().any(|e| matches!(e, AgentEvent::LlmThinking { iteration: 0 }))); + assert!( + events + .iter() + .any(|e| matches!(e, AgentEvent::LlmThinking { iteration: 0 })) + ); // History should include system, user, and assistant messages. assert!(session.history.len() >= 3); @@ -571,14 +586,16 @@ mod tests { // First turn is consumed (provider called once before cancel check kicks in // inside the loop), OR cancelled before LLM call — depends on loop ordering. // Either way we must have a TurnFinished with Cancelled or Done. - let finished = events.iter().find(|e| matches!(e, AgentEvent::TurnFinished { .. })); + let finished = events + .iter() + .find(|e| matches!(e, AgentEvent::TurnFinished { .. })); assert!(finished.is_some()); } #[tokio::test] async fn run_turn_cancel_handle_signals_session() { let provider = Arc::new(EchoProvider); - let mut session = AgentSession::new( + let session = AgentSession::new( provider, "test-model", "", @@ -618,10 +635,19 @@ mod tests { #[async_trait] impl Tool for NoopTool { - fn name(&self) -> &str { "noop" } - fn description(&self) -> &str { "no-op" } - fn parameters(&self) -> serde_json::Value { serde_json::json!({}) } - async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + fn name(&self) -> &str { + "noop" + } + fn description(&self) -> &str { + "no-op" + } + fn parameters(&self) -> serde_json::Value { + serde_json::json!({}) + } + async fn execute( + &self, + _args: serde_json::Value, + ) -> anyhow::Result { Ok(crate::agent::ToolResult { success: true, output: "done".to_string(), diff --git a/crewforge-rs/src/agent/mod.rs b/crewforge-rs/src/agent/mod.rs index d22eded..6049199 100644 --- a/crewforge-rs/src/agent/mod.rs +++ b/crewforge-rs/src/agent/mod.rs @@ -5,8 +5,8 @@ pub mod prompt; pub use loop_::{AgentEvent, AgentSession, AgentSessionConfig, StopReason}; -use async_trait::async_trait; use crate::provider::traits::ToolSpec; +use async_trait::async_trait; /// Structured result from tool execution. /// Security denials use `success: false` with an `error` message — diff --git a/crewforge-rs/src/agent/prompt.rs b/crewforge-rs/src/agent/prompt.rs index 69cf5f7..4e71822 100644 --- a/crewforge-rs/src/agent/prompt.rs +++ b/crewforge-rs/src/agent/prompt.rs @@ -1,14 +1,11 @@ /// Build a basic system prompt for an agent with optional instructions. -pub fn build_system_prompt( - agent_name: &str, - instructions: Option<&str>, -) -> String { +pub fn build_system_prompt(agent_name: &str, instructions: Option<&str>) -> String { let mut prompt = format!("You are {agent_name}, an AI assistant."); - if let Some(instr) = instructions { - if !instr.is_empty() { - prompt.push_str("\n\n"); - prompt.push_str(instr); - } + if let Some(instr) = instructions + && !instr.is_empty() + { + prompt.push_str("\n\n"); + prompt.push_str(instr); } prompt } diff --git a/crewforge-rs/src/agent_cmd.rs b/crewforge-rs/src/agent_cmd.rs index 6e3f85a..70564e4 100644 --- a/crewforge-rs/src/agent_cmd.rs +++ b/crewforge-rs/src/agent_cmd.rs @@ -20,7 +20,7 @@ use crewforge::{ auth::{AuthService, default_state_dir}, provider::{self, default_api_key_env}, security::SecurityPolicy, - tools::{default_tools, TokioRuntime}, + tools::{TokioRuntime, default_tools}, }; // ── Clap args ───────────────────────────────────────────────────────────────── @@ -71,35 +71,56 @@ fn print_event(event: &AgentEvent) { eprintln!("\x1b[2m[thinking... round {}]\x1b[0m", iteration + 1); } } - AgentEvent::LlmResponse { text, tool_call_count, usage } => { + AgentEvent::LlmResponse { + text, + tool_call_count, + usage, + } => { if *tool_call_count == 0 { if let Some(t) = text { println!("{t}"); } - } else if let Some(t) = text { - if !t.is_empty() { - eprintln!("\x1b[2m[llm]: {t}\x1b[0m"); - } + } else if let Some(t) = text + && !t.is_empty() + { + eprintln!("\x1b[2m[llm]: {t}\x1b[0m"); } - if let Some(u) = usage { - if u.input_tokens.is_some() || u.output_tokens.is_some() { - eprintln!( - "\x1b[2m[tokens] in={} out={}\x1b[0m", - u.input_tokens.unwrap_or(0), - u.output_tokens.unwrap_or(0) - ); - } + if let Some(u) = usage + && (u.input_tokens.is_some() || u.output_tokens.is_some()) + { + eprintln!( + "\x1b[2m[tokens] in={} out={}\x1b[0m", + u.input_tokens.unwrap_or(0), + u.output_tokens.unwrap_or(0) + ); } } - AgentEvent::ToolCallStarted { iteration, name, args } => { + AgentEvent::ToolCallStarted { + iteration, + name, + args, + } => { let args_str = serde_json::to_string(args).unwrap_or_else(|_| "{}".to_string()); - eprintln!("\x1b[33m[tool:{}] {} {}\x1b[0m", iteration + 1, name, args_str); + eprintln!( + "\x1b[33m[tool:{}] {} {}\x1b[0m", + iteration + 1, + name, + args_str + ); } - AgentEvent::ToolCallFinished { name, result, success } => { + AgentEvent::ToolCallFinished { + name, + result, + success, + } => { let icon = if *success { "✓" } else { "✗" }; eprintln!("\x1b[32m[{icon} {name}] {result}\x1b[0m"); } - AgentEvent::TurnFinished { final_text, iterations_used, stop_reason } => { + AgentEvent::TurnFinished { + final_text, + iterations_used, + stop_reason, + } => { let reason = match stop_reason { StopReason::Done => "done", StopReason::MaxIterations => "max_iterations", @@ -109,10 +130,10 @@ fn print_event(event: &AgentEvent) { "\x1b[2m[turn finished: {} iteration(s), reason={}]\x1b[0m", iterations_used, reason ); - if *iterations_used == 0 { - if let Some(t) = final_text { - println!("{t}"); - } + if *iterations_used == 0 + && let Some(t) = final_text + { + println!("{t}"); } } AgentEvent::Error { message, fatal } => { @@ -172,7 +193,11 @@ pub async fn run(args: AgentArgs) -> Result<()> { "\x1b[1mcrewforge agent\x1b[0m provider={} model={} tools={}", args.provider, args.model, - if args.no_tools { "off".to_string() } else { tool_names.join(", ") } + if args.no_tools { + "off".to_string() + } else { + tool_names.join(", ") + } ); eprintln!("Type your message and press Enter. Ctrl-D to exit.\n"); diff --git a/crewforge-rs/src/auth/gemini_oauth.rs b/crewforge-rs/src/auth/gemini_oauth.rs index 3deb7a2..e3f6c91 100644 --- a/crewforge-rs/src/auth/gemini_oauth.rs +++ b/crewforge-rs/src/auth/gemini_oauth.rs @@ -21,7 +21,7 @@ use tokio::net::TcpListener; // Re-export for external use #[allow(unused_imports)] -pub use crate::auth::oauth_common::{generate_pkce_state, PkceState}; +pub use crate::auth::oauth_common::{PkceState, generate_pkce_state}; /// Get Gemini OAuth client ID from environment. /// Required: set GEMINI_OAUTH_CLIENT_ID environment variable. @@ -478,12 +478,11 @@ pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Re let params = parse_query_params(query); if let Some(code) = params.get("code") { - if let Some(expected) = expected_state { - if let Some(actual) = params.get("state") { - if actual != expected { - anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}"); - } - } + if let Some(expected) = expected_state + && let Some(actual) = params.get("state") + && actual != expected + { + anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}"); } return Ok(code.clone()); } diff --git a/crewforge-rs/src/auth/mod.rs b/crewforge-rs/src/auth/mod.rs index cbb466a..303fc96 100644 --- a/crewforge-rs/src/auth/mod.rs +++ b/crewforge-rs/src/auth/mod.rs @@ -6,7 +6,7 @@ pub mod profiles; use crate::auth::openai_oauth::refresh_access_token; use crate::auth::profiles::{ - profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore, TokenSet, + AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore, TokenSet, profile_id, }; use anyhow::Result; use std::collections::HashMap; @@ -384,10 +384,10 @@ pub fn select_profile_id( return None; } - if let Some(active) = data.active_profiles.get(provider) { - if data.profiles.contains_key(active) { - return Some(active.clone()); - } + if let Some(active) = data.active_profiles.get(provider) + && data.profiles.contains_key(active) + { + return Some(active.clone()); } let default = default_profile_id(provider); diff --git a/crewforge-rs/src/auth/oauth_common.rs b/crewforge-rs/src/auth/oauth_common.rs index fc499d0..caa5ed1 100644 --- a/crewforge-rs/src/auth/oauth_common.rs +++ b/crewforge-rs/src/auth/oauth_common.rs @@ -35,7 +35,7 @@ pub fn generate_pkce_state() -> PkceState { /// Generate a cryptographically random base64url-encoded string. pub fn random_base64url(byte_len: usize) -> String { - use chacha20poly1305::aead::{rand_core::RngCore, OsRng}; + use chacha20poly1305::aead::{OsRng, rand_core::RngCore}; let mut bytes = vec![0_u8; byte_len]; OsRng.fill_bytes(&mut bytes); @@ -66,12 +66,12 @@ pub fn url_decode(input: &str) -> String { b'%' if i + 2 < bytes.len() => { let hi = bytes[i + 1] as char; let lo = bytes[i + 2] as char; - if let (Some(h), Some(l)) = (hi.to_digit(16), lo.to_digit(16)) { - if let Ok(value) = u8::try_from(h * 16 + l) { - out.push(value); - i += 3; - continue; - } + if let (Some(h), Some(l)) = (hi.to_digit(16), lo.to_digit(16)) + && let Ok(value) = u8::try_from(h * 16 + l) + { + out.push(value); + i += 3; + continue; } out.push(bytes[i]); i += 1; diff --git a/crewforge-rs/src/auth/openai_oauth.rs b/crewforge-rs/src/auth/openai_oauth.rs index 218ff11..4c84d9e 100644 --- a/crewforge-rs/src/auth/openai_oauth.rs +++ b/crewforge-rs/src/auth/openai_oauth.rs @@ -13,7 +13,7 @@ use tokio::net::TcpListener; // Re-export for external use #[allow(unused_imports)] -pub use crate::auth::oauth_common::{generate_pkce_state, PkceState}; +pub use crate::auth::oauth_common::{PkceState, generate_pkce_state}; pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_OAUTH_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize"; @@ -328,10 +328,10 @@ pub fn extract_account_id_from_jwt(token: &str) -> Option { "sub", "https://api.openai.com/account_id", ] { - if let Some(value) = claims.get(key).and_then(|v| v.as_str()) { - if !value.trim().is_empty() { - return Some(value.to_string()); - } + if let Some(value) = claims.get(key).and_then(|v| v.as_str()) + && !value.trim().is_empty() + { + return Some(value.to_string()); } } @@ -409,9 +409,10 @@ mod tests { Some("xyz"), ) .unwrap_err(); - assert!(err - .to_string() - .contains("OpenAI OAuth error: access_denied")); + assert!( + err.to_string() + .contains("OpenAI OAuth error: access_denied") + ); } #[test] diff --git a/crewforge-rs/src/auth_cmd.rs b/crewforge-rs/src/auth_cmd.rs index 69fbe83..75ba0f6 100644 --- a/crewforge-rs/src/auth_cmd.rs +++ b/crewforge-rs/src/auth_cmd.rs @@ -1,11 +1,11 @@ //! `crewforge auth` subcommand — OAuth login, token management, and profile listing. -use crewforge::auth::{self, AuthService, default_state_dir, normalize_provider}; -use crewforge::auth::oauth_common::PkceState; -use crewforge::security::SecretStore; -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use chrono::{DateTime, Utc}; use clap::Subcommand; +use crewforge::auth::oauth_common::PkceState; +use crewforge::auth::{self, AuthService, default_state_dir, normalize_provider}; +use crewforge::security::SecretStore; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -341,9 +341,7 @@ async fn run_login(provider: String, profile: String, device_code: bool) -> Resu match provider.as_str() { "gemini" => run_gemini_login(&svc, &client, &profile, device_code).await, "openai-codex" => run_openai_login(&svc, &client, &profile, device_code).await, - _ => bail!( - "`auth login` supports --provider openai-codex or gemini, got: {provider}" - ), + _ => bail!("`auth login` supports --provider openai-codex or gemini, got: {provider}"), } } @@ -407,15 +405,12 @@ async fn run_gemini_login( } Err(e) => { println!("Callback capture failed: {e}"); - println!( - "Run `crewforge auth paste-redirect --provider gemini --profile {profile}`" - ); + println!("Run `crewforge auth paste-redirect --provider gemini --profile {profile}`"); return Ok(()); } }; - let token_set = - auth::gemini_oauth::exchange_code_for_tokens(client, &code, &pkce).await?; + let token_set = auth::gemini_oauth::exchange_code_for_tokens(client, &code, &pkce).await?; let account_id = token_set .id_token .as_deref() @@ -456,9 +451,7 @@ async fn run_openai_login( return Ok(()); } Err(e) => { - println!( - "Device-code flow unavailable: {e}. Falling back to browser/paste flow." - ); + println!("Device-code flow unavailable: {e}. Falling back to browser/paste flow."); } } } @@ -494,8 +487,7 @@ async fn run_openai_login( } }; - let token_set = - auth::openai_oauth::exchange_code_for_tokens(client, &code, &pkce).await?; + let token_set = auth::openai_oauth::exchange_code_for_tokens(client, &code, &pkce).await?; let account_id = extract_openai_account_id(&token_set.access_token); svc.store_openai_tokens(profile, token_set, account_id, true) .await?; @@ -613,10 +605,12 @@ async fn run_paste_token( if token.is_empty() { bail!("Token cannot be empty"); } - let kind = - auth::anthropic_token::detect_auth_kind(&token, auth_kind.as_deref()); + let kind = auth::anthropic_token::detect_auth_kind(&token, auth_kind.as_deref()); let mut metadata = std::collections::HashMap::new(); - metadata.insert("auth_kind".to_string(), kind.as_metadata_value().to_string()); + metadata.insert( + "auth_kind".to_string(), + kind.as_metadata_value().to_string(), + ); let svc = make_auth_service(); svc.store_provider_token(&provider, &profile, &token, metadata, true) @@ -638,9 +632,7 @@ async fn run_refresh(provider: String, profile: Option) -> Result<()> { .get_valid_openai_access_token(profile.as_deref()) .await? { - Some(_) => println!( - "OpenAI Codex token is valid (refresh completed if needed)." - ), + Some(_) => println!("OpenAI Codex token is valid (refresh completed if needed)."), None => bail!( "No OpenAI Codex auth profile found. \ Run `crewforge auth login --provider openai-codex`." diff --git a/crewforge-rs/src/bin/agentctl.rs b/crewforge-rs/src/bin/agentctl.rs index b22962e..e10ad36 100644 --- a/crewforge-rs/src/bin/agentctl.rs +++ b/crewforge-rs/src/bin/agentctl.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use clap::Parser; use crewforge::{ agent::{AgentEvent, AgentSession, AgentSessionConfig, StopReason, Tool}, - provider::{self, ToolSpec}, + provider::{self}, }; // ── CLI args ────────────────────────────────────────────────────────────────── @@ -88,7 +88,10 @@ impl Tool for EchoTool { }) } - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + async fn execute( + &self, + args: serde_json::Value, + ) -> anyhow::Result { let msg = args .get("message") .and_then(|v| v.as_str()) @@ -122,7 +125,10 @@ impl Tool for DatetimeTool { }) } - async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + async fn execute( + &self, + _args: serde_json::Value, + ) -> anyhow::Result { use std::time::{SystemTime, UNIX_EPOCH}; let secs = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -160,21 +166,19 @@ fn print_event(event: &AgentEvent) { if let Some(t) = text { println!("{t}"); } - } else { - if let Some(t) = text { - if !t.is_empty() { - eprintln!("\x1b[2m[llm]: {t}\x1b[0m"); - } - } + } else if let Some(t) = text + && !t.is_empty() + { + eprintln!("\x1b[2m[llm]: {t}\x1b[0m"); } - if let Some(u) = usage { - if u.input_tokens.is_some() || u.output_tokens.is_some() { - eprintln!( - "\x1b[2m[tokens] in={} out={}\x1b[0m", - u.input_tokens.unwrap_or(0), - u.output_tokens.unwrap_or(0) - ); - } + if let Some(u) = usage + && (u.input_tokens.is_some() || u.output_tokens.is_some()) + { + eprintln!( + "\x1b[2m[tokens] in={} out={}\x1b[0m", + u.input_tokens.unwrap_or(0), + u.output_tokens.unwrap_or(0) + ); } } AgentEvent::ToolCallStarted { @@ -213,10 +217,10 @@ fn print_event(event: &AgentEvent) { iterations_used, reason ); // final_text already printed via LlmResponse when tool_call_count==0 - if *iterations_used == 0 { - if let Some(t) = final_text { - println!("{t}"); - } + if *iterations_used == 0 + && let Some(t) = final_text + { + println!("{t}"); } } AgentEvent::Error { message, fatal } => { @@ -243,10 +247,7 @@ async fn main() -> anyhow::Result<()> { None // create_provider will pick it up from the env var } else { // Fall back to stored auth profile. - let svc = crewforge::auth::AuthService::new( - &crewforge::auth::default_state_dir(), - false, - ); + let svc = crewforge::auth::AuthService::new(&crewforge::auth::default_state_dir(), false); svc.get_provider_bearer_token(&args.provider, None) .await .unwrap_or(None) @@ -272,13 +273,7 @@ async fn main() -> anyhow::Result<()> { ..Default::default() }; - let mut session = AgentSession::new( - provider, - &args.model, - &args.system, - tools, - config, - ); + let mut session = AgentSession::new(provider, &args.model, &args.system, tools, config); // Migration notice. eprintln!("\x1b[2m[note] agentctl is an internal tool. Use `crewforge agent` instead.\x1b[0m"); @@ -288,7 +283,11 @@ async fn main() -> anyhow::Result<()> { "\x1b[1magentctl\x1b[0m provider={} model={} tools={}", args.provider, args.model, - if args.no_tools { "off" } else { "echo,get_datetime" } + if args.no_tools { + "off" + } else { + "echo,get_datetime" + } ); if !args.no_tools { eprintln!("tools: echo(message), get_datetime()"); diff --git a/crewforge-rs/src/chat.rs b/crewforge-rs/src/chat.rs index 5c39875..a46c928 100644 --- a/crewforge-rs/src/chat.rs +++ b/crewforge-rs/src/chat.rs @@ -21,9 +21,9 @@ use crate::hub::{RateLimitUsage, RoomHub}; use crate::kernel::{MessageEvent, MessageRole, SessionKernel}; use crate::managed_opencode::{self, HUB_ACK_TOOL, HUB_GET_TOOL, HUB_POST_TOOL}; use crate::mcp_server::RoomHubMcpServer; +use crate::opencode_provider::{OpencodeCliProvider, OpencodeProviderConfig}; use crate::profiles::{self, GlobalProfile}; use crate::prompt_theme; -use crate::opencode_provider::{OpencodeCliProvider, OpencodeProviderConfig}; use crate::scheduler::{WakeDecision, WorkerState, decide_wake, on_wake_finished}; use crate::text::{format_time, to_single_line_error}; use crate::tui::DisplayLine; diff --git a/crewforge-rs/src/main.rs b/crewforge-rs/src/main.rs index 9b84f3f..a75917c 100644 --- a/crewforge-rs/src/main.rs +++ b/crewforge-rs/src/main.rs @@ -7,9 +7,9 @@ mod init; mod kernel; mod managed_opencode; mod mcp_server; +mod opencode_provider; mod profiles; mod prompt_theme; -mod opencode_provider; mod scheduler; mod text; mod tui; diff --git a/crewforge-rs/src/provider/anthropic.rs b/crewforge-rs/src/provider/anthropic.rs index 74453f6..0224794 100644 --- a/crewforge-rs/src/provider/anthropic.rs +++ b/crewforge-rs/src/provider/anthropic.rs @@ -38,10 +38,13 @@ struct ContentBlock { kind: String, #[serde(default)] text: Option, + #[allow(dead_code)] #[serde(default)] id: Option, + #[allow(dead_code)] #[serde(default)] name: Option, + #[allow(dead_code)] #[serde(default)] input: Option, } @@ -208,15 +211,15 @@ impl AnthropicProvider { /// Apply cache control to the last message content block fn apply_cache_to_last_message(messages: &mut [NativeMessage]) { - if let Some(last_msg) = messages.last_mut() { - if let Some(last_content) = last_msg.content.last_mut() { - match last_content { - NativeContentOut::Text { cache_control, .. } - | NativeContentOut::ToolResult { cache_control, .. } => { - *cache_control = Some(CacheControl::ephemeral()); - } - NativeContentOut::ToolUse { .. } => {} + if let Some(last_msg) = messages.last_mut() + && let Some(last_content) = last_msg.content.last_mut() + { + match last_content { + NativeContentOut::Text { cache_control, .. } + | NativeContentOut::ToolResult { cache_control, .. } => { + *cache_control = Some(CacheControl::ephemeral()); } + NativeContentOut::ToolUse { .. } => {} } } } @@ -385,10 +388,10 @@ impl AnthropicProvider { for block in response.content { match block.kind.as_str() { "text" => { - if let Some(text) = block.text.map(|t| t.trim().to_string()) { - if !text.is_empty() { - text_parts.push(text); - } + if let Some(text) = block.text.map(|t| t.trim().to_string()) + && !text.is_empty() + { + text_parts.push(text); } } "tool_use" => { @@ -518,10 +521,10 @@ impl Provider for AnthropicProvider { .header("content-type", "application/json") .json(&native_request); - if let Some(tools) = &native_request.tools { - if !tools.is_empty() { - req = req.header("anthropic-beta", "prompt-caching-2024-07-31"); - } + if let Some(tools) = &native_request.tools + && !tools.is_empty() + { + req = req.header("anthropic-beta", "prompt-caching-2024-07-31"); } req = self.apply_auth(req, credential); @@ -588,10 +591,12 @@ mod tests { .chat_with_system(None, "hello", "claude-sonnet-4", 0.7) .await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("credentials not set")); + assert!( + result + .unwrap_err() + .to_string() + .contains("credentials not set") + ); } #[test] @@ -654,16 +659,20 @@ mod tests { let (_, native) = AnthropicProvider::convert_messages(&messages); assert_eq!(native.len(), 2); // First message should contain ToolUse block - assert!(native[0] - .content - .iter() - .any(|c| matches!(c, NativeContentOut::ToolUse { .. }))); + assert!( + native[0] + .content + .iter() + .any(|c| matches!(c, NativeContentOut::ToolUse { .. })) + ); // Second message (tool result) becomes a user message with ToolResult block assert_eq!(native[1].role, "user"); - assert!(native[1] - .content - .iter() - .any(|c| matches!(c, NativeContentOut::ToolResult { .. }))); + assert!( + native[1] + .content + .iter() + .any(|c| matches!(c, NativeContentOut::ToolResult { .. })) + ); } #[test] diff --git a/crewforge-rs/src/provider/compatible.rs b/crewforge-rs/src/provider/compatible.rs index dc4463e..0b49107 100644 --- a/crewforge-rs/src/provider/compatible.rs +++ b/crewforge-rs/src/provider/compatible.rs @@ -2,16 +2,15 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. +use crate::provider::traits::ToolSpec; use crate::provider::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, TokenUsage, - ToolCall as ProviderToolCall, + Provider, TokenUsage, ToolCall as ProviderToolCall, }; -use crate::provider::traits::ToolSpec; use async_trait::async_trait; use reqwest::{ - header::{HeaderMap, HeaderValue, USER_AGENT}, Client, + header::{HeaderMap, HeaderValue, USER_AGENT}, }; use serde::{Deserialize, Serialize}; @@ -149,6 +148,7 @@ impl OpenAiCompatibleProvider { ) } + #[allow(clippy::too_many_arguments)] fn new_with_options( name: &str, base_url: &str, @@ -289,6 +289,7 @@ impl OpenAiCompatibleProvider { } } + #[allow(dead_code)] fn tool_specs_to_openai_format(tools: &[ToolSpec]) -> Vec { tools .iter() @@ -448,10 +449,10 @@ impl ToolCall { /// Extract function name with fallback logic for various provider formats fn function_name(&self) -> Option { // Standard OpenAI format: tool_calls[].function.name - if let Some(ref func) = self.function { - if let Some(ref name) = func.name { - return Some(name.clone()); - } + if let Some(ref func) = self.function + && let Some(ref name) = func.name + { + return Some(name.clone()); } // Fallback: direct name field self.name.clone() @@ -460,10 +461,10 @@ impl ToolCall { /// Extract arguments with fallback logic and type conversion fn function_arguments(&self) -> Option { // Standard OpenAI format: tool_calls[].function.arguments (string) - if let Some(ref func) = self.function { - if let Some(ref args) = func.arguments { - return Some(args.clone()); - } + if let Some(ref func) = self.function + && let Some(ref args) = func.arguments + { + return Some(args.clone()); } // Fallback: direct arguments field if let Some(ref args) = self.arguments { @@ -605,10 +606,10 @@ fn extract_responses_text(response: ResponsesResponse) -> Option { for item in &response.output { for content in &item.content { - if content.kind.as_deref() == Some("output_text") { - if let Some(text) = first_nonempty(content.text.as_deref()) { - return Some(text); - } + if content.kind.as_deref() == Some("output_text") + && let Some(text) = first_nonempty(content.text.as_deref()) + { + return Some(text); } } } @@ -705,9 +706,7 @@ impl OpenAiCompatibleProvider { .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) } - fn convert_tool_specs( - tools: Option<&[ToolSpec]>, - ) -> Option> { + fn convert_tool_specs(tools: Option<&[ToolSpec]>) -> Option> { tools.map(|items| { items .iter() @@ -725,82 +724,75 @@ impl OpenAiCompatibleProvider { }) } + #[allow(dead_code)] fn to_message_content(_role: &str, content: &str) -> MessageContent { MessageContent::Text(content.to_string()) } - fn convert_messages_for_native( - messages: &[ChatMessage], - ) -> Vec { + fn convert_messages_for_native(messages: &[ChatMessage]) -> Vec { messages .iter() .map(|message| { - if message.role == "assistant" { - if let Ok(value) = serde_json::from_str::(&message.content) - { - if let Some(tool_calls_value) = value.get("tool_calls") { - if let Ok(parsed_calls) = - serde_json::from_value::>( - tool_calls_value.clone(), - ) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| ToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: Some(Function { - name: Some(tc.name), - arguments: Some(tc.arguments), - }), - name: None, - arguments: None, - parameters: None, - }) - .collect::>(); - - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(|value| MessageContent::Text(value.to_string())); - - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return NativeMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - reasoning_content, - }; - } - } - } + if message.role == "assistant" + && let Ok(value) = serde_json::from_str::(&message.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| ToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: Some(Function { + name: Some(tc.name), + arguments: Some(tc.arguments), + }), + name: None, + arguments: None, + parameters: None, + }) + .collect::>(); + + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(|value| MessageContent::Text(value.to_string())); + + let reasoning_content = value + .get("reasoning_content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + reasoning_content, + }; } - if message.role == "tool" { - if let Ok(value) = serde_json::from_str::(&message.content) { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(|value| MessageContent::Text(value.to_string())) - .or_else(|| Some(MessageContent::Text(message.content.clone()))); - - return NativeMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - reasoning_content: None, - }; - } + if message.role == "tool" + && let Ok(value) = serde_json::from_str::(&message.content) + { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(|value| MessageContent::Text(value.to_string())) + .or_else(|| Some(MessageContent::Text(message.content.clone()))); + + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + reasoning_content: None, + }; } NativeMessage { @@ -1063,10 +1055,7 @@ impl Provider for OpenAiCompatibleProvider { // If tool_calls are present, serialize the full message as JSON // so parse_tool_calls can handle the OpenAI-style format if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) + && c.message.tool_calls.as_ref().is_some_and(|t| !t.is_empty()) { serde_json::to_string(&c.message) .unwrap_or_else(|_| c.message.effective_content()) @@ -1168,10 +1157,7 @@ impl Provider for OpenAiCompatibleProvider { // If tool_calls are present, serialize the full message as JSON // so parse_tool_calls can handle the OpenAI-style format if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) + && c.message.tool_calls.as_ref().is_some_and(|t| !t.is_empty()) { serde_json::to_string(&c.message) .unwrap_or_else(|_| c.message.effective_content()) @@ -1313,9 +1299,7 @@ impl Provider for OpenAiCompatibleProvider { }; let native_request = NativeChatRequest { model: model.to_string(), - messages: Self::convert_messages_for_native( - &effective_messages, - ), + messages: Self::convert_messages_for_native(&effective_messages), temperature, stream: Some(false), tool_choice: tools.as_ref().map(|_| "auto".to_string()), @@ -1471,10 +1455,12 @@ mod tests { .chat_with_system(None, "hello", "llama-3.3-70b", 0.7) .await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Venice API key not set")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Venice API key not set") + ); } #[test] @@ -1654,9 +1640,10 @@ mod tests { .await .expect_err("system-only fallback payload should fail"); - assert!(err - .to_string() - .contains("requires at least one non-system message")); + assert!( + err.to_string() + .contains("requires at least one non-system message") + ); } #[test] @@ -1967,7 +1954,10 @@ mod tests { OpenAiCompatibleProvider::with_prompt_guided_tool_instructions(&input, Some(&tools)); assert!(!output.is_empty()); assert_eq!(output[0].role, "system"); - assert!(output[0].content.contains("Available Tools") || output[0].content.contains("Tool Use Protocol")); + assert!( + output[0].content.contains("Available Tools") + || output[0].content.contains("Tool Use Protocol") + ); assert!(output[0].content.contains("shell_exec")); } @@ -2129,10 +2119,12 @@ mod tests { let result = p.chat_with_tools(&messages, &tools, "model", 0.7).await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("TestProvider API key not set")); + assert!( + result + .unwrap_err() + .to_string() + .contains("TestProvider API key not set") + ); } #[test] diff --git a/crewforge-rs/src/provider/copilot.rs b/crewforge-rs/src/provider/copilot.rs index b6dea7e..08ed40d 100644 --- a/crewforge-rs/src/provider/copilot.rs +++ b/crewforge-rs/src/provider/copilot.rs @@ -252,58 +252,55 @@ impl CopilotProvider { messages .iter() .map(|message| { - if message.role == "assistant" { - if let Ok(value) = serde_json::from_str::(&message.content) { - if let Some(tool_calls_value) = value.get("tool_calls") { - if let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tool_call| NativeToolCall { - id: Some(tool_call.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tool_call.name, - arguments: tool_call.arguments, - }, - }) - .collect::>(); - - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return ApiMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - }; - } - } - } + if message.role == "assistant" + && let Ok(value) = serde_json::from_str::(&message.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tool_call| NativeToolCall { + id: Some(tool_call.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tool_call.name, + arguments: tool_call.arguments, + }, + }) + .collect::>(); + + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; } - if message.role == "tool" { - if let Ok(value) = serde_json::from_str::(&message.content) { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return ApiMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - }; - } + if message.role == "tool" + && let Ok(value) = serde_json::from_str::(&message.content) + { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; } ApiMessage { @@ -390,28 +387,28 @@ impl CopilotProvider { async fn get_api_key(&self) -> anyhow::Result<(String, String)> { let mut cached = self.refresh_lock.lock().await; - if let Some(cached_key) = cached.as_ref() { - if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at { - return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); - } + if let Some(cached_key) = cached.as_ref() + && chrono::Utc::now().timestamp() + 120 < cached_key.expires_at + { + return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); } - if let Some(info) = self.load_api_key_from_disk().await { - if chrono::Utc::now().timestamp() + 120 < info.expires_at { - let endpoint = info - .endpoints - .as_ref() - .and_then(|e| e.api.clone()) - .unwrap_or_else(|| DEFAULT_API.to_string()); - let token = info.token; - - *cached = Some(CachedApiKey { - token: token.clone(), - api_endpoint: endpoint.clone(), - expires_at: info.expires_at, - }); - return Ok((token, endpoint)); - } + if let Some(info) = self.load_api_key_from_disk().await + && chrono::Utc::now().timestamp() + 120 < info.expires_at + { + let endpoint = info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + let token = info.token; + + *cached = Some(CachedApiKey { + token: token.clone(), + api_endpoint: endpoint.clone(), + expires_at: info.expires_at, + }); + return Ok((token, endpoint)); } let access_token = self.get_github_access_token().await?; @@ -702,12 +699,16 @@ mod tests { #[test] fn copilot_headers_include_required_fields() { let headers = CopilotProvider::COPILOT_HEADERS; - assert!(headers - .iter() - .any(|(header, _)| *header == "Editor-Version")); - assert!(headers - .iter() - .any(|(header, _)| *header == "Editor-Plugin-Version")); + assert!( + headers + .iter() + .any(|(header, _)| *header == "Editor-Version") + ); + assert!( + headers + .iter() + .any(|(header, _)| *header == "Editor-Plugin-Version") + ); assert!(headers.iter().any(|(header, _)| *header == "User-Agent")); } diff --git a/crewforge-rs/src/provider/gemini.rs b/crewforge-rs/src/provider/gemini.rs index 7c5cb5e..d6196e5 100644 --- a/crewforge-rs/src/provider/gemini.rs +++ b/crewforge-rs/src/provider/gemini.rs @@ -205,12 +205,7 @@ impl GeminiProvider { let url = self.build_generate_content_url(model); - let response = self - .http_client() - .post(&url) - .json(&request) - .send() - .await?; + let response = self.http_client().post(&url).json(&request).send().await?; if !response.status().is_success() { let status = response.status(); @@ -464,7 +459,12 @@ mod tests { .chat_with_system(None, "hello", "gemini-2.5-pro", 0.7) .await; assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not found")); + assert!( + result + .unwrap_err() + .to_string() + .contains("API key not found") + ); if let Some(v) = old_gemini { unsafe { std::env::set_var("GEMINI_API_KEY", v) }; } diff --git a/crewforge-rs/src/provider/glm.rs b/crewforge-rs/src/provider/glm.rs index f3ed2ef..066e49b 100644 --- a/crewforge-rs/src/provider/glm.rs +++ b/crewforge-rs/src/provider/glm.rs @@ -53,8 +53,16 @@ fn base64url_encode_bytes(data: &[u8]) -> String { let mut i = 0; while i < data.len() { let b0 = data[i] as u32; - let b1 = if i + 1 < data.len() { data[i + 1] as u32 } else { 0 }; - let b2 = if i + 2 < data.len() { data[i + 2] as u32 } else { 0 }; + let b1 = if i + 1 < data.len() { + data[i + 1] as u32 + } else { + 0 + }; + let b2 = if i + 2 < data.len() { + data[i + 2] as u32 + } else { + 0 + }; let triple = (b0 << 16) | (b1 << 8) | b2; result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); @@ -101,17 +109,14 @@ impl GlmProvider { ); } - let now_ms = SystemTime::now() - .duration_since(UNIX_EPOCH)? - .as_millis() as u64; + let now_ms = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64; // Check cache (valid for 3 minutes, token expires at 3.5 min) - if let Ok(cache) = self.token_cache.lock() { - if let Some((ref token, expiry)) = *cache { - if now_ms < expiry { - return Ok(token.clone()); - } - } + if let Ok(cache) = self.token_cache.lock() + && let Some((ref token, expiry)) = *cache + && now_ms < expiry + { + return Ok(token.clone()); } let exp_ms = now_ms + 210_000; // 3.5 minutes diff --git a/crewforge-rs/src/provider/mod.rs b/crewforge-rs/src/provider/mod.rs index 8c000eb..adb10a8 100644 --- a/crewforge-rs/src/provider/mod.rs +++ b/crewforge-rs/src/provider/mod.rs @@ -12,13 +12,13 @@ pub mod router; pub mod traits; pub use traits::{ - build_tool_instructions_text, ChatMessage, ChatRequest, ChatResponse, ConversationMessage, - Provider, ProviderCapabilities, ToolCall, ToolResultMessage, ToolSpec, TokenUsage, + ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilities, + TokenUsage, ToolCall, ToolResultMessage, ToolSpec, build_tool_instructions_text, }; +pub use compatible::{api_error, sanitize_api_error}; pub use reliable::ReliableProvider; pub use router::{Route, RouterProvider}; -pub use compatible::{api_error, sanitize_api_error}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; @@ -50,7 +50,10 @@ pub fn create_provider( resolved_key.as_deref(), )), "gemini" | "google" => Box::new(gemini::GeminiProvider::new(resolved_key.as_deref())), - "ollama" => Box::new(ollama::OllamaProvider::new(base_url, resolved_key.as_deref())), + "ollama" => Box::new(ollama::OllamaProvider::new( + base_url, + resolved_key.as_deref(), + )), "openrouter" => Box::new(openrouter::OpenRouterProvider::new(resolved_key.as_deref())), "glm" | "zhipuai" | "zhipu" => Box::new(glm::GlmProvider::new(resolved_key.as_deref())), "moonshot" | "kimi" => Box::new(OpenAiCompatibleProvider::new( @@ -128,10 +131,10 @@ pub fn create_provider( } fn resolve_api_key(provider_name: &str, explicit: Option<&str>) -> Option { - if let Some(k) = explicit { - if !k.is_empty() { - return Some(k.to_string()); - } + if let Some(k) = explicit + && !k.is_empty() + { + return Some(k.to_string()); } let env_var = default_api_key_env(provider_name)?; std::env::var(env_var).ok().filter(|k| !k.is_empty()) diff --git a/crewforge-rs/src/provider/ollama.rs b/crewforge-rs/src/provider/ollama.rs index a66eba3..d7786f2 100644 --- a/crewforge-rs/src/provider/ollama.rs +++ b/crewforge-rs/src/provider/ollama.rs @@ -93,11 +93,7 @@ struct OllamaFunction { // ─── Implementation ─────────────────────────────────────────────────────────── fn sanitize_api_error(raw: &str) -> String { - let truncated = if raw.len() > 500 { - &raw[..500] - } else { - raw - }; + let truncated = if raw.len() > 500 { &raw[..500] } else { raw }; truncated.replace('\n', " ").trim().to_string() } @@ -232,71 +228,65 @@ impl OllamaProvider { messages .iter() .map(|message| { - if message.role == "assistant" { - if let Ok(value) = serde_json::from_str::(&message.content) { - if let Some(tool_calls_value) = value.get("tool_calls") { - if let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let outgoing_calls: Vec = parsed_calls - .into_iter() - .map(|call| { - tool_name_by_id.insert(call.id.clone(), call.name.clone()); - OutgoingToolCall { - kind: "function".to_string(), - function: OutgoingFunction { - name: call.name, - arguments: Self::parse_tool_arguments( - &call.arguments, - ), - }, - } - }) - .collect(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return Message { - role: "assistant".to_string(), - content, - tool_calls: Some(outgoing_calls), - tool_name: None, - }; + if message.role == "assistant" + && let Ok(value) = serde_json::from_str::(&message.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let outgoing_calls: Vec = parsed_calls + .into_iter() + .map(|call| { + tool_name_by_id.insert(call.id.clone(), call.name.clone()); + OutgoingToolCall { + kind: "function".to_string(), + function: OutgoingFunction { + name: call.name, + arguments: Self::parse_tool_arguments(&call.arguments), + }, } - } - } + }) + .collect(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return Message { + role: "assistant".to_string(), + content, + tool_calls: Some(outgoing_calls), + tool_name: None, + }; } - if message.role == "tool" { - if let Ok(value) = serde_json::from_str::(&message.content) { - let tool_name = value - .get("tool_name") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| { - value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .and_then(|id| tool_name_by_id.get(id)) - .cloned() - }); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| { - (!message.content.trim().is_empty()) - .then_some(message.content.clone()) - }); - - return Message { - role: "tool".to_string(), - content, - tool_calls: None, - tool_name, - }; - } + if message.role == "tool" + && let Ok(value) = serde_json::from_str::(&message.content) + { + let tool_name = value + .get("tool_name") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .and_then(|id| tool_name_by_id.get(id)) + .cloned() + }); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + (!message.content.trim().is_empty()).then_some(message.content.clone()) + }); + + return Message { + role: "tool".to_string(), + content, + tool_calls: None, + tool_name, + }; } Message { @@ -334,10 +324,8 @@ impl OllamaProvider { let mut request_builder = self.http_client().post(&url).json(&request); - if should_auth { - if let Some(key) = self.api_key.as_ref() { - request_builder = request_builder.bearer_auth(key); - } + if should_auth && let Some(key) = self.api_key.as_ref() { + request_builder = request_builder.bearer_auth(key); } let response = request_builder.send().await?; @@ -412,24 +400,23 @@ impl OllamaProvider { let args = &tc.function.arguments; // Pattern 1: Nested tool_call wrapper - if name == "tool_call" + if (name == "tool_call" || name == "tool.call" || name.starts_with("tool_call>") - || name.starts_with("tool_call<") + || name.starts_with("tool_call<")) + && let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) { - if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) { - let nested_args = args - .get("arguments") - .cloned() - .unwrap_or(serde_json::json!({})); - tracing::debug!( - "Unwrapped nested tool call: {} -> {} with args {:?}", - name, - nested_name, - nested_args - ); - return (nested_name.to_string(), nested_args); - } + let nested_args = args + .get("arguments") + .cloned() + .unwrap_or(serde_json::json!({})); + tracing::debug!( + "Unwrapped nested tool call: {} -> {} with args {:?}", + name, + nested_name, + nested_args + ); + return (nested_name.to_string(), nested_args); } // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.) @@ -634,25 +621,25 @@ impl Provider for OllamaProvider { temperature: f64, ) -> anyhow::Result { // Convert ToolSpec to OpenAI-compatible JSON and delegate to chat_with_tools. - if let Some(specs) = request.tools { - if !specs.is_empty() { - let tools: Vec = specs - .iter() - .map(|s| { - serde_json::json!({ - "type": "function", - "function": { - "name": s.name, - "description": s.description, - "parameters": s.parameters - } - }) + if let Some(specs) = request.tools + && !specs.is_empty() + { + let tools: Vec = specs + .iter() + .map(|s| { + serde_json::json!({ + "type": "function", + "function": { + "name": s.name, + "description": s.description, + "parameters": s.parameters + } }) - .collect(); - return self - .chat_with_tools(request.messages, &tools, model, temperature) - .await; - } + }) + .collect(); + return self + .chat_with_tools(request.messages, &tools, model, temperature) + .await; } // No tools — fall back to plain text chat. @@ -718,9 +705,11 @@ mod tests { let error = p .resolve_request_details("qwen3:cloud") .expect_err("cloud suffix should fail on local endpoint"); - assert!(error - .to_string() - .contains("requested cloud routing, but Ollama endpoint is local")); + assert!( + error + .to_string() + .contains("requested cloud routing, but Ollama endpoint is local") + ); } #[test] @@ -729,9 +718,11 @@ mod tests { let error = p .resolve_request_details("qwen3:cloud") .expect_err("cloud suffix should require API key"); - assert!(error - .to_string() - .contains("requested cloud routing, but no API key is configured")); + assert!( + error + .to_string() + .contains("requested cloud routing, but no API key is configured") + ); } #[test] diff --git a/crewforge-rs/src/provider/openai.rs b/crewforge-rs/src/provider/openai.rs index 2d67122..7e4e68a 100644 --- a/crewforge-rs/src/provider/openai.rs +++ b/crewforge-rs/src/provider/openai.rs @@ -206,63 +206,58 @@ impl OpenAiProvider { messages .iter() .map(|m| { - if m.role == "assistant" { - if let Ok(value) = serde_json::from_str::(&m.content) { - if let Some(tool_calls_value) = value.get("tool_calls") { - if let Ok(parsed_calls) = - serde_json::from_value::>( - tool_calls_value.clone(), - ) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| NativeToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tc.name, - arguments: tc.arguments, - }, - }) - .collect::>(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - reasoning_content, - }; - } - } - } + if m.role == "assistant" + && let Ok(value) = serde_json::from_str::(&m.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let reasoning_content = value + .get("reasoning_content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + reasoning_content, + }; } - if m.role == "tool" { - if let Ok(value) = serde_json::from_str::(&m.content) { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - reasoning_content: None, - }; - } + if m.role == "tool" + && let Ok(value) = serde_json::from_str::(&m.content) + { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + reasoning_content: None, + }; } NativeMessage { @@ -694,10 +689,12 @@ mod tests { let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Invalid OpenAI tool specification")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid OpenAI tool specification") + ); } #[test] diff --git a/crewforge-rs/src/provider/openai_codex.rs b/crewforge-rs/src/provider/openai_codex.rs index 71a40e0..a6edf78 100644 --- a/crewforge-rs/src/provider/openai_codex.rs +++ b/crewforge-rs/src/provider/openai_codex.rs @@ -1,7 +1,7 @@ -use crate::auth::openai_oauth::extract_account_id_from_jwt; use crate::auth::AuthService; -use crate::provider::traits::{ChatMessage, Provider, ProviderCapabilities}; +use crate::auth::openai_oauth::extract_account_id_from_jwt; use crate::provider::ProviderRuntimeOptions; +use crate::provider::traits::{ChatMessage, Provider, ProviderCapabilities}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -198,6 +198,7 @@ fn first_nonempty(text: Option<&str>) -> Option { }) } +#[allow(dead_code)] fn resolve_instructions(system_prompt: Option<&str>) -> String { first_nonempty(system_prompt).unwrap_or_else(|| DEFAULT_CODEX_INSTRUCTIONS.to_string()) } @@ -315,10 +316,10 @@ fn extract_responses_text(response: &ResponsesResponse) -> Option { for item in &response.output { for content in &item.content { - if content.kind.as_deref() == Some("output_text") { - if let Some(text) = first_nonempty(content.text.as_deref()) { - return Some(text); - } + if content.kind.as_deref() == Some("output_text") + && let Some(text) = first_nonempty(content.text.as_deref()) + { + return Some(text); } } } diff --git a/crewforge-rs/src/provider/openrouter.rs b/crewforge-rs/src/provider/openrouter.rs index 3dae0a5..5887d6b 100644 --- a/crewforge-rs/src/provider/openrouter.rs +++ b/crewforge-rs/src/provider/openrouter.rs @@ -164,64 +164,59 @@ impl OpenRouterProvider { messages .iter() .map(|m| { - if m.role == "assistant" { - if let Ok(value) = serde_json::from_str::(&m.content) { - if let Some(tool_calls_value) = value.get("tool_calls") { - if let Ok(parsed_calls) = - serde_json::from_value::>( - tool_calls_value.clone(), - ) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| NativeToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tc.name, - arguments: tc.arguments, - }, - }) - .collect::>(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - reasoning_content, - }; - } - } - } + if m.role == "assistant" + && let Ok(value) = serde_json::from_str::(&m.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let reasoning_content = value + .get("reasoning_content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + reasoning_content, + }; } - if m.role == "tool" { - if let Ok(value) = serde_json::from_str::(&m.content) { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| Some(m.content.clone())); - return NativeMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - reasoning_content: None, - }; - } + if m.role == "tool" + && let Ok(value) = serde_json::from_str::(&m.content) + { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| Some(m.content.clone())); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + reasoning_content: None, + }; } NativeMessage { @@ -292,10 +287,9 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let credential = self - .credential - .as_ref() - .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY."))?; + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") + })?; let mut messages = Vec::new(); @@ -347,10 +341,9 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let credential = self - .credential - .as_ref() - .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY."))?; + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") + })?; let api_messages: Vec = messages .iter() @@ -478,11 +471,7 @@ impl Provider for OpenRouterProvider { }) }) .collect(); - if specs.is_empty() { - None - } else { - Some(specs) - } + if specs.is_empty() { None } else { Some(specs) } }; let native_messages = Self::convert_messages(messages); diff --git a/crewforge-rs/src/provider/reliable.rs b/crewforge-rs/src/provider/reliable.rs index 2bb183d..8a2baac 100644 --- a/crewforge-rs/src/provider/reliable.rs +++ b/crewforge-rs/src/provider/reliable.rs @@ -9,8 +9,8 @@ //! Loop invariant: `failures` accumulates every failed attempt so the final //! error message gives operators a complete diagnostic trail. -use crate::provider::traits::{ChatMessage, ChatRequest, ChatResponse}; use crate::provider::Provider; +use crate::provider::traits::{ChatMessage, ChatRequest, ChatResponse}; use async_trait::async_trait; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -30,20 +30,20 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { // 4xx errors are generally non-retryable (bad request, auth failure, etc.), // except 429 (rate-limit — transient) and 408 (timeout — worth retrying). - if let Some(reqwest_err) = err.downcast_ref::() { - if let Some(status) = reqwest_err.status() { - let code = status.as_u16(); - return status.is_client_error() && code != 429 && code != 408; - } + if let Some(reqwest_err) = err.downcast_ref::() + && let Some(status) = reqwest_err.status() + { + let code = status.as_u16(); + return status.is_client_error() && code != 429 && code != 408; } // Fallback: parse status codes from stringified errors (some providers // embed codes in error messages rather than returning typed HTTP errors). let msg = err.to_string(); for word in msg.split(|c: char| !c.is_ascii_digit()) { - if let Ok(code) = word.parse::() { - if (400..500).contains(&code) { - return code != 429 && code != 408; - } + if let Ok(code) = word.parse::() + && (400..500).contains(&code) + { + return code != 429 && code != 408; } } @@ -97,10 +97,10 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool { /// Check if an error is a rate-limit (429) error. fn is_rate_limited(err: &anyhow::Error) -> bool { - if let Some(reqwest_err) = err.downcast_ref::() { - if let Some(status) = reqwest_err.status() { - return status.as_u16() == 429; - } + if let Some(reqwest_err) = err.downcast_ref::() + && let Some(status) = reqwest_err.status() + { + return status.as_u16() == 429; } let msg = err.to_string(); msg.contains("429") @@ -143,10 +143,10 @@ fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool { // Known provider business codes observed for 429 where retry is futile. for token in lower.split(|c: char| !c.is_ascii_digit()) { - if let Ok(code) = token.parse::() { - if matches!(code, 1113 | 1311) { - return true; - } + if let Ok(code) = token.parse::() + && matches!(code, 1113 | 1311) + { + return true; } } @@ -173,12 +173,13 @@ fn parse_retry_after_ms(err: &anyhow::Error) -> Option { .chars() .take_while(|c| c.is_ascii_digit() || *c == '.') .collect(); - if let Ok(secs) = num_str.parse::() { - if secs.is_finite() && secs >= 0.0 { - let millis = Duration::from_secs_f64(secs).as_millis(); - if let Ok(value) = u64::try_from(millis) { - return Some(value); - } + if let Ok(secs) = num_str.parse::() + && secs.is_finite() + && secs >= 0.0 + { + let millis = Duration::from_secs_f64(secs).as_millis(); + if let Ok(value) = u64::try_from(millis) { + return Some(value); } } } @@ -410,17 +411,18 @@ impl Provider for ReliableProvider { // Rate-limit with rotatable keys: cycle to the next API key // so the retry hits a different quota bucket. - if rate_limited && !non_retryable_rate_limit { - if let Some(new_key) = self.rotate_key() { - tracing::warn!( - provider = provider_name, - error = %error_detail, - "Rate limited; key rotation selected key ending ...{} \ - but cannot apply (Provider trait has no set_api_key). \ - Retrying with original key.", - &new_key[new_key.len().saturating_sub(4)..] - ); - } + if rate_limited + && !non_retryable_rate_limit + && let Some(new_key) = self.rotate_key() + { + tracing::warn!( + provider = provider_name, + error = %error_detail, + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", + &new_key[new_key.len().saturating_sub(4)..] + ); } if non_retryable { @@ -528,17 +530,18 @@ impl Provider for ReliableProvider { &error_detail, ); - if rate_limited && !non_retryable_rate_limit { - if let Some(new_key) = self.rotate_key() { - tracing::warn!( - provider = provider_name, - error = %error_detail, - "Rate limited; key rotation selected key ending ...{} \ - but cannot apply (Provider trait has no set_api_key). \ - Retrying with original key.", - &new_key[new_key.len().saturating_sub(4)..] - ); - } + if rate_limited + && !non_retryable_rate_limit + && let Some(new_key) = self.rotate_key() + { + tracing::warn!( + provider = provider_name, + error = %error_detail, + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", + &new_key[new_key.len().saturating_sub(4)..] + ); } if non_retryable { @@ -652,17 +655,18 @@ impl Provider for ReliableProvider { &error_detail, ); - if rate_limited && !non_retryable_rate_limit { - if let Some(new_key) = self.rotate_key() { - tracing::warn!( - provider = provider_name, - error = %error_detail, - "Rate limited; key rotation selected key ending ...{} \ - but cannot apply (Provider trait has no set_api_key). \ - Retrying with original key.", - &new_key[new_key.len().saturating_sub(4)..] - ); - } + if rate_limited + && !non_retryable_rate_limit + && let Some(new_key) = self.rotate_key() + { + tracing::warn!( + provider = provider_name, + error = %error_detail, + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", + &new_key[new_key.len().saturating_sub(4)..] + ); } if non_retryable { @@ -763,17 +767,18 @@ impl Provider for ReliableProvider { &error_detail, ); - if rate_limited && !non_retryable_rate_limit { - if let Some(new_key) = self.rotate_key() { - tracing::warn!( - provider = provider_name, - error = %error_detail, - "Rate limited; key rotation selected key ending ...{} \ - but cannot apply (Provider trait has no set_api_key). \ - Retrying with original key.", - &new_key[new_key.len().saturating_sub(4)..] - ); - } + if rate_limited + && !non_retryable_rate_limit + && let Some(new_key) = self.rotate_key() + { + tracing::warn!( + provider = provider_name, + error = %error_detail, + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", + &new_key[new_key.len().saturating_sub(4)..] + ); } if non_retryable { diff --git a/crewforge-rs/src/provider/router.rs b/crewforge-rs/src/provider/router.rs index 167f544..f8e8888 100644 --- a/crewforge-rs/src/provider/router.rs +++ b/crewforge-rs/src/provider/router.rs @@ -1,8 +1,8 @@ //! Multi-model router that dispatches requests to different provider+model //! combinations based on a task hint encoded in the model parameter. -use crate::provider::traits::{ChatMessage, ChatRequest, ChatResponse}; use crate::provider::Provider; +use crate::provider::traits::{ChatMessage, ChatRequest, ChatResponse}; use async_trait::async_trait; use std::collections::HashMap; @@ -25,6 +25,7 @@ pub struct RouterProvider { routes: HashMap, // hint → (provider_index, model) providers: Vec<(String, Box)>, default_index: usize, + #[allow(dead_code)] default_model: String, } diff --git a/crewforge-rs/src/provider/traits.rs b/crewforge-rs/src/provider/traits.rs index 7391019..3d668c0 100644 --- a/crewforge-rs/src/provider/traits.rs +++ b/crewforge-rs/src/provider/traits.rs @@ -235,32 +235,32 @@ pub trait Provider: Send + Sync { model: &str, temperature: f64, ) -> anyhow::Result { - if let Some(tools) = request.tools { - if !tools.is_empty() && !self.supports_native_tools() { - let tool_instructions = build_tool_instructions_text(tools); - let mut modified_messages = request.messages.to_vec(); - - if let Some(system_message) = - modified_messages.iter_mut().find(|m| m.role == "system") - { - if !system_message.content.is_empty() { - system_message.content.push_str("\n\n"); - } - system_message.content.push_str(&tool_instructions); - } else { - modified_messages.insert(0, ChatMessage::system(tool_instructions)); + if let Some(tools) = request.tools + && !tools.is_empty() + && !self.supports_native_tools() + { + let tool_instructions = build_tool_instructions_text(tools); + let mut modified_messages = request.messages.to_vec(); + + if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system") + { + if !system_message.content.is_empty() { + system_message.content.push_str("\n\n"); } - - let text = self - .chat_with_history(&modified_messages, model, temperature) - .await?; - return Ok(ChatResponse { - text: Some(text), - tool_calls: Vec::new(), - usage: None, - reasoning_content: None, - }); + system_message.content.push_str(&tool_instructions); + } else { + modified_messages.insert(0, ChatMessage::system(tool_instructions)); } + + let text = self + .chat_with_history(&modified_messages, model, temperature) + .await?; + return Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + usage: None, + reasoning_content: None, + }); } let text = self diff --git a/crewforge-rs/src/security/policy.rs b/crewforge-rs/src/security/policy.rs index cba93cc..3fddb47 100644 --- a/crewforge-rs/src/security/policy.rs +++ b/crewforge-rs/src/security/policy.rs @@ -384,9 +384,7 @@ fn is_valid_env_var_name(name: &str) -> bool { .chars() .next() .is_some_and(|c| c.is_ascii_alphabetic() || c == '_') - && name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_') + && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') } /// Detect unquoted shell variable expansions that are not explicitly allowlisted. @@ -786,7 +784,7 @@ impl SecurityPolicy { } if candidate.starts_with('-') { - if let Some((_, value)) = candidate.split_once('=') { + if let Some((_, value)) = candidate.split_once('=') { let blocked = forbidden_candidate(value); if blocked.is_some() { return blocked; @@ -1220,9 +1218,10 @@ mod tests { #[test] fn validate_blocks_medium_risk_without_approval() { let p = test_policy(); - assert!(p - .validate_command_execution("git commit -m x", false) - .is_err()); + assert!( + p.validate_command_execution("git commit -m x", false) + .is_err() + ); } #[test] @@ -1247,9 +1246,10 @@ mod tests { autonomy: AutonomyLevel::Full, ..SecurityPolicy::default() }; - assert!(p - .validate_command_execution("git commit -m x", false) - .is_ok()); + assert!( + p.validate_command_execution("git commit -m x", false) + .is_ok() + ); } // ── Rate limiting tests ───────────────────────────────────────────── @@ -1305,7 +1305,10 @@ mod tests { autonomy: AutonomyLevel::ReadOnly, ..SecurityPolicy::default() }; - assert!(p.enforce_tool_operation(ToolOperation::Read, "read").is_ok()); + assert!( + p.enforce_tool_operation(ToolOperation::Read, "read") + .is_ok() + ); } #[test] @@ -1314,9 +1317,10 @@ mod tests { autonomy: AutonomyLevel::ReadOnly, ..SecurityPolicy::default() }; - assert!(p - .enforce_tool_operation(ToolOperation::Act, "write") - .is_err()); + assert!( + p.enforce_tool_operation(ToolOperation::Act, "write") + .is_err() + ); } #[test] @@ -1325,12 +1329,14 @@ mod tests { max_actions_per_hour: 1, ..SecurityPolicy::default() }; - assert!(p - .enforce_tool_operation(ToolOperation::Act, "write") - .is_ok()); - assert!(p - .enforce_tool_operation(ToolOperation::Act, "write") - .is_err()); + assert!( + p.enforce_tool_operation(ToolOperation::Act, "write") + .is_ok() + ); + assert!( + p.enforce_tool_operation(ToolOperation::Act, "write") + .is_err() + ); } // ── Forbidden path argument tests ─────────────────────────────────── @@ -1350,9 +1356,10 @@ mod tests { #[test] fn forbidden_path_argument_option_value() { let p = test_policy(); - assert!(p - .forbidden_path_argument("cmd --file=/etc/passwd") - .is_some()); + assert!( + p.forbidden_path_argument("cmd --file=/etc/passwd") + .is_some() + ); } // ── Default policy sanity ─────────────────────────────────────────── @@ -1379,10 +1386,7 @@ mod tests { for dir in &[ "/etc", "/root", "/usr", "/bin", "/sbin", "/boot", "/dev", "/proc", "/sys", ] { - assert!( - !p.is_path_allowed(dir), - "{dir} should be blocked" - ); + assert!(!p.is_path_allowed(dir), "{dir} should be blocked"); } } diff --git a/crewforge-rs/src/security/secrets.rs b/crewforge-rs/src/security/secrets.rs index 23f38d1..7f3a5a4 100644 --- a/crewforge-rs/src/security/secrets.rs +++ b/crewforge-rs/src/security/secrets.rs @@ -27,6 +27,7 @@ use std::fs; use std::path::{Path, PathBuf}; /// Length of the random encryption key in bytes (256-bit, matches `ChaCha20`). +#[allow(dead_code)] const KEY_LEN: usize = 32; /// ChaCha20-Poly1305 nonce length in bytes. @@ -254,6 +255,7 @@ fn hex_encode(data: &[u8]) -> String { /// Build the `/grant` argument for `icacls` using a normalized username. /// Returns `None` when the username is empty or whitespace-only. +#[allow(dead_code)] fn build_windows_icacls_grant_arg(username: &str) -> Option { let normalized = username.trim(); if normalized.is_empty() { diff --git a/crewforge-rs/src/tools/content_search.rs b/crewforge-rs/src/tools/content_search.rs index da4fd0d..fd679a6 100644 --- a/crewforge-rs/src/tools/content_search.rs +++ b/crewforge-rs/src/tools/content_search.rs @@ -471,15 +471,15 @@ fn format_line_output( } } "count" => { - if let Some((path, count)) = parse_count_line(&relativized) { - if count > 0 { - file_set.insert(path.to_string()); - total_matches += count; - lines.push(format!("{path}:{count}")); - if lines.len() >= max_results { - truncated = true; - break; - } + if let Some((path, count)) = parse_count_line(&relativized) + && count > 0 + { + file_set.insert(path.to_string()); + total_matches += count; + lines.push(format!("{path}:{count}")); + if lines.len() >= max_results { + truncated = true; + break; } } } @@ -563,7 +563,6 @@ fn truncate_utf8(input: &str, max_bytes: usize) -> &str { mod tests { use super::*; use crate::agent::Tool; - use crate::security::AutonomyLevel; use serde_json::json; fn test_security(workspace: std::path::PathBuf) -> Arc { @@ -763,10 +762,7 @@ mod tests { false, // no rg ); // Without multiline support in grep fallback, this should still work for basic patterns - let result = tool - .execute(json!({"pattern": "line1"})) - .await - .unwrap(); + let result = tool.execute(json!({"pattern": "line1"})).await.unwrap(); assert!(result.success); } diff --git a/crewforge-rs/src/tools/file_edit.rs b/crewforge-rs/src/tools/file_edit.rs index b922fe3..843d756 100644 --- a/crewforge-rs/src/tools/file_edit.rs +++ b/crewforge-rs/src/tools/file_edit.rs @@ -133,17 +133,17 @@ impl crate::agent::Tool for FileEditTool { let resolved_target = resolved_parent.join(file_name); // Symlink check - if let Ok(meta) = tokio::fs::symlink_metadata(&resolved_target).await { - if meta.file_type().is_symlink() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "Refusing to edit through symlink: {}", - resolved_target.display() - )), - }); - } + if let Ok(meta) = tokio::fs::symlink_metadata(&resolved_target).await + && meta.file_type().is_symlink() + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Refusing to edit through symlink: {}", + resolved_target.display() + )), + }); } if !self.security.record_action() { @@ -284,11 +284,13 @@ mod tests { .unwrap(); assert!(!result.success); - assert!(result - .error - .as_deref() - .unwrap_or("") - .contains("matches 2 times")); + assert!( + result + .error + .as_deref() + .unwrap_or("") + .contains("matches 2 times") + ); } #[tokio::test] @@ -333,11 +335,13 @@ mod tests { .unwrap(); assert!(!result.success); - assert!(result - .error - .as_deref() - .unwrap_or("") - .contains("must not be empty")); + assert!( + result + .error + .as_deref() + .unwrap_or("") + .contains("must not be empty") + ); } #[tokio::test] @@ -466,10 +470,12 @@ mod tests { .unwrap(); assert!(!result.success); - assert!(result - .error - .as_deref() - .unwrap_or("") - .contains("Failed to read file")); + assert!( + result + .error + .as_deref() + .unwrap_or("") + .contains("Failed to read file") + ); } } diff --git a/crewforge-rs/src/tools/file_read.rs b/crewforge-rs/src/tools/file_read.rs index 1d2680d..e0cba38 100644 --- a/crewforge-rs/src/tools/file_read.rs +++ b/crewforge-rs/src/tools/file_read.rs @@ -93,7 +93,10 @@ impl crate::agent::Tool for FileReadTool { return Ok(ToolResult { success: false, output: String::new(), - error: Some(self.security.resolved_path_violation_message(&resolved_path)), + error: Some( + self.security + .resolved_path_violation_message(&resolved_path), + ), }); } @@ -251,10 +254,7 @@ mod tests { async fn file_read_blocks_absolute_path() { let dir = tempfile::tempdir().unwrap(); let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"path": "/etc/passwd"})) - .await - .unwrap(); + let result = tool.execute(json!({"path": "/etc/passwd"})).await.unwrap(); assert!(!result.success); } @@ -279,9 +279,12 @@ mod tests { #[tokio::test] async fn file_read_with_offset_and_limit() { let dir = tempfile::tempdir().unwrap(); - tokio::fs::write(dir.path().join("multi.txt"), "line1\nline2\nline3\nline4\nline5") - .await - .unwrap(); + tokio::fs::write( + dir.path().join("multi.txt"), + "line1\nline2\nline3\nline4\nline5", + ) + .await + .unwrap(); let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); let result = tool @@ -318,10 +321,7 @@ mod tests { .unwrap(); let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"path": "empty.txt"})) - .await - .unwrap(); + let result = tool.execute(json!({"path": "empty.txt"})).await.unwrap(); assert!(result.success); assert!(result.output.is_empty()); } @@ -368,10 +368,7 @@ mod tests { async fn file_read_blocks_null_byte_in_path() { let dir = tempfile::tempdir().unwrap(); let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"path": "file\0.txt"})) - .await - .unwrap(); + let result = tool.execute(json!({"path": "file\0.txt"})).await.unwrap(); assert!(!result.success); } @@ -397,10 +394,7 @@ mod tests { .unwrap(); let tool = FileReadTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"path": "binary.bin"})) - .await - .unwrap(); + let result = tool.execute(json!({"path": "binary.bin"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("hello")); } diff --git a/crewforge-rs/src/tools/file_write.rs b/crewforge-rs/src/tools/file_write.rs index e8d64ad..826ef86 100644 --- a/crewforge-rs/src/tools/file_write.rs +++ b/crewforge-rs/src/tools/file_write.rs @@ -114,14 +114,14 @@ impl crate::agent::Tool for FileWriteTool { #[cfg(unix)] if full_path.exists() { let meta = std::fs::symlink_metadata(&full_path); - if let Ok(m) = meta { - if m.file_type().is_symlink() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("Refusing to write through symlink".into()), - }); - } + if let Ok(m) = meta + && m.file_type().is_symlink() + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Refusing to write through symlink".into()), + }); } } diff --git a/crewforge-rs/src/tools/glob_search.rs b/crewforge-rs/src/tools/glob_search.rs index 48da5a6..8066424 100644 --- a/crewforge-rs/src/tools/glob_search.rs +++ b/crewforge-rs/src/tools/glob_search.rs @@ -165,7 +165,6 @@ impl crate::agent::Tool for GlobSearchTool { mod tests { use super::*; use crate::agent::Tool; - use crate::security::AutonomyLevel; use serde_json::json; fn test_security(workspace: std::path::PathBuf) -> Arc { @@ -181,10 +180,7 @@ mod tests { std::fs::write(dir.path().join("hello.txt"), "content").unwrap(); let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"pattern": "hello.txt"})) - .await - .unwrap(); + let result = tool.execute(json!({"pattern": "hello.txt"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("hello.txt")); @@ -215,10 +211,7 @@ mod tests { std::fs::write(dir.path().join("sub/deep/leaf.txt"), "").unwrap(); let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"pattern": "**/*.txt"})) - .await - .unwrap(); + let result = tool.execute(json!({"pattern": "**/*.txt"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("root.txt")); @@ -243,10 +236,7 @@ mod tests { #[tokio::test] async fn glob_search_rejects_absolute_path() { let tool = GlobSearchTool::new(test_security(std::env::temp_dir())); - let result = tool - .execute(json!({"pattern": "/etc/**/*"})) - .await - .unwrap(); + let result = tool.execute(json!({"pattern": "/etc/**/*"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("Absolute paths")); @@ -342,16 +332,15 @@ mod tests { let dir = tempfile::tempdir().unwrap(); let tool = GlobSearchTool::new(test_security(dir.path().to_path_buf())); - let result = tool - .execute(json!({"pattern": "[invalid"})) - .await - .unwrap(); + let result = tool.execute(json!({"pattern": "[invalid"})).await.unwrap(); assert!(!result.success); - assert!(result - .error - .as_ref() - .unwrap() - .contains("Invalid glob pattern")); + assert!( + result + .error + .as_ref() + .unwrap() + .contains("Invalid glob pattern") + ); } } diff --git a/crewforge-rs/src/tools/shell.rs b/crewforge-rs/src/tools/shell.rs index 90dd73e..f460098 100644 --- a/crewforge-rs/src/tools/shell.rs +++ b/crewforge-rs/src/tools/shell.rs @@ -263,10 +263,12 @@ mod tests { assert_eq!(tool.name(), "shell"); let schema = tool.parameters(); assert!(schema["properties"]["command"].is_object()); - assert!(schema["required"] - .as_array() - .unwrap() - .contains(&json!("command"))); + assert!( + schema["required"] + .as_array() + .unwrap() + .contains(&json!("command")) + ); } #[tokio::test] @@ -283,10 +285,7 @@ mod tests { #[tokio::test] async fn shell_blocks_dangerous_command() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); - let result = tool - .execute(json!({"command": "rm -rf /"})) - .await - .unwrap(); + let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); assert!(!result.success); } @@ -299,10 +298,7 @@ mod tests { ..SecurityPolicy::default() }); let tool = ShellTool::new(security, test_runtime()); - let result = tool - .execute(json!({"command": "echo test"})) - .await - .unwrap(); + let result = tool.execute(json!({"command": "echo test"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_deref().unwrap_or("").contains("Rate limit")); } @@ -310,10 +306,7 @@ mod tests { #[tokio::test] async fn shell_blocks_readonly_mode() { let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime()); - let result = tool - .execute(json!({"command": "ls"})) - .await - .unwrap(); + let result = tool.execute(json!({"command": "ls"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("not allowed")); } @@ -326,11 +319,13 @@ mod tests { .await .unwrap(); assert!(!result.success); - assert!(result - .error - .as_deref() - .unwrap_or("") - .contains("Path blocked")); + assert!( + result + .error + .as_deref() + .unwrap_or("") + .contains("Path blocked") + ); } #[tokio::test] diff --git a/crewforge-rs/src/tui.rs b/crewforge-rs/src/tui.rs index 75582bb..ddb3025 100644 --- a/crewforge-rs/src/tui.rs +++ b/crewforge-rs/src/tui.rs @@ -338,7 +338,9 @@ fn prefixed_message_lines( if density.is_comfort() { spans.push(Span::styled( " ", - Style::default().fg(Color::DarkGray).add_modifier(Modifier::DIM), + Style::default() + .fg(Color::DarkGray) + .add_modifier(Modifier::DIM), )); } spans.push(Span::raw(line)); @@ -415,7 +417,10 @@ fn agent_status_symbol(state: AgentStatusState) -> &'static str { } } -fn build_status_line(statuses: &BTreeMap, density: UiDensity) -> Line<'static> { +fn build_status_line( + statuses: &BTreeMap, + density: UiDensity, +) -> Line<'static> { let has_running = statuses .values() .any(|entry| matches!(entry.state, AgentStatusState::Active)); @@ -430,7 +435,9 @@ fn build_status_line(statuses: &BTreeMap, density: UiD let density_style = if density.is_comfort() { Style::default().fg(Color::Cyan).add_modifier(Modifier::DIM) } else { - Style::default().fg(Color::Yellow).add_modifier(Modifier::DIM) + Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::DIM) }; let mut spans = vec![ @@ -596,7 +603,11 @@ impl RenderedLineCache { fn ensure_rows(&mut self, lines: &[DisplayLine], view_width: u16, density: UiDensity) { let view_width = view_width.max(1); - if !self.valid || self.width != view_width || self.len != lines.len() || self.density != density { + if !self.valid + || self.width != view_width + || self.len != lines.len() + || self.density != density + { self.rows = rendered_rows_for_lines(lines, view_width, density); self.width = view_width; self.len = lines.len(); @@ -1074,8 +1085,8 @@ pub async fn run_tui_loop( }, if pending_prefetch.is_some() => { let (start, join) = prefetch_join; let events = join.context("failed joining history prefetch task")??; - if history_pager.next_older_start() == Some(start) { - if apply_loaded_history_events( + if history_pager.next_older_start() == Some(start) + && apply_loaded_history_events( start, events, &mut history_pager, @@ -1088,7 +1099,6 @@ pub async fn run_tui_loop( ) { mark_render_dirty(&mut render_dirty, &mut render_pending_since); } - } } maybe_line = msg_rx.recv() => { match maybe_line { @@ -1359,6 +1369,7 @@ fn maybe_start_history_prefetch( *pending_prefetch = Some(HistoryPrefetch { start, task }); } +#[allow(clippy::too_many_arguments)] async fn prepend_next_older_page( history_pager: &mut SessionHistoryPager, runtime: &ChatRuntime, @@ -1407,6 +1418,7 @@ async fn prepend_next_older_page( .await } +#[allow(clippy::too_many_arguments)] fn apply_loaded_history_events( start: usize, events: Vec, @@ -1444,7 +1456,8 @@ fn prepend_history_lines( if older_lines.is_empty() { return false; } - let added_rows = rendered_line_count(&older_lines, view_width, density).min(u16::MAX as usize) as u16; + let added_rows = + rendered_line_count(&older_lines, view_width, density).min(u16::MAX as usize) as u16; display_lines.splice(0..0, older_lines); rendered_line_cache.invalidate(); *scroll_offset = scroll_offset.saturating_add(added_rows); @@ -1602,7 +1615,11 @@ fn rendered_line_count(lines: &[DisplayLine], view_width: u16, density: UiDensit rendered_rows_for_lines(lines, view_width, density).len() } -fn rendered_rows_for_lines(lines: &[DisplayLine], view_width: u16, density: UiDensity) -> Vec> { +fn rendered_rows_for_lines( + lines: &[DisplayLine], + view_width: u16, + density: UiDensity, +) -> Vec> { let view_width = view_width.max(1) as usize; let mut rows = Vec::new(); for line in lines { From 974c123d55a3229fa47a8546a1a523c10d8af869 Mon Sep 17 00:00:00 2001 From: Rexopia Date: Sun, 1 Mar 2026 21:15:30 +0800 Subject: [PATCH 7/8] fix: address review findings for tools + security MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - I-1: Deduplicate is_valid_env_var_name (pub(crate) in policy.rs, import in shell.rs) - I-2: Remove record_action() from FileReadTool (reads don't consume rate-limit budget) - I-3: ContentSearchTool builds tokio::process::Command directly (no std→tokio conversion) - S-1: Document env sandboxing intent in ShellTool - S-2: Standardize OnceLock → LazyLock for compiled regexes - S-3: GlobSearchTool uses spawn_blocking to avoid blocking async runtime - S-4: Remove undocumented max_results arg extraction from ContentSearchTool - S-5: Add ToolResult::ok()/denied() convenience constructors, apply in FileReadTool Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/agent/mod.rs | 18 +++++ crewforge-rs/src/security/policy.rs | 2 +- crewforge-rs/src/tools/content_search.rs | 91 ++++++++++------------- crewforge-rs/src/tools/file_read.rs | 88 ++++++---------------- crewforge-rs/src/tools/glob_search.rs | 95 +++++++++++++----------- crewforge-rs/src/tools/shell.rs | 16 +--- 6 files changed, 137 insertions(+), 173 deletions(-) diff --git a/crewforge-rs/src/agent/mod.rs b/crewforge-rs/src/agent/mod.rs index 6049199..80a3394 100644 --- a/crewforge-rs/src/agent/mod.rs +++ b/crewforge-rs/src/agent/mod.rs @@ -18,6 +18,24 @@ pub struct ToolResult { pub error: Option, } +impl ToolResult { + pub fn ok(output: impl Into) -> Self { + Self { + success: true, + output: output.into(), + error: None, + } + } + + pub fn denied(message: impl Into) -> Self { + Self { + success: false, + output: String::new(), + error: Some(message.into()), + } + } +} + /// Generic tool interface. Implement this for any tool the agent can use. /// CrewForge hub tools (HubGet, HubAck, HubPost) implement this trait. #[async_trait] diff --git a/crewforge-rs/src/security/policy.rs b/crewforge-rs/src/security/policy.rs index 3fddb47..d012d31 100644 --- a/crewforge-rs/src/security/policy.rs +++ b/crewforge-rs/src/security/policy.rs @@ -378,7 +378,7 @@ fn contains_unquoted_char(command: &str, target: char) -> bool { false } -fn is_valid_env_var_name(name: &str) -> bool { +pub(crate) fn is_valid_env_var_name(name: &str) -> bool { !name.is_empty() && name .chars() diff --git a/crewforge-rs/src/tools/content_search.rs b/crewforge-rs/src/tools/content_search.rs index fd679a6..4333611 100644 --- a/crewforge-rs/src/tools/content_search.rs +++ b/crewforge-rs/src/tools/content_search.rs @@ -2,7 +2,7 @@ use crate::agent::ToolResult; use crate::security::SecurityPolicy; use async_trait::async_trait; use std::process::Stdio; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; const MAX_RESULTS: usize = 1000; const MAX_OUTPUT_BYTES: usize = 1_048_576; // 1 MB @@ -130,14 +130,6 @@ impl crate::agent::Tool for ContentSearchTool { .and_then(|v| v.as_u64()) .unwrap_or(0) as usize; - #[allow(clippy::cast_possible_truncation)] - let max_results = args - .get("max_results") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - .unwrap_or(MAX_RESULTS) - .min(MAX_RESULTS); - if self.security.is_rate_limited() { return Ok(ToolResult { success: false, @@ -204,7 +196,6 @@ impl crate::agent::Tool for ContentSearchTool { }); } - // Build command let mut cmd = if self.has_rg { build_rg_command( pattern, @@ -235,31 +226,26 @@ impl crate::agent::Tool for ContentSearchTool { } } - cmd.stdout(Stdio::piped()); - cmd.stderr(Stdio::piped()); - - let output = match tokio::time::timeout( - std::time::Duration::from_secs(TIMEOUT_SECS), - tokio::process::Command::from(cmd).output(), - ) - .await - { - Ok(Ok(out)) => out, - Ok(Err(e)) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Failed to execute search command: {e}")), - }); - } - Err(_) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Search timed out after {TIMEOUT_SECS} seconds.")), - }); - } - }; + let output = + match tokio::time::timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output()) + .await + { + Ok(Ok(out)) => out, + Ok(Err(e)) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute search command: {e}")), + }); + } + Err(_) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Search timed out after {TIMEOUT_SECS} seconds.")), + }); + } + }; // Exit code: 0 = matches found, 1 = no matches (grep/rg), 2 = error let exit_code = output.status.code().unwrap_or(-1); @@ -277,7 +263,7 @@ impl crate::agent::Tool for ContentSearchTool { let workspace_canon = std::fs::canonicalize(workspace).unwrap_or_else(|_| workspace.clone()); - let formatted = format_line_output(&raw_stdout, &workspace_canon, output_mode, max_results); + let formatted = format_line_output(&raw_stdout, &workspace_canon, output_mode, MAX_RESULTS); // Truncate if too large let final_output = if formatted.len() > MAX_OUTPUT_BYTES { @@ -304,12 +290,14 @@ fn build_rg_command( case_sensitive: bool, context_before: usize, context_after: usize, -) -> std::process::Command { - let mut cmd = std::process::Command::new("rg"); +) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("rg"); cmd.arg("--no-heading"); cmd.arg("--line-number"); cmd.arg("--with-filename"); + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::piped()); match output_mode { "files_with_matches" => { @@ -351,13 +339,15 @@ fn build_grep_command( case_sensitive: bool, context_before: usize, context_after: usize, -) -> std::process::Command { - let mut cmd = std::process::Command::new("grep"); +) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("grep"); cmd.arg("-r"); cmd.arg("-n"); cmd.arg("-E"); cmd.arg("--binary-files=without-match"); + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::piped()); match output_mode { "files_with_matches" => { @@ -403,20 +393,18 @@ fn relativize_path(line: &str, workspace_prefix: &str) -> String { } fn parse_content_line(line: &str) -> Option<(&str, bool)> { - static MATCH_RE: OnceLock = OnceLock::new(); - static CONTEXT_RE: OnceLock = OnceLock::new(); - - let match_re = MATCH_RE.get_or_init(|| { + static MATCH_RE: LazyLock = LazyLock::new(|| { regex::Regex::new(r"^(?P.+?):\d+:").expect("match line regex must be valid") }); - if let Some(caps) = match_re.captures(line) { + static CONTEXT_RE: LazyLock = LazyLock::new(|| { + regex::Regex::new(r"^(?P.+?)-\d+-").expect("context line regex must be valid") + }); + + if let Some(caps) = MATCH_RE.captures(line) { return caps.name("path").map(|m| (m.as_str(), true)); } - let context_re = CONTEXT_RE.get_or_init(|| { - regex::Regex::new(r"^(?P.+?)-\d+-").expect("context line regex must be valid") - }); - if let Some(caps) = context_re.captures(line) { + if let Some(caps) = CONTEXT_RE.captures(line) { return caps.name("path").map(|m| (m.as_str(), false)); } @@ -424,12 +412,11 @@ fn parse_content_line(line: &str) -> Option<(&str, bool)> { } fn parse_count_line(line: &str) -> Option<(&str, usize)> { - static COUNT_RE: OnceLock = OnceLock::new(); - let count_re = COUNT_RE.get_or_init(|| { + static COUNT_RE: LazyLock = LazyLock::new(|| { regex::Regex::new(r"^(?P.+?):(?P\d+)\s*$").expect("count line regex valid") }); - let caps = count_re.captures(line)?; + let caps = COUNT_RE.captures(line)?; let path = caps.name("path")?.as_str(); let count = caps.name("count")?.as_str().parse::().ok()?; Some((path, count)) diff --git a/crewforge-rs/src/tools/file_read.rs b/crewforge-rs/src/tools/file_read.rs index e0cba38..3bf7b91 100644 --- a/crewforge-rs/src/tools/file_read.rs +++ b/crewforge-rs/src/tools/file_read.rs @@ -53,27 +53,13 @@ impl crate::agent::Tool for FileReadTool { .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; if self.security.is_rate_limited() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("Rate limit exceeded".into()), - }); + return Ok(ToolResult::denied("Rate limit exceeded")); } if !self.security.is_path_allowed(path) { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Path not allowed by security policy: {path}")), - }); - } - - if !self.security.record_action() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("Rate limit exceeded: action budget exhausted".into()), - }); + return Ok(ToolResult::denied(format!( + "Path not allowed by security policy: {path}" + ))); } let full_path = self.security.workspace_dir.join(path); @@ -81,44 +67,32 @@ impl crate::agent::Tool for FileReadTool { let resolved_path = match tokio::fs::canonicalize(&full_path).await { Ok(p) => p, Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Failed to resolve file path: {e}")), - }); + return Ok(ToolResult::denied(format!( + "Failed to resolve file path: {e}" + ))); } }; if !self.security.is_resolved_path_allowed(&resolved_path) { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some( - self.security - .resolved_path_violation_message(&resolved_path), - ), - }); + return Ok(ToolResult::denied( + self.security + .resolved_path_violation_message(&resolved_path), + )); } match tokio::fs::metadata(&resolved_path).await { Ok(meta) => { if meta.len() > MAX_FILE_SIZE_BYTES { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "File too large: {} bytes (limit: {MAX_FILE_SIZE_BYTES} bytes)", - meta.len() - )), - }); + return Ok(ToolResult::denied(format!( + "File too large: {} bytes (limit: {MAX_FILE_SIZE_BYTES} bytes)", + meta.len() + ))); } } Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Failed to read file metadata: {e}")), - }); + return Ok(ToolResult::denied(format!( + "Failed to read file metadata: {e}" + ))); } } @@ -128,11 +102,7 @@ impl crate::agent::Tool for FileReadTool { let total = lines.len(); if total == 0 { - return Ok(ToolResult { - success: true, - output: String::new(), - error: None, - }); + return Ok(ToolResult::ok("")); } let offset = args @@ -155,11 +125,9 @@ impl crate::agent::Tool for FileReadTool { }; if start >= end { - return Ok(ToolResult { - success: true, - output: format!("[No lines in range, file has {total} lines]"), - error: None, - }); + return Ok(ToolResult::ok(format!( + "[No lines in range, file has {total} lines]" + ))); } let numbered: String = lines[start..end] @@ -176,11 +144,7 @@ impl crate::agent::Tool for FileReadTool { format!("\n[{total} lines total]") }; - Ok(ToolResult { - success: true, - output: format!("{numbered}{summary}"), - error: None, - }) + Ok(ToolResult::ok(format!("{numbered}{summary}"))) } Err(_) => { let bytes = tokio::fs::read(&resolved_path) @@ -188,11 +152,7 @@ impl crate::agent::Tool for FileReadTool { .map_err(|e| anyhow::anyhow!("Failed to read file: {e}"))?; let lossy = String::from_utf8_lossy(&bytes).into_owned(); - Ok(ToolResult { - success: true, - output: lossy, - error: None, - }) + Ok(ToolResult::ok(lossy)) } } } diff --git a/crewforge-rs/src/tools/glob_search.rs b/crewforge-rs/src/tools/glob_search.rs index 8066424..94667e7 100644 --- a/crewforge-rs/src/tools/glob_search.rs +++ b/crewforge-rs/src/tools/glob_search.rs @@ -81,62 +81,69 @@ impl crate::agent::Tool for GlobSearchTool { let workspace = &self.security.workspace_dir; let full_pattern = workspace.join(pattern).to_string_lossy().to_string(); + let security = self.security.clone(); + let pattern_owned = pattern.to_string(); - let entries = match glob::glob(&full_pattern) { - Ok(paths) => paths, - Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Invalid glob pattern: {e}")), - }); - } - }; - - let workspace_canon = match std::fs::canonicalize(workspace) { - Ok(p) => p, - Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Cannot resolve workspace directory: {e}")), - }); - } - }; - - let mut results = Vec::new(); - let mut truncated = false; - - for entry in entries { - let path = match entry { - Ok(p) => p, - Err(_) => continue, + let (results, truncated) = match tokio::task::spawn_blocking(move || { + let entries = match glob::glob(&full_pattern) { + Ok(paths) => paths, + Err(e) => return Err(format!("Invalid glob pattern: {e}")), }; - let resolved = match std::fs::canonicalize(&path) { + let workspace_canon = match std::fs::canonicalize(&security.workspace_dir) { Ok(p) => p, - Err(_) => continue, + Err(e) => return Err(format!("Cannot resolve workspace directory: {e}")), }; - if !self.security.is_resolved_path_allowed(&resolved) { - continue; - } + let mut results = Vec::new(); + let mut truncated = false; - if resolved.is_dir() { - continue; - } + for entry in entries { + let path = match entry { + Ok(p) => p, + Err(_) => continue, + }; + + let resolved = match std::fs::canonicalize(&path) { + Ok(p) => p, + Err(_) => continue, + }; + + if !security.is_resolved_path_allowed(&resolved) { + continue; + } + + if resolved.is_dir() { + continue; + } + + if let Ok(rel) = resolved.strip_prefix(&workspace_canon) { + results.push(rel.to_string_lossy().to_string()); + } - if let Ok(rel) = resolved.strip_prefix(&workspace_canon) { - results.push(rel.to_string_lossy().to_string()); + if results.len() >= MAX_RESULTS { + truncated = true; + break; + } } - if results.len() >= MAX_RESULTS { - truncated = true; - break; + results.sort(); + Ok((results, truncated)) + }) + .await + .unwrap_or_else(|e| Err(format!("Glob task panicked: {e}"))) + { + Ok(pair) => pair, + Err(msg) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(msg), + }); } - } + }; - results.sort(); + let pattern = &pattern_owned; let output = if results.is_empty() { format!("No files matching pattern '{pattern}' found in workspace.") diff --git a/crewforge-rs/src/tools/shell.rs b/crewforge-rs/src/tools/shell.rs index f460098..dd00b2f 100644 --- a/crewforge-rs/src/tools/shell.rs +++ b/crewforge-rs/src/tools/shell.rs @@ -1,5 +1,6 @@ use crate::agent::ToolResult; use crate::security::SecurityPolicy; +use crate::security::policy::is_valid_env_var_name; use crate::tools::traits::RuntimeAdapter; use async_trait::async_trait; use std::collections::HashSet; @@ -23,18 +24,6 @@ impl ShellTool { } } -fn is_valid_env_var_name(name: &str) -> bool { - if name.is_empty() { - return false; - } - let mut chars = name.chars(); - let first = chars.next().unwrap(); - if !first.is_ascii_alphabetic() && first != '_' { - return false; - } - chars.all(|c| c.is_ascii_alphanumeric() || c == '_') -} - pub(crate) fn collect_allowed_shell_env_vars(security: &SecurityPolicy) -> Vec { let mut out = Vec::new(); let mut seen = HashSet::new(); @@ -184,6 +173,9 @@ impl crate::agent::Tool for ShellTool { }); } }; + // Environment sandboxing: clear all env vars then selectively restore + // safe ones. This is intentionally done here (not in RuntimeAdapter) so + // the security policy controls which vars are passed through. cmd.env_clear(); for var in collect_allowed_shell_env_vars(&self.security) { From 3c6dd8e3692a08a9dd46903fbc9f022d6f67e7f8 Mon Sep 17 00:00:00 2001 From: Rexopia Date: Mon, 2 Mar 2026 01:33:28 +0800 Subject: [PATCH 8/8] refactor(provider): slim provider stack from ~11k to ~4k lines MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete 8 independent provider files (anthropic, openai, gemini, ollama, copilot, glm, openrouter, openai_codex) and replace with: - anthropic_oauth.rs: lean Anthropic Messages API with OAuth/API-key dual auth - openai_oauth.rs: renamed from openai_codex.rs (OpenAI Codex OAuth provider) - compatible.rs: simplified from 2,210 to ~480 lines — pure base_url + Bearer API key, no AuthStyle enum, no special auth modes, no config files All existing compatible-layer providers (openai, moonshot, qwen, minimax, deepseek, groq, mistral, xai, openrouter) now use the 3-arg constructor. GLM (JWT auth) dropped for now. 437 tests pass, clippy/fmt clean. Co-Authored-By: Claude Opus 4.6 --- crewforge-rs/src/provider/anthropic.rs | 745 ----- crewforge-rs/src/provider/anthropic_oauth.rs | 560 ++++ crewforge-rs/src/provider/compatible.rs | 2400 ++++------------- crewforge-rs/src/provider/copilot.rs | 745 ----- crewforge-rs/src/provider/gemini.rs | 496 ---- crewforge-rs/src/provider/glm.rs | 367 --- crewforge-rs/src/provider/mod.rs | 59 +- crewforge-rs/src/provider/ollama.rs | 989 ------- crewforge-rs/src/provider/openai.rs | 827 ------ .../{openai_codex.rs => openai_oauth.rs} | 0 crewforge-rs/src/provider/openrouter.rs | 823 ------ 11 files changed, 1075 insertions(+), 6936 deletions(-) delete mode 100644 crewforge-rs/src/provider/anthropic.rs create mode 100644 crewforge-rs/src/provider/anthropic_oauth.rs delete mode 100644 crewforge-rs/src/provider/copilot.rs delete mode 100644 crewforge-rs/src/provider/gemini.rs delete mode 100644 crewforge-rs/src/provider/glm.rs delete mode 100644 crewforge-rs/src/provider/ollama.rs delete mode 100644 crewforge-rs/src/provider/openai.rs rename crewforge-rs/src/provider/{openai_codex.rs => openai_oauth.rs} (100%) delete mode 100644 crewforge-rs/src/provider/openrouter.rs diff --git a/crewforge-rs/src/provider/anthropic.rs b/crewforge-rs/src/provider/anthropic.rs deleted file mode 100644 index 0224794..0000000 --- a/crewforge-rs/src/provider/anthropic.rs +++ /dev/null @@ -1,745 +0,0 @@ -use crate::provider::traits::{ - ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -pub struct AnthropicProvider { - credential: Option, - base_url: String, -} - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - system: Option, - messages: Vec, - temperature: f64, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ChatResponse { - content: Vec, -} - -#[derive(Debug, Deserialize)] -struct ContentBlock { - #[serde(rename = "type")] - kind: String, - #[serde(default)] - text: Option, - #[allow(dead_code)] - #[serde(default)] - id: Option, - #[allow(dead_code)] - #[serde(default)] - name: Option, - #[allow(dead_code)] - #[serde(default)] - input: Option, -} - -#[derive(Debug, Serialize)] -struct NativeChatRequest<'a> { - model: String, - max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - system: Option, - messages: Vec, - temperature: f64, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, -} - -#[derive(Debug, Serialize)] -struct NativeMessage { - role: String, - content: Vec, -} - -#[derive(Debug, Serialize)] -#[serde(tag = "type")] -enum NativeContentOut { - #[serde(rename = "text")] - Text { - text: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, - }, - #[serde(rename = "tool_use")] - ToolUse { - id: String, - name: String, - input: serde_json::Value, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, - }, - #[serde(rename = "tool_result")] - ToolResult { - tool_use_id: String, - content: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, - }, -} - -#[derive(Debug, Serialize)] -struct NativeToolSpec<'a> { - name: &'a str, - description: &'a str, - input_schema: &'a serde_json::Value, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, -} - -#[derive(Debug, Clone, Serialize)] -struct CacheControl { - #[serde(rename = "type")] - cache_type: String, -} - -impl CacheControl { - fn ephemeral() -> Self { - Self { - cache_type: "ephemeral".to_string(), - } - } -} - -#[derive(Debug, Serialize)] -#[serde(untagged)] -enum SystemPrompt { - String(String), - Blocks(Vec), -} - -#[derive(Debug, Serialize)] -struct SystemBlock { - #[serde(rename = "type")] - block_type: String, - text: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, -} - -#[derive(Debug, Deserialize)] -struct NativeChatResponse { - #[serde(default)] - content: Vec, - #[serde(default)] - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct AnthropicUsage { - #[serde(default)] - input_tokens: Option, - #[serde(default)] - output_tokens: Option, -} - -#[derive(Debug, Deserialize)] -struct NativeContentIn { - #[serde(rename = "type")] - kind: String, - #[serde(default)] - text: Option, - #[serde(default)] - id: Option, - #[serde(default)] - name: Option, - #[serde(default)] - input: Option, -} - -impl AnthropicProvider { - pub fn new(credential: Option<&str>) -> Self { - Self::with_base_url(credential, None) - } - - pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self { - let base_url = base_url - .map(|u| u.trim_end_matches('/')) - .unwrap_or("https://api.anthropic.com") - .to_string(); - Self { - credential: credential - .map(str::trim) - .filter(|k| !k.is_empty()) - .map(ToString::to_string), - base_url, - } - } - - fn is_setup_token(token: &str) -> bool { - token.starts_with("sk-ant-oat01-") - } - - fn apply_auth( - &self, - request: reqwest::RequestBuilder, - credential: &str, - ) -> reqwest::RequestBuilder { - if Self::is_setup_token(credential) { - request - .header("Authorization", format!("Bearer {credential}")) - .header("anthropic-beta", "oauth-2025-04-20") - } else { - request.header("x-api-key", credential) - } - } - - /// Cache system prompts larger than ~1024 tokens (3KB of text) - fn should_cache_system(text: &str) -> bool { - text.len() > 3072 - } - - /// Cache conversations with more than 4 messages (excluding system) - fn should_cache_conversation(messages: &[ChatMessage]) -> bool { - messages.iter().filter(|m| m.role != "system").count() > 4 - } - - /// Apply cache control to the last message content block - fn apply_cache_to_last_message(messages: &mut [NativeMessage]) { - if let Some(last_msg) = messages.last_mut() - && let Some(last_content) = last_msg.content.last_mut() - { - match last_content { - NativeContentOut::Text { cache_control, .. } - | NativeContentOut::ToolResult { cache_control, .. } => { - *cache_control = Some(CacheControl::ephemeral()); - } - NativeContentOut::ToolUse { .. } => {} - } - } - } - - fn convert_tools<'a>(tools: Option<&'a [ToolSpec]>) -> Option>> { - let items = tools?; - if items.is_empty() { - return None; - } - let mut native_tools: Vec> = items - .iter() - .map(|tool| NativeToolSpec { - name: &tool.name, - description: &tool.description, - input_schema: &tool.parameters, - cache_control: None, - }) - .collect(); - - // Cache the last tool definition (caches all tools) - if let Some(last_tool) = native_tools.last_mut() { - last_tool.cache_control = Some(CacheControl::ephemeral()); - } - - Some(native_tools) - } - - fn parse_assistant_tool_call_message(content: &str) -> Option> { - let value = serde_json::from_str::(content).ok()?; - let tool_calls = value - .get("tool_calls") - .and_then(|v| serde_json::from_value::>(v.clone()).ok())?; - - let mut blocks = Vec::new(); - if let Some(text) = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(str::trim) - .filter(|t| !t.is_empty()) - { - blocks.push(NativeContentOut::Text { - text: text.to_string(), - cache_control: None, - }); - } - for call in tool_calls { - let input = serde_json::from_str::(&call.arguments) - .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())); - blocks.push(NativeContentOut::ToolUse { - id: call.id, - name: call.name, - input, - cache_control: None, - }); - } - Some(blocks) - } - - fn parse_tool_result_message(content: &str) -> Option { - let value = serde_json::from_str::(content).ok()?; - let tool_use_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str)? - .to_string(); - let result = value - .get("content") - .and_then(serde_json::Value::as_str) - .unwrap_or("") - .to_string(); - Some(NativeMessage { - role: "user".to_string(), - content: vec![NativeContentOut::ToolResult { - tool_use_id, - content: result, - cache_control: None, - }], - }) - } - - fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec) { - let mut system_text = None; - let mut native_messages = Vec::new(); - - for msg in messages { - match msg.role.as_str() { - "system" => { - if system_text.is_none() { - system_text = Some(msg.content.clone()); - } - } - "assistant" => { - if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) { - native_messages.push(NativeMessage { - role: "assistant".to_string(), - content: blocks, - }); - } else { - native_messages.push(NativeMessage { - role: "assistant".to_string(), - content: vec![NativeContentOut::Text { - text: msg.content.clone(), - cache_control: None, - }], - }); - } - } - "tool" => { - if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) { - native_messages.push(tool_result); - } else { - native_messages.push(NativeMessage { - role: "user".to_string(), - content: vec![NativeContentOut::Text { - text: msg.content.clone(), - cache_control: None, - }], - }); - } - } - _ => { - native_messages.push(NativeMessage { - role: "user".to_string(), - content: vec![NativeContentOut::Text { - text: msg.content.clone(), - cache_control: None, - }], - }); - } - } - } - - // Convert system text to SystemPrompt with cache control if large - let system_prompt = system_text.map(|text| { - if Self::should_cache_system(&text) { - SystemPrompt::Blocks(vec![SystemBlock { - block_type: "text".to_string(), - text, - cache_control: Some(CacheControl::ephemeral()), - }]) - } else { - SystemPrompt::String(text) - } - }); - - (system_prompt, native_messages) - } - - fn parse_text_response(response: ChatResponse) -> anyhow::Result { - response - .content - .into_iter() - .find(|c| c.kind == "text") - .and_then(|c| c.text) - .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) - } - - fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse { - let mut text_parts = Vec::new(); - let mut tool_calls = Vec::new(); - - let usage = response.usage.map(|u| TokenUsage { - input_tokens: u.input_tokens, - output_tokens: u.output_tokens, - }); - - for block in response.content { - match block.kind.as_str() { - "text" => { - if let Some(text) = block.text.map(|t| t.trim().to_string()) - && !text.is_empty() - { - text_parts.push(text); - } - } - "tool_use" => { - let name = block.name.unwrap_or_default(); - if name.is_empty() { - continue; - } - let arguments = block - .input - .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new())); - tool_calls.push(ProviderToolCall { - id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - name, - arguments: arguments.to_string(), - }); - } - _ => {} - } - } - - ProviderChatResponse { - text: if text_parts.is_empty() { - None - } else { - Some(text_parts.join("\n")) - }, - tool_calls, - usage, - reasoning_content: None, - } - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } -} - -async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error { - let status = response.status(); - let body = response - .text() - .await - .unwrap_or_else(|_| "".to_string()); - anyhow::anyhow!("API error ({provider}, {status}): {body}") -} - -#[async_trait] -impl Provider for AnthropicProvider { - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." - ) - })?; - - let request = ChatRequest { - model: model.to_string(), - max_tokens: 4096, - system: system_prompt.map(ToString::to_string), - messages: vec![Message { - role: "user".to_string(), - content: message.to_string(), - }], - temperature, - }; - - let mut req = self - .http_client() - .post(format!("{}/v1/messages", self.base_url)) - .header("anthropic-version", "2023-06-01") - .header("content-type", "application/json") - .json(&request); - - req = self.apply_auth(req, credential); - - let response = req.send().await?; - - if !response.status().is_success() { - return Err(api_error("Anthropic", response).await); - } - - let chat_response: ChatResponse = response.json().await?; - Self::parse_text_response(chat_response) - } - - async fn chat( - &self, - request: ProviderChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." - ) - })?; - - let (system_prompt, mut messages) = Self::convert_messages(request.messages); - - // Auto-cache last message if conversation is long - if Self::should_cache_conversation(request.messages) { - Self::apply_cache_to_last_message(&mut messages); - } - - let native_request = NativeChatRequest { - model: model.to_string(), - max_tokens: 4096, - system: system_prompt, - messages, - temperature, - tools: Self::convert_tools(request.tools), - }; - - let mut req = self - .http_client() - .post(format!("{}/v1/messages", self.base_url)) - .header("anthropic-version", "2023-06-01") - .header("content-type", "application/json") - .json(&native_request); - - if let Some(tools) = &native_request.tools - && !tools.is_empty() - { - req = req.header("anthropic-beta", "prompt-caching-2024-07-31"); - } - - req = self.apply_auth(req, credential); - - let response = req.send().await?; - - if !response.status().is_success() { - return Err(api_error("Anthropic", response).await); - } - - let native_response: NativeChatResponse = response.json().await?; - Ok(Self::parse_native_response(native_response)) - } - - fn supports_native_tools(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn creates_with_key() { - let p = AnthropicProvider::new(Some("test-key")); - assert_eq!(p.credential.as_deref(), Some("test-key")); - } - - #[test] - fn creates_without_key() { - let p = AnthropicProvider::new(None); - assert!(p.credential.is_none()); - } - - #[test] - fn trims_empty_key() { - let p = AnthropicProvider::new(Some(" ")); - assert!(p.credential.is_none()); - } - - #[test] - fn default_base_url() { - let p = AnthropicProvider::new(None); - assert_eq!(p.base_url, "https://api.anthropic.com"); - } - - #[test] - fn custom_base_url_strips_trailing_slash() { - let p = AnthropicProvider::with_base_url(None, Some("https://custom.example.com/")); - assert_eq!(p.base_url, "https://custom.example.com"); - } - - #[test] - fn is_setup_token_detects_oat01_prefix() { - assert!(AnthropicProvider::is_setup_token("sk-ant-oat01-abc123")); - assert!(!AnthropicProvider::is_setup_token("sk-ant-api03-abc123")); - } - - #[tokio::test] - async fn chat_fails_without_key() { - let p = AnthropicProvider::new(None); - let result = p - .chat_with_system(None, "hello", "claude-sonnet-4", 0.7) - .await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("credentials not set") - ); - } - - #[test] - fn parse_text_response_extracts_text() { - let response = ChatResponse { - content: vec![ContentBlock { - kind: "text".to_string(), - text: Some("Hello!".to_string()), - id: None, - name: None, - input: None, - }], - }; - let result = AnthropicProvider::parse_text_response(response).unwrap(); - assert_eq!(result, "Hello!"); - } - - #[test] - fn parse_text_response_fails_with_no_text() { - let response = ChatResponse { - content: vec![ContentBlock { - kind: "tool_use".to_string(), - text: None, - id: Some("id".to_string()), - name: Some("shell".to_string()), - input: Some(serde_json::json!({})), - }], - }; - assert!(AnthropicProvider::parse_text_response(response).is_err()); - } - - #[test] - fn convert_messages_extracts_system_prompt() { - let messages = vec![ - ChatMessage::system("Be helpful"), - ChatMessage::user("Hello"), - ]; - let (system, native) = AnthropicProvider::convert_messages(&messages); - assert!(system.is_some()); - assert_eq!(native.len(), 1); - } - - #[test] - fn convert_messages_handles_tool_call_history() { - let tool_call_json = serde_json::json!({ - "content": "Let me check", - "tool_calls": [{ - "id": "call_1", - "name": "shell", - "arguments": "{\"command\":\"date\"}" - }] - }); - let messages = vec![ - ChatMessage::assistant(tool_call_json.to_string()), - ChatMessage { - role: "tool".to_string(), - content: r#"{"tool_call_id":"call_1","content":"Mon Dec 1"}"#.to_string(), - }, - ]; - let (_, native) = AnthropicProvider::convert_messages(&messages); - assert_eq!(native.len(), 2); - // First message should contain ToolUse block - assert!( - native[0] - .content - .iter() - .any(|c| matches!(c, NativeContentOut::ToolUse { .. })) - ); - // Second message (tool result) becomes a user message with ToolResult block - assert_eq!(native[1].role, "user"); - assert!( - native[1] - .content - .iter() - .any(|c| matches!(c, NativeContentOut::ToolResult { .. })) - ); - } - - #[test] - fn parse_native_response_extracts_tool_calls() { - let response = NativeChatResponse { - content: vec![ - NativeContentIn { - kind: "text".to_string(), - text: Some("I'll use a tool".to_string()), - id: None, - name: None, - input: None, - }, - NativeContentIn { - kind: "tool_use".to_string(), - text: None, - id: Some("call_1".to_string()), - name: Some("shell".to_string()), - input: Some(serde_json::json!({"command": "date"})), - }, - ], - usage: None, - }; - let result = AnthropicProvider::parse_native_response(response); - assert_eq!(result.text.as_deref(), Some("I'll use a tool")); - assert_eq!(result.tool_calls.len(), 1); - assert_eq!(result.tool_calls[0].name, "shell"); - } - - #[test] - fn parse_native_response_reports_usage() { - let response = NativeChatResponse { - content: vec![NativeContentIn { - kind: "text".to_string(), - text: Some("Hi".to_string()), - id: None, - name: None, - input: None, - }], - usage: Some(AnthropicUsage { - input_tokens: Some(10), - output_tokens: Some(5), - }), - }; - let result = AnthropicProvider::parse_native_response(response); - let usage = result.usage.unwrap(); - assert_eq!(usage.input_tokens, Some(10)); - assert_eq!(usage.output_tokens, Some(5)); - } - - #[test] - fn should_cache_system_triggers_for_large_text() { - let small = "small"; - let large = "x".repeat(4096); - assert!(!AnthropicProvider::should_cache_system(small)); - assert!(AnthropicProvider::should_cache_system(&large)); - } - - #[test] - fn should_cache_conversation_triggers_after_4_non_system_messages() { - let messages: Vec = (0..5) - .map(|i| ChatMessage::user(format!("msg {i}"))) - .collect(); - assert!(AnthropicProvider::should_cache_conversation(&messages)); - let short: Vec = (0..3) - .map(|i| ChatMessage::user(format!("msg {i}"))) - .collect(); - assert!(!AnthropicProvider::should_cache_conversation(&short)); - } -} diff --git a/crewforge-rs/src/provider/anthropic_oauth.rs b/crewforge-rs/src/provider/anthropic_oauth.rs new file mode 100644 index 0000000..96d4164 --- /dev/null +++ b/crewforge-rs/src/provider/anthropic_oauth.rs @@ -0,0 +1,560 @@ +use crate::auth::AuthService; +use crate::auth::anthropic_token::{AnthropicAuthKind, detect_auth_kind}; +use crate::provider::ProviderRuntimeOptions; +use crate::provider::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, +}; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const ANTHROPIC_OAUTH_BETA: &str = "oauth-2025-04-20"; +const DEFAULT_MAX_TOKENS: u32 = 4096; + +pub struct AnthropicOAuthProvider { + auth: AuthService, + auth_profile_override: Option, + base_url: String, + /// Explicit API key passed via --api-key or env var (bypasses AuthService). + api_key: Option, + client: Client, +} + +// ── Request types ──────────────────────────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct MessagesRequest<'a> { + model: String, + max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>>, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + content: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + }, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec<'a> { + name: &'a str, + description: &'a str, + input_schema: &'a serde_json::Value, +} + +// ── Response types ─────────────────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct MessagesResponse { + #[serde(default)] + content: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponseContent { + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: Option, + #[serde(default)] + id: Option, + #[serde(default)] + name: Option, + #[serde(default)] + input: Option, +} + +#[derive(Debug, Deserialize)] +struct AnthropicUsage { + #[serde(default)] + input_tokens: Option, + #[serde(default)] + output_tokens: Option, +} + +// ── Provider implementation ────────────────────────────────────────────────── + +impl AnthropicOAuthProvider { + pub fn new(options: &ProviderRuntimeOptions, api_key: Option<&str>) -> anyhow::Result { + let state_dir = options + .crewforge_dir + .clone() + .unwrap_or_else(default_crewforge_dir); + let auth = AuthService::new(&state_dir, options.secrets_encrypt); + let base_url = options + .provider_api_url + .as_deref() + .unwrap_or(DEFAULT_BASE_URL) + .trim_end_matches('/') + .to_string(); + + Ok(Self { + auth, + auth_profile_override: options.auth_profile_override.clone(), + base_url, + api_key: api_key + .map(str::trim) + .filter(|k| !k.is_empty()) + .map(ToString::to_string), + client: Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + }) + } + + /// Resolve credential: explicit api_key → AuthService profile. + async fn resolve_credential(&self) -> anyhow::Result { + if let Some(key) = &self.api_key { + return Ok(key.clone()); + } + self.auth + .get_provider_bearer_token("anthropic", self.auth_profile_override.as_deref()) + .await? + .ok_or_else(|| { + anyhow::anyhow!( + "Anthropic credentials not found. Set ANTHROPIC_API_KEY or run `crewforge auth paste-token --provider anthropic`." + ) + }) + } + + /// Apply auth headers based on token kind. + fn apply_auth( + &self, + mut req: reqwest::RequestBuilder, + credential: &str, + ) -> reqwest::RequestBuilder { + let kind = detect_auth_kind(credential, None); + match kind { + AnthropicAuthKind::Authorization => { + req = req + .header("Authorization", format!("Bearer {credential}")) + .header("anthropic-beta", ANTHROPIC_OAUTH_BETA); + } + AnthropicAuthKind::ApiKey => { + req = req.header("x-api-key", credential); + } + } + req + } + + async fn send_messages( + &self, + request: &MessagesRequest<'_>, + ) -> anyhow::Result { + let credential = self.resolve_credential().await?; + + let mut req = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json") + .json(request); + + req = self.apply_auth(req, &credential); + + let response = req.send().await?; + if !response.status().is_success() { + return Err(super::api_error("Anthropic", response).await); + } + + response + .json() + .await + .map_err(|e| anyhow::anyhow!("Anthropic response parse failed: {e}")) + } +} + +fn default_crewforge_dir() -> PathBuf { + directories::UserDirs::new().map_or_else( + || PathBuf::from(".crewforge"), + |dirs| dirs.home_dir().join(".crewforge"), + ) +} + +// ── Message conversion ─────────────────────────────────────────────────────── + +fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec) { + let mut system_text = None; + let mut native = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if system_text.is_none() { + system_text = Some(msg.content.clone()); + } + } + "assistant" => { + if let Some(blocks) = parse_assistant_tool_call_content(&msg.content) { + native.push(NativeMessage { + role: "assistant".to_string(), + content: blocks, + }); + } else { + native.push(NativeMessage { + role: "assistant".to_string(), + content: vec![ContentBlock::Text { + text: msg.content.clone(), + }], + }); + } + } + "tool" => { + if let Some(result_msg) = parse_tool_result_content(&msg.content) { + native.push(result_msg); + } else { + native.push(NativeMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: msg.content.clone(), + }], + }); + } + } + _ => { + native.push(NativeMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: msg.content.clone(), + }], + }); + } + } + } + + (system_text, native) +} + +fn parse_assistant_tool_call_content(content: &str) -> Option> { + let value = serde_json::from_str::(content).ok()?; + let tool_calls = value + .get("tool_calls") + .and_then(|v| serde_json::from_value::>(v.clone()).ok())?; + + let mut blocks = Vec::new(); + if let Some(text) = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(str::trim) + .filter(|t| !t.is_empty()) + { + blocks.push(ContentBlock::Text { + text: text.to_string(), + }); + } + for call in tool_calls { + let input = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())); + blocks.push(ContentBlock::ToolUse { + id: call.id, + name: call.name, + input, + }); + } + Some(blocks) +} + +fn parse_tool_result_content(content: &str) -> Option { + let value = serde_json::from_str::(content).ok()?; + let tool_use_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str)? + .to_string(); + let result = value + .get("content") + .and_then(serde_json::Value::as_str) + .unwrap_or("") + .to_string(); + Some(NativeMessage { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult { + tool_use_id, + content: result, + }], + }) +} + +fn convert_tools<'a>(tools: Option<&'a [ToolSpec]>) -> Option>> { + let items = tools?; + if items.is_empty() { + return None; + } + Some( + items + .iter() + .map(|tool| NativeToolSpec { + name: &tool.name, + description: &tool.description, + input_schema: &tool.parameters, + }) + .collect(), + ) +} + +fn parse_response(response: MessagesResponse) -> ProviderChatResponse { + let mut text_parts = Vec::new(); + let mut tool_calls = Vec::new(); + + let usage = response.usage.map(|u| TokenUsage { + input_tokens: u.input_tokens, + output_tokens: u.output_tokens, + }); + + for block in response.content { + match block.kind.as_str() { + "text" => { + if let Some(text) = block + .text + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()) + { + text_parts.push(text); + } + } + "tool_use" => { + let name = block.name.unwrap_or_default(); + if name.is_empty() { + continue; + } + let arguments = block + .input + .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new())); + tool_calls.push(ProviderToolCall { + id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name, + arguments: arguments.to_string(), + }); + } + _ => {} + } + } + + ProviderChatResponse { + text: if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }, + tool_calls, + usage, + reasoning_content: None, + } +} + +// ── Provider trait ──────────────────────────────────────────────────────────── + +#[async_trait] +impl Provider for AnthropicOAuthProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: false, + } + } + + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let request = MessagesRequest { + model: model.to_string(), + max_tokens: DEFAULT_MAX_TOKENS, + system: system_prompt.map(ToString::to_string), + messages: vec![NativeMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text { + text: message.to_string(), + }], + }], + temperature, + tools: None, + }; + + let response = self.send_messages(&request).await?; + response + .content + .into_iter() + .find(|c| c.kind == "text") + .and_then(|c| c.text) + .ok_or_else(|| anyhow::anyhow!("No text response from Anthropic")) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (system, messages) = convert_messages(request.messages); + + let native_request = MessagesRequest { + model: model.to_string(), + max_tokens: DEFAULT_MAX_TOKENS, + system, + messages, + temperature, + tools: convert_tools(request.tools), + }; + + let response = self.send_messages(&native_request).await?; + Ok(parse_response(response)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_state_dir_is_non_empty() { + let path = default_crewforge_dir(); + assert!(!path.as_os_str().is_empty()); + } + + #[test] + fn convert_messages_extracts_system() { + let messages = vec![ + ChatMessage::system("Be helpful"), + ChatMessage::user("Hello"), + ]; + let (system, native) = convert_messages(&messages); + assert_eq!(system.as_deref(), Some("Be helpful")); + assert_eq!(native.len(), 1); + } + + #[test] + fn convert_messages_handles_tool_call_history() { + let tool_call_json = serde_json::json!({ + "content": "Let me check", + "tool_calls": [{ + "id": "call_1", + "name": "shell", + "arguments": "{\"command\":\"date\"}" + }] + }); + let messages = vec![ + ChatMessage::assistant(tool_call_json.to_string()), + ChatMessage { + role: "tool".to_string(), + content: r#"{"tool_call_id":"call_1","content":"Mon Dec 1"}"#.to_string(), + }, + ]; + let (_, native) = convert_messages(&messages); + assert_eq!(native.len(), 2); + assert!( + native[0] + .content + .iter() + .any(|c| matches!(c, ContentBlock::ToolUse { .. })) + ); + assert_eq!(native[1].role, "user"); + assert!( + native[1] + .content + .iter() + .any(|c| matches!(c, ContentBlock::ToolResult { .. })) + ); + } + + #[test] + fn parse_response_extracts_text_and_tool_calls() { + let response = MessagesResponse { + content: vec![ + ResponseContent { + kind: "text".to_string(), + text: Some("I'll help".to_string()), + id: None, + name: None, + input: None, + }, + ResponseContent { + kind: "tool_use".to_string(), + text: None, + id: Some("call_1".to_string()), + name: Some("shell".to_string()), + input: Some(serde_json::json!({"command": "date"})), + }, + ], + usage: Some(AnthropicUsage { + input_tokens: Some(10), + output_tokens: Some(5), + }), + }; + let result = parse_response(response); + assert_eq!(result.text.as_deref(), Some("I'll help")); + assert_eq!(result.tool_calls.len(), 1); + assert_eq!(result.tool_calls[0].name, "shell"); + let usage = result.usage.unwrap(); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(5)); + } + + #[test] + fn convert_tools_maps_spec() { + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run a shell command".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + let native = convert_tools(Some(&tools)).unwrap(); + assert_eq!(native.len(), 1); + assert_eq!(native[0].name, "shell"); + } + + #[test] + fn convert_tools_returns_none_for_empty() { + assert!(convert_tools(Some(&[])).is_none()); + assert!(convert_tools(None).is_none()); + } + + #[tokio::test] + async fn resolve_credential_uses_explicit_key() { + let opts = ProviderRuntimeOptions::default(); + let provider = AnthropicOAuthProvider::new(&opts, Some("sk-ant-api-test")).unwrap(); + let cred = provider.resolve_credential().await.unwrap(); + assert_eq!(cred, "sk-ant-api-test"); + } + + #[test] + fn capabilities_reports_native_tools() { + let opts = ProviderRuntimeOptions::default(); + let provider = AnthropicOAuthProvider::new(&opts, None).unwrap(); + let caps = provider.capabilities(); + assert!(caps.native_tool_calling); + assert!(!caps.vision); + } +} diff --git a/crewforge-rs/src/provider/compatible.rs b/crewforge-rs/src/provider/compatible.rs index 0b49107..434bd82 100644 --- a/crewforge-rs/src/provider/compatible.rs +++ b/crewforge-rs/src/provider/compatible.rs @@ -1,316 +1,30 @@ //! Generic OpenAI-compatible provider. -//! Most LLM APIs follow the same `/v1/chat/completions` format. +//! Most LLM APIs follow the same `/v1/chat/completions` format with Bearer auth. //! This module provides a single implementation that works for all of them. -use crate::provider::traits::ToolSpec; use crate::provider::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, TokenUsage, ToolCall as ProviderToolCall, + Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, }; use async_trait::async_trait; -use reqwest::{ - Client, - header::{HeaderMap, HeaderValue, USER_AGENT}, -}; +use reqwest::Client; use serde::{Deserialize, Serialize}; /// A provider that speaks the OpenAI-compatible chat completions API. -/// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot, -/// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc. -#[allow(clippy::struct_excessive_bools)] +/// Authentication is always via `Authorization: Bearer `. pub struct OpenAiCompatibleProvider { - pub(crate) name: String, - pub(crate) base_url: String, - pub(crate) credential: Option, - pub(crate) auth_header: AuthStyle, - supports_vision: bool, - /// When false, do not fall back to /v1/responses on chat completions 404. - /// GLM/Zhipu does not support the responses API. - supports_responses_fallback: bool, - user_agent: Option, - /// When true, collect all `system` messages and prepend their content - /// to the first `user` message, then drop the system messages. - /// Required for providers that reject `role: system` (e.g. MiniMax). - merge_system_into_user: bool, - /// Whether this provider supports OpenAI-style native tool calling. - /// When false, tools are injected into the system prompt as text. - native_tool_calling: bool, -} - -/// How the provider expects the API key to be sent. -#[derive(Debug, Clone)] -pub enum AuthStyle { - /// `Authorization: Bearer ` - Bearer, - /// `x-api-key: ` (used by some Chinese providers) - XApiKey, - /// Custom header name - Custom(String), + name: String, + base_url: String, + credential: Option, + client: Client, } -impl OpenAiCompatibleProvider { - pub fn new( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - ) -> Self { - Self::new_with_options( - name, base_url, credential, auth_style, false, true, None, false, - ) - } - - pub fn new_with_vision( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - supports_vision: bool, - ) -> Self { - Self::new_with_options( - name, - base_url, - credential, - auth_style, - supports_vision, - true, - None, - false, - ) - } - - /// Same as `new` but skips the /v1/responses fallback on 404. - /// Use for providers (e.g. GLM) that only support chat completions. - pub fn new_no_responses_fallback( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - ) -> Self { - Self::new_with_options( - name, base_url, credential, auth_style, false, false, None, false, - ) - } - - /// Create a provider with a custom User-Agent header. - /// - /// Some providers (for example Kimi Code) require a specific User-Agent - /// for request routing and policy enforcement. - pub fn new_with_user_agent( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - user_agent: &str, - ) -> Self { - Self::new_with_options( - name, - base_url, - credential, - auth_style, - false, - true, - Some(user_agent), - false, - ) - } - - pub fn new_with_user_agent_and_vision( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - user_agent: &str, - supports_vision: bool, - ) -> Self { - Self::new_with_options( - name, - base_url, - credential, - auth_style, - supports_vision, - true, - Some(user_agent), - false, - ) - } - - /// For providers that do not support `role: system` (e.g. MiniMax). - /// System prompt content is prepended to the first user message instead. - pub fn new_merge_system_into_user( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - ) -> Self { - Self::new_with_options( - name, base_url, credential, auth_style, false, false, None, true, - ) - } - - #[allow(clippy::too_many_arguments)] - fn new_with_options( - name: &str, - base_url: &str, - credential: Option<&str>, - auth_style: AuthStyle, - supports_vision: bool, - supports_responses_fallback: bool, - user_agent: Option<&str>, - merge_system_into_user: bool, - ) -> Self { - Self { - name: name.to_string(), - base_url: base_url.trim_end_matches('/').to_string(), - credential: credential.map(ToString::to_string), - auth_header: auth_style, - supports_vision, - supports_responses_fallback, - user_agent: user_agent.map(ToString::to_string), - merge_system_into_user, - native_tool_calling: !merge_system_into_user, - } - } - - /// Collect all `system` role messages, concatenate their content, - /// and prepend to the first `user` message. Drop all system messages. - /// Used for providers (e.g. MiniMax) that reject `role: system`. - fn flatten_system_messages(messages: &[ChatMessage]) -> Vec { - let system_content: String = messages - .iter() - .filter(|m| m.role == "system") - .map(|m| m.content.as_str()) - .collect::>() - .join("\n\n"); - - if system_content.is_empty() { - return messages.to_vec(); - } - - let mut result: Vec = messages - .iter() - .filter(|m| m.role != "system") - .cloned() - .collect(); - - if let Some(first_user) = result.iter_mut().find(|m| m.role == "user") { - first_user.content = format!("{system_content}\n\n{}", first_user.content); - } else { - // No user message found: insert a synthetic user message with system content - result.insert(0, ChatMessage::user(&system_content)); - } - - result - } - - fn http_client(&self) -> Client { - if let Some(ua) = self.user_agent.as_deref() { - let mut headers = HeaderMap::new(); - if let Ok(value) = HeaderValue::from_str(ua) { - headers.insert(USER_AGENT, value); - } - - let builder = Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .connect_timeout(std::time::Duration::from_secs(10)) - .default_headers(headers); - - return builder.build().unwrap_or_else(|error| { - tracing::warn!("Failed to build timeout client with user-agent: {error}"); - Client::new() - }); - } - - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } - - /// Build the full URL for chat completions, detecting if base_url already includes the path. - /// This allows custom providers with non-standard endpoints (e.g., VolcEngine ARK uses - /// `/api/coding/v3/chat/completions` instead of `/v1/chat/completions`). - fn chat_completions_url(&self) -> String { - let has_full_endpoint = reqwest::Url::parse(&self.base_url) - .map(|url| { - url.path() - .trim_end_matches('/') - .ends_with("/chat/completions") - }) - .unwrap_or_else(|_| { - self.base_url - .trim_end_matches('/') - .ends_with("/chat/completions") - }); - - if has_full_endpoint { - self.base_url.clone() - } else { - format!("{}/chat/completions", self.base_url) - } - } - - fn path_ends_with(&self, suffix: &str) -> bool { - if let Ok(url) = reqwest::Url::parse(&self.base_url) { - return url.path().trim_end_matches('/').ends_with(suffix); - } - - self.base_url.trim_end_matches('/').ends_with(suffix) - } - - fn has_explicit_api_path(&self) -> bool { - let Ok(url) = reqwest::Url::parse(&self.base_url) else { - return false; - }; - - let path = url.path().trim_end_matches('/'); - !path.is_empty() && path != "/" - } - - /// Build the full URL for responses API, detecting if base_url already includes the path. - fn responses_url(&self) -> String { - if self.path_ends_with("/responses") { - return self.base_url.clone(); - } - - let normalized_base = self.base_url.trim_end_matches('/'); - - // If chat endpoint is explicitly configured, derive sibling responses endpoint. - if let Some(prefix) = normalized_base.strip_suffix("/chat/completions") { - return format!("{prefix}/responses"); - } - - // If an explicit API path already exists (e.g. /v1, /openai, /api/coding/v3), - // append responses directly to avoid duplicate /v1 segments. - if self.has_explicit_api_path() { - format!("{normalized_base}/responses") - } else { - format!("{normalized_base}/v1/responses") - } - } - - #[allow(dead_code)] - fn tool_specs_to_openai_format(tools: &[ToolSpec]) -> Vec { - tools - .iter() - .map(|tool| { - serde_json::json!({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } - }) - }) - .collect() - } -} +// ── Request types ──────────────────────────────────────────────────────────── #[derive(Debug, Serialize)] -struct ApiChatRequest { +struct ChatCompletionsRequest { model: String, - messages: Vec, + messages: Vec, temperature: f64, #[serde(skip_serializing_if = "Option::is_none")] stream: Option, @@ -321,19 +35,39 @@ struct ApiChatRequest { } #[derive(Debug, Serialize)] -struct Message { +struct RequestMessage { role: String, - content: MessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, } -#[derive(Debug, Serialize)] -#[serde(untagged)] -enum MessageContent { - Text(String), +#[derive(Debug, Serialize, Deserialize)] +struct ToolCallOut { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: FunctionRef, } +#[derive(Debug, Serialize, Deserialize)] +struct FunctionRef { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +// ── Response types ─────────────────────────────────────────────────────────── + #[derive(Debug, Deserialize)] -struct ApiChatResponse { +struct ChatCompletionsResponse { choices: Vec, #[serde(default)] usage: Option, @@ -352,419 +86,156 @@ struct Choice { message: ResponseMessage, } -/// Remove `...` blocks from model output. -/// Some reasoning models (e.g. MiniMax) embed their chain-of-thought inline -/// in the `content` field rather than a separate `reasoning_content` field. -/// The resulting `` tags must be stripped before returning to the user. -fn strip_think_tags(s: &str) -> String { - let mut result = String::with_capacity(s.len()); - let mut rest = s; - loop { - if let Some(start) = rest.find("") { - result.push_str(&rest[..start]); - if let Some(end) = rest[start..].find("") { - rest = &rest[start + end + "".len()..]; - } else { - // Unclosed tag: drop the rest to avoid leaking partial reasoning. - break; - } - } else { - result.push_str(rest); - break; - } - } - result.trim().to_string() -} - -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize)] struct ResponseMessage { #[serde(default)] content: Option, - /// Reasoning/thinking models (e.g. Qwen3, GLM-4) may return their output - /// in `reasoning_content` instead of `content`. Used as automatic fallback. #[serde(default)] reasoning_content: Option, #[serde(default)] - tool_calls: Option>, -} - -impl ResponseMessage { - /// Extract text content, falling back to `reasoning_content` when `content` - /// is missing or empty. Reasoning/thinking models (Qwen3, GLM-4, etc.) - /// often return their output solely in `reasoning_content`. - /// Strips `...` blocks that some models (e.g. MiniMax) embed - /// inline in `content` instead of using a separate field. - fn effective_content(&self) -> String { - if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) { - let stripped = strip_think_tags(content); - if !stripped.is_empty() { - return stripped; - } - } - - self.reasoning_content - .as_ref() - .map(|c| strip_think_tags(c)) - .filter(|c| !c.is_empty()) - .unwrap_or_default() - } - - fn effective_content_optional(&self) -> Option { - if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) { - let stripped = strip_think_tags(content); - if !stripped.is_empty() { - return Some(stripped); - } - } - - self.reasoning_content - .as_ref() - .map(|c| strip_think_tags(c)) - .filter(|c| !c.is_empty()) - } + tool_calls: Option>, } -#[derive(Debug, Deserialize, Serialize)] -struct ToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(rename = "type")] +#[derive(Debug, Deserialize)] +struct ToolCallIn { #[serde(default)] - kind: Option, + id: Option, #[serde(default)] - function: Option, - - // Compatibility: Some providers (e.g., older GLM) may use 'name' directly + function: Option, + // Fallback: some providers use top-level name/arguments #[serde(default)] name: Option, #[serde(default)] arguments: Option, - - // Compatibility: DeepSeek sometimes wraps arguments differently #[serde(rename = "parameters", default)] parameters: Option, } -impl ToolCall { - /// Extract function name with fallback logic for various provider formats +impl ToolCallIn { fn function_name(&self) -> Option { - // Standard OpenAI format: tool_calls[].function.name - if let Some(ref func) = self.function - && let Some(ref name) = func.name - { - return Some(name.clone()); - } - // Fallback: direct name field - self.name.clone() + self.function + .as_ref() + .and_then(|f| f.name.clone()) + .or_else(|| self.name.clone()) } - /// Extract arguments with fallback logic and type conversion fn function_arguments(&self) -> Option { - // Standard OpenAI format: tool_calls[].function.arguments (string) - if let Some(ref func) = self.function - && let Some(ref args) = func.arguments - { - return Some(args.clone()); - } - // Fallback: direct arguments field - if let Some(ref args) = self.arguments { - return Some(args.clone()); - } - // Compatibility: Some providers return parameters as object instead of string - if let Some(ref params) = self.parameters { - return serde_json::to_string(params).ok(); - } - None + self.function + .as_ref() + .and_then(|f| f.arguments.clone()) + .or_else(|| self.arguments.clone()) + .or_else(|| { + self.parameters + .as_ref() + .and_then(|p| serde_json::to_string(p).ok()) + }) } } -#[derive(Debug, Deserialize, Serialize)] -struct Function { - #[serde(default)] - name: Option, - #[serde(default)] - arguments: Option, -} - -#[derive(Debug, Serialize)] -struct NativeChatRequest { - model: String, - messages: Vec, - temperature: f64, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, -} - -#[derive(Debug, Serialize)] -struct NativeMessage { - role: String, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - /// Raw reasoning content from thinking models; pass-through for providers - /// that require it in assistant tool-call history messages. - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, -} - -#[derive(Debug, Serialize)] -struct ResponsesRequest { - model: String, - input: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - instructions: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, -} - -#[derive(Debug, Serialize)] -struct ResponsesInput { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ResponsesResponse { - #[serde(default)] - output: Vec, - #[serde(default)] - output_text: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponsesOutput { - #[serde(default)] - content: Vec, -} - -#[derive(Debug, Deserialize)] -struct ResponsesContent { - #[serde(rename = "type")] - kind: Option, - text: Option, -} - -// --------------------------------------------------------------- -fn first_nonempty(text: Option<&str>) -> Option { - text.and_then(|value| { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_string()) - } - }) -} - -fn normalize_responses_role(role: &str) -> &'static str { - match role { - "assistant" | "tool" => "assistant", - _ => "user", +impl ResponseMessage { + fn effective_content(&self) -> Option { + self.content + .as_ref() + .filter(|c| !c.trim().is_empty()) + .cloned() + .or_else(|| { + self.reasoning_content + .as_ref() + .filter(|c| !c.trim().is_empty()) + .cloned() + }) } } -fn build_responses_prompt(messages: &[ChatMessage]) -> (Option, Vec) { - let mut instructions_parts = Vec::new(); - let mut input = Vec::new(); - - for message in messages { - if message.content.trim().is_empty() { - continue; - } +// ── Provider ───────────────────────────────────────────────────────────────── - if message.role == "system" { - instructions_parts.push(message.content.clone()); - continue; +impl OpenAiCompatibleProvider { + pub fn new(name: &str, base_url: &str, credential: Option<&str>) -> Self { + Self { + name: name.to_string(), + base_url: base_url.trim_end_matches('/').to_string(), + credential: credential + .map(str::trim) + .filter(|k| !k.is_empty()) + .map(ToString::to_string), + client: Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), } - - input.push(ResponsesInput { - role: normalize_responses_role(&message.role).to_string(), - content: message.content.clone(), - }); } - let instructions = if instructions_parts.is_empty() { - None - } else { - Some(instructions_parts.join("\n\n")) - }; - - (instructions, input) -} - -fn extract_responses_text(response: ResponsesResponse) -> Option { - if let Some(text) = first_nonempty(response.output_text.as_deref()) { - return Some(text); + fn chat_completions_url(&self) -> String { + if self.base_url.ends_with("/chat/completions") { + self.base_url.clone() + } else { + format!("{}/chat/completions", self.base_url) + } } - for item in &response.output { - for content in &item.content { - if content.kind.as_deref() == Some("output_text") - && let Some(text) = first_nonempty(content.text.as_deref()) - { - return Some(text); - } - } + fn require_credential(&self) -> anyhow::Result<&str> { + self.credential.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "{} API key not set. Set the appropriate env var.", + self.name + ) + }) } - for item in &response.output { - for content in &item.content { - if let Some(text) = first_nonempty(content.text.as_deref()) { - return Some(text); - } - } + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + let items = tools.filter(|t| !t.is_empty())?; + Some( + items + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + }) + }) + .collect(), + ) } - None -} + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + // Assistant message with embedded tool_calls JSON + if m.role == "assistant" + && let Ok(value) = serde_json::from_str::(&m.content) + && let Some(tool_calls_value) = value.get("tool_calls") + && let Ok(parsed) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed + .into_iter() + .map(|tc| ToolCallOut { + id: Some(tc.id), + kind: Some("function".to_string()), + function: FunctionRef { + name: Some(tc.name), + arguments: Some(tc.arguments), + }, + }) + .collect(); -fn compact_sanitized_body_snippet(body: &str) -> String { - sanitize_api_error(body) - .split_whitespace() - .collect::>() - .join(" ") -} + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); -fn parse_chat_response_body(provider_name: &str, body: &str) -> anyhow::Result { - serde_json::from_str::(body).map_err(|error| { - let snippet = compact_sanitized_body_snippet(body); - anyhow::anyhow!( - "{provider_name} API returned an unexpected chat-completions payload: {error}; body={snippet}" - ) - }) -} + let reasoning_content = value + .get("reasoning_content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); -fn parse_responses_response_body( - provider_name: &str, - body: &str, -) -> anyhow::Result { - serde_json::from_str::(body).map_err(|error| { - let snippet = compact_sanitized_body_snippet(body); - anyhow::anyhow!( - "{provider_name} Responses API returned an unexpected payload: {error}; body={snippet}" - ) - }) -} - -impl OpenAiCompatibleProvider { - fn apply_auth_header( - &self, - req: reqwest::RequestBuilder, - credential: &str, - ) -> reqwest::RequestBuilder { - match &self.auth_header { - AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")), - AuthStyle::XApiKey => req.header("x-api-key", credential), - AuthStyle::Custom(header) => req.header(header, credential), - } - } - - async fn chat_via_responses( - &self, - credential: &str, - messages: &[ChatMessage], - model: &str, - ) -> anyhow::Result { - let (instructions, input) = build_responses_prompt(messages); - if input.is_empty() { - anyhow::bail!( - "{} Responses API fallback requires at least one non-system message", - self.name - ); - } - - let request = ResponsesRequest { - model: model.to_string(), - input, - instructions, - stream: Some(false), - }; - - let url = self.responses_url(); - - let response = self - .apply_auth_header(self.http_client().post(&url).json(&request), credential) - .send() - .await?; - - if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("{} Responses API error: {error}", self.name); - } - - let body = response.text().await?; - let responses = parse_responses_response_body(&self.name, &body)?; - - extract_responses_text(responses) - .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) - } - - fn convert_tool_specs(tools: Option<&[ToolSpec]>) -> Option> { - tools.map(|items| { - items - .iter() - .map(|tool| { - serde_json::json!({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - }) - }) - .collect() - }) - } - - #[allow(dead_code)] - fn to_message_content(_role: &str, content: &str) -> MessageContent { - MessageContent::Text(content.to_string()) - } - - fn convert_messages_for_native(messages: &[ChatMessage]) -> Vec { - messages - .iter() - .map(|message| { - if message.role == "assistant" - && let Ok(value) = serde_json::from_str::(&message.content) - && let Some(tool_calls_value) = value.get("tool_calls") - && let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| ToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: Some(Function { - name: Some(tc.name), - arguments: Some(tc.arguments), - }), - name: None, - arguments: None, - parameters: None, - }) - .collect::>(); - - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(|value| MessageContent::Text(value.to_string())); - - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return NativeMessage { + return RequestMessage { role: "assistant".to_string(), content, tool_call_id: None, @@ -773,8 +244,9 @@ impl OpenAiCompatibleProvider { }; } - if message.role == "tool" - && let Ok(value) = serde_json::from_str::(&message.content) + // Tool result message + if m.role == "tool" + && let Ok(value) = serde_json::from_str::(&m.content) { let tool_call_id = value .get("tool_call_id") @@ -783,10 +255,10 @@ impl OpenAiCompatibleProvider { let content = value .get("content") .and_then(serde_json::Value::as_str) - .map(|value| MessageContent::Text(value.to_string())) - .or_else(|| Some(MessageContent::Text(message.content.clone()))); + .map(ToString::to_string) + .or_else(|| Some(m.content.clone())); - return NativeMessage { + return RequestMessage { role: "tool".to_string(), content, tool_call_id, @@ -795,9 +267,10 @@ impl OpenAiCompatibleProvider { }; } - NativeMessage { - role: message.role.clone(), - content: Some(MessageContent::Text(message.content.clone())), + // Regular message + RequestMessage { + role: m.role.clone(), + content: Some(m.content.clone()), tool_call_id: None, tool_calls: None, reasoning_content: None, @@ -806,117 +279,80 @@ impl OpenAiCompatibleProvider { .collect() } - fn with_prompt_guided_tool_instructions( - messages: &[ChatMessage], - tools: Option<&[ToolSpec]>, - ) -> Vec { - let Some(tools) = tools else { - return messages.to_vec(); - }; - - if tools.is_empty() { - return messages.to_vec(); - } - - let instructions = crate::provider::traits::build_tool_instructions_text(tools); - let mut modified_messages = messages.to_vec(); + fn parse_response( + response: ChatCompletionsResponse, + ) -> (ProviderChatResponse, Option) { + let usage = response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); - if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system") { - if !system_message.content.is_empty() { - system_message.content.push_str("\n\n"); - } - system_message.content.push_str(&instructions); - } else { - modified_messages.insert(0, ChatMessage::system(instructions)); - } + let Some(choice) = response.choices.into_iter().next() else { + return ( + ProviderChatResponse { + text: None, + tool_calls: vec![], + usage: None, + reasoning_content: None, + }, + usage, + ); + }; - modified_messages - } + let msg = choice.message; + let text = msg.effective_content(); + let reasoning_content = msg.reasoning_content.clone(); - fn parse_native_response(message: ResponseMessage) -> ProviderChatResponse { - let text = message.effective_content_optional(); - let reasoning_content = message.reasoning_content.clone(); - let tool_calls = message + let tool_calls = msg .tool_calls .unwrap_or_default() .into_iter() .filter_map(|tc| { let name = tc.function_name()?; let arguments = tc.function_arguments().unwrap_or_else(|| "{}".to_string()); - let normalized_arguments = - if serde_json::from_str::(&arguments).is_ok() { - arguments - } else { - tracing::warn!( - function = %name, - arguments = %arguments, - "Invalid JSON in native tool-call arguments, using empty object" - ); - "{}".to_string() - }; + let arguments = if serde_json::from_str::(&arguments).is_ok() { + arguments + } else { + "{}".to_string() + }; Some(ProviderToolCall { id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), name, - arguments: normalized_arguments, + arguments, }) }) - .collect::>(); - - ProviderChatResponse { - text, - tool_calls, - usage: None, - reasoning_content, - } - } - - fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool { - if !matches!( - status, - reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::UNPROCESSABLE_ENTITY - ) { - return false; - } + .collect(); - let lower = error.to_lowercase(); - [ - "unknown parameter: tools", - "unsupported parameter: tools", - "unrecognized field `tools`", - "does not support tools", - "function calling is not supported", - "tool_choice", - ] - .iter() - .any(|hint| lower.contains(hint)) + ( + ProviderChatResponse { + text, + tool_calls, + usage: None, + reasoning_content, + }, + usage, + ) } } +// ── Public utilities (used by other provider modules) ──────────────────────── + /// Sanitize API error text by scrubbing secrets and truncating length. pub fn sanitize_api_error(input: &str) -> String { - // Redact common secret patterns (API keys, tokens, etc.) - let patterns = [ - // sk-ant-*, sk-*, similar key patterns - (r"sk-[A-Za-z0-9\-_]{8,}", "[REDACTED]"), - ]; - let mut result = input.to_string(); - for (pattern, replacement) in &patterns { - if let Ok(re) = regex::Regex::new(pattern) { - result = re.replace_all(&result, *replacement).to_string(); - } + if let Ok(re) = regex::Regex::new(r"sk-[A-Za-z0-9\-_]{8,}") { + result = re.replace_all(&result, "[REDACTED]").to_string(); } - const MAX_API_ERROR_CHARS: usize = 200; - if result.chars().count() <= MAX_API_ERROR_CHARS { + const MAX_CHARS: usize = 200; + if result.chars().count() <= MAX_CHARS { return result; } - let mut end = MAX_API_ERROR_CHARS; + let mut end = MAX_CHARS; while end > 0 && !result.is_char_boundary(end) { end -= 1; } - format!("{}...", &result[..end]) } @@ -931,15 +367,30 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E anyhow::anyhow!("{provider} API error ({status}): {sanitized}") } +// ── Provider trait ─────────────────────────────────────────────────────────── + #[async_trait] impl Provider for OpenAiCompatibleProvider { - fn capabilities(&self) -> crate::provider::traits::ProviderCapabilities { - crate::provider::traits::ProviderCapabilities { - native_tool_calling: self.native_tool_calling, - vision: self.supports_vision, + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: false, } } + async fn warmup(&self) -> anyhow::Result<()> { + if let Some(credential) = &self.credential { + let url = self.chat_completions_url(); + let _ = self + .client + .get(&url) + .header("Authorization", format!("Bearer {credential}")) + .send() + .await; + } + Ok(()) + } + async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -947,38 +398,27 @@ impl Provider for OpenAiCompatibleProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "{} API key not set. Set the appropriate env var.", - self.name - ) - })?; + let credential = self.require_credential()?; let mut messages = Vec::new(); - - if self.merge_system_into_user { - let content = match system_prompt { - Some(sys) => format!("{sys}\n\n{message}"), - None => message.to_string(), - }; - messages.push(Message { - role: "user".to_string(), - content: MessageContent::Text(content), - }); - } else { - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: MessageContent::Text(sys.to_string()), - }); - } - messages.push(Message { - role: "user".to_string(), - content: MessageContent::Text(message.to_string()), + if let Some(sys) = system_prompt { + messages.push(RequestMessage { + role: "system".to_string(), + content: Some(sys.to_string()), + tool_call_id: None, + tool_calls: None, + reasoning_content: None, }); } + messages.push(RequestMessage { + role: "user".to_string(), + content: Some(message.to_string()), + tool_call_id: None, + tool_calls: None, + reasoning_content: None, + }); - let request = ApiChatRequest { + let request = ChatCompletionsRequest { model: model.to_string(), messages, temperature, @@ -988,823 +428,159 @@ impl Provider for OpenAiCompatibleProvider { }; let url = self.chat_completions_url(); - - let mut fallback_messages = Vec::new(); - if let Some(system_prompt) = system_prompt { - fallback_messages.push(ChatMessage::system(system_prompt)); - } - fallback_messages.push(ChatMessage::user(message)); - let fallback_messages = if self.merge_system_into_user { - Self::flatten_system_messages(&fallback_messages) - } else { - fallback_messages - }; - - let response = match self - .apply_auth_header(self.http_client().post(&url).json(&request), credential) - .send() - .await - { - Ok(response) => response, - Err(chat_error) => { - if self.supports_responses_fallback { - let sanitized = sanitize_api_error(&chat_error.to_string()); - return self - .chat_via_responses(credential, &fallback_messages, model) - .await - .map_err(|responses_err| { - anyhow::anyhow!( - "{} chat completions transport error: {sanitized} (responses fallback failed: {responses_err})", - self.name - ) - }); - } - - return Err(chat_error.into()); - } - }; - - if !response.status().is_success() { - let status = response.status(); - let error = response.text().await?; - let sanitized = sanitize_api_error(&error); - - if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { - return self - .chat_via_responses(credential, &fallback_messages, model) - .await - .map_err(|responses_err| { - anyhow::anyhow!( - "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", - self.name - ) - }); - } - - anyhow::bail!("{} API error ({status}): {sanitized}", self.name); - } - - let body = response.text().await?; - let chat_response = parse_chat_response_body(&self.name, &body)?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message.tool_calls.as_ref().is_some_and(|t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.effective_content()) - } else { - // No tool calls, return content (with reasoning_content fallback) - c.message.effective_content() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "{} API key not set. Set the appropriate env var.", - self.name - ) - })?; - - let effective_messages = if self.merge_system_into_user { - Self::flatten_system_messages(messages) - } else { - messages.to_vec() - }; - let api_messages: Vec = effective_messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: MessageContent::Text(m.content.clone()), - }) - .collect(); - - let request = ApiChatRequest { - model: model.to_string(), - messages: api_messages, - temperature, - stream: Some(false), - tools: None, - tool_choice: None, - }; - - let url = self.chat_completions_url(); - let response = match self - .apply_auth_header(self.http_client().post(&url).json(&request), credential) - .send() - .await - { - Ok(response) => response, - Err(chat_error) => { - if self.supports_responses_fallback { - let sanitized = sanitize_api_error(&chat_error.to_string()); - return self - .chat_via_responses(credential, &effective_messages, model) - .await - .map_err(|responses_err| { - anyhow::anyhow!( - "{} chat completions transport error: {sanitized} (responses fallback failed: {responses_err})", - self.name - ) - }); - } - - return Err(chat_error.into()); - } - }; - - if !response.status().is_success() { - let status = response.status(); - - // Mirror chat_with_system: 404 may mean this provider uses the Responses API - if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { - return self - .chat_via_responses(credential, &effective_messages, model) - .await - .map_err(|responses_err| { - anyhow::anyhow!( - "{} API error (chat completions unavailable; responses fallback failed: {responses_err})", - self.name - ) - }); - } - - return Err(api_error(&self.name, response).await); - } - - let body = response.text().await?; - let chat_response = parse_chat_response_body(&self.name, &body)?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message.tool_calls.as_ref().is_some_and(|t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.effective_content()) - } else { - // No tool calls, return content (with reasoning_content fallback) - c.message.effective_content() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) - } - - async fn chat_with_tools( - &self, - messages: &[ChatMessage], - tools: &[serde_json::Value], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "{} API key not set. Set the appropriate env var.", - self.name - ) - })?; - - let effective_messages = if self.merge_system_into_user { - Self::flatten_system_messages(messages) - } else { - messages.to_vec() - }; - let api_messages: Vec = effective_messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: MessageContent::Text(m.content.clone()), - }) - .collect(); - - let request = ApiChatRequest { - model: model.to_string(), - messages: api_messages, - temperature, - stream: Some(false), - tools: if tools.is_empty() { - None - } else { - Some(tools.to_vec()) - }, - tool_choice: if tools.is_empty() { - None - } else { - Some("auto".to_string()) - }, - }; - - let url = self.chat_completions_url(); - let response = match self - .apply_auth_header(self.http_client().post(&url).json(&request), credential) - .send() - .await - { - Ok(response) => response, - Err(error) => { - tracing::warn!( - "{} native tool call transport failed: {error}; falling back to history path", - self.name - ); - let text = self.chat_with_history(messages, model, temperature).await?; - return Ok(ProviderChatResponse { - text: Some(text), - tool_calls: vec![], - usage: None, - reasoning_content: None, - }); - } - }; - - if !response.status().is_success() { - return Err(api_error(&self.name, response).await); - } - - let body = response.text().await?; - let chat_response = parse_chat_response_body(&self.name, &body)?; - let usage = chat_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let choice = chat_response - .choices - .into_iter() - .next() - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; - - let text = choice.message.effective_content_optional(); - let reasoning_content = choice.message.reasoning_content; - let tool_calls = choice - .message - .tool_calls - .unwrap_or_default() - .into_iter() - .filter_map(|tc| { - let function = tc.function?; - let name = function.name?; - let arguments = function.arguments.unwrap_or_else(|| "{}".to_string()); - Some(ProviderToolCall { - id: uuid::Uuid::new_v4().to_string(), - name, - arguments, - }) - }) - .collect::>(); - - Ok(ProviderChatResponse { - text, - tool_calls, - usage, - reasoning_content, - }) - } - - async fn chat( - &self, - request: ProviderChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!( - "{} API key not set. Set the appropriate env var.", - self.name - ) - })?; - - let tools = Self::convert_tool_specs(request.tools); - let effective_messages = if self.merge_system_into_user { - Self::flatten_system_messages(request.messages) - } else { - request.messages.to_vec() - }; - let native_request = NativeChatRequest { - model: model.to_string(), - messages: Self::convert_messages_for_native(&effective_messages), - temperature, - stream: Some(false), - tool_choice: tools.as_ref().map(|_| "auto".to_string()), - tools, - }; - - let url = self.chat_completions_url(); - let response = match self - .apply_auth_header( - self.http_client().post(&url).json(&native_request), - credential, - ) + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {credential}")) + .json(&request) .send() - .await - { - Ok(response) => response, - Err(chat_error) => { - if self.supports_responses_fallback { - let sanitized = sanitize_api_error(&chat_error.to_string()); - return self - .chat_via_responses(credential, &effective_messages, model) - .await - .map(|text| ProviderChatResponse { - text: Some(text), - tool_calls: vec![], - usage: None, - reasoning_content: None, - }) - .map_err(|responses_err| { - anyhow::anyhow!( - "{} native chat transport error: {sanitized} (responses fallback failed: {responses_err})", - self.name - ) - }); - } - - return Err(chat_error.into()); - } - }; + .await?; if !response.status().is_success() { - let status = response.status(); - let error = response.text().await?; - let sanitized = sanitize_api_error(&error); - - if Self::is_native_tool_schema_unsupported(status, &sanitized) { - let fallback_messages = - Self::with_prompt_guided_tool_instructions(request.messages, request.tools); - let text = self - .chat_with_history(&fallback_messages, model, temperature) - .await?; - return Ok(ProviderChatResponse { - text: Some(text), - tool_calls: vec![], - usage: None, - reasoning_content: None, - }); - } - - if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { - return self - .chat_via_responses(credential, &effective_messages, model) - .await - .map(|text| ProviderChatResponse { - text: Some(text), - tool_calls: vec![], - usage: None, - reasoning_content: None, - }) - .map_err(|responses_err| { - anyhow::anyhow!( - "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", - self.name - ) - }); - } - - anyhow::bail!("{} API error ({status}): {sanitized}", self.name); - } - - let native_response: ApiChatResponse = response.json().await?; - let usage = native_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let message = native_response - .choices - .into_iter() - .next() - .map(|choice| choice.message) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; - - let mut result = Self::parse_native_response(message); - result.usage = usage; - Ok(result) - } - - fn supports_native_tools(&self) -> bool { - self.native_tool_calling - } - - async fn warmup(&self) -> anyhow::Result<()> { - if let Some(credential) = self.credential.as_ref() { - // Hit the chat completions URL with a GET to establish the connection pool. - // The server will likely return 405 Method Not Allowed, which is fine - - // the goal is TLS handshake and HTTP/2 negotiation. - let url = self.chat_completions_url(); - let _ = self - .apply_auth_header(self.http_client().get(&url), credential) - .send() - .await?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_provider(name: &str, url: &str, key: Option<&str>) -> OpenAiCompatibleProvider { - OpenAiCompatibleProvider::new(name, url, key, AuthStyle::Bearer) - } - - #[test] - fn creates_with_key() { - let p = make_provider( - "venice", - "https://api.venice.ai", - Some("venice-test-credential"), - ); - assert_eq!(p.name, "venice"); - assert_eq!(p.base_url, "https://api.venice.ai"); - assert_eq!(p.credential.as_deref(), Some("venice-test-credential")); - } - - #[test] - fn creates_without_key() { - let p = make_provider("test", "https://example.com", None); - assert!(p.credential.is_none()); - } - - #[test] - fn strips_trailing_slash() { - let p = make_provider("test", "https://example.com/", None); - assert_eq!(p.base_url, "https://example.com"); - } - - #[tokio::test] - async fn chat_fails_without_key() { - let p = make_provider("Venice", "https://api.venice.ai", None); - let result = p - .chat_with_system(None, "hello", "llama-3.3-70b", 0.7) - .await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("Venice API key not set") - ); - } - - #[test] - fn request_serializes_correctly() { - let req = ApiChatRequest { - model: "llama-3.3-70b".to_string(), - messages: vec![ - Message { - role: "system".to_string(), - content: MessageContent::Text("You are a helpful assistant".to_string()), - }, - Message { - role: "user".to_string(), - content: MessageContent::Text("hello".to_string()), - }, - ], - temperature: 0.4, - stream: Some(false), - tools: None, - tool_choice: None, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("llama-3.3-70b")); - assert!(json.contains("system")); - assert!(json.contains("user")); - // tools/tool_choice should be omitted when None - assert!(!json.contains("tools")); - assert!(!json.contains("tool_choice")); - } - - #[test] - fn response_deserializes() { - let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!( - resp.choices[0].message.content, - Some("Hello from Venice!".to_string()) - ); - } - - #[test] - fn response_empty_choices() { - let json = r#"{"choices":[]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.choices.is_empty()); - } - - #[test] - fn parse_chat_response_body_reports_sanitized_snippet() { - let body = r#"{"choices":"invalid","api_key":"sk-test-secret-value"}"#; - let err = parse_chat_response_body("custom", body).expect_err("payload should fail"); - let msg = err.to_string(); - - assert!(msg.contains("custom API returned an unexpected chat-completions payload")); - assert!(msg.contains("body=")); - // Secrets should be redacted - assert!(!msg.contains("sk-test-secret-value")); - } - - #[test] - fn parse_responses_response_body_reports_sanitized_snippet() { - let body = r#"{"output_text":123,"api_key":"sk-another-secret"}"#; - let err = parse_responses_response_body("custom", body).expect_err("payload should fail"); - let msg = err.to_string(); - - assert!(msg.contains("custom Responses API returned an unexpected payload")); - assert!(msg.contains("body=")); - assert!(!msg.contains("sk-another-secret")); - } - - #[test] - fn x_api_key_auth_style() { - let p = OpenAiCompatibleProvider::new( - "moonshot", - "https://api.moonshot.cn", - Some("ms-key"), - AuthStyle::XApiKey, - ); - assert!(matches!(p.auth_header, AuthStyle::XApiKey)); - } - - #[test] - fn custom_auth_style() { - let p = OpenAiCompatibleProvider::new( - "custom", - "https://api.example.com", - Some("key"), - AuthStyle::Custom("X-Custom-Key".into()), - ); - assert!(matches!(p.auth_header, AuthStyle::Custom(_))); - } - - #[tokio::test] - async fn all_compatible_providers_fail_without_key() { - let providers = vec![ - make_provider("Venice", "https://api.venice.ai", None), - make_provider("Moonshot", "https://api.moonshot.cn", None), - make_provider("GLM", "https://open.bigmodel.cn", None), - make_provider("MiniMax", "https://api.minimaxi.com/v1", None), - make_provider("Groq", "https://api.groq.com/openai", None), - make_provider("Mistral", "https://api.mistral.ai", None), - make_provider("xAI", "https://api.x.ai", None), - make_provider("Astrai", "https://as-trai.com/v1", None), - ]; - - for p in providers { - let result = p.chat_with_system(None, "test", "model", 0.7).await; - assert!(result.is_err(), "{} should fail without key", p.name); - assert!( - result.unwrap_err().to_string().contains("API key not set"), - "{} error should mention key", - p.name - ); + return Err(api_error(&self.name, response).await); } - } - #[test] - fn responses_extracts_top_level_output_text() { - let json = r#"{"output_text":"Hello from top-level","output":[]}"#; - let response: ResponsesResponse = serde_json::from_str(json).unwrap(); - assert_eq!( - extract_responses_text(response).as_deref(), - Some("Hello from top-level") - ); - } + let body = response.text().await?; + let parsed: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|e| { + anyhow::anyhow!( + "{} unexpected response: {e}; body={}", + self.name, + sanitize_api_error(&body) + ) + })?; - #[test] - fn responses_extracts_nested_output_text() { - let json = - r#"{"output":[{"content":[{"type":"output_text","text":"Hello from nested"}]}]}"#; - let response: ResponsesResponse = serde_json::from_str(json).unwrap(); - assert_eq!( - extract_responses_text(response).as_deref(), - Some("Hello from nested") - ); + let (result, _) = Self::parse_response(parsed); + result + .text + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) } - #[test] - fn responses_extracts_any_text_as_fallback() { - let json = r#"{"output":[{"content":[{"type":"message","text":"Fallback text"}]}]}"#; - let response: ResponsesResponse = serde_json::from_str(json).unwrap(); - assert_eq!( - extract_responses_text(response).as_deref(), - Some("Fallback text") - ); - } + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.require_credential()?; - #[test] - fn build_responses_prompt_preserves_multi_turn_history() { - let messages = vec![ - ChatMessage::system("policy"), - ChatMessage::user("step 1"), - ChatMessage::assistant("ack 1"), - ChatMessage::tool("{\"result\":\"ok\"}"), - ChatMessage::user("step 2"), - ]; - - let (instructions, input) = build_responses_prompt(&messages); - - assert_eq!(instructions.as_deref(), Some("policy")); - assert_eq!(input.len(), 4); - assert_eq!(input[0].role, "user"); - assert_eq!(input[0].content, "step 1"); - assert_eq!(input[1].role, "assistant"); - assert_eq!(input[1].content, "ack 1"); - assert_eq!(input[2].role, "assistant"); - assert_eq!(input[2].content, "{\"result\":\"ok\"}"); - assert_eq!(input[3].role, "user"); - assert_eq!(input[3].content, "step 2"); - } + let request = ChatCompletionsRequest { + model: model.to_string(), + messages: Self::convert_messages(messages), + temperature, + stream: Some(false), + tools: None, + tool_choice: None, + }; - #[tokio::test] - async fn chat_via_responses_requires_non_system_message() { - let provider = make_provider("custom", "https://api.example.com", Some("test-key")); - let err = provider - .chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test") - .await - .expect_err("system-only fallback payload should fail"); - - assert!( - err.to_string() - .contains("requires at least one non-system message") - ); - } + let url = self.chat_completions_url(); + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {credential}")) + .json(&request) + .send() + .await?; - #[test] - fn tool_call_function_name_falls_back_to_top_level_name() { - let call: ToolCall = serde_json::from_value(serde_json::json!({ - "name": "memory_recall", - "arguments": "{\"query\":\"latest roadmap\"}" - })) - .unwrap(); - - assert_eq!(call.function_name().as_deref(), Some("memory_recall")); - } + if !response.status().is_success() { + return Err(api_error(&self.name, response).await); + } - #[test] - fn tool_call_function_arguments_falls_back_to_parameters_object() { - let call: ToolCall = serde_json::from_value(serde_json::json!({ - "name": "shell", - "parameters": {"command": "pwd"} - })) - .unwrap(); + let body = response.text().await?; + let parsed: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|e| { + anyhow::anyhow!( + "{} unexpected response: {e}; body={}", + self.name, + sanitize_api_error(&body) + ) + })?; - assert_eq!( - call.function_arguments().as_deref(), - Some("{\"command\":\"pwd\"}") - ); + let (result, _) = Self::parse_response(parsed); + result + .text + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) } - #[test] - fn tool_call_function_arguments_prefers_nested_function_field() { - let call: ToolCall = serde_json::from_value(serde_json::json!({ - "name": "ignored_name", - "arguments": "{\"query\":\"ignored\"}", - "function": { - "name": "memory_recall", - "arguments": "{\"query\":\"preferred\"}" - } - })) - .unwrap(); - - assert_eq!(call.function_name().as_deref(), Some("memory_recall")); - assert_eq!( - call.function_arguments().as_deref(), - Some("{\"query\":\"preferred\"}") - ); - } + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.require_credential()?; - #[test] - fn chat_completions_url_standard_openai() { - let p = make_provider("openai", "https://api.openai.com/v1", None); - assert_eq!( - p.chat_completions_url(), - "https://api.openai.com/v1/chat/completions" - ); - } + let tools = Self::convert_tools(request.tools); + let api_request = ChatCompletionsRequest { + model: model.to_string(), + messages: Self::convert_messages(request.messages), + temperature, + stream: Some(false), + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; - #[test] - fn chat_completions_url_trailing_slash() { - let p = make_provider("test", "https://api.example.com/v1/", None); - assert_eq!( - p.chat_completions_url(), - "https://api.example.com/v1/chat/completions" - ); - } + let url = self.chat_completions_url(); + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {credential}")) + .json(&api_request) + .send() + .await?; - #[test] - fn chat_completions_url_volcengine_ark() { - let p = make_provider( - "volcengine", - "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions", - None, - ); - assert_eq!( - p.chat_completions_url(), - "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions" - ); - } + if !response.status().is_success() { + return Err(api_error(&self.name, response).await); + } - #[test] - fn chat_completions_url_custom_full_endpoint() { - let p = make_provider( - "custom", - "https://my-api.example.com/v2/llm/chat/completions", - None, - ); - assert_eq!( - p.chat_completions_url(), - "https://my-api.example.com/v2/llm/chat/completions" - ); - } + let body = response.text().await?; + let parsed: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|e| { + anyhow::anyhow!( + "{} unexpected response: {e}; body={}", + self.name, + sanitize_api_error(&body) + ) + })?; - #[test] - fn chat_completions_url_requires_exact_suffix_match() { - let p = make_provider( - "custom", - "https://my-api.example.com/v2/llm/chat/completions-proxy", - None, - ); - assert_eq!( - p.chat_completions_url(), - "https://my-api.example.com/v2/llm/chat/completions-proxy/chat/completions" - ); + let (mut result, usage) = Self::parse_response(parsed); + result.usage = usage; + Ok(result) } +} - #[test] - fn responses_url_standard() { - let p = make_provider("test", "https://api.example.com", None); - assert_eq!(p.responses_url(), "https://api.example.com/v1/responses"); - } +#[cfg(test)] +mod tests { + use super::*; - #[test] - fn responses_url_custom_full_endpoint() { - let p = make_provider( - "custom", - "https://my-api.example.com/api/v2/responses", - None, - ); - assert_eq!( - p.responses_url(), - "https://my-api.example.com/api/v2/responses" - ); + fn make_provider(name: &str, url: &str, key: Option<&str>) -> OpenAiCompatibleProvider { + OpenAiCompatibleProvider::new(name, url, key) } #[test] - fn responses_url_derives_from_chat_endpoint() { - let p = make_provider( - "custom", - "https://my-api.example.com/api/v2/chat/completions", - None, - ); - assert_eq!( - p.responses_url(), - "https://my-api.example.com/api/v2/responses" - ); + fn creates_with_key() { + let p = make_provider("test", "https://api.example.com/v1", Some("sk-test")); + assert_eq!(p.credential.as_deref(), Some("sk-test")); + assert_eq!(p.name, "test"); } #[test] - fn responses_url_base_with_v1_no_duplicate() { + fn creates_without_key() { let p = make_provider("test", "https://api.example.com/v1", None); - assert_eq!(p.responses_url(), "https://api.example.com/v1/responses"); + assert!(p.credential.is_none()); } #[test] - fn responses_url_non_v1_api_path_uses_raw_suffix() { - let p = make_provider("test", "https://api.example.com/api/coding/v3", None); - assert_eq!( - p.responses_url(), - "https://api.example.com/api/coding/v3/responses" - ); + fn trims_empty_key() { + let p = make_provider("test", "https://api.example.com/v1", Some(" ")); + assert!(p.credential.is_none()); } #[test] - fn chat_completions_url_without_v1() { - let p = make_provider("test", "https://api.example.com", None); - assert_eq!( - p.chat_completions_url(), - "https://api.example.com/chat/completions" - ); + fn strips_trailing_slash() { + let p = make_provider("test", "https://api.example.com/v1/", None); + assert_eq!(p.base_url, "https://api.example.com/v1"); } #[test] - fn chat_completions_url_base_with_v1() { + fn chat_completions_url_appends_path() { let p = make_provider("test", "https://api.example.com/v1", None); assert_eq!( p.chat_completions_url(), @@ -1813,398 +589,208 @@ mod tests { } #[test] - fn chat_completions_url_zai() { - let p = make_provider("zai", "https://api.z.ai/api/paas/v4", None); - assert_eq!( - p.chat_completions_url(), - "https://api.z.ai/api/paas/v4/chat/completions" - ); - } - - #[test] - fn chat_completions_url_minimax() { - let p = make_provider("minimax", "https://api.minimaxi.com/v1", None); + fn chat_completions_url_keeps_existing_path() { + let p = make_provider("test", "https://api.example.com/v1/chat/completions", None); assert_eq!( p.chat_completions_url(), - "https://api.minimaxi.com/v1/chat/completions" + "https://api.example.com/v1/chat/completions" ); } - #[test] - fn chat_completions_url_glm() { - let p = make_provider("glm", "https://open.bigmodel.cn/api/paas/v4", None); - assert_eq!( - p.chat_completions_url(), - "https://open.bigmodel.cn/api/paas/v4/chat/completions" - ); + #[tokio::test] + async fn chat_fails_without_key() { + let p = make_provider("test", "https://api.example.com/v1", None); + let result = p.chat_with_system(None, "hi", "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); } #[test] - fn parse_native_response_preserves_tool_call_id() { - let message = ResponseMessage { - content: None, - tool_calls: Some(vec![ToolCall { - id: Some("call_123".to_string()), - kind: Some("function".to_string()), - function: Some(Function { - name: Some("shell".to_string()), - arguments: Some(r#"{"command":"pwd"}"#.to_string()), - }), - name: None, - arguments: None, - parameters: None, - }]), - reasoning_content: None, + fn request_serializes_correctly() { + let request = ChatCompletionsRequest { + model: "gpt-4o".to_string(), + messages: vec![RequestMessage { + role: "user".to_string(), + content: Some("hello".to_string()), + tool_call_id: None, + tool_calls: None, + reasoning_content: None, + }], + temperature: 0.7, + stream: Some(false), + tools: None, + tool_choice: None, }; - - let parsed = OpenAiCompatibleProvider::parse_native_response(message); - assert_eq!(parsed.tool_calls.len(), 1); - assert_eq!(parsed.tool_calls[0].id, "call_123"); - assert_eq!(parsed.tool_calls[0].name, "shell"); + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("gpt-4o")); + assert!(json.contains("hello")); + assert!(!json.contains("tool_call_id")); + assert!(!json.contains("reasoning_content")); } #[test] - fn convert_messages_for_native_maps_tool_result_payload() { - let input = vec![ChatMessage::tool( - r#"{"tool_call_id":"call_abc","content":"done"}"#, - )]; - - let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input); - assert_eq!(converted.len(), 1); - assert_eq!(converted[0].role, "tool"); - assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_abc")); - assert!(matches!( - converted[0].content.as_ref(), - Some(MessageContent::Text(value)) if value == "done" - )); + fn response_deserializes() { + let json = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#; + let resp: ChatCompletionsResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!")); } #[test] - fn flatten_system_messages_merges_into_first_user() { - let input = vec![ - ChatMessage::system("core policy"), - ChatMessage::assistant("ack"), - ChatMessage::system("delivery rules"), - ChatMessage::user("hello"), - ChatMessage::assistant("post-user"), - ]; - - let output = OpenAiCompatibleProvider::flatten_system_messages(&input); - assert_eq!(output.len(), 3); - assert_eq!(output[0].role, "assistant"); - assert_eq!(output[0].content, "ack"); - assert_eq!(output[1].role, "user"); - assert_eq!(output[1].content, "core policy\n\ndelivery rules\n\nhello"); - assert_eq!(output[2].role, "assistant"); - assert_eq!(output[2].content, "post-user"); - assert!(output.iter().all(|m| m.role != "system")); + fn response_with_tool_calls() { + let json = r#"{ + "choices":[{ + "message":{ + "content":null, + "tool_calls":[{ + "id":"call_1", + "type":"function", + "function":{"name":"shell","arguments":"{\"cmd\":\"ls\"}"} + }] + } + }], + "usage":{"prompt_tokens":10,"completion_tokens":5} + }"#; + let resp: ChatCompletionsResponse = serde_json::from_str(json).unwrap(); + let (result, usage) = OpenAiCompatibleProvider::parse_response(resp); + assert!(result.text.is_none()); + assert_eq!(result.tool_calls.len(), 1); + assert_eq!(result.tool_calls[0].name, "shell"); + let usage = usage.unwrap(); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(5)); + } + + #[test] + fn tool_call_fallback_to_top_level_name() { + let tc = ToolCallIn { + id: Some("1".into()), + function: None, + name: Some("shell".into()), + arguments: Some("{}".into()), + parameters: None, + }; + assert_eq!(tc.function_name().as_deref(), Some("shell")); + assert_eq!(tc.function_arguments().as_deref(), Some("{}")); } #[test] - fn flatten_system_messages_inserts_user_when_missing() { - let input = vec![ - ChatMessage::system("core policy"), - ChatMessage::assistant("ack"), - ]; - - let output = OpenAiCompatibleProvider::flatten_system_messages(&input); - assert_eq!(output.len(), 2); - assert_eq!(output[0].role, "user"); - assert_eq!(output[0].content, "core policy"); - assert_eq!(output[1].role, "assistant"); - assert_eq!(output[1].content, "ack"); + fn tool_call_fallback_to_parameters() { + let tc = ToolCallIn { + id: Some("1".into()), + function: None, + name: Some("shell".into()), + arguments: None, + parameters: Some(serde_json::json!({"cmd": "ls"})), + }; + let args = tc.function_arguments().unwrap(); + assert!(args.contains("cmd")); } #[test] - fn strip_think_tags_drops_unclosed_block_suffix() { - let input = "visiblehidden"; - assert_eq!(strip_think_tags(input), "visible"); + fn convert_messages_handles_tool_call_history() { + let messages = vec![ChatMessage { + role: "assistant".into(), + content: r#"{"content":"checking","tool_calls":[{"id":"c1","name":"shell","arguments":"{}"}]}"#.into(), + }]; + let converted = OpenAiCompatibleProvider::convert_messages(&messages); + assert_eq!(converted[0].role, "assistant"); + assert_eq!(converted[0].content.as_deref(), Some("checking")); + assert!(converted[0].tool_calls.is_some()); } #[test] - fn native_tool_schema_unsupported_detection_is_precise() { - assert!(OpenAiCompatibleProvider::is_native_tool_schema_unsupported( - reqwest::StatusCode::BAD_REQUEST, - "unknown parameter: tools" - )); - assert!( - !OpenAiCompatibleProvider::is_native_tool_schema_unsupported( - reqwest::StatusCode::UNAUTHORIZED, - "unknown parameter: tools" - ) - ); + fn convert_messages_handles_tool_result() { + let messages = vec![ChatMessage { + role: "tool".into(), + content: r#"{"tool_call_id":"c1","content":"done"}"#.into(), + }]; + let converted = OpenAiCompatibleProvider::convert_messages(&messages); + assert_eq!(converted[0].role, "tool"); + assert_eq!(converted[0].tool_call_id.as_deref(), Some("c1")); + assert_eq!(converted[0].content.as_deref(), Some("done")); } #[test] - fn prompt_guided_tool_fallback_injects_system_instruction() { - let input = vec![ChatMessage::user("check status")]; + fn convert_tools_maps_spec() { let tools = vec![ToolSpec { - name: "shell_exec".to_string(), - description: "Execute shell command".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "command": { "type": "string" } - }, - "required": ["command"] - }), + name: "shell".to_string(), + description: "Run command".to_string(), + parameters: serde_json::json!({"type": "object"}), }]; - - let output = - OpenAiCompatibleProvider::with_prompt_guided_tool_instructions(&input, Some(&tools)); - assert!(!output.is_empty()); - assert_eq!(output[0].role, "system"); - assert!( - output[0].content.contains("Available Tools") - || output[0].content.contains("Tool Use Protocol") - ); - assert!(output[0].content.contains("shell_exec")); + let converted = OpenAiCompatibleProvider::convert_tools(Some(&tools)).unwrap(); + assert_eq!(converted.len(), 1); + assert!(converted[0]["function"]["name"] == "shell"); } - #[tokio::test] - async fn warmup_without_key_is_noop() { - let provider = make_provider("test", "https://example.com", None); - let result = provider.warmup().await; - assert!(result.is_ok()); + #[test] + fn convert_tools_returns_none_for_empty() { + assert!(OpenAiCompatibleProvider::convert_tools(Some(&[])).is_none()); + assert!(OpenAiCompatibleProvider::convert_tools(None).is_none()); } #[test] - fn capabilities_reports_native_tool_calling() { - let p = make_provider("test", "https://example.com", None); - let caps = ::capabilities(&p); - assert!(caps.native_tool_calling); - assert!(!caps.vision); + fn sanitize_api_error_redacts_keys() { + let input = "Error with key sk-ant-api-12345678 in request"; + let result = sanitize_api_error(input); + assert!(result.contains("[REDACTED]")); + assert!(!result.contains("sk-ant-api")); } #[test] - fn capabilities_reports_vision_for_qwen_compatible_provider() { - let p = OpenAiCompatibleProvider::new_with_vision( - "Qwen", - "https://dashscope.aliyuncs.com/compatible-mode/v1", - Some("k"), - AuthStyle::Bearer, - true, - ); - let caps = ::capabilities(&p); - assert!(caps.native_tool_calling); - assert!(caps.vision); + fn sanitize_api_error_truncates() { + let long = "x".repeat(300); + let result = sanitize_api_error(&long); + assert!(result.ends_with("...")); + assert!(result.len() < 210); } #[test] - fn minimax_provider_disables_native_tool_calling() { - let p = OpenAiCompatibleProvider::new_merge_system_into_user( - "MiniMax", - "https://api.minimax.chat/v1", - Some("k"), - AuthStyle::Bearer, - ); - let caps = ::capabilities(&p); - assert!( - !caps.native_tool_calling, - "MiniMax should use prompt-guided tool calling, not native" - ); - assert!(!caps.vision); + fn reasoning_content_fallback() { + let msg = ResponseMessage { + content: None, + reasoning_content: Some("thinking...".into()), + tool_calls: None, + }; + assert_eq!(msg.effective_content().as_deref(), Some("thinking...")); } #[test] - fn no_responses_fallback_constructor_keeps_native_tool_calling_enabled() { - let p = OpenAiCompatibleProvider::new_no_responses_fallback( - "FallbackProvider", - "https://example.com", - Some("k"), - AuthStyle::Bearer, - ); - let caps = ::capabilities(&p); + fn capabilities_reports_native_tools() { + let p = make_provider("test", "https://api.example.com/v1", None); + let caps = p.capabilities(); assert!(caps.native_tool_calling); assert!(!caps.vision); - assert!(p.user_agent.is_none()); } - #[test] - fn tool_specs_convert_to_openai_format() { - let specs = vec![ToolSpec { - name: "shell".to_string(), - description: "Run shell command".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": {"command": {"type": "string"}}, - "required": ["command"] - }), - }]; - - let tools = OpenAiCompatibleProvider::tool_specs_to_openai_format(&specs); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0]["type"], "function"); - assert_eq!(tools[0]["function"]["name"], "shell"); - assert_eq!(tools[0]["function"]["description"], "Run shell command"); - assert_eq!(tools[0]["function"]["parameters"]["required"][0], "command"); - } - - #[test] - fn request_serializes_with_tools() { - let tools = vec![serde_json::json!({ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather for a location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string"} - } - } - } - })]; - - let req = ApiChatRequest { - model: "test-model".to_string(), - messages: vec![Message { - role: "user".to_string(), - content: MessageContent::Text("What is the weather?".to_string()), - }], - temperature: 0.7, - stream: Some(false), - tools: Some(tools), - tool_choice: Some("auto".to_string()), - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("\"tools\"")); - assert!(json.contains("get_weather")); - assert!(json.contains("\"tool_choice\":\"auto\"")); + #[tokio::test] + async fn warmup_without_key_is_noop() { + let p = make_provider("test", "https://api.example.com/v1", None); + assert!(p.warmup().await.is_ok()); } #[test] - fn response_with_tool_calls_deserializes() { - let json = r#"{ - "choices": [{ - "message": { - "content": null, - "tool_calls": [{ - "type": "function", - "function": { - "name": "get_weather", - "arguments": "{\"location\":\"London\"}" - } - }] - } - }] - }"#; - - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - let msg = &resp.choices[0].message; - assert!(msg.content.is_none()); - let tool_calls = msg.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().name.as_deref(), - Some("get_weather") - ); - } - - #[tokio::test] - async fn chat_with_tools_fails_without_key() { - let p = make_provider("TestProvider", "https://example.com", None); + fn reasoning_content_round_trips_in_convert() { let messages = vec![ChatMessage { - role: "user".to_string(), - content: "hello".to_string(), + role: "assistant".into(), + content: r#"{"content":"ok","tool_calls":[{"id":"c1","name":"shell","arguments":"{}"}],"reasoning_content":"let me think"}"#.into(), }]; - let tools = vec![serde_json::json!({ - "type": "function", - "function": { - "name": "test_tool", - "description": "A test tool", - "parameters": {} - } - })]; - - let result = p.chat_with_tools(&messages, &tools, "model", 0.7).await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("TestProvider API key not set") + let converted = OpenAiCompatibleProvider::convert_messages(&messages); + assert_eq!( + converted[0].reasoning_content.as_deref(), + Some("let me think") ); } #[test] - fn strip_think_tags_removes_multiple_blocks_with_surrounding_text() { - let input = "Answer A hidden 1 and B hidden 2 done"; - let output = strip_think_tags(input); - assert_eq!(output, "Answer A and B done"); - } - - #[test] - fn reasoning_content_fallback_when_content_empty() { - let json = r#"{"choices":[{"message":{"content":"","reasoning_content":"Thinking output here"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - let msg = &resp.choices[0].message; - assert_eq!(msg.effective_content(), "Thinking output here"); - } - - #[test] - fn api_response_parses_usage() { - let json = r#"{ - "choices": [{"message": {"content": "Hello"}}], - "usage": {"prompt_tokens": 150, "completion_tokens": 60} - }"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - let usage = resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, Some(150)); - assert_eq!(usage.completion_tokens, Some(60)); - } - - #[test] - fn parse_native_response_captures_reasoning_content() { - let message = ResponseMessage { - content: Some("answer".to_string()), - reasoning_content: Some("thinking step".to_string()), - tool_calls: Some(vec![ToolCall { - id: Some("call_1".to_string()), - kind: Some("function".to_string()), - function: Some(Function { - name: Some("shell".to_string()), - arguments: Some(r#"{"cmd":"ls"}"#.to_string()), - }), - name: None, - arguments: None, - parameters: None, - }]), - }; - - let parsed = OpenAiCompatibleProvider::parse_native_response(message); - assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking step")); - assert_eq!(parsed.text.as_deref(), Some("answer")); - assert_eq!(parsed.tool_calls.len(), 1); - } - - #[test] - fn convert_messages_for_native_reasoning_content_serialized_only_when_present() { - let msg_without = NativeMessage { + fn reasoning_content_omitted_when_none() { + let msg = RequestMessage { role: "assistant".to_string(), - content: Some(MessageContent::Text("hi".to_string())), + content: Some("hi".into()), tool_call_id: None, tool_calls: None, reasoning_content: None, }; - let json = serde_json::to_string(&msg_without).unwrap(); - assert!( - !json.contains("reasoning_content"), - "reasoning_content should be omitted when None" - ); - - let msg_with = NativeMessage { - role: "assistant".to_string(), - content: Some(MessageContent::Text("hi".to_string())), - tool_call_id: None, - tool_calls: None, - reasoning_content: Some("thinking...".to_string()), - }; - let json = serde_json::to_string(&msg_with).unwrap(); - assert!( - json.contains("reasoning_content"), - "reasoning_content should be present when Some" - ); - assert!(json.contains("thinking...")); + let json = serde_json::to_string(&msg).unwrap(); + assert!(!json.contains("reasoning_content")); } } diff --git a/crewforge-rs/src/provider/copilot.rs b/crewforge-rs/src/provider/copilot.rs deleted file mode 100644 index 08ed40d..0000000 --- a/crewforge-rs/src/provider/copilot.rs +++ /dev/null @@ -1,745 +0,0 @@ -//! GitHub Copilot provider with OAuth device-flow authentication. -//! -//! Authenticates via GitHub's device code flow (same as VS Code Copilot), -//! then exchanges the OAuth token for short-lived Copilot API keys. -//! Tokens are cached to disk and auto-refreshed. -//! -//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and -//! editor headers. This is the same approach used by LiteLLM, Codex CLI, -//! and other third-party Copilot integrations. The Copilot token endpoint is -//! private; there is no public OAuth scope or app registration for it. -//! GitHub could change or revoke this at any time, which would break all -//! third-party integrations simultaneously. - -use crate::provider::traits::{ - ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; -use tracing::warn; - -/// GitHub OAuth client ID for Copilot (VS Code extension). -const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; -const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; -const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; -const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token"; -const DEFAULT_API: &str = "https://api.githubcopilot.com"; - -// ── Token types ────────────────────────────────────────────────── - -#[derive(Debug, Deserialize)] -struct DeviceCodeResponse { - device_code: String, - user_code: String, - verification_uri: String, - #[serde(default = "default_interval")] - interval: u64, - #[serde(default = "default_expires_in")] - expires_in: u64, -} - -fn default_interval() -> u64 { - 5 -} - -fn default_expires_in() -> u64 { - 900 -} - -#[derive(Debug, Deserialize)] -struct AccessTokenResponse { - access_token: Option, - error: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ApiKeyInfo { - token: String, - expires_at: i64, - #[serde(default)] - endpoints: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ApiEndpoints { - api: Option, -} - -struct CachedApiKey { - token: String, - api_endpoint: String, - expires_at: i64, -} - -// ── Chat completions types ─────────────────────────────────────── - -#[derive(Debug, Serialize)] -struct ApiChatRequest<'a> { - model: String, - messages: Vec, - temperature: f64, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, -} - -#[derive(Debug, Serialize)] -struct ApiMessage { - role: String, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, -} - -#[derive(Debug, Serialize)] -struct NativeToolSpec<'a> { - #[serde(rename = "type")] - kind: &'static str, - function: NativeToolFunctionSpec<'a>, -} - -#[derive(Debug, Serialize)] -struct NativeToolFunctionSpec<'a> { - name: &'a str, - description: &'a str, - parameters: &'a serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - kind: Option, - function: NativeFunctionCall, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeFunctionCall { - name: String, - arguments: String, -} - -#[derive(Debug, Deserialize)] -struct ApiChatResponse { - choices: Vec, - #[serde(default)] - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct UsageInfo { - #[serde(default)] - prompt_tokens: Option, - #[serde(default)] - completion_tokens: Option, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: ResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct ResponseMessage { - #[serde(default)] - content: Option, - #[serde(default)] - tool_calls: Option>, -} - -// ── Provider ───────────────────────────────────────────────────── - -/// GitHub Copilot provider with automatic OAuth and token refresh. -/// -/// On first use, prompts the user to visit github.com/login/device. -/// Tokens are cached to `~/.config/crewforge/copilot/` and refreshed -/// automatically. -pub struct CopilotProvider { - github_token: Option, - /// Mutex ensures only one caller refreshes tokens at a time, - /// preventing duplicate device flow prompts or redundant API calls. - refresh_lock: Arc>>, - token_dir: PathBuf, -} - -impl CopilotProvider { - pub fn new(github_token: Option<&str>) -> Self { - let token_dir = directories::ProjectDirs::from("", "", "crewforge") - .map(|dir| dir.config_dir().join("copilot")) - .unwrap_or_else(|| { - // Fall back to a user-specific temp directory to avoid - // shared-directory symlink attacks. - let user = std::env::var("USER") - .or_else(|_| std::env::var("USERNAME")) - .unwrap_or_else(|_| "unknown".to_string()); - std::env::temp_dir().join(format!("crewforge-copilot-{user}")) - }); - - if let Err(err) = std::fs::create_dir_all(&token_dir) { - warn!( - "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.", - token_dir - ); - } else { - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - - if let Err(err) = - std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700)) - { - warn!( - "Failed to set Copilot token directory permissions on {:?}: {err}", - token_dir - ); - } - } - } - - Self { - github_token: github_token - .filter(|token| !token.is_empty()) - .map(String::from), - refresh_lock: Arc::new(Mutex::new(None)), - token_dir, - } - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(Duration::from_secs(120)) - .connect_timeout(Duration::from_secs(10)) - .build() - .unwrap_or_default() - } - - /// Required headers for Copilot API requests (editor identification). - const COPILOT_HEADERS: [(&'static str, &'static str); 4] = [ - ("Editor-Version", "vscode/1.85.1"), - ("Editor-Plugin-Version", "copilot/1.155.0"), - ("User-Agent", "GithubCopilot/1.155.0"), - ("Accept", "application/json"), - ]; - - fn convert_tools(tools: Option<&[ToolSpec]>) -> Option>> { - tools.map(|items| { - items - .iter() - .map(|tool| NativeToolSpec { - kind: "function", - function: NativeToolFunctionSpec { - name: &tool.name, - description: &tool.description, - parameters: &tool.parameters, - }, - }) - .collect() - }) - } - - fn convert_messages(messages: &[ChatMessage]) -> Vec { - messages - .iter() - .map(|message| { - if message.role == "assistant" - && let Ok(value) = serde_json::from_str::(&message.content) - && let Some(tool_calls_value) = value.get("tool_calls") - && let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tool_call| NativeToolCall { - id: Some(tool_call.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tool_call.name, - arguments: tool_call.arguments, - }, - }) - .collect::>(); - - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return ApiMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - }; - } - - if message.role == "tool" - && let Ok(value) = serde_json::from_str::(&message.content) - { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - return ApiMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - }; - } - - ApiMessage { - role: message.role.clone(), - content: Some(message.content.clone()), - tool_call_id: None, - tool_calls: None, - } - }) - .collect() - } - - /// Send a chat completions request with required Copilot headers. - async fn send_chat_request( - &self, - messages: Vec, - tools: Option<&[ToolSpec]>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let (token, endpoint) = self.get_api_key().await?; - let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); - - let native_tools = Self::convert_tools(tools); - let request = ApiChatRequest { - model: model.to_string(), - messages, - temperature, - tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), - tools: native_tools, - }; - - let mut req = self - .http_client() - .post(&url) - .header("Authorization", format!("Bearer {token}")) - .json(&request); - - for (header, value) in &Self::COPILOT_HEADERS { - req = req.header(*header, *value); - } - - let response = req.send().await?; - - if !response.status().is_success() { - return Err(super::api_error("GitHub Copilot", response).await); - } - - let api_response: ApiChatResponse = response.json().await?; - let usage = api_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let choice = api_response - .choices - .into_iter() - .next() - .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?; - - let tool_calls = choice - .message - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|tool_call| ProviderToolCall { - id: tool_call - .id - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - name: tool_call.function.name, - arguments: tool_call.function.arguments, - }) - .collect(); - - Ok(ProviderChatResponse { - text: choice.message.content, - tool_calls, - usage, - reasoning_content: None, - }) - } - - /// Get a valid Copilot API key, refreshing or re-authenticating as needed. - /// Uses a Mutex to ensure only one caller refreshes at a time. - async fn get_api_key(&self) -> anyhow::Result<(String, String)> { - let mut cached = self.refresh_lock.lock().await; - - if let Some(cached_key) = cached.as_ref() - && chrono::Utc::now().timestamp() + 120 < cached_key.expires_at - { - return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); - } - - if let Some(info) = self.load_api_key_from_disk().await - && chrono::Utc::now().timestamp() + 120 < info.expires_at - { - let endpoint = info - .endpoints - .as_ref() - .and_then(|e| e.api.clone()) - .unwrap_or_else(|| DEFAULT_API.to_string()); - let token = info.token; - - *cached = Some(CachedApiKey { - token: token.clone(), - api_endpoint: endpoint.clone(), - expires_at: info.expires_at, - }); - return Ok((token, endpoint)); - } - - let access_token = self.get_github_access_token().await?; - let api_key_info = self.exchange_for_api_key(&access_token).await?; - self.save_api_key_to_disk(&api_key_info).await; - - let endpoint = api_key_info - .endpoints - .as_ref() - .and_then(|e| e.api.clone()) - .unwrap_or_else(|| DEFAULT_API.to_string()); - - *cached = Some(CachedApiKey { - token: api_key_info.token.clone(), - api_endpoint: endpoint.clone(), - expires_at: api_key_info.expires_at, - }); - - Ok((api_key_info.token, endpoint)) - } - - /// Get a GitHub access token from config, cache, or device flow. - async fn get_github_access_token(&self) -> anyhow::Result { - if let Some(token) = &self.github_token { - return Ok(token.clone()); - } - - let access_token_path = self.token_dir.join("access-token"); - if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await { - let token = cached.trim(); - if !token.is_empty() { - return Ok(token.to_string()); - } - } - - let token = self.device_code_login().await?; - write_file_secure(&access_token_path, &token).await; - Ok(token) - } - - /// Run GitHub OAuth device code flow. - async fn device_code_login(&self) -> anyhow::Result { - let response: DeviceCodeResponse = self - .http_client() - .post(GITHUB_DEVICE_CODE_URL) - .header("Accept", "application/json") - .json(&serde_json::json!({ - "client_id": GITHUB_CLIENT_ID, - "scope": "read:user" - })) - .send() - .await? - .error_for_status()? - .json() - .await?; - - let mut poll_interval = Duration::from_secs(response.interval.max(5)); - let expires_in = response.expires_in.max(1); - let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in); - - eprintln!( - "\nGitHub Copilot authentication is required.\n\ - Visit: {}\n\ - Code: {}\n\ - Waiting for authorization...\n", - response.verification_uri, response.user_code - ); - - while tokio::time::Instant::now() < expires_at { - tokio::time::sleep(poll_interval).await; - - let token_response: AccessTokenResponse = self - .http_client() - .post(GITHUB_ACCESS_TOKEN_URL) - .header("Accept", "application/json") - .json(&serde_json::json!({ - "client_id": GITHUB_CLIENT_ID, - "device_code": response.device_code, - "grant_type": "urn:ietf:params:oauth:grant-type:device_code" - })) - .send() - .await? - .json() - .await?; - - if let Some(token) = token_response.access_token { - eprintln!("Authentication succeeded.\n"); - return Ok(token); - } - - match token_response.error.as_deref() { - Some("slow_down") => { - poll_interval += Duration::from_secs(5); - } - Some("authorization_pending") | None => {} - Some("expired_token") => { - anyhow::bail!("GitHub device authorization expired") - } - Some(error) => anyhow::bail!("GitHub auth failed: {error}"), - } - } - - anyhow::bail!("Timed out waiting for GitHub authorization") - } - - /// Exchange a GitHub access token for a Copilot API key. - async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result { - let mut request = self.http_client().get(GITHUB_API_KEY_URL); - for (header, value) in &Self::COPILOT_HEADERS { - request = request.header(*header, *value); - } - request = request.header("Authorization", format!("token {access_token}")); - - let response = request.send().await?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let sanitized = super::sanitize_api_error(&body); - - if status.as_u16() == 401 || status.as_u16() == 403 { - let access_token_path = self.token_dir.join("access-token"); - tokio::fs::remove_file(&access_token_path).await.ok(); - } - - anyhow::bail!( - "Failed to get Copilot API key ({status}): {sanitized}. \ - Ensure your GitHub account has an active Copilot subscription." - ); - } - - let info: ApiKeyInfo = response.json().await?; - Ok(info) - } - - async fn load_api_key_from_disk(&self) -> Option { - let path = self.token_dir.join("api-key.json"); - let data = tokio::fs::read_to_string(&path).await.ok()?; - serde_json::from_str(&data).ok() - } - - async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) { - let path = self.token_dir.join("api-key.json"); - if let Ok(json) = serde_json::to_string_pretty(info) { - write_file_secure(&path, &json).await; - } - } -} - -/// Write a file with 0600 permissions (owner read/write only). -/// Uses `spawn_blocking` to avoid blocking the async runtime. -async fn write_file_secure(path: &Path, content: &str) { - let path = path.to_path_buf(); - let content = content.to_string(); - - let result = tokio::task::spawn_blocking(move || { - #[cfg(unix)] - { - use std::io::Write; - use std::os::unix::fs::{OpenOptionsExt, PermissionsExt}; - - let mut file = std::fs::OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .mode(0o600) - .open(&path)?; - file.write_all(content.as_bytes())?; - - std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?; - Ok::<(), std::io::Error>(()) - } - #[cfg(not(unix))] - { - std::fs::write(&path, &content)?; - Ok::<(), std::io::Error>(()) - } - }) - .await; - - match result { - Ok(Ok(())) => {} - Ok(Err(err)) => warn!("Failed to write secure file: {err}"), - Err(err) => warn!("Failed to spawn blocking write: {err}"), - } -} - -#[async_trait] -impl Provider for CopilotProvider { - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - native_tool_calling: true, - vision: false, - } - } - - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let mut messages = Vec::new(); - if let Some(system) = system_prompt { - messages.push(ApiMessage { - role: "system".to_string(), - content: Some(system.to_string()), - tool_call_id: None, - tool_calls: None, - }); - } - messages.push(ApiMessage { - role: "user".to_string(), - content: Some(message.to_string()), - tool_call_id: None, - tool_calls: None, - }); - - let response = self - .send_chat_request(messages, None, model, temperature) - .await?; - Ok(response.text.unwrap_or_default()) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let response = self - .send_chat_request(Self::convert_messages(messages), None, model, temperature) - .await?; - Ok(response.text.unwrap_or_default()) - } - - async fn chat( - &self, - request: ProviderChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - self.send_chat_request( - Self::convert_messages(request.messages), - request.tools, - model, - temperature, - ) - .await - } - - async fn warmup(&self) -> anyhow::Result<()> { - let _ = self.get_api_key().await?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn new_without_token() { - let provider = CopilotProvider::new(None); - assert!(provider.github_token.is_none()); - } - - #[test] - fn new_with_token() { - let provider = CopilotProvider::new(Some("ghp_test")); - assert_eq!(provider.github_token.as_deref(), Some("ghp_test")); - } - - #[test] - fn empty_token_treated_as_none() { - let provider = CopilotProvider::new(Some("")); - assert!(provider.github_token.is_none()); - } - - #[tokio::test] - async fn cache_starts_empty() { - let provider = CopilotProvider::new(None); - let cached = provider.refresh_lock.lock().await; - assert!(cached.is_none()); - } - - #[test] - fn copilot_headers_include_required_fields() { - let headers = CopilotProvider::COPILOT_HEADERS; - assert!( - headers - .iter() - .any(|(header, _)| *header == "Editor-Version") - ); - assert!( - headers - .iter() - .any(|(header, _)| *header == "Editor-Plugin-Version") - ); - assert!(headers.iter().any(|(header, _)| *header == "User-Agent")); - } - - #[test] - fn default_interval_and_expiry() { - assert_eq!(default_interval(), 5); - assert_eq!(default_expires_in(), 900); - } - - #[test] - fn supports_native_tools() { - let provider = CopilotProvider::new(None); - assert!(provider.supports_native_tools()); - } - - #[test] - fn api_response_parses_usage() { - let json = r#"{ - "choices": [{"message": {"content": "Hello"}}], - "usage": {"prompt_tokens": 200, "completion_tokens": 80} - }"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - let usage = resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, Some(200)); - assert_eq!(usage.completion_tokens, Some(80)); - } - - #[test] - fn api_response_parses_without_usage() { - let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.usage.is_none()); - } -} diff --git a/crewforge-rs/src/provider/gemini.rs b/crewforge-rs/src/provider/gemini.rs deleted file mode 100644 index d6196e5..0000000 --- a/crewforge-rs/src/provider/gemini.rs +++ /dev/null @@ -1,496 +0,0 @@ -//! Google Gemini provider with API key authentication. -//! Supports GEMINI_API_KEY and GOOGLE_API_KEY environment variables. - -use crate::provider::traits::{ChatMessage, ChatResponse, Provider, TokenUsage}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -/// Gemini provider supporting API key authentication. -pub struct GeminiProvider { - api_key: Option, -} - -// ══════════════════════════════════════════════════════════════════════════════ -// API REQUEST/RESPONSE TYPES -// ══════════════════════════════════════════════════════════════════════════════ - -#[derive(Debug, Serialize, Clone)] -struct GenerateContentRequest { - contents: Vec, - #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")] - system_instruction: Option, - #[serde(rename = "generationConfig")] - generation_config: GenerationConfig, -} - -#[derive(Debug, Serialize, Clone)] -struct Content { - #[serde(skip_serializing_if = "Option::is_none")] - role: Option, - parts: Vec, -} - -#[derive(Debug, Serialize, Clone)] -struct Part { - text: String, -} - -#[derive(Debug, Serialize, Clone)] -struct GenerationConfig { - temperature: f64, - #[serde(rename = "maxOutputTokens")] - max_output_tokens: u32, -} - -#[derive(Debug, Deserialize)] -struct GenerateContentResponse { - candidates: Option>, - error: Option, - #[serde(default, rename = "usageMetadata")] - usage_metadata: Option, -} - -#[derive(Debug, Deserialize)] -struct GeminiUsageMetadata { - #[serde(default, rename = "promptTokenCount")] - prompt_token_count: Option, - #[serde(default, rename = "candidatesTokenCount")] - candidates_token_count: Option, -} - -#[derive(Debug, Deserialize)] -struct Candidate { - #[serde(default)] - content: Option, -} - -#[derive(Debug, Deserialize)] -struct CandidateContent { - parts: Vec, -} - -#[derive(Debug, Deserialize)] -struct ResponsePart { - #[serde(default)] - text: Option, - /// Thinking models (e.g. gemini-3-pro-preview) mark reasoning parts with `thought: true`. - #[serde(default)] - thought: bool, -} - -impl CandidateContent { - /// Extract effective text, skipping thinking/signature parts. - /// - /// Gemini thinking models return parts like: - /// - `{"thought": true, "text": "reasoning..."}` — internal reasoning - /// - `{"text": "actual answer"}` — the real response - /// - `{"thoughtSignature": "..."}` — opaque signature (no text field) - /// - /// Returns the non-thinking text, falling back to thinking text only when - /// no non-thinking content is available. - fn effective_text(self) -> Option { - let mut answer_parts: Vec = Vec::new(); - let mut first_thinking: Option = None; - - for part in self.parts { - if let Some(text) = part.text { - if text.is_empty() { - continue; - } - if !part.thought { - answer_parts.push(text); - } else if first_thinking.is_none() { - first_thinking = Some(text); - } - } - } - - if answer_parts.is_empty() { - first_thinking - } else { - Some(answer_parts.join("")) - } - } -} - -#[derive(Debug, Deserialize)] -struct ApiError { - message: String, -} - -/// Public API endpoint for API key users. -const PUBLIC_API_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta"; - -impl GeminiProvider { - /// Create a new Gemini provider. - /// - /// Authentication priority: - /// 1. Explicit API key passed in - /// 2. `GEMINI_API_KEY` environment variable - /// 3. `GOOGLE_API_KEY` environment variable - pub fn new(api_key: Option<&str>) -> Self { - let resolved_key = api_key - .and_then(Self::normalize_non_empty) - .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY")) - .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY")); - - Self { - api_key: resolved_key, - } - } - - fn normalize_non_empty(value: &str) -> Option { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_string()) - } - } - - fn load_non_empty_env(name: &str) -> Option { - std::env::var(name) - .ok() - .and_then(|value| Self::normalize_non_empty(&value)) - } - - fn format_model_name(model: &str) -> String { - if model.starts_with("models/") { - model.to_string() - } else { - format!("models/{model}") - } - } - - fn build_generate_content_url(&self, model: &str) -> String { - let model_name = Self::format_model_name(model); - let base_url = format!("{PUBLIC_API_ENDPOINT}/{model_name}:generateContent"); - if let Some(key) = &self.api_key { - format!("{base_url}?key={key}") - } else { - base_url - } - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } - - async fn send_generate_content( - &self, - contents: Vec, - system_instruction: Option, - model: &str, - temperature: f64, - ) -> anyhow::Result<(String, Option)> { - if self.api_key.is_none() { - anyhow::bail!( - "Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable, \ - or pass the key directly." - ); - } - - let request = GenerateContentRequest { - contents, - system_instruction, - generation_config: GenerationConfig { - temperature, - max_output_tokens: 8192, - }, - }; - - let url = self.build_generate_content_url(model); - - let response = self.http_client().post(&url).json(&request).send().await?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("Gemini API error (HTTP {status}): {body}"); - } - - let parsed: GenerateContentResponse = response.json().await?; - - // Check for API-level errors embedded in a 200 response - if let Some(err) = parsed.error { - anyhow::bail!("Gemini API error: {}", err.message); - } - - let usage = parsed.usage_metadata.map(|u| TokenUsage { - input_tokens: u.prompt_token_count, - output_tokens: u.candidates_token_count, - }); - - let text = parsed - .candidates - .and_then(|candidates| candidates.into_iter().next()) - .and_then(|c| c.content) - .and_then(|c| c.effective_text()) - .ok_or_else(|| anyhow::anyhow!("No response from Gemini"))?; - - Ok((text, usage)) - } - - fn messages_to_contents(messages: &[ChatMessage]) -> Vec { - messages - .iter() - .filter(|m| m.role != "system") - .map(|m| { - let role = if m.role == "assistant" { - "model" - } else { - "user" - }; - Content { - role: Some(role.to_string()), - parts: vec![Part { - text: m.content.clone(), - }], - } - }) - .collect() - } -} - -#[async_trait] -impl Provider for GeminiProvider { - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let contents = vec![Content { - role: Some("user".to_string()), - parts: vec![Part { - text: message.to_string(), - }], - }]; - - let system_instruction = system_prompt.map(|sys| Content { - role: None, - parts: vec![Part { - text: sys.to_string(), - }], - }); - - let (text, _usage) = self - .send_generate_content(contents, system_instruction, model, temperature) - .await?; - Ok(text) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let system_instruction = messages - .iter() - .find(|m| m.role == "system") - .map(|m| Content { - role: None, - parts: vec![Part { - text: m.content.clone(), - }], - }); - - let contents = Self::messages_to_contents(messages); - - let (text, _usage) = self - .send_generate_content(contents, system_instruction, model, temperature) - .await?; - Ok(text) - } - - async fn chat( - &self, - request: crate::provider::traits::ChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let system_instruction = request - .messages - .iter() - .find(|m| m.role == "system") - .map(|m| Content { - role: None, - parts: vec![Part { - text: m.content.clone(), - }], - }); - - let contents = Self::messages_to_contents(request.messages); - - let (text, usage) = self - .send_generate_content(contents, system_instruction, model, temperature) - .await?; - - Ok(ChatResponse { - text: Some(text), - tool_calls: vec![], - usage, - reasoning_content: None, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn creates_with_explicit_key() { - // Test that explicit key takes priority - let p = GeminiProvider::new(Some("explicit-key")); - assert_eq!(p.api_key.as_deref(), Some("explicit-key")); - } - - #[test] - fn creates_without_key() { - // Clear any env vars for clean test - let old_gemini = std::env::var("GEMINI_API_KEY").ok(); - let old_google = std::env::var("GOOGLE_API_KEY").ok(); - unsafe { - std::env::remove_var("GEMINI_API_KEY"); - std::env::remove_var("GOOGLE_API_KEY"); - } - let p = GeminiProvider::new(None); - assert!(p.api_key.is_none()); - // Restore env vars - if let Some(v) = old_gemini { - unsafe { std::env::set_var("GEMINI_API_KEY", v) }; - } - if let Some(v) = old_google { - unsafe { std::env::set_var("GOOGLE_API_KEY", v) }; - } - } - - #[test] - fn trims_whitespace_key() { - let p = GeminiProvider::new(Some(" trimmed-key ")); - assert_eq!(p.api_key.as_deref(), Some("trimmed-key")); - } - - #[test] - fn rejects_empty_key() { - let old_gemini = std::env::var("GEMINI_API_KEY").ok(); - let old_google = std::env::var("GOOGLE_API_KEY").ok(); - unsafe { - std::env::remove_var("GEMINI_API_KEY"); - std::env::remove_var("GOOGLE_API_KEY"); - } - let p = GeminiProvider::new(Some("")); - assert!(p.api_key.is_none()); - if let Some(v) = old_gemini { - unsafe { std::env::set_var("GEMINI_API_KEY", v) }; - } - if let Some(v) = old_google { - unsafe { std::env::set_var("GOOGLE_API_KEY", v) }; - } - } - - #[test] - fn format_model_name_adds_prefix() { - assert_eq!( - GeminiProvider::format_model_name("gemini-2.5-pro"), - "models/gemini-2.5-pro" - ); - } - - #[test] - fn format_model_name_preserves_prefix() { - assert_eq!( - GeminiProvider::format_model_name("models/gemini-2.5-pro"), - "models/gemini-2.5-pro" - ); - } - - #[test] - fn effective_text_skips_thought_parts() { - let content = CandidateContent { - parts: vec![ - ResponsePart { - text: Some("thinking...".to_string()), - thought: true, - }, - ResponsePart { - text: Some("actual answer".to_string()), - thought: false, - }, - ], - }; - assert_eq!(content.effective_text(), Some("actual answer".to_string())); - } - - #[test] - fn effective_text_falls_back_to_thought_if_no_answer() { - let content = CandidateContent { - parts: vec![ResponsePart { - text: Some("only thinking".to_string()), - thought: true, - }], - }; - assert_eq!(content.effective_text(), Some("only thinking".to_string())); - } - - #[test] - fn effective_text_returns_none_for_empty_parts() { - let content = CandidateContent { parts: vec![] }; - assert_eq!(content.effective_text(), None); - } - - #[tokio::test] - async fn chat_fails_without_key() { - let old_gemini = std::env::var("GEMINI_API_KEY").ok(); - let old_google = std::env::var("GOOGLE_API_KEY").ok(); - unsafe { - std::env::remove_var("GEMINI_API_KEY"); - std::env::remove_var("GOOGLE_API_KEY"); - } - let p = GeminiProvider::new(None); - let result = p - .chat_with_system(None, "hello", "gemini-2.5-pro", 0.7) - .await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("API key not found") - ); - if let Some(v) = old_gemini { - unsafe { std::env::set_var("GEMINI_API_KEY", v) }; - } - if let Some(v) = old_google { - unsafe { std::env::set_var("GOOGLE_API_KEY", v) }; - } - } - - #[test] - fn generate_request_serializes_correctly() { - let req = GenerateContentRequest { - contents: vec![Content { - role: Some("user".to_string()), - parts: vec![Part { - text: "hello".to_string(), - }], - }], - system_instruction: None, - generation_config: GenerationConfig { - temperature: 0.7, - max_output_tokens: 8192, - }, - }; - let json = serde_json::to_value(&req).unwrap(); - assert!(json.get("contents").is_some()); - assert!(json.get("generationConfig").is_some()); - assert!(json.get("systemInstruction").is_none()); - } -} diff --git a/crewforge-rs/src/provider/glm.rs b/crewforge-rs/src/provider/glm.rs deleted file mode 100644 index 066e49b..0000000 --- a/crewforge-rs/src/provider/glm.rs +++ /dev/null @@ -1,367 +0,0 @@ -//! Zhipu GLM provider with JWT authentication. -//! The GLM API requires JWT tokens generated from the `id.secret` API key format -//! with a custom `sign_type: "SIGN"` header, and uses `/v4/chat/completions`. - -use crate::provider::traits::{ChatMessage, Provider}; -use async_trait::async_trait; -use reqwest::Client; -use ring::hmac; -use serde::{Deserialize, Serialize}; -use std::sync::Mutex; -use std::time::{SystemTime, UNIX_EPOCH}; - -pub struct GlmProvider { - api_key_id: String, - api_key_secret: String, - base_url: String, - /// Cached JWT token + expiry timestamp (ms) - token_cache: Mutex>, -} - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - messages: Vec, - temperature: f64, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ChatResponse { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: ResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct ResponseMessage { - content: String, -} - -/// Base64url encode without padding (per JWT spec). -fn base64url_encode_bytes(data: &[u8]) -> String { - const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - let mut result = String::new(); - let mut i = 0; - while i < data.len() { - let b0 = data[i] as u32; - let b1 = if i + 1 < data.len() { - data[i + 1] as u32 - } else { - 0 - }; - let b2 = if i + 2 < data.len() { - data[i + 2] as u32 - } else { - 0 - }; - let triple = (b0 << 16) | (b1 << 8) | b2; - - result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); - result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); - - if i + 1 < data.len() { - result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); - } - if i + 2 < data.len() { - result.push(CHARS[(triple & 0x3F) as usize] as char); - } - - i += 3; - } - - // Convert to base64url: replace + with -, / with _, strip = - result.replace('+', "-").replace('/', "_") -} - -fn base64url_encode_str(s: &str) -> String { - base64url_encode_bytes(s.as_bytes()) -} - -impl GlmProvider { - pub fn new(api_key: Option<&str>) -> Self { - let (id, secret) = api_key - .and_then(|k| k.split_once('.')) - .map(|(id, secret)| (id.to_string(), secret.to_string())) - .unwrap_or_default(); - - Self { - api_key_id: id, - api_key_secret: secret, - base_url: "https://api.z.ai/api/paas/v4".to_string(), - token_cache: Mutex::new(None), - } - } - - fn generate_token(&self) -> anyhow::Result { - if self.api_key_id.is_empty() || self.api_key_secret.is_empty() { - anyhow::bail!( - "GLM API key not set or invalid format. Expected 'id.secret'. \ - Set GLM_API_KEY environment variable." - ); - } - - let now_ms = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64; - - // Check cache (valid for 3 minutes, token expires at 3.5 min) - if let Ok(cache) = self.token_cache.lock() - && let Some((ref token, expiry)) = *cache - && now_ms < expiry - { - return Ok(token.clone()); - } - - let exp_ms = now_ms + 210_000; // 3.5 minutes - - // Build JWT manually to include custom sign_type header - // Header: {"alg":"HS256","typ":"JWT","sign_type":"SIGN"} - let header_json = r#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#; - let header_b64 = base64url_encode_str(header_json); - - // Payload: {"api_key":"...","exp":...,"timestamp":...} - let payload_json = format!( - r#"{{"api_key":"{}","exp":{},"timestamp":{}}}"#, - self.api_key_id, exp_ms, now_ms - ); - let payload_b64 = base64url_encode_str(&payload_json); - - // Sign: HMAC-SHA256(header.payload, secret) - let signing_input = format!("{header_b64}.{payload_b64}"); - let key = hmac::Key::new(hmac::HMAC_SHA256, self.api_key_secret.as_bytes()); - let signature = hmac::sign(&key, signing_input.as_bytes()); - let sig_b64 = base64url_encode_bytes(signature.as_ref()); - - let token = format!("{signing_input}.{sig_b64}"); - - // Cache for 3 minutes - if let Ok(mut cache) = self.token_cache.lock() { - *cache = Some((token.clone(), now_ms + 180_000)); - } - - Ok(token) - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } -} - -#[async_trait] -impl Provider for GlmProvider { - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let token = self.generate_token()?; - - let mut messages = Vec::new(); - - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: sys.to_string(), - }); - } - - messages.push(Message { - role: "user".to_string(), - content: message.to_string(), - }); - - let request = ChatRequest { - model: model.to_string(), - messages, - temperature, - }; - - let url = format!("{}/chat/completions", self.base_url); - - let response = self - .http_client() - .post(&url) - .header("Authorization", format!("Bearer {token}")) - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("GLM API error: {error}"); - } - - let chat_response: ChatResponse = response.json().await?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| c.message.content) - .ok_or_else(|| anyhow::anyhow!("No response from GLM")) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let token = self.generate_token()?; - - let api_messages: Vec = messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: m.content.clone(), - }) - .collect(); - - let request = ChatRequest { - model: model.to_string(), - messages: api_messages, - temperature, - }; - - let url = format!("{}/chat/completions", self.base_url); - - let response = self - .http_client() - .post(&url) - .header("Authorization", format!("Bearer {token}")) - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("GLM API error: {error}"); - } - - let chat_response: ChatResponse = response.json().await?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| c.message.content) - .ok_or_else(|| anyhow::anyhow!("No response from GLM")) - } - - async fn warmup(&self) -> anyhow::Result<()> { - if self.api_key_id.is_empty() || self.api_key_secret.is_empty() { - return Ok(()); - } - - // Generate and cache a JWT token, establishing TLS to the GLM API. - let token = self.generate_token()?; - let url = format!("{}/chat/completions", self.base_url); - // GET will likely return 405 but establishes the TLS + HTTP/2 connection pool. - let _ = self - .http_client() - .get(&url) - .header("Authorization", format!("Bearer {token}")) - .send() - .await?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parses_api_key() { - let p = GlmProvider::new(Some("abc123.secretXYZ")); - assert_eq!(p.api_key_id, "abc123"); - assert_eq!(p.api_key_secret, "secretXYZ"); - } - - #[test] - fn handles_no_key() { - let p = GlmProvider::new(None); - assert!(p.api_key_id.is_empty()); - assert!(p.api_key_secret.is_empty()); - } - - #[test] - fn handles_invalid_key_format() { - let p = GlmProvider::new(Some("no-dot-here")); - assert!(p.api_key_id.is_empty()); - assert!(p.api_key_secret.is_empty()); - } - - #[test] - fn generates_jwt_token() { - let p = GlmProvider::new(Some("testid.testsecret")); - let token = p.generate_token().unwrap(); - assert!(!token.is_empty()); - // JWT has 3 dot-separated parts - let parts: Vec<&str> = token.split('.').collect(); - assert_eq!(parts.len(), 3, "JWT should have 3 parts: {token}"); - } - - #[test] - fn caches_token() { - let p = GlmProvider::new(Some("testid.testsecret")); - let token1 = p.generate_token().unwrap(); - let token2 = p.generate_token().unwrap(); - assert_eq!(token1, token2, "Cached token should be reused"); - } - - #[test] - fn fails_without_key() { - let p = GlmProvider::new(None); - let result = p.generate_token(); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[tokio::test] - async fn chat_fails_without_key() { - let p = GlmProvider::new(None); - let result = p.chat_with_system(None, "hello", "glm-4.7", 0.7).await; - assert!(result.is_err()); - } - - #[tokio::test] - async fn chat_with_history_fails_without_key() { - let p = GlmProvider::new(None); - let messages = vec![ - ChatMessage::system("You are helpful."), - ChatMessage::user("Hello"), - ChatMessage::assistant("Hi there!"), - ChatMessage::user("What did I say?"), - ]; - let result = p.chat_with_history(&messages, "glm-4.7", 0.7).await; - assert!(result.is_err()); - } - - #[test] - fn base64url_no_padding() { - let encoded = base64url_encode_bytes(b"hello"); - assert!(!encoded.contains('=')); - assert!(!encoded.contains('+')); - assert!(!encoded.contains('/')); - } - - #[tokio::test] - async fn warmup_without_key_is_noop() { - let provider = GlmProvider::new(None); - let result = provider.warmup().await; - assert!(result.is_ok()); - } -} diff --git a/crewforge-rs/src/provider/mod.rs b/crewforge-rs/src/provider/mod.rs index adb10a8..0911d3b 100644 --- a/crewforge-rs/src/provider/mod.rs +++ b/crewforge-rs/src/provider/mod.rs @@ -1,12 +1,6 @@ -pub mod anthropic; +pub mod anthropic_oauth; pub mod compatible; -pub mod copilot; -pub mod gemini; -pub mod glm; -pub mod ollama; -pub mod openai; -pub mod openai_codex; -pub mod openrouter; +pub mod openai_oauth; pub mod reliable; pub mod router; pub mod traits; @@ -20,9 +14,9 @@ pub use compatible::{api_error, sanitize_api_error}; pub use reliable::ReliableProvider; pub use router::{Route, RouterProvider}; -use compatible::{AuthStyle, OpenAiCompatibleProvider}; +use compatible::OpenAiCompatibleProvider; -/// Runtime options for providers that use OAuth/auth services (copilot, openai-codex). +/// Runtime options for providers that use OAuth/auth services (anthropic, openai-codex). #[derive(Debug, Default)] pub struct ProviderRuntimeOptions { pub crewforge_dir: Option, @@ -41,73 +35,68 @@ pub fn create_provider( ) -> anyhow::Result> { let resolved_key = resolve_api_key(provider_name, api_key); let p: Box = match provider_name.to_lowercase().as_str() { - "anthropic" | "claude" => Box::new(anthropic::AnthropicProvider::with_base_url( - resolved_key.as_deref(), - base_url, - )), - "openai" | "gpt" => Box::new(openai::OpenAiProvider::with_base_url( - base_url, + "anthropic" | "claude" => { + let opts = ProviderRuntimeOptions { + provider_api_url: base_url.map(ToString::to_string), + ..ProviderRuntimeOptions::default() + }; + Box::new( + anthropic_oauth::AnthropicOAuthProvider::new(&opts, resolved_key.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to create Anthropic provider: {e}"))?, + ) + } + "openai" | "gpt" => Box::new(OpenAiCompatibleProvider::new( + "openai", + base_url.unwrap_or("https://api.openai.com/v1"), resolved_key.as_deref(), )), - "gemini" | "google" => Box::new(gemini::GeminiProvider::new(resolved_key.as_deref())), - "ollama" => Box::new(ollama::OllamaProvider::new( - base_url, + "openrouter" => Box::new(OpenAiCompatibleProvider::new( + "openrouter", + base_url.unwrap_or("https://openrouter.ai/api/v1"), resolved_key.as_deref(), )), - "openrouter" => Box::new(openrouter::OpenRouterProvider::new(resolved_key.as_deref())), - "glm" | "zhipuai" | "zhipu" => Box::new(glm::GlmProvider::new(resolved_key.as_deref())), "moonshot" | "kimi" => Box::new(OpenAiCompatibleProvider::new( "moonshot", base_url.unwrap_or("https://api.moonshot.ai/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), "qwen" | "dashscope" => Box::new(OpenAiCompatibleProvider::new( "qwen", base_url.unwrap_or("https://dashscope.aliyuncs.com/compatible-mode/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), - "minimax" => Box::new(OpenAiCompatibleProvider::new_merge_system_into_user( + "minimax" => Box::new(OpenAiCompatibleProvider::new( "minimax", base_url.unwrap_or("https://api.minimax.io/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), "deepseek" => Box::new(OpenAiCompatibleProvider::new( "deepseek", base_url.unwrap_or("https://api.deepseek.com/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), "groq" => Box::new(OpenAiCompatibleProvider::new( "groq", base_url.unwrap_or("https://api.groq.com/openai/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), "mistral" => Box::new(OpenAiCompatibleProvider::new( "mistral", base_url.unwrap_or("https://api.mistral.ai/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), "xai" | "grok" => Box::new(OpenAiCompatibleProvider::new( "xai", base_url.unwrap_or("https://api.x.ai/v1"), resolved_key.as_deref(), - AuthStyle::Bearer, )), - "copilot" | "github-copilot" => { - Box::new(copilot::CopilotProvider::new(resolved_key.as_deref())) - } "openai-codex" | "codex" => { let opts = ProviderRuntimeOptions { provider_api_url: base_url.map(ToString::to_string), ..ProviderRuntimeOptions::default() }; Box::new( - openai_codex::OpenAiCodexProvider::new(&opts, resolved_key.as_deref()) + openai_oauth::OpenAiCodexProvider::new(&opts, resolved_key.as_deref()) .map_err(|e| anyhow::anyhow!("Failed to create OpenAI Codex provider: {e}"))?, ) } @@ -117,7 +106,6 @@ pub fn create_provider( other, url, resolved_key.as_deref(), - AuthStyle::Bearer, )) } else { anyhow::bail!( @@ -145,10 +133,7 @@ pub fn default_api_key_env(provider_name: &str) -> Option<&'static str> { match provider_name.to_lowercase().as_str() { "anthropic" | "claude" => Some("ANTHROPIC_API_KEY"), "openai" | "gpt" => Some("OPENAI_API_KEY"), - "gemini" | "google" => Some("GEMINI_API_KEY"), - "ollama" => None, "openrouter" => Some("OPENROUTER_API_KEY"), - "glm" | "zhipuai" | "zhipu" => Some("GLM_API_KEY"), "moonshot" | "kimi" => Some("MOONSHOT_API_KEY"), "qwen" | "dashscope" => Some("QWEN_API_KEY"), "minimax" => Some("MINIMAX_API_KEY"), diff --git a/crewforge-rs/src/provider/ollama.rs b/crewforge-rs/src/provider/ollama.rs deleted file mode 100644 index d7786f2..0000000 --- a/crewforge-rs/src/provider/ollama.rs +++ /dev/null @@ -1,989 +0,0 @@ -use crate::provider::traits::{ - ChatMessage, ChatResponse, Provider, ProviderCapabilities, TokenUsage, ToolCall, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -pub struct OllamaProvider { - base_url: String, - api_key: Option, - reasoning_enabled: Option, -} - -// ─── Request Structures ─────────────────────────────────────────────────────── - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - messages: Vec, - stream: bool, - options: Options, - #[serde(skip_serializing_if = "Option::is_none")] - think: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_name: Option, -} - -#[derive(Debug, Serialize)] -struct OutgoingToolCall { - #[serde(rename = "type")] - kind: String, - function: OutgoingFunction, -} - -#[derive(Debug, Serialize)] -struct OutgoingFunction { - name: String, - arguments: serde_json::Value, -} - -#[derive(Debug, Serialize)] -struct Options { - temperature: f64, -} - -// ─── Response Structures ────────────────────────────────────────────────────── - -#[derive(Debug, Deserialize)] -struct ApiChatResponse { - message: ResponseMessage, - #[serde(default)] - prompt_eval_count: Option, - #[serde(default)] - eval_count: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponseMessage { - #[serde(default)] - content: String, - #[serde(default)] - tool_calls: Vec, - /// Some models return a "thinking" field with internal reasoning - #[serde(default)] - thinking: Option, -} - -#[derive(Debug, Deserialize)] -struct OllamaToolCall { - id: Option, - function: OllamaFunction, -} - -#[derive(Debug, Deserialize)] -struct OllamaFunction { - name: String, - #[serde(default)] - arguments: serde_json::Value, -} - -// ─── Implementation ─────────────────────────────────────────────────────────── - -fn sanitize_api_error(raw: &str) -> String { - let truncated = if raw.len() > 500 { &raw[..500] } else { raw }; - truncated.replace('\n', " ").trim().to_string() -} - -impl OllamaProvider { - fn normalize_base_url(raw_url: &str) -> String { - let trimmed = raw_url.trim().trim_end_matches('/'); - if trimmed.is_empty() { - return String::new(); - } - - trimmed - .strip_suffix("/api") - .unwrap_or(trimmed) - .trim_end_matches('/') - .to_string() - } - - pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self { - Self::new_with_reasoning(base_url, api_key, None) - } - - pub fn new_with_reasoning( - base_url: Option<&str>, - api_key: Option<&str>, - reasoning_enabled: Option, - ) -> Self { - let api_key = api_key.and_then(|value| { - let trimmed = value.trim(); - (!trimmed.is_empty()).then(|| trimmed.to_string()) - }); - - Self { - base_url: Self::normalize_base_url(base_url.unwrap_or("http://localhost:11434")), - api_key, - reasoning_enabled, - } - } - - fn is_local_endpoint(&self) -> bool { - reqwest::Url::parse(&self.base_url) - .ok() - .and_then(|url| url.host_str().map(|host| host.to_string())) - .is_some_and(|host| matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1")) - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(300)) - .build() - .unwrap_or_default() - } - - fn resolve_request_details(&self, model: &str) -> anyhow::Result<(String, bool)> { - let requests_cloud = model.ends_with(":cloud"); - let normalized_model = model.strip_suffix(":cloud").unwrap_or(model).to_string(); - - if requests_cloud && self.is_local_endpoint() { - anyhow::bail!( - "Model '{}' requested cloud routing, but Ollama endpoint is local. Configure api_url with a remote Ollama endpoint.", - model - ); - } - - if requests_cloud && self.api_key.is_none() { - anyhow::bail!( - "Model '{}' requested cloud routing, but no API key is configured. Set OLLAMA_API_KEY or config api_key.", - model - ); - } - - let should_auth = self.api_key.is_some() && !self.is_local_endpoint(); - - Ok((normalized_model, should_auth)) - } - - fn parse_tool_arguments(arguments: &str) -> serde_json::Value { - serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({})) - } - - fn normalize_response_text(content: String) -> Option { - if content.trim().is_empty() { - None - } else { - Some(content) - } - } - - fn fallback_text_for_empty_content(model: &str, thinking: Option<&str>) -> String { - if let Some(thinking) = thinking.map(str::trim).filter(|value| !value.is_empty()) { - let thinking_log_excerpt: String = thinking.chars().take(100).collect(); - let thinking_reply_excerpt: String = thinking.chars().take(200).collect(); - tracing::warn!( - "Ollama returned empty content with only thinking for model '{}': '{}'. Model may have stopped prematurely.", - model, - thinking_log_excerpt - ); - return format!( - "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", - thinking_reply_excerpt - ); - } - - tracing::warn!( - "Ollama returned empty or whitespace content with no tool calls for model '{}'", - model - ); - "I couldn't get a complete response from Ollama. Please try again or switch to a different model." - .to_string() - } - - fn build_chat_request( - &self, - messages: Vec, - model: &str, - temperature: f64, - tools: Option<&[serde_json::Value]>, - ) -> ChatRequest { - ChatRequest { - model: model.to_string(), - messages, - stream: false, - options: Options { temperature }, - think: self.reasoning_enabled, - tools: tools.map(|t| t.to_vec()), - } - } - - /// Convert internal chat history format to Ollama's native tool-call message schema. - fn convert_messages(&self, messages: &[ChatMessage]) -> Vec { - let mut tool_name_by_id: HashMap = HashMap::new(); - - messages - .iter() - .map(|message| { - if message.role == "assistant" - && let Ok(value) = serde_json::from_str::(&message.content) - && let Some(tool_calls_value) = value.get("tool_calls") - && let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let outgoing_calls: Vec = parsed_calls - .into_iter() - .map(|call| { - tool_name_by_id.insert(call.id.clone(), call.name.clone()); - OutgoingToolCall { - kind: "function".to_string(), - function: OutgoingFunction { - name: call.name, - arguments: Self::parse_tool_arguments(&call.arguments), - }, - } - }) - .collect(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return Message { - role: "assistant".to_string(), - content, - tool_calls: Some(outgoing_calls), - tool_name: None, - }; - } - - if message.role == "tool" - && let Ok(value) = serde_json::from_str::(&message.content) - { - let tool_name = value - .get("tool_name") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| { - value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .and_then(|id| tool_name_by_id.get(id)) - .cloned() - }); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| { - (!message.content.trim().is_empty()).then_some(message.content.clone()) - }); - - return Message { - role: "tool".to_string(), - content, - tool_calls: None, - tool_name, - }; - } - - Message { - role: message.role.clone(), - content: Some(message.content.clone()), - tool_calls: None, - tool_name: None, - } - }) - .collect() - } - - /// Send a request to Ollama and get the parsed response. - async fn send_request( - &self, - messages: Vec, - model: &str, - temperature: f64, - should_auth: bool, - tools: Option<&[serde_json::Value]>, - ) -> anyhow::Result { - let request = self.build_chat_request(messages, model, temperature, tools); - - let url = format!("{}/api/chat", self.base_url); - - tracing::debug!( - "Ollama request: url={} model={} message_count={} temperature={} think={:?} tool_count={}", - url, - model, - request.messages.len(), - temperature, - request.think, - request.tools.as_ref().map_or(0, |t| t.len()), - ); - - let mut request_builder = self.http_client().post(&url).json(&request); - - if should_auth && let Some(key) = self.api_key.as_ref() { - request_builder = request_builder.bearer_auth(key); - } - - let response = request_builder.send().await?; - let status = response.status(); - tracing::debug!("Ollama response status: {}", status); - - let body = response.bytes().await?; - tracing::debug!("Ollama response body length: {} bytes", body.len()); - - if !status.is_success() { - let raw = String::from_utf8_lossy(&body); - let sanitized = sanitize_api_error(&raw); - tracing::error!( - "Ollama error response: status={} body_excerpt={}", - status, - sanitized - ); - anyhow::bail!( - "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)", - status, - sanitized - ); - } - - let chat_response: ApiChatResponse = match serde_json::from_slice(&body) { - Ok(r) => r, - Err(e) => { - let raw = String::from_utf8_lossy(&body); - let sanitized = sanitize_api_error(&raw); - tracing::error!( - "Ollama response deserialization failed: {e}. body_excerpt={}", - sanitized - ); - anyhow::bail!("Failed to parse Ollama response: {e}"); - } - }; - - Ok(chat_response) - } - - /// Convert Ollama tool calls to the JSON format expected by the agent loop. - fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String { - let formatted_calls: Vec = tool_calls - .iter() - .map(|tc| { - let (tool_name, tool_args) = self.extract_tool_name_and_args(tc); - - let args_str = - serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string()); - - serde_json::json!({ - "id": tc.id, - "type": "function", - "function": { - "name": tool_name, - "arguments": args_str - } - }) - }) - .collect(); - - serde_json::json!({ - "content": "", - "tool_calls": formatted_calls - }) - .to_string() - } - - /// Extract the actual tool name and arguments from potentially nested structures. - fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) { - let name = &tc.function.name; - let args = &tc.function.arguments; - - // Pattern 1: Nested tool_call wrapper - if (name == "tool_call" - || name == "tool.call" - || name.starts_with("tool_call>") - || name.starts_with("tool_call<")) - && let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) - { - let nested_args = args - .get("arguments") - .cloned() - .unwrap_or(serde_json::json!({})); - tracing::debug!( - "Unwrapped nested tool call: {} -> {} with args {:?}", - name, - nested_name, - nested_args - ); - return (nested_name.to_string(), nested_args); - } - - // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.) - if let Some(stripped) = name.strip_prefix("tool.") { - return (stripped.to_string(), args.clone()); - } - - // Pattern 3: Normal tool call - (name.clone(), args.clone()) - } -} - -#[async_trait] -impl Provider for OllamaProvider { - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - native_tool_calling: true, - vision: false, - } - } - - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let (normalized_model, should_auth) = self.resolve_request_details(model)?; - - let mut messages = Vec::new(); - - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: Some(sys.to_string()), - tool_calls: None, - tool_name: None, - }); - } - - messages.push(Message { - role: "user".to_string(), - content: Some(message.to_string()), - tool_calls: None, - tool_name: None, - }); - - let response = self - .send_request(messages, &normalized_model, temperature, should_auth, None) - .await?; - - // If model returned tool calls, format them for the agent loop parser - if !response.message.tool_calls.is_empty() { - tracing::debug!( - "Ollama returned {} tool call(s), formatting for loop parser", - response.message.tool_calls.len() - ); - return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); - } - - // Plain text response - let content = response.message.content; - if let Some(content) = Self::normalize_response_text(content) { - return Ok(content); - } - - Ok(Self::fallback_text_for_empty_content( - &normalized_model, - response.message.thinking.as_deref(), - )) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let (normalized_model, should_auth) = self.resolve_request_details(model)?; - - let api_messages = self.convert_messages(messages); - - let response = self - .send_request( - api_messages, - &normalized_model, - temperature, - should_auth, - None, - ) - .await?; - - // If model returned tool calls, format them for the agent loop parser - if !response.message.tool_calls.is_empty() { - tracing::debug!( - "Ollama returned {} tool call(s), formatting for loop parser", - response.message.tool_calls.len() - ); - return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); - } - - // Plain text response - let content = response.message.content; - if let Some(content) = Self::normalize_response_text(content) { - return Ok(content); - } - - Ok(Self::fallback_text_for_empty_content( - &normalized_model, - response.message.thinking.as_deref(), - )) - } - - async fn chat_with_tools( - &self, - messages: &[ChatMessage], - tools: &[serde_json::Value], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let (normalized_model, should_auth) = self.resolve_request_details(model)?; - - let api_messages = self.convert_messages(messages); - - let tools_opt = if tools.is_empty() { None } else { Some(tools) }; - - let response = self - .send_request( - api_messages, - &normalized_model, - temperature, - should_auth, - tools_opt, - ) - .await?; - - let usage = if response.prompt_eval_count.is_some() || response.eval_count.is_some() { - Some(TokenUsage { - input_tokens: response.prompt_eval_count, - output_tokens: response.eval_count, - }) - } else { - None - }; - - // Native tool calls returned by the model. - if !response.message.tool_calls.is_empty() { - let tool_calls: Vec = response - .message - .tool_calls - .iter() - .map(|tc| { - let (name, args) = self.extract_tool_name_and_args(tc); - ToolCall { - id: tc - .id - .clone() - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - name, - arguments: serde_json::to_string(&args) - .unwrap_or_else(|_| "{}".to_string()), - } - }) - .collect(); - let text = Self::normalize_response_text(response.message.content); - return Ok(ChatResponse { - text, - tool_calls, - usage, - reasoning_content: None, - }); - } - - // Plain text response. - let content = response.message.content; - let text = if let Some(content) = Self::normalize_response_text(content) { - content - } else { - Self::fallback_text_for_empty_content( - &normalized_model, - response.message.thinking.as_deref(), - ) - }; - Ok(ChatResponse { - text: Some(text), - tool_calls: vec![], - usage, - reasoning_content: None, - }) - } - - fn supports_native_tools(&self) -> bool { - // Ollama's /api/chat supports native function-calling for capable models - true - } - - async fn chat( - &self, - request: crate::provider::traits::ChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - // Convert ToolSpec to OpenAI-compatible JSON and delegate to chat_with_tools. - if let Some(specs) = request.tools - && !specs.is_empty() - { - let tools: Vec = specs - .iter() - .map(|s| { - serde_json::json!({ - "type": "function", - "function": { - "name": s.name, - "description": s.description, - "parameters": s.parameters - } - }) - }) - .collect(); - return self - .chat_with_tools(request.messages, &tools, model, temperature) - .await; - } - - // No tools — fall back to plain text chat. - let text = self - .chat_with_history(request.messages, model, temperature) - .await?; - Ok(ChatResponse { - text: Some(text), - tool_calls: vec![], - usage: None, - reasoning_content: None, - }) - } -} - -// ─── Tests ──────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn default_url() { - let p = OllamaProvider::new(None, None); - assert_eq!(p.base_url, "http://localhost:11434"); - } - - #[test] - fn custom_url_trailing_slash() { - let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"), None); - assert_eq!(p.base_url, "http://192.168.1.100:11434"); - } - - #[test] - fn custom_url_no_trailing_slash() { - let p = OllamaProvider::new(Some("http://myserver:11434"), None); - assert_eq!(p.base_url, "http://myserver:11434"); - } - - #[test] - fn custom_url_strips_api_suffix() { - let p = OllamaProvider::new(Some("https://ollama.com/api/"), None); - assert_eq!(p.base_url, "https://ollama.com"); - } - - #[test] - fn empty_url_uses_empty() { - let p = OllamaProvider::new(Some(""), None); - assert_eq!(p.base_url, ""); - } - - #[test] - fn cloud_suffix_strips_model_name() { - let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key")); - let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap(); - assert_eq!(model, "qwen3"); - assert!(should_auth); - } - - #[test] - fn cloud_suffix_with_local_endpoint_errors() { - let p = OllamaProvider::new(None, Some("ollama-key")); - let error = p - .resolve_request_details("qwen3:cloud") - .expect_err("cloud suffix should fail on local endpoint"); - assert!( - error - .to_string() - .contains("requested cloud routing, but Ollama endpoint is local") - ); - } - - #[test] - fn cloud_suffix_without_api_key_errors() { - let p = OllamaProvider::new(Some("https://ollama.com"), None); - let error = p - .resolve_request_details("qwen3:cloud") - .expect_err("cloud suffix should require API key"); - assert!( - error - .to_string() - .contains("requested cloud routing, but no API key is configured") - ); - } - - #[test] - fn remote_endpoint_auth_enabled_when_key_present() { - let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key")); - let (_model, should_auth) = p.resolve_request_details("qwen3").unwrap(); - assert!(should_auth); - } - - #[test] - fn remote_endpoint_with_api_suffix_still_allows_cloud_models() { - let p = OllamaProvider::new(Some("https://ollama.com/api"), Some("ollama-key")); - let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap(); - assert_eq!(model, "qwen3"); - assert!(should_auth); - } - - #[test] - fn local_endpoint_auth_disabled_even_with_key() { - let p = OllamaProvider::new(None, Some("ollama-key")); - let (_model, should_auth) = p.resolve_request_details("llama3").unwrap(); - assert!(!should_auth); - } - - #[test] - fn request_omits_think_when_reasoning_not_configured() { - let provider = OllamaProvider::new(None, None); - let request = provider.build_chat_request( - vec![Message { - role: "user".to_string(), - content: Some("hello".to_string()), - tool_calls: None, - tool_name: None, - }], - "llama3", - 0.7, - None, - ); - - let json = serde_json::to_value(request).unwrap(); - assert!(json.get("think").is_none()); - } - - #[test] - fn request_includes_think_when_reasoning_configured() { - let provider = OllamaProvider::new_with_reasoning(None, None, Some(false)); - let request = provider.build_chat_request( - vec![Message { - role: "user".to_string(), - content: Some("hello".to_string()), - tool_calls: None, - tool_name: None, - }], - "llama3", - 0.7, - None, - ); - - let json = serde_json::to_value(request).unwrap(); - assert_eq!(json.get("think"), Some(&serde_json::json!(false))); - } - - #[test] - fn response_deserializes() { - let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.message.content, "Hello from Ollama!"); - } - - #[test] - fn response_with_empty_content() { - let json = r#"{"message":{"role":"assistant","content":""}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.is_empty()); - } - - #[test] - fn normalize_response_text_rejects_whitespace_only_content() { - assert_eq!( - OllamaProvider::normalize_response_text("\n \t".to_string()), - None - ); - assert_eq!( - OllamaProvider::normalize_response_text(" hello ".to_string()), - Some(" hello ".to_string()) - ); - } - - #[test] - fn fallback_text_for_empty_content_without_thinking_is_generic() { - let text = OllamaProvider::fallback_text_for_empty_content("qwen3-coder", None); - assert!(text.contains("couldn't get a complete response from Ollama")); - } - - #[test] - fn response_with_missing_content_defaults_to_empty() { - let json = r#"{"message":{"role":"assistant"}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.is_empty()); - } - - #[test] - fn response_with_thinking_field_extracts_content() { - let json = - r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.message.content, "hello"); - } - - #[test] - fn response_with_tool_calls_parses_correctly() { - let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.is_empty()); - assert_eq!(resp.message.tool_calls.len(), 1); - assert_eq!(resp.message.tool_calls[0].function.name, "shell"); - } - - #[test] - fn extract_tool_name_handles_nested_tool_call() { - let provider = OllamaProvider::new(None, None); - let tc = OllamaToolCall { - id: Some("call_123".into()), - function: OllamaFunction { - name: "tool_call".into(), - arguments: serde_json::json!({ - "name": "shell", - "arguments": {"command": "date"} - }), - }, - }; - let (name, args) = provider.extract_tool_name_and_args(&tc); - assert_eq!(name, "shell"); - assert_eq!(args.get("command").unwrap(), "date"); - } - - #[test] - fn extract_tool_name_handles_prefixed_name() { - let provider = OllamaProvider::new(None, None); - let tc = OllamaToolCall { - id: Some("call_123".into()), - function: OllamaFunction { - name: "tool.shell".into(), - arguments: serde_json::json!({"command": "ls"}), - }, - }; - let (name, args) = provider.extract_tool_name_and_args(&tc); - assert_eq!(name, "shell"); - assert_eq!(args.get("command").unwrap(), "ls"); - } - - #[test] - fn extract_tool_name_handles_normal_call() { - let provider = OllamaProvider::new(None, None); - let tc = OllamaToolCall { - id: Some("call_123".into()), - function: OllamaFunction { - name: "file_read".into(), - arguments: serde_json::json!({"path": "/tmp/test"}), - }, - }; - let (name, args) = provider.extract_tool_name_and_args(&tc); - assert_eq!(name, "file_read"); - assert_eq!(args.get("path").unwrap(), "/tmp/test"); - } - - #[test] - fn format_tool_calls_produces_valid_json() { - let provider = OllamaProvider::new(None, None); - let tool_calls = vec![OllamaToolCall { - id: Some("call_abc".into()), - function: OllamaFunction { - name: "shell".into(), - arguments: serde_json::json!({"command": "date"}), - }, - }]; - - let formatted = provider.format_tool_calls_for_loop(&tool_calls); - let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap(); - - assert!(parsed.get("tool_calls").is_some()); - let calls = parsed.get("tool_calls").unwrap().as_array().unwrap(); - assert_eq!(calls.len(), 1); - - let func = calls[0].get("function").unwrap(); - assert_eq!(func.get("name").unwrap(), "shell"); - // arguments should be a string (JSON-encoded) - assert!(func.get("arguments").unwrap().is_string()); - } - - #[test] - fn convert_messages_parses_native_assistant_tool_calls() { - let provider = OllamaProvider::new(None, None); - let messages = vec![ChatMessage { - role: "assistant".into(), - content: r#"{"content":null,"tool_calls":[{"id":"call_1","name":"shell","arguments":"{\"command\":\"ls\"}"}]}"#.into(), - }]; - - let converted = provider.convert_messages(&messages); - - assert_eq!(converted.len(), 1); - assert_eq!(converted[0].role, "assistant"); - assert!(converted[0].content.is_none()); - let calls = converted[0] - .tool_calls - .as_ref() - .expect("tool calls expected"); - assert_eq!(calls.len(), 1); - assert_eq!(calls[0].kind, "function"); - assert_eq!(calls[0].function.name, "shell"); - assert_eq!(calls[0].function.arguments.get("command").unwrap(), "ls"); - } - - #[test] - fn convert_messages_maps_tool_result_call_id_to_tool_name() { - let provider = OllamaProvider::new(None, None); - let messages = vec![ - ChatMessage { - role: "assistant".into(), - content: r#"{"content":null,"tool_calls":[{"id":"call_7","name":"file_read","arguments":"{\"path\":\"README.md\"}"}]}"#.into(), - }, - ChatMessage { - role: "tool".into(), - content: r#"{"tool_call_id":"call_7","content":"ok"}"#.into(), - }, - ]; - - let converted = provider.convert_messages(&messages); - - assert_eq!(converted.len(), 2); - assert_eq!(converted[1].role, "tool"); - assert_eq!(converted[1].tool_name.as_deref(), Some("file_read")); - assert_eq!(converted[1].content.as_deref(), Some("ok")); - assert!(converted[1].tool_calls.is_none()); - } - - #[test] - fn capabilities_include_native_tools_but_not_vision() { - let provider = OllamaProvider::new(None, None); - let caps = ::capabilities(&provider); - assert!(caps.native_tool_calling); - assert!(!caps.vision); - } - - #[test] - fn api_response_parses_eval_counts() { - let json = r#"{ - "message": {"content": "Hello", "tool_calls": []}, - "prompt_eval_count": 50, - "eval_count": 25 - }"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.prompt_eval_count, Some(50)); - assert_eq!(resp.eval_count, Some(25)); - } - - #[test] - fn api_response_parses_without_eval_counts() { - let json = r#"{"message": {"content": "Hello", "tool_calls": []}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.prompt_eval_count.is_none()); - assert!(resp.eval_count.is_none()); - } -} diff --git a/crewforge-rs/src/provider/openai.rs b/crewforge-rs/src/provider/openai.rs deleted file mode 100644 index 7e4e68a..0000000 --- a/crewforge-rs/src/provider/openai.rs +++ /dev/null @@ -1,827 +0,0 @@ -use crate::provider::traits::{ - ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -pub struct OpenAiProvider { - base_url: String, - credential: Option, -} - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - messages: Vec, - temperature: f64, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ChatResponse { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: ResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct ResponseMessage { - #[serde(default)] - content: Option, - /// Reasoning/thinking models may return output in `reasoning_content`. - #[serde(default)] - reasoning_content: Option, -} - -impl ResponseMessage { - fn effective_content(&self) -> String { - match &self.content { - Some(c) if !c.is_empty() => c.clone(), - _ => self.reasoning_content.clone().unwrap_or_default(), - } - } -} - -#[derive(Debug, Serialize)] -struct NativeChatRequest { - model: String, - messages: Vec, - temperature: f64, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, -} - -#[derive(Debug, Serialize)] -struct NativeMessage { - role: String, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - /// Raw reasoning content from thinking models; pass-through for providers - /// that require it in assistant tool-call history messages. - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeToolSpec { - #[serde(rename = "type")] - kind: String, - function: NativeToolFunctionSpec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeToolFunctionSpec { - name: String, - description: String, - parameters: serde_json::Value, -} - -fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result { - let spec: NativeToolSpec = serde_json::from_value(value) - .map_err(|e| anyhow::anyhow!("Invalid OpenAI tool specification: {e}"))?; - - if spec.kind != "function" { - anyhow::bail!( - "Invalid OpenAI tool specification: unsupported tool type '{}', expected 'function'", - spec.kind - ); - } - - Ok(spec) -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - kind: Option, - function: NativeFunctionCall, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeFunctionCall { - name: String, - arguments: String, -} - -#[derive(Debug, Deserialize)] -struct NativeChatResponse { - choices: Vec, - #[serde(default)] - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct UsageInfo { - #[serde(default)] - prompt_tokens: Option, - #[serde(default)] - completion_tokens: Option, -} - -#[derive(Debug, Deserialize)] -struct NativeChoice { - message: NativeResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct NativeResponseMessage { - #[serde(default)] - content: Option, - /// Reasoning/thinking models may return output in `reasoning_content`. - #[serde(default)] - reasoning_content: Option, - #[serde(default)] - tool_calls: Option>, -} - -impl NativeResponseMessage { - fn effective_content(&self) -> Option { - match &self.content { - Some(c) if !c.is_empty() => Some(c.clone()), - _ => self.reasoning_content.clone(), - } - } -} - -async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error { - let status = response.status(); - let body = response - .text() - .await - .unwrap_or_else(|_| "".to_string()); - anyhow::anyhow!("API error ({provider}, {status}): {body}") -} - -impl OpenAiProvider { - pub fn new(credential: Option<&str>) -> Self { - Self::with_base_url(None, credential) - } - - /// Create a provider with an optional custom base URL. - /// Defaults to `https://api.openai.com/v1` when `base_url` is `None`. - pub fn with_base_url(base_url: Option<&str>, credential: Option<&str>) -> Self { - Self { - base_url: base_url - .map(|u| u.trim_end_matches('/').to_string()) - .unwrap_or_else(|| "https://api.openai.com/v1".to_string()), - credential: credential.map(ToString::to_string), - } - } - - fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { - tools.map(|items| { - items - .iter() - .map(|tool| NativeToolSpec { - kind: "function".to_string(), - function: NativeToolFunctionSpec { - name: tool.name.clone(), - description: tool.description.clone(), - parameters: tool.parameters.clone(), - }, - }) - .collect() - }) - } - - fn convert_messages(messages: &[ChatMessage]) -> Vec { - messages - .iter() - .map(|m| { - if m.role == "assistant" - && let Ok(value) = serde_json::from_str::(&m.content) - && let Some(tool_calls_value) = value.get("tool_calls") - && let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| NativeToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tc.name, - arguments: tc.arguments, - }, - }) - .collect::>(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - reasoning_content, - }; - } - - if m.role == "tool" - && let Ok(value) = serde_json::from_str::(&m.content) - { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - reasoning_content: None, - }; - } - - NativeMessage { - role: m.role.clone(), - content: Some(m.content.clone()), - tool_call_id: None, - tool_calls: None, - reasoning_content: None, - } - }) - .collect() - } - - fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { - let text = message.effective_content(); - let reasoning_content = message.reasoning_content.clone(); - let tool_calls = message - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|tc| ProviderToolCall { - id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - name: tc.function.name, - arguments: tc.function.arguments, - }) - .collect::>(); - - ProviderChatResponse { - text, - tool_calls, - usage: None, - reasoning_content, - } - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } -} - -#[async_trait] -impl Provider for OpenAiProvider { - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.") - })?; - - let mut messages = Vec::new(); - - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: sys.to_string(), - }); - } - - messages.push(Message { - role: "user".to_string(), - content: message.to_string(), - }); - - let request = ChatRequest { - model: model.to_string(), - messages, - temperature, - }; - - let response = self - .http_client() - .post(format!("{}/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {credential}")) - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenAI", response).await); - } - - let chat_response: ChatResponse = response.json().await?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| c.message.effective_content()) - .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) - } - - async fn chat( - &self, - request: ProviderChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.") - })?; - - let tools = Self::convert_tools(request.tools); - let native_request = NativeChatRequest { - model: model.to_string(), - messages: Self::convert_messages(request.messages), - temperature, - tool_choice: tools.as_ref().map(|_| "auto".to_string()), - tools, - }; - - let response = self - .http_client() - .post(format!("{}/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {credential}")) - .json(&native_request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenAI", response).await); - } - - let native_response: NativeChatResponse = response.json().await?; - let usage = native_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let message = native_response - .choices - .into_iter() - .next() - .map(|c| c.message) - .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; - let mut result = Self::parse_native_response(message); - result.usage = usage; - Ok(result) - } - - fn supports_native_tools(&self) -> bool { - true - } - - async fn chat_with_tools( - &self, - messages: &[ChatMessage], - tools: &[serde_json::Value], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.") - })?; - - let native_tools: Option> = if tools.is_empty() { - None - } else { - Some( - tools - .iter() - .cloned() - .map(parse_native_tool_spec) - .collect::, _>>()?, - ) - }; - - let native_request = NativeChatRequest { - model: model.to_string(), - messages: Self::convert_messages(messages), - temperature, - tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), - tools: native_tools, - }; - - let response = self - .http_client() - .post(format!("{}/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {credential}")) - .json(&native_request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenAI", response).await); - } - - let native_response: NativeChatResponse = response.json().await?; - let usage = native_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let message = native_response - .choices - .into_iter() - .next() - .map(|c| c.message) - .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; - let mut result = Self::parse_native_response(message); - result.usage = usage; - Ok(result) - } - - async fn warmup(&self) -> anyhow::Result<()> { - if let Some(credential) = self.credential.as_ref() { - self.http_client() - .get(format!("{}/models", self.base_url)) - .header("Authorization", format!("Bearer {credential}")) - .send() - .await? - .error_for_status()?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn creates_with_key() { - let p = OpenAiProvider::new(Some("openai-test-credential")); - assert_eq!(p.credential.as_deref(), Some("openai-test-credential")); - } - - #[test] - fn creates_without_key() { - let p = OpenAiProvider::new(None); - assert!(p.credential.is_none()); - } - - #[test] - fn creates_with_empty_key() { - let p = OpenAiProvider::new(Some("")); - assert_eq!(p.credential.as_deref(), Some("")); - } - - #[tokio::test] - async fn chat_fails_without_key() { - let p = OpenAiProvider::new(None); - let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[tokio::test] - async fn chat_with_system_fails_without_key() { - let p = OpenAiProvider::new(None); - let result = p - .chat_with_system(Some("You are a helpful assistant"), "test", "gpt-4o", 0.5) - .await; - assert!(result.is_err()); - } - - #[test] - fn request_serializes_with_system_message() { - let req = ChatRequest { - model: "gpt-4o".to_string(), - messages: vec![ - Message { - role: "system".to_string(), - content: "You are a helpful assistant".to_string(), - }, - Message { - role: "user".to_string(), - content: "hello".to_string(), - }, - ], - temperature: 0.7, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("\"role\":\"system\"")); - assert!(json.contains("\"role\":\"user\"")); - assert!(json.contains("gpt-4o")); - } - - #[test] - fn request_serializes_without_system() { - let req = ChatRequest { - model: "gpt-4o".to_string(), - messages: vec![Message { - role: "user".to_string(), - content: "hello".to_string(), - }], - temperature: 0.0, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(!json.contains("system")); - assert!(json.contains("\"temperature\":0.0")); - } - - #[test] - fn response_deserializes_single_choice() { - let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices.len(), 1); - assert_eq!(resp.choices[0].message.effective_content(), "Hi!"); - } - - #[test] - fn response_deserializes_empty_choices() { - let json = r#"{"choices":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.choices.is_empty()); - } - - #[test] - fn response_deserializes_multiple_choices() { - let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices.len(), 2); - assert_eq!(resp.choices[0].message.effective_content(), "A"); - } - - #[test] - fn response_with_unicode() { - let json = r#"{"choices":[{"message":{"content":"Hello \u03A9"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!( - resp.choices[0].message.effective_content(), - "Hello \u{03A9}" - ); - } - - #[test] - fn response_with_long_content() { - let long = "x".repeat(100_000); - let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#); - let resp: ChatResponse = serde_json::from_str(&json).unwrap(); - assert_eq!( - resp.choices[0].message.content.as_ref().unwrap().len(), - 100_000 - ); - } - - #[tokio::test] - async fn warmup_without_key_is_noop() { - let provider = OpenAiProvider::new(None); - let result = provider.warmup().await; - assert!(result.is_ok()); - } - - #[test] - fn reasoning_content_fallback_empty_content() { - let json = r#"{"choices":[{"message":{"content":"","reasoning_content":"Thinking..."}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices[0].message.effective_content(), "Thinking..."); - } - - #[test] - fn reasoning_content_fallback_null_content() { - let json = - r#"{"choices":[{"message":{"content":null,"reasoning_content":"Thinking..."}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices[0].message.effective_content(), "Thinking..."); - } - - #[test] - fn reasoning_content_not_used_when_content_present() { - let json = r#"{"choices":[{"message":{"content":"Hello","reasoning_content":"Ignored"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices[0].message.effective_content(), "Hello"); - } - - #[test] - fn native_response_reasoning_content_fallback() { - let json = - r#"{"choices":[{"message":{"content":"","reasoning_content":"Native thinking"}}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let msg = &resp.choices[0].message; - assert_eq!(msg.effective_content(), Some("Native thinking".to_string())); - } - - #[test] - fn native_response_reasoning_content_ignored_when_content_present() { - let json = - r#"{"choices":[{"message":{"content":"Real answer","reasoning_content":"Ignored"}}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let msg = &resp.choices[0].message; - assert_eq!(msg.effective_content(), Some("Real answer".to_string())); - } - - #[tokio::test] - async fn chat_with_tools_fails_without_key() { - let p = OpenAiProvider::new(None); - let messages = vec![ChatMessage::user("hello".to_string())]; - let tools = vec![serde_json::json!({ - "type": "function", - "function": { - "name": "shell", - "description": "Run a shell command", - "parameters": { - "type": "object", - "properties": { - "command": { "type": "string" } - }, - "required": ["command"] - } - } - })]; - let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[tokio::test] - async fn chat_with_tools_rejects_invalid_tool_shape() { - let p = OpenAiProvider::new(Some("openai-test-credential")); - let messages = vec![ChatMessage::user("hello".to_string())]; - let tools = vec![serde_json::json!({ - "type": "function", - "function": { - "name": "shell", - "parameters": { - "type": "object", - "properties": { - "command": { "type": "string" } - }, - "required": ["command"] - } - } - })]; - - let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("Invalid OpenAI tool specification") - ); - } - - #[test] - fn native_tool_spec_deserializes_from_openai_format() { - let json = serde_json::json!({ - "type": "function", - "function": { - "name": "shell", - "description": "Run a shell command", - "parameters": { - "type": "object", - "properties": { - "command": { "type": "string" } - }, - "required": ["command"] - } - } - }); - let spec = parse_native_tool_spec(json).unwrap(); - assert_eq!(spec.kind, "function"); - assert_eq!(spec.function.name, "shell"); - } - - #[test] - fn native_response_parses_usage() { - let json = r#"{ - "choices": [{"message": {"content": "Hello"}}], - "usage": {"prompt_tokens": 100, "completion_tokens": 50} - }"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let usage = resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, Some(100)); - assert_eq!(usage.completion_tokens, Some(50)); - } - - #[test] - fn native_response_parses_without_usage() { - let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.usage.is_none()); - } - - #[test] - fn parse_native_response_captures_reasoning_content() { - let json = r#"{"choices":[{"message":{ - "content":"answer", - "reasoning_content":"thinking step", - "tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{}"}}] - }}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let message = resp.choices.into_iter().next().unwrap().message; - let parsed = OpenAiProvider::parse_native_response(message); - assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking step")); - assert_eq!(parsed.tool_calls.len(), 1); - } - - #[test] - fn parse_native_response_none_reasoning_content_for_normal_model() { - let json = r#"{"choices":[{"message":{"content":"hello"}}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let message = resp.choices.into_iter().next().unwrap().message; - let parsed = OpenAiProvider::parse_native_response(message); - assert!(parsed.reasoning_content.is_none()); - } - - #[test] - fn convert_messages_round_trips_reasoning_content() { - let history_json = serde_json::json!({ - "content": "I will check", - "tool_calls": [{ - "id": "tc_1", - "name": "shell", - "arguments": "{}" - }], - "reasoning_content": "Let me think..." - }); - - let messages = vec![ChatMessage::assistant(history_json.to_string())]; - let native = OpenAiProvider::convert_messages(&messages); - assert_eq!(native.len(), 1); - assert_eq!( - native[0].reasoning_content.as_deref(), - Some("Let me think...") - ); - } - - #[test] - fn convert_messages_no_reasoning_content_when_absent() { - let history_json = serde_json::json!({ - "content": "I will check", - "tool_calls": [{ - "id": "tc_1", - "name": "shell", - "arguments": "{}" - }] - }); - - let messages = vec![ChatMessage::assistant(history_json.to_string())]; - let native = OpenAiProvider::convert_messages(&messages); - assert_eq!(native.len(), 1); - assert!(native[0].reasoning_content.is_none()); - } - - #[test] - fn native_message_omits_reasoning_content_when_none() { - let msg = NativeMessage { - role: "assistant".to_string(), - content: Some("hi".to_string()), - tool_call_id: None, - tool_calls: None, - reasoning_content: None, - }; - let json = serde_json::to_string(&msg).unwrap(); - assert!(!json.contains("reasoning_content")); - } - - #[test] - fn native_message_includes_reasoning_content_when_some() { - let msg = NativeMessage { - role: "assistant".to_string(), - content: Some("hi".to_string()), - tool_call_id: None, - tool_calls: None, - reasoning_content: Some("thinking...".to_string()), - }; - let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains("reasoning_content")); - assert!(json.contains("thinking...")); - } -} diff --git a/crewforge-rs/src/provider/openai_codex.rs b/crewforge-rs/src/provider/openai_oauth.rs similarity index 100% rename from crewforge-rs/src/provider/openai_codex.rs rename to crewforge-rs/src/provider/openai_oauth.rs diff --git a/crewforge-rs/src/provider/openrouter.rs b/crewforge-rs/src/provider/openrouter.rs deleted file mode 100644 index 5887d6b..0000000 --- a/crewforge-rs/src/provider/openrouter.rs +++ /dev/null @@ -1,823 +0,0 @@ -use crate::provider::traits::{ - ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolSpec, -}; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -pub struct OpenRouterProvider { - credential: Option, -} - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - messages: Vec, - temperature: f64, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ApiChatResponse { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: ResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct ResponseMessage { - content: String, -} - -#[derive(Debug, Serialize)] -struct NativeChatRequest { - model: String, - messages: Vec, - temperature: f64, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, -} - -#[derive(Debug, Serialize)] -struct NativeMessage { - role: String, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - /// Raw reasoning content from thinking models; pass-through for providers - /// that require it in assistant tool-call history messages. - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, -} - -#[derive(Debug, Serialize)] -struct NativeToolSpec { - #[serde(rename = "type")] - kind: String, - function: NativeToolFunctionSpec, -} - -#[derive(Debug, Serialize)] -struct NativeToolFunctionSpec { - name: String, - description: String, - parameters: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - kind: Option, - function: NativeFunctionCall, -} - -#[derive(Debug, Serialize, Deserialize)] -struct NativeFunctionCall { - name: String, - arguments: String, -} - -#[derive(Debug, Deserialize)] -struct NativeChatResponse { - choices: Vec, - #[serde(default)] - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct UsageInfo { - #[serde(default)] - prompt_tokens: Option, - #[serde(default)] - completion_tokens: Option, -} - -#[derive(Debug, Deserialize)] -struct NativeChoice { - message: NativeResponseMessage, -} - -#[derive(Debug, Deserialize)] -struct NativeResponseMessage { - #[serde(default)] - content: Option, - /// Reasoning/thinking models may return output in `reasoning_content`. - #[serde(default)] - reasoning_content: Option, - #[serde(default)] - tool_calls: Option>, -} - -async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error { - let status = response.status(); - let body = response - .text() - .await - .unwrap_or_else(|_| "".to_string()); - anyhow::anyhow!("API error ({provider}, {status}): {body}") -} - -impl OpenRouterProvider { - pub fn new(credential: Option<&str>) -> Self { - Self { - credential: credential.map(ToString::to_string), - } - } - - fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { - let items = tools?; - if items.is_empty() { - return None; - } - Some( - items - .iter() - .map(|tool| NativeToolSpec { - kind: "function".to_string(), - function: NativeToolFunctionSpec { - name: tool.name.clone(), - description: tool.description.clone(), - parameters: tool.parameters.clone(), - }, - }) - .collect(), - ) - } - - fn convert_messages(messages: &[ChatMessage]) -> Vec { - messages - .iter() - .map(|m| { - if m.role == "assistant" - && let Ok(value) = serde_json::from_str::(&m.content) - && let Some(tool_calls_value) = value.get("tool_calls") - && let Ok(parsed_calls) = - serde_json::from_value::>(tool_calls_value.clone()) - { - let tool_calls = parsed_calls - .into_iter() - .map(|tc| NativeToolCall { - id: Some(tc.id), - kind: Some("function".to_string()), - function: NativeFunctionCall { - name: tc.name, - arguments: tc.arguments, - }, - }) - .collect::>(); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let reasoning_content = value - .get("reasoning_content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - return NativeMessage { - role: "assistant".to_string(), - content, - tool_call_id: None, - tool_calls: Some(tool_calls), - reasoning_content, - }; - } - - if m.role == "tool" - && let Ok(value) = serde_json::from_str::(&m.content) - { - let tool_call_id = value - .get("tool_call_id") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - let content = value - .get("content") - .and_then(serde_json::Value::as_str) - .map(ToString::to_string) - .or_else(|| Some(m.content.clone())); - return NativeMessage { - role: "tool".to_string(), - content, - tool_call_id, - tool_calls: None, - reasoning_content: None, - }; - } - - NativeMessage { - role: m.role.clone(), - content: Some(m.content.clone()), - tool_call_id: None, - tool_calls: None, - reasoning_content: None, - } - }) - .collect() - } - - fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { - let reasoning_content = message.reasoning_content.clone(); - let tool_calls = message - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|tc| ProviderToolCall { - id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - name: tc.function.name, - arguments: tc.function.arguments, - }) - .collect::>(); - - ProviderChatResponse { - text: message.content, - tool_calls, - usage: None, - reasoning_content, - } - } - - fn http_client(&self) -> Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap_or_default() - } -} - -#[async_trait] -impl Provider for OpenRouterProvider { - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - native_tool_calling: true, - vision: false, - } - } - - async fn warmup(&self) -> anyhow::Result<()> { - if let Some(credential) = self.credential.as_ref() { - self.http_client() - .get("https://openrouter.ai/api/v1/auth/key") - .header("Authorization", format!("Bearer {credential}")) - .send() - .await? - .error_for_status()?; - } - Ok(()) - } - - async fn chat_with_system( - &self, - system_prompt: Option<&str>, - message: &str, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") - })?; - - let mut messages = Vec::new(); - - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: sys.to_string(), - }); - } - - messages.push(Message { - role: "user".to_string(), - content: message.to_string(), - }); - - let request = ChatRequest { - model: model.to_string(), - messages, - temperature, - }; - - let response = self - .http_client() - .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {credential}")) - .header("HTTP-Referer", "https://github.com/crewforge/crewforge") - .header("X-Title", "CrewForge") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenRouter", response).await); - } - - let chat_response: ApiChatResponse = response.json().await?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| c.message.content) - .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) - } - - async fn chat_with_history( - &self, - messages: &[ChatMessage], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") - })?; - - let api_messages: Vec = messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: m.content.clone(), - }) - .collect(); - - let request = ChatRequest { - model: model.to_string(), - messages: api_messages, - temperature, - }; - - let response = self - .http_client() - .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {credential}")) - .header("HTTP-Referer", "https://github.com/crewforge/crewforge") - .header("X-Title", "CrewForge") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenRouter", response).await); - } - - let chat_response: ApiChatResponse = response.json().await?; - - chat_response - .choices - .into_iter() - .next() - .map(|c| c.message.content) - .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) - } - - async fn chat( - &self, - request: ProviderChatRequest<'_>, - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") - })?; - - let tools = Self::convert_tools(request.tools); - let native_request = NativeChatRequest { - model: model.to_string(), - messages: Self::convert_messages(request.messages), - temperature, - tool_choice: tools.as_ref().map(|_| "auto".to_string()), - tools, - }; - - let response = self - .http_client() - .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {credential}")) - .header("HTTP-Referer", "https://github.com/crewforge/crewforge") - .header("X-Title", "CrewForge") - .json(&native_request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenRouter", response).await); - } - - let native_response: NativeChatResponse = response.json().await?; - let usage = native_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let message = native_response - .choices - .into_iter() - .next() - .map(|c| c.message) - .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; - let mut result = Self::parse_native_response(message); - result.usage = usage; - Ok(result) - } - - fn supports_native_tools(&self) -> bool { - true - } - - async fn chat_with_tools( - &self, - messages: &[ChatMessage], - tools: &[serde_json::Value], - model: &str, - temperature: f64, - ) -> anyhow::Result { - let credential = self.credential.as_ref().ok_or_else(|| { - anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY.") - })?; - - let native_tools: Option> = if tools.is_empty() { - None - } else { - let specs: Vec = tools - .iter() - .filter_map(|t| { - let func = t.get("function")?; - Some(NativeToolSpec { - kind: "function".to_string(), - function: NativeToolFunctionSpec { - name: func.get("name")?.as_str()?.to_string(), - description: func - .get("description") - .and_then(|d| d.as_str()) - .unwrap_or("") - .to_string(), - parameters: func - .get("parameters") - .cloned() - .unwrap_or(serde_json::json!({})), - }, - }) - }) - .collect(); - if specs.is_empty() { None } else { Some(specs) } - }; - - let native_messages = Self::convert_messages(messages); - - let native_request = NativeChatRequest { - model: model.to_string(), - messages: native_messages, - temperature, - tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), - tools: native_tools, - }; - - let response = self - .http_client() - .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {credential}")) - .header("HTTP-Referer", "https://github.com/crewforge/crewforge") - .header("X-Title", "CrewForge") - .json(&native_request) - .send() - .await?; - - if !response.status().is_success() { - return Err(api_error("OpenRouter", response).await); - } - - let native_response: NativeChatResponse = response.json().await?; - let usage = native_response.usage.map(|u| TokenUsage { - input_tokens: u.prompt_tokens, - output_tokens: u.completion_tokens, - }); - let message = native_response - .choices - .into_iter() - .next() - .map(|c| c.message) - .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; - let mut result = Self::parse_native_response(message); - result.usage = usage; - Ok(result) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::provider::traits::{ChatMessage, Provider}; - - #[test] - fn capabilities_report_native_tool_support() { - let provider = OpenRouterProvider::new(Some("openrouter-test-credential")); - let caps = ::capabilities(&provider); - assert!(caps.native_tool_calling); - assert!(!caps.vision); - } - - #[test] - fn creates_with_key() { - let provider = OpenRouterProvider::new(Some("openrouter-test-credential")); - assert_eq!( - provider.credential.as_deref(), - Some("openrouter-test-credential") - ); - } - - #[test] - fn creates_without_key() { - let provider = OpenRouterProvider::new(None); - assert!(provider.credential.is_none()); - } - - #[tokio::test] - async fn warmup_without_key_is_noop() { - let provider = OpenRouterProvider::new(None); - let result = provider.warmup().await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn chat_with_system_fails_without_key() { - let provider = OpenRouterProvider::new(None); - let result = provider - .chat_with_system(Some("system"), "hello", "openai/gpt-4o", 0.2) - .await; - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[tokio::test] - async fn chat_with_history_fails_without_key() { - let provider = OpenRouterProvider::new(None); - let messages = vec![ - ChatMessage { - role: "system".into(), - content: "be concise".into(), - }, - ChatMessage { - role: "user".into(), - content: "hello".into(), - }, - ]; - - let result = provider - .chat_with_history(&messages, "anthropic/claude-sonnet-4", 0.7) - .await; - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[test] - fn chat_request_serializes_with_system_and_user() { - let request = ChatRequest { - model: "anthropic/claude-sonnet-4".into(), - messages: vec![ - Message { - role: "system".into(), - content: "You are helpful".into(), - }, - Message { - role: "user".into(), - content: "Summarize this".into(), - }, - ], - temperature: 0.5, - }; - - let json = serde_json::to_string(&request).unwrap(); - - assert!(json.contains("anthropic/claude-sonnet-4")); - assert!(json.contains("\"role\":\"system\"")); - assert!(json.contains("\"role\":\"user\"")); - assert!(json.contains("\"temperature\":0.5")); - } - - #[test] - fn response_deserializes_single_choice() { - let json = r#"{"choices":[{"message":{"content":"Hi from OpenRouter"}}]}"#; - - let response: ApiChatResponse = serde_json::from_str(json).unwrap(); - - assert_eq!(response.choices.len(), 1); - assert_eq!(response.choices[0].message.content, "Hi from OpenRouter"); - } - - #[test] - fn response_deserializes_empty_choices() { - let json = r#"{"choices":[]}"#; - - let response: ApiChatResponse = serde_json::from_str(json).unwrap(); - - assert!(response.choices.is_empty()); - } - - #[tokio::test] - async fn chat_with_tools_fails_without_key() { - let provider = OpenRouterProvider::new(None); - let messages = vec![ChatMessage { - role: "user".into(), - content: "What is the date?".into(), - }]; - let tools = vec![serde_json::json!({ - "type": "function", - "function": { - "name": "shell", - "description": "Run a shell command", - "parameters": {"type": "object", "properties": {"command": {"type": "string"}}} - } - })]; - - let result = provider - .chat_with_tools(&messages, &tools, "deepseek/deepseek-chat", 0.5) - .await; - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("API key not set")); - } - - #[test] - fn native_response_deserializes_with_tool_calls() { - let json = r#"{ - "choices":[{ - "message":{ - "content":null, - "tool_calls":[ - {"id":"call_123","type":"function","function":{"name":"get_price","arguments":"{\"symbol\":\"BTC\"}"}} - ] - } - }] - }"#; - - let response: NativeChatResponse = serde_json::from_str(json).unwrap(); - - assert_eq!(response.choices.len(), 1); - let message = &response.choices[0].message; - assert!(message.content.is_none()); - let tool_calls = message.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].id.as_deref(), Some("call_123")); - assert_eq!(tool_calls[0].function.name, "get_price"); - } - - #[test] - fn parse_native_response_converts_to_chat_response() { - let message = NativeResponseMessage { - content: Some("Here you go.".into()), - reasoning_content: None, - tool_calls: Some(vec![NativeToolCall { - id: Some("call_789".into()), - kind: Some("function".into()), - function: NativeFunctionCall { - name: "file_read".into(), - arguments: r#"{"path":"test.txt"}"#.into(), - }, - }]), - }; - - let response = OpenRouterProvider::parse_native_response(message); - - assert_eq!(response.text.as_deref(), Some("Here you go.")); - assert_eq!(response.tool_calls.len(), 1); - assert_eq!(response.tool_calls[0].id, "call_789"); - assert_eq!(response.tool_calls[0].name, "file_read"); - } - - #[test] - fn convert_messages_parses_assistant_tool_call_payload() { - let messages = vec![ChatMessage { - role: "assistant".into(), - content: r#"{"content":"Using tool","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{\"command\":\"pwd\"}"}]}"# - .into(), - }]; - - let converted = OpenRouterProvider::convert_messages(&messages); - assert_eq!(converted.len(), 1); - assert_eq!(converted[0].role, "assistant"); - assert_eq!(converted[0].content.as_deref(), Some("Using tool")); - - let tool_calls = converted[0].tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc")); - assert_eq!(tool_calls[0].function.name, "shell"); - } - - #[test] - fn convert_messages_parses_tool_result_payload() { - let messages = vec![ChatMessage { - role: "tool".into(), - content: r#"{"tool_call_id":"call_xyz","content":"done"}"#.into(), - }]; - - let converted = OpenRouterProvider::convert_messages(&messages); - assert_eq!(converted.len(), 1); - assert_eq!(converted[0].role, "tool"); - assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_xyz")); - assert_eq!(converted[0].content.as_deref(), Some("done")); - assert!(converted[0].tool_calls.is_none()); - } - - #[test] - fn native_response_parses_usage() { - let json = r#"{ - "choices": [{"message": {"content": "Hello"}}], - "usage": {"prompt_tokens": 42, "completion_tokens": 15} - }"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - let usage = resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, Some(42)); - assert_eq!(usage.completion_tokens, Some(15)); - } - - #[test] - fn native_response_parses_without_usage() { - let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; - let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.usage.is_none()); - } - - #[test] - fn parse_native_response_captures_reasoning_content() { - let message = NativeResponseMessage { - content: Some("answer".into()), - reasoning_content: Some("thinking step".into()), - tool_calls: Some(vec![NativeToolCall { - id: Some("call_1".into()), - kind: Some("function".into()), - function: NativeFunctionCall { - name: "shell".into(), - arguments: "{}".into(), - }, - }]), - }; - let parsed = OpenRouterProvider::parse_native_response(message); - assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking step")); - assert_eq!(parsed.tool_calls.len(), 1); - } - - #[test] - fn convert_messages_round_trips_reasoning_content() { - let history_json = serde_json::json!({ - "content": "I will check", - "tool_calls": [{ - "id": "tc_1", - "name": "shell", - "arguments": "{}" - }], - "reasoning_content": "Let me think..." - }); - - let messages = vec![ChatMessage { - role: "assistant".into(), - content: history_json.to_string(), - }]; - let native = OpenRouterProvider::convert_messages(&messages); - assert_eq!(native.len(), 1); - assert_eq!( - native[0].reasoning_content.as_deref(), - Some("Let me think...") - ); - } - - #[test] - fn native_message_omits_reasoning_content_when_none() { - let msg = NativeMessage { - role: "assistant".to_string(), - content: Some("hi".into()), - tool_call_id: None, - tool_calls: None, - reasoning_content: None, - }; - let json = serde_json::to_string(&msg).unwrap(); - assert!(!json.contains("reasoning_content")); - } - - #[test] - fn native_message_includes_reasoning_content_when_some() { - let msg = NativeMessage { - role: "assistant".to_string(), - content: Some("hi".into()), - tool_call_id: None, - tool_calls: None, - reasoning_content: Some("thinking...".to_string()), - }; - let json = serde_json::to_string(&msg).unwrap(); - assert!(json.contains("reasoning_content")); - assert!(json.contains("thinking...")); - } -}