diff --git a/Cargo.lock b/Cargo.lock index b252580..e88faa8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1602,6 +1602,7 @@ dependencies = [ "dirs", "dotenvy", "glob", + "once_cell", "regex", "rustyline", "serde", diff --git a/Cargo.toml b/Cargo.toml index 330c7a4..4ee7bd9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ clap = { version = "4", features = ["derive", "env"] } dirs = "5" dotenvy = "0.15" glob = "0.3" +once_cell = "1" regex = "1" rustyline = { version = "17", features = ["with-file-history"] } serde = { version = "1", features = ["derive"] } diff --git a/IDEAS.md b/IDEAS.md new file mode 100644 index 0000000..fc9bbf9 --- /dev/null +++ b/IDEAS.md @@ -0,0 +1,183 @@ +# Future Optimization Ideas for `-O` Mode + +This document captures ideas for future enhancements to the `-O` optimization flag, building on the research that shorter, denser prompts improve LLM performance. + +## Research Foundation + +- **Context Rot**: Accuracy degrades as prompts grow longer (Chroma study on 18 models) +- **LLMLingua**: 20x compression with only 1.5% performance loss +- **Positive Framing**: "Do this" outperforms "don't do this" in prompts +- **Signal Density**: "Find the smallest set of high-signal tokens that maximize the likelihood of your desired outcome" + +Sources: +- https://github.com/microsoft/LLMLingua +- https://www.anthropic.com/engineering/effective-context-engineering-for-ai-agents +- https://gritdaily.com/impact-prompt-length-llm-performance/ + +--- + +## Implemented Layers + +### Layer 1: Terse System Prompt +- Reduced from ~60 tokens to ~15 tokens +- Positive framing: "AI-to-AI mode. Maximum information density. Structure over prose. No narration." + +### Layer 2: Compressed Tool Schemas +- Tool descriptions shortened (e.g., "Read file content. Paths relative to project root." 
→ "Read file") +- Parameter descriptions stripped in optimize mode +- Uses `SchemaOptions` struct for extensibility + +--- + +## Future Layers + +### Layer 3: Tool Result Compression + +**Concept**: Strip metadata from tool results in `-O` mode. + +Current Read result: +```json +{ + "path": "foo.rs", + "offset": 0, + "truncated": false, + "content": "...", + "sha256": "abc123", + "lines": 42 +} +``` + +Optimized result: +```json +{"content": "..."} +``` + +**Implementation**: +- Add `optimize` flag to `tools::execute()` +- Conditionally strip fields: `path`, `offset`, `truncated`, `sha256`, `lines` +- Keep only essential data needed for task completion + +**Estimated token savings**: 30-50% per tool result + +--- + +### Layer 4: History Summarization + +**Concept**: Compress older conversation turns to maintain context while reducing tokens. + +**Approaches**: +1. **Sliding Window**: Keep only last N turns in full, summarize older ones +2. **Semantic Compression**: Use small model to compress verbose assistant responses +3. **Result Deduplication**: Merge repeated tool results (e.g., multiple Read calls on same file) + +**Implementation ideas**: +- Add `conversation_compressor` module +- Trigger compression when context exceeds threshold +- Preserve tool call/result structure for agent continuity + +**Research reference**: LLMLingua-2 achieves 3-6x faster compression with task-agnostic distillation + +--- + +### Layer 5: Output Style Enforcement + +**Concept**: Enforce structured output format in `-O` mode. + +**Current**: LLM outputs natural language explanations mixed with actions +**Optimized**: Pure structured output, no prose + +**Implementation ideas**: +1. **Structured Output Schema**: Add JSON schema for responses +2. **Response Format Instruction**: "Respond only with tool calls or structured JSON" +3. 
**Post-processing**: Strip explanation text, keep only actions + +**Example transformation**: +``` +Before: "I'll read the config file to understand the settings. Let me use the Read tool..." +After: [tool_call: Read, path: "config.toml"] +``` + +**Trade-off**: May reduce transparency for human review, but ideal for AI-to-AI pipelines + +--- + +### Layer 6: Dynamic Tool Injection + +**Concept**: Only include tool schemas likely needed for the current task. + +**Current**: All 8 tools included in every request +**Optimized**: Analyze prompt, inject relevant subset + +**Heuristics**: +- "read", "view", "show" → Read, Grep, Glob +- "edit", "modify", "change" → Read, Edit, Write +- "run", "execute", "test", "build" → Bash +- "find", "search" → Grep, Glob +- "delegate", "subagent" → Task + +**Implementation**: +- Add `infer_tools_from_prompt(prompt: &str) -> Vec` +- Apply before schema generation +- Fall back to full toolset if uncertain + +--- + +### Layer 7: CodeAgents-Style Pseudocode + +**Concept**: Use structured pseudocode instead of natural language for reasoning. + +**Research**: CodeAgents framework reduces tokens by 55-87%. + +**Current**: +``` +I need to first read the file to understand its structure, then I'll make the edit... +``` + +**Optimized**: +``` +PLAN: Read("src/main.rs") -> Edit(find="old", replace="new") +``` + +**Implementation**: +- Add `--reasoning-format=pseudocode` option +- Train/prompt model to use structured planning notation +- Parse pseudocode for execution + +--- + +## Measurement & Validation + +To validate optimization effectiveness: + +1. **Token Counting**: Compare input/output tokens with and without `-O` +2. **Task Success Rate**: Ensure optimizations don't reduce accuracy +3. **Latency**: Measure time-to-first-token improvement +4. 
**Cost**: Calculate API cost savings + +**Suggested benchmarks**: +- Simple file read/edit tasks +- Multi-step refactoring tasks +- Codebase exploration tasks + +--- + +## Configuration Ideas + +Future `SchemaOptions` extensions: +```rust +pub struct SchemaOptions { + pub optimize: bool, + // Future fields: + pub compress_results: bool, + pub dynamic_tools: bool, + pub pseudocode_reasoning: bool, + pub max_history_turns: Option, +} +``` + +Command-line exposure: +``` +yo -O # Enable all optimizations +yo -O --no-compress # Optimize schemas but not results +yo --optimize-level=2 # Granular control +``` diff --git a/fixtures/mcp_calc_server/src/main.rs b/fixtures/mcp_calc_server/src/main.rs index a02aeb7..fc19ee1 100644 --- a/fixtures/mcp_calc_server/src/main.rs +++ b/fixtures/mcp_calc_server/src/main.rs @@ -6,8 +6,6 @@ use std::io::{self, BufRead, Write}; #[derive(Deserialize)] struct JsonRpcRequest { - #[allow(dead_code)] - jsonrpc: String, id: Option, method: String, params: Option, diff --git a/src/agent.rs b/src/agent.rs index 65f38ad..be4657b 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -2,7 +2,7 @@ use crate::{ cli::Context, - llm, + llm::{self, LlmClient}, plan::{self, PlanPhase}, policy::Decision, tools, @@ -34,6 +34,15 @@ impl CommandStats { } } +/// Result of a turn, including stats and continuation info +#[derive(Debug, Default, Clone)] +pub struct TurnResult { + pub stats: CommandStats, + /// If true, a Stop hook requested continuation with the given prompt + pub force_continue: bool, + pub continue_prompt: Option, +} + const SYSTEM_PROMPT: &str = r#"You are an agentic coding assistant running locally. You can only access files via tools. All paths are relative to the project root. Use Glob/Grep to find files before Read. Before Edit/Write, explain what you will change. 
@@ -53,12 +62,8 @@ fn verbose(ctx: &Context, message: &str) { } } -pub fn run_turn( - ctx: &Context, - user_input: &str, - messages: &mut Vec, -) -> Result { - let mut stats = CommandStats::default(); +pub fn run_turn(ctx: &Context, user_input: &str, messages: &mut Vec) -> Result { + let mut result = TurnResult::default(); let _ = ctx.transcript.borrow_mut().user_message(user_input); messages.push(json!({ @@ -111,9 +116,10 @@ pub fn run_turn( } // Get built-in tool schemas (including Task for main agent) and add MCP tools + let schema_opts = tools::SchemaOptions::new(ctx.args.optimize); let mut tool_schemas = if in_planning_mode { // In planning mode, only provide read-only tools - tools::schemas() + tools::schemas(&schema_opts) .into_iter() .filter(|schema| { if let Some(name) = schema @@ -128,7 +134,7 @@ pub fn run_turn( }) .collect() } else { - tools::schemas_with_task() + tools::schemas_with_task(&schema_opts) }; // Only add MCP tools if not in planning mode @@ -187,6 +193,11 @@ pub fn run_turn( SYSTEM_PROMPT.to_string() }; + // Add optimization mode instructions if -O flag is set + if ctx.args.optimize { + system_prompt.push_str("\n\nAI-to-AI mode. Maximum information density. Structure over prose. 
No narration."); + } + // Add skill pack index let skill_index = ctx.skill_index.borrow(); let skill_prompt = skill_index.format_for_prompt(50); @@ -222,8 +233,25 @@ pub fn run_turn( // Track token usage from this LLM call if let Some(usage) = &response.usage { - stats.input_tokens += usage.prompt_tokens; - stats.output_tokens += usage.completion_tokens; + result.stats.input_tokens += usage.prompt_tokens; + result.stats.output_tokens += usage.completion_tokens; + + // Record cost for this operation + let turn_number = *ctx.turn_counter.borrow(); + let op = ctx.session_costs.borrow_mut().record_operation( + turn_number, + &target.model, + usage.prompt_tokens, + usage.completion_tokens, + ); + + // Log token usage to transcript + let _ = ctx.transcript.borrow_mut().token_usage( + &target.model, + usage.prompt_tokens, + usage.completion_tokens, + op.cost_usd, + ); } if response.choices.is_empty() { @@ -234,6 +262,13 @@ pub fn run_turn( let choice = &response.choices[0]; let msg = &choice.message; + // Warn if response was truncated due to length limit + if choice.finish_reason.as_deref() == Some("length") { + eprintln!( + "⚠️ Response truncated (max tokens reached). Consider increasing max_tokens or using /compact." 
+ ); + } + if let Some(content) = &msg.content { if !content.is_empty() { println!("{}", content); @@ -311,7 +346,7 @@ pub fn run_turn( let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(json!({})); // Count this tool use - stats.tool_uses += 1; + result.stats.tool_uses += 1; trace( ctx, @@ -411,9 +446,9 @@ pub fn run_turn( } } else if name == "Task" { // Execute Task tool (subagent delegation) - let (result, sub_stats) = tools::task::execute(args.clone(), ctx)?; - stats.merge(&sub_stats); - result + let (task_result, sub_stats) = tools::task::execute(args.clone(), ctx)?; + result.stats.merge(&sub_stats); + task_result } else if name.starts_with("mcp.") { // Execute MCP tool let start = std::time::Instant::now(); @@ -508,8 +543,28 @@ pub fn run_turn( } } - // Run Stop hooks (note: force_continue not implemented yet) - let _ = ctx.hooks.borrow().on_stop("end_turn", None); + // Run Stop hooks - may request continuation + let last_assistant_message = messages.iter().rev().find_map(|m| { + if m["role"].as_str() == Some("assistant") { + m["content"].as_str().map(|s| s.to_string()) + } else { + None + } + }); + + let (force_continue, continue_prompt) = ctx + .hooks + .borrow() + .on_stop("end_turn", last_assistant_message.as_deref()); + + // If force_continue is requested, signal to caller to run another turn + if force_continue { + if let Some(prompt) = continue_prompt { + result.force_continue = true; + result.continue_prompt = Some(prompt); + verbose(ctx, "Stop hook requested continuation"); + } + } - Ok(stats) + Ok(result) } diff --git a/src/cli.rs b/src/cli.rs index bf8992e..30ff1b2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,9 +1,12 @@ use crate::{ agent::{self, CommandStats}, backend::BackendRegistry, + commands::CommandIndex, + compact, config::Config, config::PermissionMode, config::Target, + cost::{format_cost, SessionCosts}, hooks::HookManager, mcp::manager::McpManager, model_routing::ModelRouter, @@ -44,22 +47,37 @@ pub struct Context 
{ pub model_router: RefCell, pub plan_mode: RefCell, pub hooks: RefCell, + // Cost tracking + pub session_costs: RefCell, + pub turn_counter: RefCell, + // Slash commands + pub command_index: RefCell, } /// Print command stats to stderr -fn print_stats(duration: Duration, stats: &CommandStats) { +fn print_stats(duration: Duration, stats: &CommandStats, cost: Option) { let tokens = stats.total_tokens(); let token_display = if tokens >= 1000 { format!("{:.1}k", tokens as f64 / 1000.0) } else { tokens.to_string() }; - eprintln!( - "[Duration: {:.1}s | Tokens: {} | Tools: {}]", - duration.as_secs_f64(), - token_display, - stats.tool_uses - ); + if let Some(cost_usd) = cost { + eprintln!( + "[Duration: {:.1}s | Tokens: {} | Cost: {} | Tools: {}]", + duration.as_secs_f64(), + token_display, + format_cost(cost_usd), + stats.tool_uses + ); + } else { + eprintln!( + "[Duration: {:.1}s | Tokens: {} | Tools: {}]", + duration.as_secs_f64(), + token_display, + stats.tool_uses + ); + } } pub fn run_once(ctx: &Context, prompt: &str) -> Result<()> { @@ -71,10 +89,40 @@ pub fn run_once(ctx: &Context, prompt: &str) -> Result<()> { } let prompt = updated_prompt.as_deref().unwrap_or(prompt); + // Increment turn counter + let turn_number = { + let mut counter = ctx.turn_counter.borrow_mut(); + *counter += 1; + *counter + }; + let start = Instant::now(); let mut messages = Vec::new(); - let stats = agent::run_turn(ctx, prompt, &mut messages)?; - print_stats(start.elapsed(), &stats); + let result = agent::run_turn(ctx, prompt, &mut messages)?; + + // Handle force_continue - run another turn if requested + let mut total_stats = result.stats.clone(); + if result.force_continue { + if let Some(continue_prompt) = result.continue_prompt { + println!("[Continuing due to Stop hook...]"); + let continuation = agent::run_turn(ctx, &continue_prompt, &mut messages)?; + total_stats.merge(&continuation.stats); + } + } + + // Get cost for this turn if cost tracking is enabled + let cost = if 
ctx.config.borrow().cost_tracking.display_in_stats { + let costs = ctx.session_costs.borrow(); + costs + .turns() + .iter() + .find(|t| t.turn_number == turn_number) + .map(|t| t.total_cost()) + } else { + None + }; + + print_stats(start.elapsed(), &total_stats, cost); Ok(()) } @@ -112,10 +160,41 @@ pub fn run_repl(ctx: Context) -> Result<()> { } let line = updated_prompt.unwrap_or_else(|| line.to_string()); + // Increment turn counter + let turn_number = { + let mut counter = ctx.turn_counter.borrow_mut(); + *counter += 1; + *counter + }; + let start = Instant::now(); match agent::run_turn(&ctx, &line, &mut messages) { - Ok(stats) => { - print_stats(start.elapsed(), &stats); + Ok(result) => { + // Handle force_continue - run another turn if requested + let mut total_stats = result.stats.clone(); + if result.force_continue { + if let Some(continue_prompt) = result.continue_prompt { + println!("[Continuing due to Stop hook...]"); + if let Ok(continuation) = + agent::run_turn(&ctx, &continue_prompt, &mut messages) + { + total_stats.merge(&continuation.stats); + } + } + } + + // Get cost for this turn if cost tracking is enabled + let cost = if ctx.config.borrow().cost_tracking.display_in_stats { + let costs = ctx.session_costs.borrow(); + costs + .turns() + .iter() + .find(|t| t.turn_number == turn_number) + .map(|t| t.total_cost()) + } else { + None + }; + print_stats(start.elapsed(), &total_stats, cost); } Err(e) => { eprintln!("Error: {}", e); @@ -159,6 +238,8 @@ fn handle_command(ctx: &Context, cmd: &str, messages: &mut Vec"); println!("Context:"); println!(" /context - show context usage stats"); + println!(" /compact - compact conversation history"); + println!(" /cost - show session cost breakdown"); println!("Subagents:"); println!(" /agents - list available subagents"); println!(" /task - run a subagent with the given prompt"); @@ -173,6 +254,9 @@ fn handle_command(ctx: &Context, cmd: &str, messages: &mut Vec - activate skill"); println!(" /skillpack drop - 
deactivate skill"); println!(" /skillpack active - list active skills"); + println!("Slash Commands:"); + println!(" /commands - list user-defined commands"); + println!(" / [args] - run a user-defined command"); println!("Plan Mode:"); println!(" /plan - enter plan mode with a task"); println!(" /plan - show current plan or help"); @@ -261,6 +345,18 @@ fn handle_command(ctx: &Context, cmd: &str, messages: &mut Vec { + handle_compact_command(ctx, messages); + } + "/cost" => { + handle_cost_command(ctx); + } + "/commands" => { + handle_commands_list(ctx); } "/mcp" => { handle_mcp_command(ctx, if parts.len() > 1 { parts[1] } else { "" }); @@ -280,7 +376,14 @@ fn handle_command(ctx: &Context, cmd: &str, messages: &mut Vec { handle_plan_command(ctx, if parts.len() > 1 { parts[1] } else { "" }, messages); } - _ => println!("Unknown command: {}", parts[0]), + _ => { + // Check for user-defined slash commands + let cmd_name = &parts[0][1..]; // Remove leading / + let args = if parts.len() > 1 { parts[1] } else { "" }; + if !try_run_slash_command(ctx, cmd_name, args, messages) { + println!("Unknown command: {}", parts[0]); + } + } } false } @@ -469,6 +572,66 @@ fn handle_mcp_command(ctx: &Context, args: &str) { } } +fn handle_cost_command(ctx: &Context) { + use crate::cost::format_tokens; + + let costs = ctx.session_costs.borrow(); + let total_cost = costs.total_cost(); + let total_tokens = costs.total_tokens(); + + println!("Session Cost Summary"); + println!("────────────────────"); + println!( + "Total: {} ({} tokens)", + format_cost(total_cost), + format_tokens(total_tokens) + ); + + // Breakdown by model + let by_model = costs.cost_by_model(); + if !by_model.is_empty() { + println!("\nBy Model:"); + let mut models: Vec<_> = by_model.iter().collect(); + models.sort_by(|a, b| { + b.1 .1 + .partial_cmp(&a.1 .1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + for (model, (tokens, cost)) in models { + println!( + " {}: {} ({} tokens)", + model, + format_cost(*cost), + 
format_tokens(*tokens) + ); + } + } + + // Breakdown by turn + let turns = costs.turns(); + if !turns.is_empty() { + println!("\nBy Turn:"); + for turn in turns { + println!( + " Turn {}: {} ({} tokens)", + turn.turn_number, + format_cost(turn.total_cost()), + format_tokens(turn.total_tokens()) + ); + } + } + + // Check for warning threshold + if let Some(threshold) = ctx.config.borrow().cost_tracking.warn_threshold_usd { + if total_cost > threshold { + println!( + "\n⚠️ Session cost exceeds threshold of {}", + format_cost(threshold) + ); + } + } +} + fn handle_agents_command(ctx: &Context) { let config = ctx.config.borrow(); if config.agents.is_empty() { @@ -538,7 +701,8 @@ fn handle_task_command(ctx: &Context, args: &str) { } else if let Some(error) = &result.error { println!("Subagent error: {} - {}", error.code, error.message); } - print_stats(start.elapsed(), &stats); + // TODO: Add cost tracking for explicit /task commands + print_stats(start.elapsed(), &stats, None); } Err(e) => { eprintln!("Failed to run subagent: {}", e); @@ -741,11 +905,28 @@ fn handle_plan_start(ctx: &Context, goal: String, messages: &mut Vec { - print_stats(start.elapsed(), &stats); + Ok(result) => { + let cost = if ctx.config.borrow().cost_tracking.display_in_stats { + let costs = ctx.session_costs.borrow(); + costs + .turns() + .iter() + .find(|t| t.turn_number == turn_number) + .map(|t| t.total_cost()) + } else { + None + }; + print_stats(start.elapsed(), &result.stats, cost); // Check if we got a plan let state = ctx.plan_mode.borrow(); @@ -860,12 +1041,19 @@ fn handle_plan_execute(ctx: &Context, messages: &mut Vec) { } ); + // Increment turn counter for plan step + let turn_number = { + let mut counter = ctx.turn_counter.borrow_mut(); + *counter += 1; + *counter + }; + // Execute the step let start = Instant::now(); - let result = agent::run_turn(ctx, &prompt, messages); + let turn_result = agent::run_turn(ctx, &prompt, messages); // Update step status based on result - let 
step_status = if result.is_ok() { + let step_status = if turn_result.is_ok() { plan::PlanStepStatus::Completed } else { plan::PlanStepStatus::Failed @@ -884,16 +1072,26 @@ fn handle_plan_execute(ctx: &Context, messages: &mut Vec) { let _ = ctx.transcript.borrow_mut().plan_step_end( &plan_name, step.number, - if result.is_ok() { + if turn_result.is_ok() { "completed" } else { "failed" }, ); - match result { - Ok(stats) => { - print_stats(start.elapsed(), &stats); + match turn_result { + Ok(result) => { + let cost = if ctx.config.borrow().cost_tracking.display_in_stats { + let costs = ctx.session_costs.borrow(); + costs + .turns() + .iter() + .find(|t| t.turn_number == turn_number) + .map(|t| t.total_cost()) + } else { + None + }; + print_stats(start.elapsed(), &result.stats, cost); println!("\nStep {} complete.", step.number); } Err(e) => { @@ -1040,3 +1238,164 @@ fn handle_plan_delete(ctx: &Context, name: &str) { } } } + +fn handle_compact_command(ctx: &Context, messages: &mut Vec) { + if messages.is_empty() { + println!("No messages to compact."); + return; + } + + // Get target and client + let target = { + let current = ctx.current_target.borrow(); + if let Some(t) = current.as_ref() { + t.clone() + } else { + match ctx.config.borrow().get_default_target() { + Some(t) => t, + None => { + println!("No target configured. 
Use /target to set one."); + return; + } + } + } + }; + + println!("Compacting conversation..."); + + // Get context config before borrowing backends + let context_config = ctx.config.borrow().context.clone(); + + // Get client and perform compaction - capture result and release borrow + let compact_result = { + let mut backends = ctx.backends.borrow_mut(); + let client = match backends.get_client(&target.backend) { + Ok(c) => c, + Err(e) => { + eprintln!("Failed to get client: {}", e); + return; + } + }; + + compact::compact_messages(messages, &context_config, client, &target.model) + }; + + match compact_result { + Ok((compacted, result)) => { + *messages = compacted; + println!("{}", compact::format_result(&result)); + if !result.summary.is_empty() { + println!("\nSummary:\n{}", result.summary); + } + } + Err(e) => { + eprintln!("Compaction failed: {}", e); + } + } +} + +fn handle_commands_list(ctx: &Context) { + use crate::commands::CommandSource; + + let index = ctx.command_index.borrow(); + let commands = index.list(); + + if commands.is_empty() { + println!("No slash commands defined."); + println!("Add commands to .yo/commands/.md"); + } else { + println!("Slash Commands ({}):", commands.len()); + for cmd in commands { + let source = match cmd.source { + CommandSource::Project => "[project]", + CommandSource::User => "[user]", + }; + let desc = cmd + .meta + .description + .as_deref() + .unwrap_or("(no description)"); + println!(" /{} {} - {}", cmd.name, source, desc); + } + } + + // Show errors if any + for (path, error) in index.errors() { + eprintln!(" [error] {}: {}", path.display(), error); + } +} + +/// Try to run a user-defined slash command +/// Returns true if a command was found and executed +fn try_run_slash_command( + ctx: &Context, + cmd_name: &str, + args: &str, + messages: &mut Vec, +) -> bool { + let command = { + let index = ctx.command_index.borrow(); + index.get(cmd_name).cloned() + }; + + let Some(command) = command else { + return 
false; + }; + + // Expand the command with arguments + let prompt = command.expand(args); + + println!("Running command: /{}", cmd_name); + if ctx.args.verbose { + println!("Expanded prompt: {}", prompt); + } + + // Run UserPromptSubmit hooks + let (proceed, updated_prompt) = ctx.hooks.borrow().user_prompt_submit(&prompt); + if !proceed { + eprintln!("Command blocked by hook"); + return true; + } + let prompt = updated_prompt.unwrap_or(prompt); + + // Increment turn counter + let turn_number = { + let mut counter = ctx.turn_counter.borrow_mut(); + *counter += 1; + *counter + }; + + // Run the command as a regular prompt + let start = Instant::now(); + match agent::run_turn(ctx, &prompt, messages) { + Ok(result) => { + // Handle force_continue + let mut total_stats = result.stats.clone(); + if result.force_continue { + if let Some(continue_prompt) = result.continue_prompt { + println!("[Continuing due to Stop hook...]"); + // Report a failed continuation instead of dropping it silently; the stats of the + // first turn are still printed below. + match agent::run_turn(ctx, &continue_prompt, messages) { + Ok(continuation) => { + total_stats.merge(&continuation.stats); + } + Err(e) => { + eprintln!("Continuation error: {}", e); + } + } + } + } + + let cost = if ctx.config.borrow().cost_tracking.display_in_stats { + let costs = ctx.session_costs.borrow(); + costs + .turns() + .iter() + .find(|t| t.turn_number == turn_number) + .map(|t| t.total_cost()) + } else { + None + }; + print_stats(start.elapsed(), &total_stats, cost); + } + Err(e) => { + eprintln!("Command error: {}", e); + } + } + + true +} diff --git a/src/commands.rs b/src/commands.rs new file mode 100644 index 0000000..1c5b9d9 --- /dev/null +++ b/src/commands.rs @@ -0,0 +1,209 @@ +//! Slash commands system for user-defined markdown commands. +//! +//! Commands are defined as markdown files in: +//! - .yo/commands/.md (project-level) +//! - ~/.yo/commands/.md (user-level) +//! +//! The command name is derived from the filename (without .md extension). +//! The file content becomes the prompt when the command is invoked. +//! Use $ARGUMENTS as a placeholder for user-provided arguments. 
+ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +/// Metadata parsed from optional YAML frontmatter +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct CommandMeta { + #[serde(default)] + pub description: Option, + #[serde(default)] + pub allowed_tools: Option>, +} + +/// A loaded slash command +#[derive(Debug, Clone)] +pub struct Command { + pub name: String, + pub source: CommandSource, + pub meta: CommandMeta, + pub content: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandSource { + Project, + User, +} + +impl Command { + /// Expand the command with the given arguments + pub fn expand(&self, arguments: &str) -> String { + self.content.replace("$ARGUMENTS", arguments) + } +} + +/// Index of available slash commands +#[derive(Debug, Default)] +pub struct CommandIndex { + commands: HashMap, + errors: Vec<(PathBuf, String)>, +} + +impl CommandIndex { + /// Build the command index by scanning .yo/commands/ and ~/.yo/commands/ + pub fn build(root: &Path) -> Self { + let mut index = Self::default(); + + // Load user-level commands first (lower priority) + if let Some(home) = dirs::home_dir() { + let user_commands_dir = home.join(".yo").join("commands"); + index.load_from_dir(&user_commands_dir, CommandSource::User); + } + + // Load project-level commands (higher priority, overrides user) + let project_commands_dir = root.join(".yo").join("commands"); + index.load_from_dir(&project_commands_dir, CommandSource::Project); + + index + } + + fn load_from_dir(&mut self, dir: &Path, source: CommandSource) { + if !dir.exists() { + return; + } + + let entries = match std::fs::read_dir(dir) { + Ok(entries) => entries, + Err(_) => return, + }; + + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "md") { + if let Some(stem) = path.file_stem() { + let name = stem.to_string_lossy().to_string(); + match 
self.load_command(&path, &name, source) { + Ok(cmd) => { + self.commands.insert(name, cmd); + } + Err(e) => { + self.errors.push((path.clone(), e.to_string())); + } + } + } + } + } + } + + fn load_command(&mut self, path: &Path, name: &str, source: CommandSource) -> Result { + let content = std::fs::read_to_string(path)?; + + // Parse optional YAML frontmatter + let (meta, content, warning) = parse_frontmatter(&content); + + // Record warning but still load the command + if let Some(warn) = warning { + self.errors.push((path.to_path_buf(), warn)); + } + + Ok(Command { + name: name.to_string(), + source, + meta, + content, + }) + } + + /// Get a command by name + pub fn get(&self, name: &str) -> Option<&Command> { + self.commands.get(name) + } + + /// List all available commands + pub fn list(&self) -> Vec<&Command> { + let mut commands: Vec<_> = self.commands.values().collect(); + commands.sort_by(|a, b| a.name.cmp(&b.name)); + commands + } + + /// Get parse errors + pub fn errors(&self) -> &[(PathBuf, String)] { + &self.errors + } +} + +/// Parse optional YAML frontmatter from markdown content +/// Returns (metadata, body, optional_warning) +fn parse_frontmatter(content: &str) -> (CommandMeta, String, Option) { + let trimmed = content.trim_start(); + + if !trimmed.starts_with("---") { + return (CommandMeta::default(), content.to_string(), None); + } + + // Find the closing --- + if let Some(end_pos) = trimmed[3..].find("\n---") { + let yaml_content = &trimmed[3..3 + end_pos].trim(); + let rest = &trimmed[3 + end_pos + 4..].trim_start(); + + match serde_yaml::from_str(yaml_content) { + Ok(meta) => (meta, rest.to_string(), None), + Err(e) => ( + CommandMeta::default(), + content.to_string(), + Some(format!("invalid YAML frontmatter: {}", e)), + ), + } + } else { + (CommandMeta::default(), content.to_string(), None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_frontmatter_no_frontmatter() { + let content = "Just some content"; + let 
(meta, body, warning) = parse_frontmatter(content); + assert!(meta.description.is_none()); + assert_eq!(body, "Just some content"); + assert!(warning.is_none()); + } + + #[test] + fn test_parse_frontmatter_with_metadata() { + let content = r#"--- +description: A test command +allowed_tools: + - Read + - Grep +--- + +The actual command content"#; + let (meta, body, warning) = parse_frontmatter(content); + assert_eq!(meta.description, Some("A test command".to_string())); + assert_eq!( + meta.allowed_tools, + Some(vec!["Read".to_string(), "Grep".to_string()]) + ); + assert_eq!(body, "The actual command content"); + assert!(warning.is_none()); + } + + #[test] + fn test_command_expand() { + let cmd = Command { + name: "test".to_string(), + source: CommandSource::Project, + meta: CommandMeta::default(), + content: "Fix issue #$ARGUMENTS in the codebase".to_string(), + }; + + let expanded = cmd.expand("123"); + assert_eq!(expanded, "Fix issue #123 in the codebase"); + } +} diff --git a/src/compact.rs b/src/compact.rs new file mode 100644 index 0000000..3afaf69 --- /dev/null +++ b/src/compact.rs @@ -0,0 +1,230 @@ +//! Context compaction for managing conversation history. +//! +//! When the context window fills up, this module summarizes older messages +//! to reclaim space while preserving essential information. 
+ +use crate::config::ContextConfig; +use crate::llm::{ChatRequest, Client, LlmClient}; +use anyhow::Result; +use serde_json::{json, Value}; + +/// Estimate character count for a message +fn estimate_chars(msg: &Value) -> usize { + serde_json::to_string(msg).map(|s| s.len()).unwrap_or(0) +} + +/// Calculate total context size in characters +pub fn context_size(messages: &[Value]) -> usize { + messages.iter().map(estimate_chars).sum() +} + +/// Check if compaction is needed based on config thresholds +pub fn needs_compaction(messages: &[Value], config: &ContextConfig) -> bool { + if !config.auto_compact_enabled { + return false; + } + let current_size = context_size(messages); + let threshold = (config.max_chars as f64 * config.auto_compact_threshold) as usize; + current_size > threshold +} + +/// Result of compaction +#[derive(Debug)] +pub struct CompactionResult { + pub original_count: usize, + pub compacted_count: usize, + pub original_chars: usize, + pub compacted_chars: usize, + pub summary: String, +} + +/// Compact conversation history by summarizing older messages +/// +/// Strategy: +/// 1. Keep the most recent `keep_last_turns` messages +/// 2. Summarize all earlier messages into a single system message +/// 3. 
Return the compacted message list +/// +/// The split point is moved forward past any leading `tool` results so the kept tail +/// never starts with a tool message whose originating `tool_calls` message was summarized away. +/// +/// # Errors +/// Propagates any error returned by the summarization LLM call. +pub fn compact_messages( + messages: &[Value], + config: &ContextConfig, + llm_client: &Client, + model: &str, +) -> Result<(Vec, CompactionResult)> { + let original_count = messages.len(); + let original_chars = context_size(messages); + + // Nothing to compact when there are at most keep_last_turns * 2 messages + if messages.len() <= config.keep_last_turns * 2 { + return Ok(( + messages.to_vec(), + CompactionResult { + original_count, + compacted_count: messages.len(), + original_chars, + compacted_chars: original_chars, + summary: String::new(), + }, + )); + } + + // Split messages: older ones to summarize, recent ones to keep + let mut split_point = messages.len().saturating_sub(config.keep_last_turns * 2); + // Do not begin the kept tail on a tool result: the assistant message carrying its + // `tool_calls` would be in the summarized half, leaving the result orphaned. + while split_point < messages.len() && messages[split_point]["role"].as_str() == Some("tool") { + split_point += 1; + } + let (to_summarize, to_keep) = messages.split_at(split_point); + + // Generate summary of older messages + let summary = generate_summary(to_summarize, llm_client, model)?; + + // Build compacted message list + let mut compacted = Vec::new(); + + // Add summary as a system message + compacted.push(json!({ + "role": "system", + "content": format!( + "CONVERSATION SUMMARY (compacted from {} earlier messages):\n\n{}", + to_summarize.len(), + summary + ) + })); + + // Add the recent messages + compacted.extend(to_keep.iter().cloned()); + + let compacted_chars = context_size(&compacted); + let compacted_count = compacted.len(); + + Ok(( + compacted, + CompactionResult { + original_count, + compacted_count, + original_chars, + compacted_chars, + summary, + }, + )) +} + +/// Generate a summary of messages using the LLM +fn generate_summary(messages: &[Value], client: &Client, model: &str) -> Result { + // Format messages for summarization + let mut conversation_text = String::new(); + for msg in messages { + let role = msg["role"].as_str().unwrap_or("unknown"); + let content = msg["content"].as_str().unwrap_or(""); + + // Skip tool call messages but note their presence + if msg.get("tool_calls").is_some() { + 
conversation_text.push_str(&format!("[{}: used tools]\n", role)); + continue; + } + + // Handle tool responses + if role == "tool" { + let tool_id = msg["tool_call_id"].as_str().unwrap_or("unknown"); + // Truncate long tool results + let content_preview = if content.len() > 200 { + format!("{}...", &content[..200]) + } else { + content.to_string() + }; + conversation_text + .push_str(&format!("[tool result {}]: {}\n", tool_id, content_preview)); + continue; + } + + if !content.is_empty() { + conversation_text.push_str(&format!("{}: {}\n\n", role, content)); + } + } + + // Create summarization request + let request = ChatRequest { + model: model.to_string(), + messages: vec![ + json!({ + "role": "system", + "content": "You are a conversation summarizer. Create a concise summary that captures: +1. What the user asked for +2. What was accomplished (files created/modified, commands run) +3. Any important decisions or context +4. Current state and any pending work + +Be brief but complete. Focus on facts and outcomes." 
+ }), + json!({ + "role": "user", + "content": format!("Summarize this conversation:\n\n{}", conversation_text) + }), + ], + tools: None, + tool_choice: None, + }; + + let response = client.chat(&request)?; + + if let Some(choice) = response.choices.first() { + if let Some(content) = &choice.message.content { + return Ok(content.clone()); + } + } + + Ok("Unable to generate summary.".to_string()) +} + +/// Format compaction result for display +pub fn format_result(result: &CompactionResult) -> String { + let reduction = if result.original_chars > 0 { + 100.0 - (result.compacted_chars as f64 / result.original_chars as f64 * 100.0) + } else { + 0.0 + }; + + format!( + "Compacted: {} → {} messages, {} → {} chars ({:.0}% reduction)", + result.original_count, + result.compacted_count, + result.original_chars, + result.compacted_chars, + reduction + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_size() { + let messages = vec![ + json!({"role": "user", "content": "hello"}), + json!({"role": "assistant", "content": "hi there"}), + ]; + let size = context_size(&messages); + assert!(size > 0); + } + + #[test] + fn test_needs_compaction() { + let config = ContextConfig { + max_chars: 100, + auto_compact_threshold: 0.8, + auto_compact_enabled: true, + keep_last_turns: 2, + }; + + // Small context - no compaction needed + let small_messages = vec![json!({"role": "user", "content": "hi"})]; + assert!(!needs_compaction(&small_messages, &config)); + + // Disabled - no compaction + let disabled_config = ContextConfig { + auto_compact_enabled: false, + ..config + }; + let large_messages: Vec = (0..100) + .map(|i| json!({"role": "user", "content": format!("message {}", i)})) + .collect(); + assert!(!needs_compaction(&large_messages, &disabled_config)); + } +} diff --git a/src/config.rs b/src/config.rs index feaf879..2662100 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,19 @@ use serde::{Deserialize, Serialize}; use 
std::collections::HashMap; use std::path::Path; +/// A validation error in the configuration +#[derive(Debug, Clone)] +pub struct ValidationError { + pub field: String, + pub message: String, +} + +impl std::fmt::Display for ValidationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}]: {}", self.field, self.message) + } +} + /// Permission mode for tool calls #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] @@ -82,10 +95,23 @@ pub struct BashConfig { pub max_output_bytes: Option, } +/// MCP transport type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum McpTransport { + #[default] + Stdio, + Http, + Sse, +} + /// Configuration for an MCP server #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct McpServerConfig { + /// For stdio: the command to spawn + /// For http/sse: not used (use url instead) + #[serde(default)] pub command: String, #[serde(default)] pub args: Vec, @@ -99,6 +125,12 @@ pub struct McpServerConfig { pub auto_start: bool, #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, + /// Transport type: stdio, http, or sse + #[serde(default)] + pub transport: McpTransport, + /// URL for http/sse transports + #[serde(default)] + pub url: Option, } fn default_cwd() -> String { @@ -288,6 +320,7 @@ impl BackendConfig { } } +use crate::cost::{CostConfig, ModelPricing}; use crate::model_routing::ModelRoutingConfig; /// Main configuration structure @@ -309,6 +342,10 @@ pub struct Config { pub model_routing: ModelRoutingConfig, #[serde(default)] pub hooks: Vec, + #[serde(default)] + pub cost_tracking: CostConfig, + #[serde(default)] + pub model_pricing: HashMap, #[serde(skip)] pub agents: HashMap, } @@ -364,6 +401,8 @@ impl Config { mcp: McpConfig::default(), model_routing: ModelRoutingConfig::default(), hooks: Vec::new(), + 
cost_tracking: CostConfig::default(), + model_pricing: HashMap::new(), agents: HashMap::new(), } } @@ -462,6 +501,14 @@ impl Config { // Merge hooks (concatenate) self.hooks.extend(other.hooks); + + // Merge cost tracking (take other's values) + self.cost_tracking = other.cost_tracking; + + // Merge model pricing (other takes priority) + for (model, pricing) in other.model_pricing { + self.model_pricing.insert(model, pricing); + } } /// Get the default target @@ -509,6 +556,99 @@ impl Config { !self.backends.is_empty() } + /// Validate configuration and return any errors found + pub fn validate(&self) -> Result<(), Vec> { + let mut errors = Vec::new(); + + // Validate default_target format if set + if let Some(target) = &self.default_target { + if Target::parse(target).is_none() { + errors.push(ValidationError { + field: "default_target".to_string(), + message: format!( + "Invalid target format '{}', expected 'model@backend'", + target + ), + }); + } + } + + // Validate context.auto_compact_threshold range + if !(0.0..=1.0).contains(&self.context.auto_compact_threshold) { + errors.push(ValidationError { + field: "context.auto_compact_threshold".to_string(), + message: format!( + "Must be between 0.0 and 1.0, got {}", + self.context.auto_compact_threshold + ), + }); + } + + // Validate agent specs + for (name, spec) in &self.agents { + if spec.max_turns == 0 { + errors.push(ValidationError { + field: format!("agents.{}.max_turns", name), + message: "Must be greater than 0".to_string(), + }); + } + // Validate permission_mode is recognized + if PermissionMode::from_str(&spec.permission_mode).is_none() { + errors.push(ValidationError { + field: format!("agents.{}.permission_mode", name), + message: format!("Invalid permission mode '{}'", spec.permission_mode), + }); + } + } + + // Validate hook matchers are valid regex + for (i, hook) in self.hooks.iter().enumerate() { + if let Some(matcher) = &hook.matcher { + if regex::Regex::new(matcher).is_err() { + 
errors.push(ValidationError { + field: format!("hooks[{}].matcher", i), + message: format!("Invalid regex pattern '{}'", matcher), + }); + } + } + // Validate hook command is not empty + if hook.command.is_empty() { + errors.push(ValidationError { + field: format!("hooks[{}].command", i), + message: "Command must not be empty".to_string(), + }); + } + } + + // Validate MCP server configs based on transport type + for (name, server) in &self.mcp.servers { + match server.transport { + McpTransport::Stdio => { + if server.command.is_empty() { + errors.push(ValidationError { + field: format!("mcp.servers.{}.command", name), + message: "Command required for stdio transport".to_string(), + }); + } + } + McpTransport::Http | McpTransport::Sse => { + if server.url.is_none() { + errors.push(ValidationError { + field: format!("mcp.servers.{}.url", name), + message: "URL required for http/sse transport".to_string(), + }); + } + } + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } + } + /// Save permissions to local config file (.yo/config.local.toml) /// Creates the .yo directory if it doesn't exist pub fn save_local_permissions(&self) -> Result<()> { @@ -565,4 +705,58 @@ mod tests { }; assert_eq!(format!("{}", target), "gpt-4@chatgpt"); } + + #[test] + fn test_validate_valid_config() { + let config = Config::with_builtin_backends(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validate_invalid_target() { + let mut config = Config::with_builtin_backends(); + config.default_target = Some("no-backend".to_string()); + let errors = config.validate().unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].field.contains("default_target")); + } + + #[test] + fn test_validate_invalid_threshold() { + let mut config = Config::with_builtin_backends(); + config.context.auto_compact_threshold = 1.5; + let errors = config.validate().unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].field.contains("auto_compact_threshold")); 
+ assert!(errors[0].message.contains("between 0.0 and 1.0")); + } + + #[test] + fn test_validate_invalid_hook_regex() { + let mut config = Config::with_builtin_backends(); + config.hooks.push(HookConfig { + event: HookEvent::PreToolUse, + command: vec!["echo".to_string(), "test".to_string()], + matcher: Some("[invalid regex".to_string()), + timeout_ms: 1000, + }); + let errors = config.validate().unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].field.contains("hooks")); + assert!(errors[0].message.contains("Invalid regex")); + } + + #[test] + fn test_validate_empty_hook_command() { + let mut config = Config::with_builtin_backends(); + config.hooks.push(HookConfig { + event: HookEvent::Stop, + command: vec![], + matcher: None, + timeout_ms: 1000, + }); + let errors = config.validate().unwrap_err(); + assert_eq!(errors.len(), 1); + assert!(errors[0].message.contains("empty")); + } } diff --git a/src/cost.rs b/src/cost.rs new file mode 100644 index 0000000..4e021c4 --- /dev/null +++ b/src/cost.rs @@ -0,0 +1,391 @@ +//! Cost tracking module for token economics. +//! +//! Tracks token usage and costs across operations, turns, and sessions. +//! Supports per-model pricing configuration with sensible defaults. 
+ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Pricing for a single model (per 1M tokens in USD) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ModelPricing { + /// Cost per 1M input tokens + pub input: f64, + /// Cost per 1M output tokens + pub output: f64, +} + +impl ModelPricing { + pub fn new(input: f64, output: f64) -> Self { + Self { input, output } + } + + /// Calculate cost for given token counts + pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 { + let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input; + let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output; + input_cost + output_cost + } +} + +/// Configuration for cost tracking +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CostConfig { + /// Enable cost tracking and display + #[serde(default = "default_true")] + pub enabled: bool, + /// Warn when session cost exceeds this threshold (USD) + #[serde(default)] + pub warn_threshold_usd: Option, + /// Show cost in the stats line after each turn + #[serde(default = "default_true")] + pub display_in_stats: bool, +} + +fn default_true() -> bool { + true +} + +impl Default for CostConfig { + fn default() -> Self { + Self { + enabled: true, + warn_threshold_usd: None, + display_in_stats: true, + } + } +} + +/// Cost for a single LLM operation +#[derive(Debug, Clone, Serialize)] +pub struct OperationCost { + pub model: String, + pub input_tokens: u64, + pub output_tokens: u64, + pub cost_usd: f64, +} + +impl OperationCost { + pub fn new(model: String, input_tokens: u64, output_tokens: u64, cost_usd: f64) -> Self { + Self { + model, + input_tokens, + output_tokens, + cost_usd, + } + } + + pub fn total_tokens(&self) -> u64 { + self.input_tokens + self.output_tokens + } +} + +/// Aggregated costs for a single turn (user message -> assistant response) +#[derive(Debug, Clone, Default)] +pub struct TurnCost { + pub turn_number: u32, + pub operations: Vec, +} + 
+impl TurnCost { + pub fn new(turn_number: u32) -> Self { + Self { + turn_number, + operations: Vec::new(), + } + } + + pub fn add_operation(&mut self, op: OperationCost) { + self.operations.push(op); + } + + pub fn total_tokens(&self) -> u64 { + self.operations.iter().map(|op| op.total_tokens()).sum() + } + + pub fn total_cost(&self) -> f64 { + self.operations.iter().map(|op| op.cost_usd).sum() + } + + #[allow(dead_code)] // For future detailed reporting + pub fn input_tokens(&self) -> u64 { + self.operations.iter().map(|op| op.input_tokens).sum() + } + + #[allow(dead_code)] // For future detailed reporting + pub fn output_tokens(&self) -> u64 { + self.operations.iter().map(|op| op.output_tokens).sum() + } +} + +/// Session-level cost tracker +#[derive(Debug, Clone)] +pub struct SessionCosts { + #[allow(dead_code)] // For future session persistence + session_id: String, + turns: Vec, + pricing: PricingTable, +} + +impl SessionCosts { + pub fn new(session_id: String, pricing: PricingTable) -> Self { + Self { + session_id, + turns: Vec::new(), + pricing, + } + } + + /// Record an LLM operation and return the cost + pub fn record_operation( + &mut self, + turn_number: u32, + model: &str, + input_tokens: u64, + output_tokens: u64, + ) -> OperationCost { + let cost_usd = self.pricing.calculate(model, input_tokens, output_tokens); + let op = OperationCost::new(model.to_string(), input_tokens, output_tokens, cost_usd); + + // Find or create the turn + if let Some(turn) = self.turns.iter_mut().find(|t| t.turn_number == turn_number) { + turn.add_operation(op.clone()); + } else { + let mut turn = TurnCost::new(turn_number); + turn.add_operation(op.clone()); + self.turns.push(turn); + } + + op + } + + /// Merge costs from a subagent into the current turn + #[allow(dead_code)] // For future parallel subagent support + pub fn merge_operations(&mut self, turn_number: u32, ops: Vec) { + if let Some(turn) = self.turns.iter_mut().find(|t| t.turn_number == turn_number) { + for op 
in ops { + turn.add_operation(op); + } + } else { + let mut turn = TurnCost::new(turn_number); + for op in ops { + turn.add_operation(op); + } + self.turns.push(turn); + } + } + + #[allow(dead_code)] // For future session persistence + pub fn session_id(&self) -> &str { + &self.session_id + } + + pub fn turns(&self) -> &[TurnCost] { + &self.turns + } + + pub fn total_tokens(&self) -> u64 { + self.turns.iter().map(|t| t.total_tokens()).sum() + } + + pub fn total_cost(&self) -> f64 { + self.turns.iter().map(|t| t.total_cost()).sum() + } + + #[allow(dead_code)] // For future detailed reporting + pub fn input_tokens(&self) -> u64 { + self.turns.iter().map(|t| t.input_tokens()).sum() + } + + #[allow(dead_code)] // For future detailed reporting + pub fn output_tokens(&self) -> u64 { + self.turns.iter().map(|t| t.output_tokens()).sum() + } + + /// Get cost breakdown by model + pub fn cost_by_model(&self) -> HashMap { + let mut result: HashMap = HashMap::new(); + for turn in &self.turns { + for op in &turn.operations { + let entry = result.entry(op.model.clone()).or_insert((0, 0.0)); + entry.0 += op.total_tokens(); + entry.1 += op.cost_usd; + } + } + result + } +} + +/// Pricing table with model-specific costs +#[derive(Debug, Clone)] +pub struct PricingTable { + models: HashMap, + default_pricing: ModelPricing, +} + +impl Default for PricingTable { + fn default() -> Self { + Self::with_defaults() + } +} + +impl PricingTable { + /// Create a pricing table with default model prices + pub fn with_defaults() -> Self { + let mut models = HashMap::new(); + + // OpenAI models + models.insert("gpt-4o".to_string(), ModelPricing::new(2.50, 10.00)); + models.insert("gpt-4o-mini".to_string(), ModelPricing::new(0.15, 0.60)); + models.insert("gpt-4-turbo".to_string(), ModelPricing::new(10.00, 30.00)); + models.insert("gpt-3.5-turbo".to_string(), ModelPricing::new(0.50, 1.50)); + models.insert("o1".to_string(), ModelPricing::new(15.00, 60.00)); + models.insert("o1-mini".to_string(), 
ModelPricing::new(3.00, 12.00)); + models.insert("o1-preview".to_string(), ModelPricing::new(15.00, 60.00)); + + // Anthropic models + models.insert( + "claude-3-5-sonnet-latest".to_string(), + ModelPricing::new(3.00, 15.00), + ); + models.insert( + "claude-3-5-sonnet-20241022".to_string(), + ModelPricing::new(3.00, 15.00), + ); + models.insert( + "claude-3-5-haiku-latest".to_string(), + ModelPricing::new(0.80, 4.00), + ); + models.insert( + "claude-3-opus-latest".to_string(), + ModelPricing::new(15.00, 75.00), + ); + + // Venice.ai models - free tier default, override via [pricing] config if needed + models.insert( + "qwen3-235b-a22b-instruct-2507".to_string(), + ModelPricing::new(0.00, 0.00), + ); + models.insert("llama-3.3-70b".to_string(), ModelPricing::new(0.00, 0.00)); + + // Ollama / local models (no cost) + models.insert("llama3".to_string(), ModelPricing::new(0.00, 0.00)); + models.insert("llama3:8b".to_string(), ModelPricing::new(0.00, 0.00)); + models.insert("codellama".to_string(), ModelPricing::new(0.00, 0.00)); + + Self { + models, + default_pricing: ModelPricing::new(1.00, 3.00), // Conservative default + } + } + + /// Create from user config, merging with defaults + pub fn from_config(user_pricing: &HashMap) -> Self { + let mut table = Self::with_defaults(); + for (model, pricing) in user_pricing { + table.models.insert(model.clone(), pricing.clone()); + } + table + } + + /// Get pricing for a model (falls back to default if unknown) + pub fn get(&self, model: &str) -> &ModelPricing { + // Try exact match first + if let Some(pricing) = self.models.get(model) { + return pricing; + } + + // Try prefix matching for versioned models (e.g., "gpt-4o-2024-08-06" -> "gpt-4o") + for (name, pricing) in &self.models { + if model.starts_with(name) { + return pricing; + } + } + + &self.default_pricing + } + + /// Calculate cost for given model and token counts + pub fn calculate(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 { + 
self.get(model).calculate(input_tokens, output_tokens) + } + + /// Add or update pricing for a model + #[allow(dead_code)] // For future runtime pricing updates + pub fn set(&mut self, model: &str, pricing: ModelPricing) { + self.models.insert(model.to_string(), pricing); + } +} + +/// Format a cost value for display +pub fn format_cost(cost: f64) -> String { + if cost < 0.01 { + format!("${:.4}", cost) + } else if cost < 1.0 { + format!("${:.3}", cost) + } else { + format!("${:.2}", cost) + } +} + +/// Format token count for display +pub fn format_tokens(tokens: u64) -> String { + if tokens >= 1_000_000 { + format!("{:.1}M", tokens as f64 / 1_000_000.0) + } else if tokens >= 1_000 { + format!("{:.1}k", tokens as f64 / 1_000.0) + } else { + tokens.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_pricing_calculation() { + let pricing = ModelPricing::new(2.50, 10.00); + // 1000 input tokens, 500 output tokens + let cost = pricing.calculate(1000, 500); + // (1000/1M) * 2.50 + (500/1M) * 10.00 = 0.0025 + 0.005 = 0.0075 + assert!((cost - 0.0075).abs() < 0.0001); + } + + #[test] + fn test_pricing_table_prefix_matching() { + let table = PricingTable::with_defaults(); + // Versioned model should match base model + let pricing = table.get("gpt-4o-2024-08-06"); + assert_eq!(pricing.input, 2.50); + } + + #[test] + fn test_session_costs() { + let pricing = PricingTable::with_defaults(); + let mut session = SessionCosts::new("test-session".to_string(), pricing); + + session.record_operation(1, "gpt-4o-mini", 1000, 500); + session.record_operation(1, "gpt-4o-mini", 500, 200); + + assert_eq!(session.total_tokens(), 2200); + assert!(session.total_cost() > 0.0); + + let by_model = session.cost_by_model(); + assert!(by_model.contains_key("gpt-4o-mini")); + } + + #[test] + fn test_format_cost() { + assert_eq!(format_cost(0.001), "$0.0010"); + assert_eq!(format_cost(0.05), "$0.050"); + assert_eq!(format_cost(1.23), "$1.23"); + } + + #[test] + 
fn test_format_tokens() { + assert_eq!(format_tokens(500), "500"); + assert_eq!(format_tokens(1500), "1.5k"); + assert_eq!(format_tokens(1_500_000), "1.5M"); + } +} diff --git a/src/hooks.rs b/src/hooks.rs index 70c4d6b..52a65fc 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -500,7 +500,6 @@ impl HookManager { } } } - } #[cfg(test)] diff --git a/src/llm.rs b/src/llm.rs index 67f3a40..fdc6c0f 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -19,10 +19,6 @@ pub struct Usage { pub prompt_tokens: u64, #[serde(default)] pub completion_tokens: u64, - /// Total tokens from API (may be redundant with prompt_tokens + completion_tokens) - #[serde(default)] - #[allow(dead_code)] - pub total_tokens: u64, } #[derive(Debug, Deserialize)] @@ -35,7 +31,6 @@ pub struct ChatResponse { #[derive(Debug, Deserialize)] pub struct Choice { pub message: Message, - #[allow(dead_code)] pub finish_reason: Option, } @@ -62,6 +57,11 @@ pub struct FunctionCall { pub arguments: String, } +/// Trait for LLM clients to allow mocking and abstraction +pub trait LlmClient { + fn chat(&self, request: &ChatRequest) -> Result; +} + pub struct Client { base_url: String, api_key: String, @@ -76,8 +76,10 @@ impl Client { agent: ureq::Agent::new(), } } +} - pub fn chat(&self, request: &ChatRequest) -> Result { +impl LlmClient for Client { + fn chat(&self, request: &ChatRequest) -> Result { let url = format!("{}/chat/completions", self.base_url); let resp = self diff --git a/src/main.rs b/src/main.rs index 0294748..da9f2b6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,10 @@ mod agent; mod backend; mod cli; +mod commands; +mod compact; mod config; +mod cost; mod hooks; mod llm; mod mcp; @@ -10,6 +13,7 @@ mod plan; mod policy; mod skillpacks; mod subagent; +mod tool_filter; mod tools; mod transcript; @@ -87,6 +91,13 @@ pub struct Args { #[arg(long, help = "Debug output (print HTTP details and settings)")] pub debug: bool, + + #[arg( + short = 'O', + long = "optimize", + help = "Optimize output for token 
efficiency" + )] + pub optimize: bool, } fn main() -> Result<()> { @@ -156,6 +167,17 @@ fn main() -> Result<()> { )); } + // Validate configuration + if let Err(errors) = cfg.validate() { + for err in &errors { + eprintln!("Config error {}", err); + } + return Err(anyhow::anyhow!( + "Configuration has {} validation error(s)", + errors.len() + )); + } + // Apply CLI permission overrides if let Some(mode_str) = &args.mode { if let Some(mode) = config::PermissionMode::from_str(mode_str) { @@ -220,6 +242,13 @@ fn main() -> Result<()> { // Create hook manager let hook_manager = hooks::HookManager::new(cfg.hooks.clone(), session_id.clone(), root.clone()); + // Create cost tracker with pricing from config + let pricing_table = cost::PricingTable::from_config(&cfg.model_pricing); + let session_costs = cost::SessionCosts::new(session_id.clone(), pricing_table); + + // Build command index + let command_index = commands::CommandIndex::build(&root); + let ctx = cli::Context { args, root, @@ -236,6 +265,9 @@ fn main() -> Result<()> { model_router: RefCell::new(model_router), plan_mode: RefCell::new(plan::PlanModeState::new()), hooks: RefCell::new(hook_manager), + session_costs: RefCell::new(session_costs), + turn_counter: RefCell::new(0), + command_index: RefCell::new(command_index), }; // Fire SessionStart hook diff --git a/src/mcp/client.rs b/src/mcp/client.rs index 719edde..3f4a753 100644 --- a/src/mcp/client.rs +++ b/src/mcp/client.rs @@ -1,12 +1,11 @@ //! MCP JSON-RPC client for protocol communication. 
-use super::transport::StdioTransport; +use super::transport::McpTransportImpl; use super::McpToolDef; use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Duration; /// JSON-RPC request structure #[derive(Serialize)] @@ -21,9 +20,6 @@ struct JsonRpcRequest { /// JSON-RPC response structure #[derive(Deserialize)] struct JsonRpcResponse { - #[allow(dead_code)] - jsonrpc: String, - id: Option, result: Option, error: Option, } @@ -37,18 +33,16 @@ struct JsonRpcError { /// MCP client for communicating with an MCP server pub struct McpClient { - transport: StdioTransport, + transport: McpTransportImpl, request_id: AtomicU64, - timeout: Duration, } impl McpClient { - /// Create a new MCP client wrapping a transport - pub fn new(transport: StdioTransport, timeout_ms: u64) -> Self { + /// Create a new MCP client with any transport type + pub fn with_transport(transport: McpTransportImpl) -> Self { Self { transport, request_id: AtomicU64::new(1), - timeout: Duration::from_millis(timeout_ms), } } @@ -67,25 +61,17 @@ impl McpClient { params, }; - self.transport.send(&serde_json::to_value(&request)?)?; - - // Wait for response with matching ID - loop { - let response_value = self.transport.recv_timeout(self.timeout)?; - let response: JsonRpcResponse = serde_json::from_value(response_value)?; + let response_value = self.transport.send(&serde_json::to_value(&request)?)?; + let response: JsonRpcResponse = serde_json::from_value(response_value)?; - if response.id == Some(id) { - if let Some(error) = response.error { - return Err(anyhow::anyhow!( - "MCP error {}: {}", - error.code, - error.message - )); - } - return Ok(response.result.unwrap_or(Value::Null)); - } - // Ignore notifications or responses for other requests + if let Some(error) = response.error { + return Err(anyhow::anyhow!( + "MCP error {}: {}", + error.code, + error.message + )); } + 
Ok(response.result.unwrap_or(Value::Null)) } /// Perform MCP initialize handshake @@ -102,11 +88,14 @@ impl McpClient { let result = self.call("initialize", Some(params))?; // Send initialized notification (no response expected) + // For HTTP/SSE this is a fire-and-forget, but log errors for debugging let notification = json!({ "jsonrpc": "2.0", "method": "notifications/initialized" }); - self.transport.send(¬ification)?; + if let Err(e) = self.transport.send(¬ification) { + eprintln!("MCP: Failed to send initialized notification: {}", e); + } Ok(result) } diff --git a/src/mcp/manager.rs b/src/mcp/manager.rs index 3b66fc8..47c65e4 100644 --- a/src/mcp/manager.rs +++ b/src/mcp/manager.rs @@ -1,9 +1,9 @@ //! MCP server lifecycle manager. use super::client::McpClient; -use super::transport::StdioTransport; +use super::transport::{HttpTransport, McpTransportImpl, SseTransport, StdioTransport}; use super::McpToolDef; -use crate::config::McpServerConfig; +use crate::config::{McpServerConfig, McpTransport}; use anyhow::Result; use serde_json::Value; use std::collections::HashMap; @@ -53,6 +53,7 @@ impl McpManager { } /// Connect to an MCP server by name + /// Returns (pid_or_0, tool_count) - pid is 0 for HTTP/SSE transports pub fn connect(&mut self, name: &str, root: &Path) -> Result<(u32, usize)> { // Check if already connected if self.clients.contains_key(name) { @@ -69,14 +70,34 @@ impl McpManager { return Err(anyhow::anyhow!("Server {} is disabled", name)); } - // Resolve cwd relative to project root - let cwd = root.join(&config.cwd); - - // Spawn transport - let transport = StdioTransport::spawn(&config.command, &config.args, &config.env, &cwd)?; + // Create appropriate transport based on config + let (transport, pid) = match config.transport { + McpTransport::Stdio => { + // Resolve cwd relative to project root + let cwd = root.join(&config.cwd); + let t = StdioTransport::spawn(&config.command, &config.args, &config.env, &cwd)?; + let pid = t.pid(); + 
(McpTransportImpl::Stdio(t), pid) + } + McpTransport::Http => { + let url = config + .url + .as_ref() + .ok_or_else(|| anyhow::anyhow!("HTTP transport requires 'url' in config"))?; + let t = HttpTransport::new(url, config.timeout_ms); + (McpTransportImpl::Http(t), 0) + } + McpTransport::Sse => { + let url = config + .url + .as_ref() + .ok_or_else(|| anyhow::anyhow!("SSE transport requires 'url' in config"))?; + let t = SseTransport::new(url, config.timeout_ms); + (McpTransportImpl::Sse(t), 0) + } + }; - let pid = transport.pid(); - let mut client = McpClient::new(transport, config.timeout_ms); + let mut client = McpClient::with_transport(transport); // Initialize connection client.initialize()?; diff --git a/src/mcp/transport.rs b/src/mcp/transport.rs index 9835932..7e67133 100644 --- a/src/mcp/transport.rs +++ b/src/mcp/transport.rs @@ -1,6 +1,9 @@ -//! Stdio transport layer for MCP server communication. +//! Transport layer for MCP server communication. //! -//! Spawns MCP servers as subprocesses and communicates via newline-delimited JSON. +//! Supports three transport types: +//! - Stdio: Spawns MCP servers as subprocesses (newline-delimited JSON) +//! - HTTP: Communicates via HTTP POST requests +//! 
- SSE: Server-Sent Events for streaming responses use anyhow::{Context, Result}; use serde_json::Value; @@ -142,3 +145,219 @@ impl Drop for StdioTransport { } } } + +/// HTTP transport for communicating with an MCP server over HTTP +pub struct HttpTransport { + url: String, + agent: ureq::Agent, + timeout: Duration, +} + +impl HttpTransport { + /// Create a new HTTP transport + pub fn new(url: &str, timeout_ms: u64) -> Self { + Self { + url: url.to_string(), + agent: ureq::Agent::new(), + timeout: Duration::from_millis(timeout_ms), + } + } + + /// Send a JSON-RPC message and receive response + pub fn send(&self, message: &Value) -> Result { + let resp = self + .agent + .post(&self.url) + .timeout(self.timeout) + .set("Content-Type", "application/json") + .send_json(message.clone()); + + match resp { + Ok(r) => { + let body: Value = r.into_json()?; + Ok(body) + } + Err(ureq::Error::Status(code, resp)) => { + let body = resp.into_string().unwrap_or_default(); + Err(anyhow::anyhow!("HTTP error {}: {}", code, body)) + } + Err(e) => Err(anyhow::anyhow!("HTTP request failed: {}", e)), + } + } + + /// HTTP transport is always "alive" since it's stateless + pub fn is_alive(&self) -> bool { + true + } +} + +/// SSE (Server-Sent Events) transport for MCP servers +/// Uses HTTP POST for requests and SSE for streaming responses +pub struct SseTransport { + url: String, + agent: ureq::Agent, + timeout: Duration, +} + +impl SseTransport { + /// Create a new SSE transport + pub fn new(url: &str, timeout_ms: u64) -> Self { + Self { + url: url.to_string(), + agent: ureq::Agent::new(), + timeout: Duration::from_millis(timeout_ms), + } + } + + /// Send a JSON-RPC message and wait for response via SSE + pub fn send(&self, message: &Value) -> Result { + // For SSE, we send the request and then listen for events + // The request includes a unique ID that we match in the response + let request_id = message.get("id").and_then(|v| v.as_u64()); + + // Send the request via POST + let resp = 
self + .agent + .post(&self.url) + .timeout(self.timeout) + .set("Content-Type", "application/json") + .send_json(message.clone()); + + match resp { + Ok(r) => { + // Check if response is SSE stream or direct JSON + let content_type = r.header("content-type").unwrap_or("").to_lowercase(); + if content_type.contains("text/event-stream") { + self.parse_sse_response(request_id, r) + } else { + // For simple implementations, the response comes back as JSON directly + let body: Value = r.into_json()?; + Ok(body) + } + } + Err(ureq::Error::Status(code, resp)) => { + // Try to get SSE response from the event endpoint + // Some servers send the response on a separate event stream + self.try_sse_fallback(request_id, code, resp) + } + Err(e) => Err(anyhow::anyhow!("SSE request failed: {}", e)), + } + } + + /// Parse SSE event stream from a response + fn parse_sse_response(&self, request_id: Option, resp: ureq::Response) -> Result { + let mut reader = BufReader::new(resp.into_reader()); + let mut line = String::new(); + let mut data = String::new(); + let mut events_read = 0; + const MAX_EVENTS: usize = 1000; // Prevent infinite loops + + loop { + line.clear(); + match reader.read_line(&mut line) { + Ok(0) => break, // EOF + Ok(_) => { + let line = line.trim(); + if let Some(stripped) = line.strip_prefix("data:") { + data = stripped.trim().to_string(); + } else if line.is_empty() && !data.is_empty() { + // End of event, parse the data + events_read += 1; + if events_read > MAX_EVENTS { + return Err(anyhow::anyhow!( + "SSE stream exceeded {} events without matching response", + MAX_EVENTS + )); + } + if let Ok(value) = serde_json::from_str::(&data) { + // Check if this is the response we're waiting for + if let Some(id) = request_id { + if value.get("id").and_then(|v| v.as_u64()) == Some(id) { + return Ok(value); + } + } else { + return Ok(value); + } + } + data.clear(); + } + } + Err(e) => return Err(anyhow::anyhow!("SSE read error: {}", e)), + } + } + + Err(anyhow::anyhow!( 
+ "SSE stream ended without matching response" + )) + } + + fn try_sse_fallback( + &self, + request_id: Option, + http_code: u16, + resp: ureq::Response, + ) -> Result { + // If the POST returned an error status, check if it's an SSE stream + let content_type = resp.header("content-type").unwrap_or("").to_lowercase(); + + if content_type.contains("text/event-stream") { + return self.parse_sse_response(request_id, resp); + } + + Err(anyhow::anyhow!( + "HTTP error {}: SSE fallback failed", + http_code + )) + } + + /// SSE transport is always "alive" since it's stateless + pub fn is_alive(&self) -> bool { + true + } +} + +/// Unified transport enum for MCP communication +pub enum McpTransportImpl { + Stdio(StdioTransport), + Http(HttpTransport), + Sse(SseTransport), +} + +impl McpTransportImpl { + /// Send a message and receive response + pub fn send(&mut self, message: &Value) -> Result { + match self { + McpTransportImpl::Stdio(t) => { + t.send(message)?; + t.recv_timeout(Duration::from_secs(30)) + } + McpTransportImpl::Http(t) => t.send(message), + McpTransportImpl::Sse(t) => t.send(message), + } + } + + /// Check if the transport is alive + pub fn is_alive(&mut self) -> bool { + match self { + McpTransportImpl::Stdio(t) => t.is_alive(), + McpTransportImpl::Http(t) => t.is_alive(), + McpTransportImpl::Sse(t) => t.is_alive(), + } + } + + /// Get exit status (only for stdio) + pub fn exit_status(&mut self) -> Option { + match self { + McpTransportImpl::Stdio(t) => t.exit_status(), + _ => None, + } + } + + /// Kill the transport (only affects stdio) + pub fn kill(&mut self) -> Result<()> { + match self { + McpTransportImpl::Stdio(t) => t.kill(), + _ => Ok(()), // HTTP/SSE are stateless + } + } +} diff --git a/src/plan.rs b/src/plan.rs index ff8433f..4c90cca 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -5,9 +5,14 @@ use anyhow::{anyhow, Result}; use chrono::{DateTime, Utc}; +use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::path::{Path, 
PathBuf}; +/// Cached regex for parsing STEP lines in plan output +static STEP_REGEX: Lazy = + Lazy::new(|| regex::Regex::new(r"(?i)STEP\s+(\d+):\s*(.+)").expect("Invalid step regex")); + // ============================================================================ // Core Data Structures // ============================================================================ @@ -390,15 +395,14 @@ pub fn parse_plan_output(output: &str, goal: &str) -> Result { plan.summary = summary; } - // Parse STEPs using regex - let step_re = regex::Regex::new(r"(?i)STEP\s+(\d+):\s*(.+)")?; + // Parse STEPs using cached regex let mut current_step: Option = None; let mut in_description = false; for line in plan_block.lines() { let trimmed = line.trim(); - if let Some(caps) = step_re.captures(trimmed) { + if let Some(caps) = STEP_REGEX.captures(trimmed) { // Save previous step if exists if let Some(step) = current_step.take() { plan.steps.push(step); diff --git a/src/policy.rs b/src/policy.rs index a6906f2..82bd7a6 100644 --- a/src/policy.rs +++ b/src/policy.rs @@ -103,49 +103,7 @@ impl PolicyEngine { /// Check if a rule pattern matches a tool call /// Pattern format: "ToolName" or "ToolName(prefix:*)" or "mcp.*" or "mcp.server.*" fn rule_matches(pattern: &str, tool: &str, arg: Option<&str>) -> bool { - // Simple tool name match: "Write" matches all Write calls - if pattern == tool { - return true; - } - - // MCP wildcard matching: "mcp.*" or "mcp.server.*" - // Pattern "mcp.*" matches any MCP tool (e.g., "mcp.echo.add") - // Pattern "mcp.echo.*" matches any tool from echo server (e.g., "mcp.echo.add", "mcp.echo.echo") - if pattern.ends_with(".*") && tool.starts_with("mcp.") { - let prefix = &pattern[..pattern.len() - 2]; // Remove ".*" - if let Some(remaining) = tool.strip_prefix(prefix) { - // Check that the match is at a dot boundary - if remaining.is_empty() || remaining.starts_with('.') { - return true; - } - } - } - - // Pattern with argument: "Bash(git diff:*)" or 
"Edit(src/lib.rs)" - if let Some(open_paren) = pattern.find('(') { - let rule_tool = &pattern[..open_paren]; - if rule_tool != tool { - return false; - } - - // Extract the argument pattern - let close_paren = pattern.rfind(')').unwrap_or(pattern.len()); - let arg_pattern = &pattern[open_paren + 1..close_paren]; - - let Some(arg) = arg else { - return false; - }; - - // Check for prefix match: "git diff:*" - if let Some(prefix) = arg_pattern.strip_suffix(":*") { - return arg.starts_with(prefix); - } - - // Exact match - return arg_pattern == arg; - } - - false + crate::tool_filter::tool_matches(tool, pattern, arg) } /// Determine the permission decision for a tool call diff --git a/src/skillpacks/activation.rs b/src/skillpacks/activation.rs index 3759314..6a6334e 100644 --- a/src/skillpacks/activation.rs +++ b/src/skillpacks/activation.rs @@ -10,8 +10,6 @@ pub struct SkillActivation { pub name: String, pub description: String, pub allowed_tools: Option>, - #[allow(dead_code)] - pub instructions: String, } /// Manages the set of active skills @@ -44,7 +42,6 @@ impl ActiveSkills { name: pack.name.clone(), description: pack.description.clone(), allowed_tools: pack.allowed_tools.clone(), - instructions: pack.instructions.clone(), }; self.active.insert(pack.name.clone(), pack); diff --git a/src/subagent.rs b/src/subagent.rs index 299a02c..28749ad 100644 --- a/src/subagent.rs +++ b/src/subagent.rs @@ -2,6 +2,7 @@ use crate::agent::CommandStats; use crate::config::{AgentSpec, PermissionMode}; +use crate::llm::LlmClient; use crate::policy::{Decision, PolicyEngine}; use crate::{cli::Context, llm, tools}; use anyhow::Result; @@ -76,38 +77,21 @@ pub fn clamp_mode(requested: PermissionMode, parent: PermissionMode) -> Permissi } } -/// Check if a tool name matches an allowed pattern -/// Pattern formats: "ToolName" (exact) or "mcp.*" or "mcp.server.*" (wildcard) -fn tool_matches_pattern(tool_name: &str, pattern: &str) -> bool { - if pattern == tool_name { - return true; - } - - 
// Handle MCP wildcard patterns like "mcp.*" or "mcp.server.*" - // Pattern "mcp.*" matches "mcp.echo.add" but not "mcpfoo" - // Pattern "mcp.echo.*" matches "mcp.echo.add" but not "mcp.echoserver.add" - if let Some(prefix) = pattern.strip_suffix(".*") { - if let Some(remaining) = tool_name.strip_prefix(prefix) { - // Match at dot boundary: remaining must be empty or start with '.' - return remaining.is_empty() || remaining.starts_with('.'); - } - } - - false -} +use crate::tool_filter; /// Filter tool schemas to only include allowed tools -pub fn filter_tool_schemas(allowed_tools: &[String]) -> Vec { - let all_schemas = tools::schemas(); +pub fn filter_tool_schemas( + allowed_tools: &[String], + schema_opts: &tools::SchemaOptions, +) -> Vec { + let all_schemas = tools::schemas(schema_opts); all_schemas .into_iter() .filter(|schema| { if let Some(func) = schema.get("function") { if let Some(name) = func.get("name").and_then(|n| n.as_str()) { - return allowed_tools - .iter() - .any(|pattern| tool_matches_pattern(name, pattern)); + return tool_filter::tool_matches_any_simple(name, allowed_tools); } } false @@ -122,9 +106,7 @@ fn is_tool_allowed(tool_name: &str, allowed_tools: &[String]) -> bool { return false; } - allowed_tools - .iter() - .any(|pattern| tool_matches_pattern(tool_name, pattern)) + tool_filter::tool_matches_any_simple(tool_name, allowed_tools) } fn trace(ctx: &Context, agent: &str, label: &str, content: &str) { @@ -209,9 +191,20 @@ pub fn run_subagent( trace(ctx, agent_name, "TARGET", &format!("{}", target)); // Build system prompt for subagent - let system_prompt = spec.system_prompt.as_deref().unwrap_or( - "You are a specialized subagent. Complete the assigned task using only your available tools." - ); + let mut system_prompt = spec + .system_prompt + .as_deref() + .unwrap_or( + "You are a specialized subagent. 
Complete the assigned task using only your available tools.", + ) + .to_string(); + + // Add optimization mode instructions if -O flag is set + if ctx.args.optimize { + system_prompt.push_str( + "\n\nAI-to-AI mode. Maximum information density. Structure over prose. No narration.", + ); + } // Build initial messages let mut messages: Vec = Vec::new(); @@ -236,7 +229,8 @@ pub fn run_subagent( })); // Get filtered tool schemas - let tool_schemas = filter_tool_schemas(&spec.allowed_tools); + let schema_opts = tools::SchemaOptions::new(ctx.args.optimize); + let tool_schemas = filter_tool_schemas(&spec.allowed_tools, &schema_opts); // Also add allowed MCP tools if any let mut all_tool_schemas = tool_schemas; @@ -304,6 +298,23 @@ pub fn run_subagent( if let Some(usage) = &response.usage { stats.input_tokens += usage.prompt_tokens; stats.output_tokens += usage.completion_tokens; + + // Record cost for this operation (uses parent turn number) + let turn_number = *ctx.turn_counter.borrow(); + let op = ctx.session_costs.borrow_mut().record_operation( + turn_number, + &target.model, + usage.prompt_tokens, + usage.completion_tokens, + ); + + // Log token usage to transcript + let _ = ctx.transcript.borrow_mut().token_usage( + &target.model, + usage.prompt_tokens, + usage.completion_tokens, + op.cost_usd, + ); } if response.choices.is_empty() { @@ -313,6 +324,16 @@ pub fn run_subagent( let choice = &response.choices[0]; let msg = &choice.message; + // Warn if response was truncated due to length limit + if choice.finish_reason.as_deref() == Some("length") { + trace( + ctx, + agent_name, + "WARN", + "Response truncated (max tokens reached)", + ); + } + // Collect assistant text if let Some(content) = &msg.content { if !content.is_empty() { diff --git a/src/tool_filter.rs b/src/tool_filter.rs new file mode 100644 index 0000000..d8f8a17 --- /dev/null +++ b/src/tool_filter.rs @@ -0,0 +1,147 @@ +//! Unified tool name pattern matching. +//! +//! 
Consolidates tool filtering logic used by policy.rs and subagent.rs. + +/// Check if a tool name matches a pattern. +/// +/// # Pattern formats: +/// - `"Read"` - exact match +/// - `"mcp.*"` - matches all MCP tools (e.g., "mcp.echo.add") +/// - `"mcp.server.*"` - matches tools from specific MCP server +/// - `"Bash(git:*)"` - matches Bash with args starting with "git" +/// - `"Edit(src/lib.rs)"` - matches Edit with exact file path +/// +/// # Arguments +/// * `tool` - The tool name being checked +/// * `pattern` - The pattern to match against +/// * `arg` - Optional argument for tools that support arg matching (Bash, Edit, etc.) +pub fn tool_matches(tool: &str, pattern: &str, arg: Option<&str>) -> bool { + // Exact match + if pattern == tool { + return true; + } + + // MCP wildcard matching: "mcp.*" or "mcp.server.*" + // Pattern "mcp.*" matches any MCP tool (e.g., "mcp.echo.add") + // Pattern "mcp.echo.*" matches any tool from echo server + if let Some(prefix) = pattern.strip_suffix(".*") { + if let Some(remaining) = tool.strip_prefix(prefix) { + // Match at dot boundary: remaining must be empty or start with '.' 
+ if remaining.is_empty() || remaining.starts_with('.') { + return true; + } + } + } + + // Pattern with argument: "Bash(git diff:*)" or "Edit(src/lib.rs)" + if let Some(open_paren) = pattern.find('(') { + let pattern_tool = &pattern[..open_paren]; + if pattern_tool != tool { + return false; + } + + // Require matching closing paren for well-formed patterns + let Some(close_paren) = pattern.rfind(')') else { + return false; // Malformed pattern + }; + let arg_pattern = &pattern[open_paren + 1..close_paren]; + + let Some(actual_arg) = arg else { + return false; + }; + + // Check for prefix match: "git diff:*" + if let Some(prefix) = arg_pattern.strip_suffix(":*") { + return actual_arg.starts_with(prefix); + } + + // Exact argument match + return arg_pattern == actual_arg; + } + + false +} + +/// Check if a tool matches any pattern in a list +#[allow(dead_code)] // For future use with argument matching +pub fn tool_matches_any(tool: &str, patterns: &[String], arg: Option<&str>) -> bool { + patterns.iter().any(|p| tool_matches(tool, p, arg)) +} + +/// Check if a tool matches any pattern (no argument version) +/// Convenience for subagent tool filtering where arguments aren't used +pub fn tool_matches_any_simple(tool: &str, patterns: &[String]) -> bool { + patterns.iter().any(|p| tool_matches(tool, p, None)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_match() { + assert!(tool_matches("Read", "Read", None)); + assert!(tool_matches("Write", "Write", None)); + assert!(!tool_matches("Read", "Write", None)); + } + + #[test] + fn test_mcp_wildcard() { + // mcp.* matches all mcp tools + assert!(tool_matches("mcp.echo.add", "mcp.*", None)); + assert!(tool_matches("mcp.github.list_prs", "mcp.*", None)); + + // mcp.server.* matches tools from that server only + assert!(tool_matches("mcp.echo.add", "mcp.echo.*", None)); + assert!(tool_matches("mcp.echo.multiply", "mcp.echo.*", None)); + assert!(!tool_matches("mcp.github.list_prs", "mcp.echo.*", 
None)); + + // Should not match non-dot boundary + assert!(!tool_matches("mcpfake.tool", "mcp.*", None)); + } + + #[test] + fn test_arg_prefix_match() { + assert!(tool_matches("Bash", "Bash(git:*)", Some("git status"))); + assert!(tool_matches( + "Bash", + "Bash(git diff:*)", + Some("git diff HEAD") + )); + assert!(!tool_matches("Bash", "Bash(git:*)", Some("npm install"))); + assert!(!tool_matches("Bash", "Bash(git:*)", None)); + } + + #[test] + fn test_arg_exact_match() { + assert!(tool_matches("Edit", "Edit(src/lib.rs)", Some("src/lib.rs"))); + assert!(!tool_matches( + "Edit", + "Edit(src/lib.rs)", + Some("src/main.rs") + )); + } + + #[test] + fn test_matches_any() { + let patterns = vec![ + "Read".to_string(), + "Grep".to_string(), + "mcp.echo.*".to_string(), + ]; + + assert!(tool_matches_any("Read", &patterns, None)); + assert!(tool_matches_any("Grep", &patterns, None)); + assert!(tool_matches_any("mcp.echo.add", &patterns, None)); + assert!(!tool_matches_any("Write", &patterns, None)); + } + + #[test] + fn test_matches_any_simple() { + let patterns = vec!["Read".to_string(), "Glob".to_string()]; + + assert!(tool_matches_any_simple("Read", &patterns)); + assert!(tool_matches_any_simple("Glob", &patterns)); + assert!(!tool_matches_any_simple("Write", &patterns)); + } +} diff --git a/src/tools/activate_skill.rs b/src/tools/activate_skill.rs index 4a7612a..e6cc44a 100644 --- a/src/tools/activate_skill.rs +++ b/src/tools/activate_skill.rs @@ -1,28 +1,47 @@ //! ActivateSkill tool for model-invoked skill activation. +use super::SchemaOptions; use serde_json::{json, Value}; /// Get the ActivateSkill tool schema -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "ActivateSkill", - "description": "Activate a skill pack to gain specialized instructions and optionally restrict available tools. Use when the task matches a skill's description. 
View available skills in the 'Available skill packs:' section of the system prompt.", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Name of the skill pack to activate" +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "ActivateSkill", + "description": "Activate skill pack", + "parameters": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "reason": { "type": "string" } }, - "reason": { - "type": "string", - "description": "Brief reason for activating this skill (optional)" - } - }, - "required": ["name"] + "required": ["name"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "ActivateSkill", + "description": "Activate a skill pack to gain specialized instructions and optionally restrict available tools. Use when the task matches a skill's description. View available skills in the 'Available skill packs:' section of the system prompt.", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the skill pack to activate" + }, + "reason": { + "type": "string", + "description": "Brief reason for activating this skill (optional)" + } + }, + "required": ["name"] + } + } + }) + } } diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 2b8dc8d..515606e 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -23,33 +23,54 @@ struct BashArgs { cwd: Option, } +use super::SchemaOptions; + /// Returns the JSON schema for the Bash tool -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Bash", - "description": "Execute a shell command in the project directory. Commands are parsed as shell words (not passed to sh -c). Returns stdout, stderr, and exit code. 
Use for builds, tests, git operations, etc.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The command to execute (parsed as shell words)" +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Bash", + "description": "Run shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" }, + "timeout_ms": { "type": "integer" }, + "cwd": { "type": "string" } }, - "timeout_ms": { - "type": "integer", - "description": "Timeout in milliseconds (default 120000, max 600000)" + "required": ["command"] + } + } + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Bash", + "description": "Execute a shell command in the project directory. Commands are parsed as shell words (not passed to sh -c). Returns stdout, stderr, and exit code. Use for builds, tests, git operations, etc.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The command to execute (parsed as shell words)" + }, + "timeout_ms": { + "type": "integer", + "description": "Timeout in milliseconds (default 120000, max 600000)" + }, + "cwd": { + "type": "string", + "description": "Working directory relative to project root (default: project root)" + } }, - "cwd": { - "type": "string", - "description": "Working directory relative to project root (default: project root)" - } - }, - "required": ["command"] + "required": ["command"] + } } - } - }) + }) + } } /// Execute the Bash tool @@ -216,8 +237,19 @@ mod tests { #[test] fn test_schema() { - let schema = schema(); + let opts = SchemaOptions::default(); + let schema = schema(&opts); + assert_eq!(schema["function"]["name"], "Bash"); + } + + #[test] + fn test_schema_optimized() { + let opts = SchemaOptions::new(true); + let schema = schema(&opts); assert_eq!(schema["function"]["name"], "Bash"); + // Optimized schema 
should have shorter description + let desc = schema["function"]["description"].as_str().unwrap(); + assert_eq!(desc, "Run shell command"); } #[test] diff --git a/src/tools/edit.rs b/src/tools/edit.rs index 52959b9..c658e5a 100644 --- a/src/tools/edit.rs +++ b/src/tools/edit.rs @@ -1,34 +1,63 @@ -use super::{sha256, validate_path}; +use super::{sha256, validate_path, SchemaOptions}; use serde_json::{json, Value}; use std::path::Path; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Edit", - "description": "Edit file with find/replace. Requires permission.", - "parameters": { - "type": "object", - "properties": { - "path": { "type": "string", "description": "File path relative to root" }, - "edits": { - "type": "array", - "items": { - "type": "object", - "properties": { - "find": { "type": "string" }, - "replace": { "type": "string" }, - "count": { "type": "integer", "description": "Times to replace (0=all, default 1)" } - }, - "required": ["find", "replace"] +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Edit", + "description": "Edit file: find→replace", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "edits": { + "type": "array", + "items": { + "type": "object", + "properties": { + "find": { "type": "string" }, + "replace": { "type": "string" }, + "count": { "type": "integer", "description": "0=all, default 1" } + }, + "required": ["find", "replace"] + } } - } - }, - "required": ["path", "edits"] + }, + "required": ["path", "edits"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Edit", + "description": "Edit file with find/replace. 
Requires permission.", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path relative to root" }, + "edits": { + "type": "array", + "items": { + "type": "object", + "properties": { + "find": { "type": "string" }, + "replace": { "type": "string" }, + "count": { "type": "integer", "description": "Times to replace (0=all, default 1)" } + }, + "required": ["find", "replace"] + } + } + }, + "required": ["path", "edits"] + } + } + }) + } } pub fn execute(args: Value, root: &Path) -> anyhow::Result { diff --git a/src/tools/glob.rs b/src/tools/glob.rs index 3b3da02..d07ee9a 100644 --- a/src/tools/glob.rs +++ b/src/tools/glob.rs @@ -1,22 +1,41 @@ +use super::SchemaOptions; use serde_json::{json, Value}; use std::path::Path; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Glob", - "description": "Find files matching glob pattern. Skips .git, target, .yo dirs.", - "parameters": { - "type": "object", - "properties": { - "pattern": { "type": "string", "description": "Glob pattern (e.g. **/*.rs)" }, - "max_results": { "type": "integer", "description": "Max files (default 2000)" } - }, - "required": ["pattern"] +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Glob", + "description": "Find files by pattern", + "parameters": { + "type": "object", + "properties": { + "pattern": { "type": "string" }, + "max_results": { "type": "integer" } + }, + "required": ["pattern"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Glob", + "description": "Find files matching glob pattern. Skips .git, target, .yo dirs.", + "parameters": { + "type": "object", + "properties": { + "pattern": { "type": "string", "description": "Glob pattern (e.g. 
**/*.rs)" }, + "max_results": { "type": "integer", "description": "Max files (default 2000)" } + }, + "required": ["pattern"] + } + } + }) + } } pub fn execute(args: Value, root: &Path) -> anyhow::Result { diff --git a/src/tools/grep.rs b/src/tools/grep.rs index de39082..72ee653 100644 --- a/src/tools/grep.rs +++ b/src/tools/grep.rs @@ -1,25 +1,45 @@ +use super::SchemaOptions; use regex::Regex; use serde_json::{json, Value}; use std::path::Path; use walkdir::WalkDir; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Grep", - "description": "Search file contents for pattern. Skips .git, target, .yo dirs.", - "parameters": { - "type": "object", - "properties": { - "pattern": { "type": "string", "description": "Regex pattern to search" }, - "paths": { "type": "array", "items": { "type": "string" }, "description": "Paths to search (default: all)" }, - "max_results": { "type": "integer", "description": "Max matches (default 100)" } - }, - "required": ["pattern"] +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Grep", + "description": "Search content by regex", + "parameters": { + "type": "object", + "properties": { + "pattern": { "type": "string" }, + "paths": { "type": "array", "items": { "type": "string" } }, + "max_results": { "type": "integer" } + }, + "required": ["pattern"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Grep", + "description": "Search file contents for pattern. 
Skips .git, target, .yo dirs.", + "parameters": { + "type": "object", + "properties": { + "pattern": { "type": "string", "description": "Regex pattern to search" }, + "paths": { "type": "array", "items": { "type": "string" }, "description": "Paths to search (default: all)" }, + "max_results": { "type": "integer", "description": "Max matches (default 100)" } + }, + "required": ["pattern"] + } + } + }) + } } pub fn execute(args: Value, root: &Path) -> anyhow::Result { diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 637c145..fa91c84 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -13,29 +13,42 @@ use anyhow::Result; use serde_json::{json, Value}; use std::path::Path; +/// Configuration for schema generation +#[derive(Debug, Clone, Copy, Default)] +pub struct SchemaOptions { + /// Generate terse schemas optimized for token efficiency + pub optimize: bool, +} + +impl SchemaOptions { + pub fn new(optimize: bool) -> Self { + Self { optimize } + } +} + /// Get all built-in tool schemas (excluding Task - used by subagents) -pub fn schemas() -> Vec { +pub fn schemas(opts: &SchemaOptions) -> Vec { vec![ - read::schema(), - write::schema(), - edit::schema(), - grep::schema(), - glob::schema(), - bash::schema(), + read::schema(opts), + write::schema(opts), + edit::schema(opts), + grep::schema(opts), + glob::schema(opts), + bash::schema(opts), ] } /// Get all tool schemas including Task and ActivateSkill (used by main agent) -pub fn schemas_with_task() -> Vec { +pub fn schemas_with_task(opts: &SchemaOptions) -> Vec { vec![ - read::schema(), - write::schema(), - edit::schema(), - grep::schema(), - glob::schema(), - bash::schema(), - task::schema(), - activate_skill::schema(), + read::schema(opts), + write::schema(opts), + edit::schema(opts), + grep::schema(opts), + glob::schema(opts), + bash::schema(opts), + task::schema(opts), + activate_skill::schema(opts), ] } diff --git a/src/tools/read.rs b/src/tools/read.rs index 1334d2e..a837fee 100644 --- 
a/src/tools/read.rs +++ b/src/tools/read.rs @@ -1,24 +1,43 @@ -use super::{sha256, validate_path}; +use super::{sha256, validate_path, SchemaOptions}; use serde_json::{json, Value}; use std::path::Path; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Read", - "description": "Read file content. Paths relative to project root.", - "parameters": { - "type": "object", - "properties": { - "path": { "type": "string", "description": "File path relative to root" }, - "max_bytes": { "type": "integer", "description": "Max bytes to read (default 65536)" }, - "offset": { "type": "integer", "description": "Byte offset to start from (default 0)" } - }, - "required": ["path"] +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Read", + "description": "Read file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "max_bytes": { "type": "integer" }, + "offset": { "type": "integer" } + }, + "required": ["path"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Read", + "description": "Read file content. Paths relative to project root.", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path relative to root" }, + "max_bytes": { "type": "integer", "description": "Max bytes to read (default 65536)" }, + "offset": { "type": "integer", "description": "Byte offset to start from (default 0)" } + }, + "required": ["path"] + } + } + }) + } } pub fn execute(args: Value, root: &Path) -> anyhow::Result { diff --git a/src/tools/task.rs b/src/tools/task.rs index ad34620..c6cb5cd 100644 --- a/src/tools/task.rs +++ b/src/tools/task.rs @@ -1,53 +1,86 @@ //! Task tool for delegating work to subagents. 
+use super::SchemaOptions; use crate::agent::CommandStats; use crate::cli::Context; use crate::subagent::{self, InputContext, SubagentResult}; use serde_json::{json, Value}; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Task", - "description": "Delegate a task to a specialized subagent. Use /agents to see available agents.", - "parameters": { - "type": "object", - "properties": { - "agent": { - "type": "string", - "description": "Name of the subagent to delegate to (e.g., 'scout', 'patch', 'test', 'docs')" - }, - "prompt": { - "type": "string", - "description": "Task description for the subagent" +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Task", + "description": "Delegate to subagent", + "parameters": { + "type": "object", + "properties": { + "agent": { "type": "string" }, + "prompt": { "type": "string" }, + "input_context": { + "type": "object", + "properties": { + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { "path": { "type": "string" } }, + "required": ["path"] + } + }, + "notes": { "type": "string" } + } + } }, - "input_context": { - "type": "object", - "description": "Optional context to provide to the subagent", - "properties": { - "files": { - "type": "array", - "items": { - "type": "object", - "properties": { - "path": { "type": "string", "description": "File path hint for the subagent" } + "required": ["agent", "prompt"] + } + } + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Task", + "description": "Delegate a task to a specialized subagent. 
Use /agents to see available agents.", + "parameters": { + "type": "object", + "properties": { + "agent": { + "type": "string", + "description": "Name of the subagent to delegate to (e.g., 'scout', 'patch', 'test', 'docs')" + }, + "prompt": { + "type": "string", + "description": "Task description for the subagent" + }, + "input_context": { + "type": "object", + "description": "Optional context to provide to the subagent", + "properties": { + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path hint for the subagent" } + }, + "required": ["path"] }, - "required": ["path"] + "description": "File paths to hint to the subagent (it can use Read tool to access them)" }, - "description": "File paths to hint to the subagent (it can use Read tool to access them)" - }, - "notes": { - "type": "string", - "description": "Additional notes for the subagent" + "notes": { + "type": "string", + "description": "Additional notes for the subagent" + } } } - } - }, - "required": ["agent", "prompt"] + }, + "required": ["agent", "prompt"] + } } - } - }) + }) + } } /// Execute the Task tool - delegates to a subagent diff --git a/src/tools/write.rs b/src/tools/write.rs index 8e0a86a..3cc8ade 100644 --- a/src/tools/write.rs +++ b/src/tools/write.rs @@ -1,24 +1,43 @@ -use super::{sha256, validate_path}; +use super::{sha256, validate_path, SchemaOptions}; use serde_json::{json, Value}; use std::path::Path; -pub fn schema() -> Value { - json!({ - "type": "function", - "function": { - "name": "Write", - "description": "Create or overwrite a file. 
Requires permission.", - "parameters": { - "type": "object", - "properties": { - "path": { "type": "string", "description": "File path relative to root" }, - "content": { "type": "string", "description": "Content to write" }, - "overwrite": { "type": "boolean", "description": "Allow overwrite (default true)" } - }, - "required": ["path", "content"] +pub fn schema(opts: &SchemaOptions) -> Value { + if opts.optimize { + json!({ + "type": "function", + "function": { + "name": "Write", + "description": "Write file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "content": { "type": "string" }, + "overwrite": { "type": "boolean" } + }, + "required": ["path", "content"] + } } - } - }) + }) + } else { + json!({ + "type": "function", + "function": { + "name": "Write", + "description": "Create or overwrite a file. Requires permission.", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path relative to root" }, + "content": { "type": "string", "description": "Content to write" }, + "overwrite": { "type": "boolean", "description": "Allow overwrite (default true)" } + }, + "required": ["path", "content"] + } + } + }) + } } pub fn execute(args: Value, root: &Path) -> anyhow::Result { diff --git a/src/transcript.rs b/src/transcript.rs index b5915b0..cdbe3f9 100644 --- a/src/transcript.rs +++ b/src/transcript.rs @@ -345,4 +345,24 @@ impl Transcript { }), ) } + + /// Log token usage for an LLM call + pub fn token_usage( + &mut self, + model: &str, + input_tokens: u64, + output_tokens: u64, + cost_usd: f64, + ) -> Result<()> { + self.log( + "token_usage", + serde_json::json!({ + "model": model, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cost_usd": cost_usd, + }), + ) + } }