diff --git a/.gitignore b/.gitignore
index 18911ad4..974a4934 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,4 @@ logs/
 # g3 artifacts
 requirements.md
 todo.g3.md
+config.toml
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 00000000..e6508f03
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,27 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased] - 2025-12-16
+
+### Added
+- **OpenRouter Support**: Implemented full support for OpenRouter as an LLM provider.
+  - Added `OpenRouterProvider` implementation in `crates/g3-providers`.
+  - Added configuration structures `OpenRouterConfig` and `ProviderPreferencesConfig` in `crates/g3-config`.
+  - Added integration tests in `crates/tests/openrouter_integration_tests.rs`.
+  - Updated `g3-cli` to accept `openrouter` as a valid provider type in command line arguments.
+  - Updated `g3-core` to register and handle OpenRouter providers.
+  - Added example configuration in `config.example.toml`.
+- **Configuration**: Added `config.toml` to `.gitignore`.
+
+### Changed
+- **g3-core**: Updated `provider_max_tokens` and `resolve_max_tokens` to correctly handle OpenRouter configuration and context window sizes (defaulting to 128k if not specified, but respecting config).
+- **g3-cli**: Updated provider validation logic to support the `openrouter` prefix (e.g., `openrouter.grok`).
+
+### Fixed
+- **g3-providers**: Fixed unused variable warnings in `anthropic.rs` by renaming `cache_config` to `_cache_config`.
+- **g3-planner**: Fixed unused function warning in `llm.rs` by renaming `print_status_line` to `_print_status_line`.
+- **g3-providers**: Fixed a typo in `openrouter.rs` (`pub mode` -> `pub mod`).
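As context for the `openrouter.grok` syntax called out under *Changed*: a provider override names a provider type plus an optional config name, split on the first dot. A minimal standalone sketch of that resolution (hypothetical helper name; the shipped logic is the `split('.')` call in the `g3-cli` hunk below):

```rust
/// Hypothetical helper mirroring the CLI's prefix check: "openrouter.grok"
/// resolves to the provider type "openrouter" and the config name "grok".
fn split_provider(spec: &str) -> (&str, Option<&str>) {
    match spec.split_once('.') {
        Some((ty, name)) => (ty, Some(name)),
        None => (spec, None),
    }
}

fn main() {
    assert_eq!(split_provider("openrouter.grok"), ("openrouter", Some("grok")));
    assert_eq!(split_provider("anthropic"), ("anthropic", None));
}
```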
diff --git a/config.example.toml b/config.example.toml
index 8c47b5cf..afb172a4 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -46,6 +46,27 @@ use_oauth = true
 # max_tokens = 4096
 # temperature = 0.1
 
+# Named OpenRouter configurations
+# OpenRouter provides access to 200+ AI models through a unified API
+[providers.openrouter.default]
+api_key = "${OPENROUTER_API_KEY}"
+model = "anthropic/claude-3.5-sonnet"
+max_tokens = 4096
+temperature = 0.7
+# http_referer = "https://yourapp.com" # Optional: Your app URL for analytics
+# x_title = "Your App Name" # Optional: Your app name for analytics
+# provider_order = ["Anthropic"] # Optional: Preferred provider routing
+# allow_fallbacks = true # Optional: Allow fallback to other providers
+
+# Multiple OpenAI-compatible providers can be configured
+# [providers.openai_compatible.groq]
+# api_key = "your-groq-api-key"
+# model = "llama-3.3-70b-versatile"
+# base_url = "https://api.groq.com/openai/v1"
+# max_tokens = 4096
+# temperature = 0.1
+
+
 # Multiple OpenAI-compatible providers can be configured
 # [providers.openai_compatible.openrouter]
 # api_key = "your-openrouter-api-key"
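For reference, a sketch of how a table like `[providers.openrouter.default]` above deserializes into the `OpenRouterConfig` struct added in `g3-config` later in this diff. It assumes the `toml` and `serde` crates, trims the struct to four fields, and uses a literal key, since `${OPENROUTER_API_KEY}` expansion is g3-config behavior rather than something `toml` does:

```rust
use serde::Deserialize;

// Trimmed-down mirror of the OpenRouterConfig struct from g3-config.
#[derive(Debug, Deserialize)]
struct OpenRouterConfig {
    api_key: String,
    model: String,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
}

fn main() -> Result<(), toml::de::Error> {
    let raw = r#"
        [providers.openrouter.default]
        api_key = "sk-or-example"
        model = "anthropic/claude-3.5-sonnet"
        max_tokens = 4096
        temperature = 0.7
    "#;
    // Parse the whole document, then pull out the named provider table.
    let value: toml::Value = toml::from_str(raw)?;
    let cfg: OpenRouterConfig = value["providers"]["openrouter"]["default"]
        .clone()
        .try_into()?;
    assert_eq!(cfg.model, "anthropic/claude-3.5-sonnet");
    Ok(())
}
```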
diff --git a/crates/g3-cli/src/lib.rs b/crates/g3-cli/src/lib.rs
index 6bcff99f..5f503862 100644
--- a/crates/g3-cli/src/lib.rs
+++ b/crates/g3-cli/src/lib.rs
@@ -185,34 +185,45 @@ fn extract_coach_feedback_from_logs(
         if let Some(prev_content) = prev_msg.get("content") {
             if let Some(prev_content_str) = prev_content.as_str() {
                 // Check if the previous assistant message contains a final_output tool call
-                if prev_content_str.contains("\"tool\": \"final_output\"") {
-                    // This is a final_output tool result
-                    let feedback = if content_str.starts_with("Tool result: ") {
+                // If the previous assistant message explicitly indicates a final_output tool,
+                // treat it as verified. Otherwise, accept the tool result as a
+                // fallback (with a warning) to avoid losing coach feedback when logs
+                // don't include an exact final_output marker.
+                let feedback = if prev_content_str.contains("\"tool\": \"final_output\"") {
+                    if content_str.starts_with("Tool result: ") {
                         content_str.strip_prefix("Tool result: ")
                             .unwrap_or(content_str)
                             .to_string()
                     } else {
                         content_str.to_string()
-                    };
-
-                    output.print(&format!(
-                        "Coach feedback extracted: {} characters (from {} total)",
-                        feedback.len(),
-                        content_str.len()
-                    ));
-                    output.print(&format!("Coach feedback:\n{}", feedback));
-
-                    output.print(&format!(
-                        "✅ Extracted coach feedback from session: {} (verified final_output tool)",
-                        session_id
-                    ));
-                    return Ok(feedback);
+                    }
                 } else {
+                    // Unverified fallback: accept the tool result but warn
                     output.print(&format!(
-                        "⚠️ Skipping tool result at index {} - not a final_output tool call",
+                        "⚠️ Tool result at index {} not verified as final_output; accepting as fallback",
                         i
                     ));
-                }
+                    if content_str.starts_with("Tool result: ") {
+                        content_str.strip_prefix("Tool result: ")
+                            .unwrap_or(content_str)
+                            .to_string()
+                    } else {
+                        content_str.to_string()
+                    }
+                };
+
+                output.print(&format!(
+                    "Coach feedback extracted: {} characters (from {} total)",
+                    feedback.len(),
+                    content_str.len()
+                ));
+                output.print(&format!("Coach feedback:\n{}", feedback));
+
+                output.print(&format!(
+                    "✅ Extracted coach feedback from session: {}",
+                    session_id
+                ));
+                return Ok(feedback);
             }
         }
     }
@@ -233,8 +244,8 @@ fn extract_coach_feedback_from_logs(
         }
     }
 
-    // If we couldn't extract from logs, panic with detailed error
-    panic!(
+    // If we couldn't extract from logs, return an error (avoid panicking)
+    return Err(anyhow::anyhow!(
         "CRITICAL: Could not extract coach feedback from session: {}\n\
          Log file path: {:?}\n\
          Log file exists: {}\n\
@@ -244,7 +255,7 @@ fn extract_coach_feedback_from_logs(
         log_file_path,
         log_file_path.exists(),
         coach_result.response.len()
-    );
+    ));
 }
 
 use clap::Parser;
@@ -323,7 +334,7 @@ pub struct Cli {
     #[arg(long)]
     pub machine: bool,
 
-    /// Override the configured provider (anthropic, databricks, embedded, openai)
+    /// Override the configured provider (anthropic, databricks, embedded, openai, openrouter)
     #[arg(long, value_name = "PROVIDER")]
     pub provider: Option<String>,
 
@@ -533,8 +544,9 @@ pub async fn run() -> Result<()> {
     // Validate provider if specified
     if let Some(ref provider) = cli.provider {
-        let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
-        if !valid_providers.contains(&provider.as_str()) {
+        let valid_providers = ["anthropic", "databricks", "embedded", "openai", "openrouter"];
+        let provider_type = provider.split('.').next().unwrap_or(provider);
+        if !valid_providers.contains(&provider_type) {
             return Err(anyhow::anyhow!(
                 "Invalid provider '{}'. Valid options: {:?}",
                 provider,
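Both branches of the rewritten extraction above unwrap an optional `Tool result: ` prefix the same way; as a standalone sketch (hypothetical helper name, the diff keeps the logic inline):

```rust
/// Drop a leading "Tool result: " marker if present, mirroring the
/// strip_prefix pattern duplicated in both branches of the g3-cli hunk.
fn strip_tool_result(content: &str) -> String {
    content
        .strip_prefix("Tool result: ")
        .unwrap_or(content)
        .to_string()
}

fn main() {
    assert_eq!(strip_tool_result("Tool result: looks good"), "looks good");
    assert_eq!(strip_tool_result("looks good"), "looks good");
}
```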
diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs
index a3f5910e..63d94353 100644
--- a/crates/g3-config/src/lib.rs
+++ b/crates/g3-config/src/lib.rs
@@ -47,6 +47,10 @@ pub struct ProvidersConfig {
     /// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.)
     #[serde(default)]
     pub openai_compatible: HashMap<String, OpenAIConfig>,
+
+    /// Named OpenRouter provider configs
+    #[serde(default)]
+    pub openrouter: HashMap<String, OpenRouterConfig>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -58,6 +62,25 @@ pub struct OpenAIConfig {
     pub temperature: Option<f32>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OpenRouterConfig {
+    pub api_key: String,
+    pub model: String,
+    pub base_url: Option<String>,
+    pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+    pub provider_preferences: Option<ProviderPreferencesConfig>,
+    pub http_referer: Option<String>,
+    pub x_title: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProviderPreferencesConfig {
+    pub order: Option<Vec<String>>,
+    pub allow_fallbacks: Option<bool>,
+    pub require_parameters: Option<bool>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct AnthropicConfig {
     pub api_key: String,
@@ -198,6 +221,7 @@ impl Default for Config {
                 databricks: databricks_configs,
                 embedded: HashMap::new(),
                 openai_compatible: HashMap::new(),
+                openrouter: HashMap::new(),
             },
             agent: AgentConfig {
                 max_context_length: None,
@@ -414,11 +438,20 @@ impl Config {
                     );
                 }
             }
+            "openrouter" => {
+                if !self.providers.openrouter.contains_key(config_name) {
+                    anyhow::bail!(
+                        "Provider config 'openrouter.{}' not found. Available: {:?}",
+                        config_name,
+                        self.providers.openrouter.keys().collect::<Vec<_>>()
+                    );
+                }
+            }
             _ => {
                 // Check openai_compatible providers
                 if !self.providers.openai_compatible.contains_key(provider_type) {
                     anyhow::bail!(
-                        "Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, or openai_compatible names",
+                        "Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, openrouter, or openai_compatible names",
                         provider_type
                     );
                 }
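The new `openrouter` validation arm follows the same lookup-or-bail pattern as the other provider types; a self-contained sketch of that pattern (a plain `HashMap` stands in for the real `Config` type, and the value type is simplified to `String`):

```rust
use std::collections::HashMap;

/// Stand-in for Config::validate's per-type arm: a named config either
/// exists in the typed map, or validation fails listing what is available.
fn validate_openrouter(
    configs: &HashMap<String, String>, // name -> model, stand-in for OpenRouterConfig
    config_name: &str,
) -> anyhow::Result<()> {
    if !configs.contains_key(config_name) {
        anyhow::bail!(
            "Provider config 'openrouter.{}' not found. Available: {:?}",
            config_name,
            configs.keys().collect::<Vec<_>>()
        );
    }
    Ok(())
}

fn main() {
    let mut configs = HashMap::new();
    configs.insert("default".to_string(), "anthropic/claude-3.5-sonnet".to_string());
    assert!(validate_openrouter(&configs, "default").is_ok());
    assert!(validate_openrouter(&configs, "grok").is_err());
}
```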
diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs
index 69a702aa..9432cd79 100644
--- a/crates/g3-core/src/lib.rs
+++ b/crates/g3-core/src/lib.rs
@@ -1316,6 +1316,39 @@ impl Agent {
         }
     }
 
+        // Register OpenRouter providers from HashMap
+        for (name, openrouter_config) in &config.providers.openrouter {
+            if should_register("openrouter", name) {
+                let mut openrouter_provider = g3_providers::OpenRouterProvider::new_with_name(
+                    format!("openrouter.{}", name),
+                    openrouter_config.api_key.clone(),
+                    Some(openrouter_config.model.clone()),
+                    openrouter_config.max_tokens,
+                    openrouter_config.temperature,
+                )?;
+
+                if let Some(prefs) = &openrouter_config.provider_preferences {
+                    // Convert config prefs to provider prefs
+                    let provider_prefs = g3_providers::ProviderPreferences {
+                        order: prefs.order.clone(),
+                        allow_fallbacks: prefs.allow_fallbacks,
+                        require_parameters: prefs.require_parameters,
+                    };
+                    openrouter_provider = openrouter_provider.with_provider_preferences(provider_prefs);
+                }
+
+                if let Some(referer) = &openrouter_config.http_referer {
+                    openrouter_provider = openrouter_provider.with_http_referer(referer.clone());
+                }
+
+                if let Some(title) = &openrouter_config.x_title {
+                    openrouter_provider = openrouter_provider.with_x_title(title.clone());
+                }
+
+                providers.register(openrouter_provider);
+            }
+        }
+
         // Register Anthropic providers from HashMap
         for (name, anthropic_config) in &config.providers.anthropic {
             if should_register("anthropic", name) {
@@ -1543,7 +1576,11 @@ impl Agent {
             "openai" => config.providers.openai.get(config_name)?.max_tokens,
             "databricks" => config.providers.databricks.get(config_name)?.max_tokens,
             "embedded" => config.providers.embedded.get(config_name)?.max_tokens,
-            _ => None,
+            "openrouter" => config.providers.openrouter.get(config_name)?.max_tokens,
+            _ => {
+                // Check openai_compatible
+                config.providers.openai_compatible.get(provider_type)?.max_tokens
+            }
         }
     }
 
@@ -1563,7 +1600,11 @@ impl Agent {
             "openai" => config.providers.openai.get(config_name)?.temperature,
             "databricks" => config.providers.databricks.get(config_name)?.temperature,
             "embedded" => config.providers.embedded.get(config_name)?.temperature,
-            _ => None,
+            "openrouter" => config.providers.openrouter.get(config_name)?.temperature,
+            _ => {
+                // Check openai_compatible
+                config.providers.openai_compatible.get(provider_type)?.temperature
+            }
         }
     }
 
@@ -1946,6 +1987,17 @@ impl Agent {
                 16384 // Conservative default for other Databricks models
             }
         }
+        "openrouter" => {
+            if let Some(max_tokens) = Self::provider_max_tokens(config, provider_name) {
+                warnings.push(format!(
+                    "Context length falling back to max_tokens ({}) for provider={}",
+                    max_tokens, provider_name
+                ));
+                max_tokens
+            } else {
+                128000 // Default for OpenRouter
+            }
+        }
         _ => config.agent.fallback_default_max_tokens as u32,
     };
diff --git a/crates/g3-planner/src/llm.rs b/crates/g3-planner/src/llm.rs
index 9eb610b0..4623574e 100644
--- a/crates/g3-planner/src/llm.rs
+++ b/crates/g3-planner/src/llm.rs
@@ -206,7 +206,7 @@ impl PlannerUiWriter {
     }
 
     /// Clear the current line and print a status message
-    fn print_status_line(&self, message: &str) {
+    fn _print_status_line(&self, message: &str) {
         // Print status message without overwriting previous content
         // Use println to ensure each status is on its own line
         println!("{:.80}", message);
diff --git a/crates/g3-providers/src/anthropic.rs b/crates/g3-providers/src/anthropic.rs
index d0258641..a2fd29a5 100644
--- a/crates/g3-providers/src/anthropic.rs
+++ b/crates/g3-providers/src/anthropic.rs
@@ -125,7 +125,7 @@ pub struct AnthropicProvider {
     model: String,
     max_tokens: u32,
     temperature: f32,
-    cache_config: Option<CacheConfig>,
+    _cache_config: Option<CacheConfig>,
     enable_1m_context: bool,
     thinking_budget_tokens: Option<u32>,
 }
@@ -136,7 +136,7 @@ impl AnthropicProvider {
         model: Option<String>,
         max_tokens: Option<u32>,
         temperature: Option<f32>,
-        cache_config: Option<CacheConfig>,
+        _cache_config: Option<CacheConfig>,
         enable_1m_context: Option<bool>,
         thinking_budget_tokens: Option<u32>,
     ) -> Result<Self> {
@@ -156,7 +156,7 @@ impl AnthropicProvider {
             model,
             max_tokens: max_tokens.unwrap_or(4096),
             temperature: temperature.unwrap_or(0.1),
-            cache_config,
+            _cache_config,
             enable_1m_context: enable_1m_context.unwrap_or(false),
             thinking_budget_tokens,
         })
@@ -169,7 +169,7 @@ impl AnthropicProvider {
         model: Option<String>,
         max_tokens: Option<u32>,
         temperature: Option<f32>,
-        cache_config: Option<CacheConfig>,
+        _cache_config: Option<CacheConfig>,
         enable_1m_context: Option<bool>,
         thinking_budget_tokens: Option<u32>,
     ) -> Result<Self> {
@@ -189,7 +189,7 @@ impl AnthropicProvider {
             model,
             max_tokens: max_tokens.unwrap_or(4096),
             temperature: temperature.unwrap_or(0.1),
-            cache_config,
+            _cache_config,
             enable_1m_context: enable_1m_context.unwrap_or(false),
             thinking_budget_tokens,
         })
diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs
index 0e088ce2..970dda68 100644
--- a/crates/g3-providers/src/lib.rs
+++ b/crates/g3-providers/src/lib.rs
@@ -144,11 +144,13 @@ pub mod databricks;
 pub mod embedded;
 pub mod oauth;
 pub mod openai;
+pub mod openrouter;
 
 pub use anthropic::AnthropicProvider;
 pub use databricks::DatabricksProvider;
 pub use embedded::EmbeddedProvider;
 pub use openai::OpenAIProvider;
+pub use openrouter::{OpenRouterProvider, ProviderPreferences};
 
 impl Message {
     /// Generate a unique message ID in format HHMMSS-XXX
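Putting the registration hunk above together with the builder API defined in `openrouter.rs` below, construction looks roughly like this (illustrative values; `new_with_name`, `with_provider_preferences`, and `with_http_referer` are the methods this diff adds):

```rust
use g3_providers::{OpenRouterProvider, ProviderPreferences};

fn build() -> anyhow::Result<OpenRouterProvider> {
    // Mirrors what g3-core does for a [providers.openrouter.default] entry.
    let provider = OpenRouterProvider::new_with_name(
        "openrouter.default".to_string(),
        std::env::var("OPENROUTER_API_KEY")?,
        Some("anthropic/claude-3.5-sonnet".to_string()),
        Some(4096), // max_tokens
        Some(0.7),  // temperature
    )?
    .with_provider_preferences(ProviderPreferences {
        order: Some(vec!["Anthropic".to_string()]),
        allow_fallbacks: Some(true),
        require_parameters: None,
    })
    .with_http_referer("https://yourapp.com".to_string());
    Ok(provider)
}
```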
diff --git a/crates/g3-providers/src/openrouter.rs b/crates/g3-providers/src/openrouter.rs
new file mode 100644
index 00000000..0a436db7
--- /dev/null
+++ b/crates/g3-providers/src/openrouter.rs
@@ -0,0 +1,621 @@
+//! OpenRouter provider implementation for the g3-providers crate.
+//!
+//! This module provides an implementation of the `LLMProvider` trait for OpenRouter's unified API,
+//! which provides access to 200+ AI models through a single OpenAI-compatible endpoint.
+//!
+//! # Features
+//!
+//! - Support for 200+ models from multiple providers (Anthropic, OpenAI, Google, Meta, etc.)
+//! - OpenAI-compatible API with provider routing extensions
+//! - Both completion and streaming response modes
+//! - Provider preference configuration for routing control
+//! - Optional HTTP-Referer and X-Title headers for better analytics
+//!
+//! # Usage
+//!
+//! ```rust,no_run
+//! use g3_providers::{OpenRouterProvider, LLMProvider, CompletionRequest, Message, MessageRole};
+//!
+//! #[tokio::main]
+//! async fn main() -> anyhow::Result<()> {
+//!     // Create the provider with your API key
+//!     let provider = OpenRouterProvider::new(
+//!         "your-api-key".to_string(),
+//!         Some("anthropic/claude-3.5-sonnet".to_string()),
+//!         None, // max_tokens
+//!         None, // temperature
+//!     )?;
+//!
+//!     // Create a completion request
+//!     let request = CompletionRequest {
+//!         messages: vec![
+//!             Message::new(MessageRole::User, "Hello! How are you?".to_string()),
+//!         ],
+//!         max_tokens: Some(1000),
+//!         temperature: Some(0.7),
+//!         stream: false,
+//!         tools: None,
+//!         disable_thinking: false,
+//!     };
+//!
+//!     // Get a completion
+//!     let response = provider.complete(request).await?;
+//!     println!("Response: {}", response.content);
+//!
+//!     Ok(())
+//! }
+//! ```
+
+use anyhow::Result;
+use async_trait::async_trait;
+use bytes::Bytes;
+use futures_util::stream::StreamExt;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tracing::{debug, error};
+
+use crate::{
+    CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
+    MessageRole, Tool, ToolCall, Usage,
+};
+
+const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
+
+#[derive(Debug, Clone, Serialize)]
+pub struct ProviderPreferences {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub order: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub allow_fallbacks: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub require_parameters: Option<bool>,
+}
+
+#[derive(Clone)]
+pub struct OpenRouterProvider {
+    client: Client,
+    api_key: String,
+    model: String,
+    base_url: String,
+    max_tokens: Option<u32>,
+    _temperature: Option<f32>,
+    name: String,
+    provider_preferences: Option<ProviderPreferences>,
+    http_referer: Option<String>,
+    x_title: Option<String>,
+}
+
+impl OpenRouterProvider {
+    pub fn new(
+        api_key: String,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+    ) -> Result<Self> {
+        Self::new_with_name(
+            "openrouter".to_string(),
+            api_key,
+            model,
+            max_tokens,
+            temperature,
+        )
+    }
+
+    pub fn new_with_name(
+        name: String,
+        api_key: String,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+    ) -> Result<Self> {
+        Ok(Self {
+            client: Client::new(),
+            api_key,
+            model: model.unwrap_or_else(|| "anthropic/claude-3.5-sonnet".to_string()),
+            base_url: OPENROUTER_BASE_URL.to_string(),
+            max_tokens,
+            _temperature: temperature,
+            name,
+            provider_preferences: None,
+            http_referer: None,
+            x_title: None,
+        })
+    }
+
+    pub fn with_provider_preferences(mut self, preferences: ProviderPreferences) -> Self {
+        self.provider_preferences = Some(preferences);
+        self
+    }
+
+    pub fn with_http_referer(mut self, referer: String) -> Self {
+        self.http_referer = Some(referer);
+        self
+    }
+
+    pub fn with_x_title(mut self, title: String) -> Self {
+        self.x_title = Some(title);
+        self
+    }
+
+    fn create_request_body(
+        &self,
+        messages: &[Message],
+        tools: Option<&[Tool]>,
+        stream: bool,
+        max_tokens: Option<u32>,
+        _temperature: Option<f32>,
+    ) -> serde_json::Value {
+        let mut body = json!({
+            "model": self.model,
+            "messages": convert_messages(messages),
+            "stream": stream,
+        });
+
+        if let Some(max_tokens) = max_tokens.or(self.max_tokens) {
+            body["max_tokens"] = json!(max_tokens);
+        }
+
+        if let Some(tools) = tools {
+            if !tools.is_empty() {
+                body["tools"] = json!(convert_tools(tools));
+            }
+        }
+
+        if let Some(ref preferences) = self.provider_preferences {
+            body["provider"] = serde_json::to_value(preferences).unwrap_or(json!({}));
+        }
+
+        if stream {
+            body["stream_options"] = json!({
+                "include_usage": true,
+            });
+        }
+
+        body
+    }
+
+    async fn parse_streaming_response(
+        &self,
+        mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
+        tx: mpsc::Sender<Result<CompletionChunk>>,
+    ) -> Option<Usage> {
+        let mut buffer = String::new();
+        let mut accumulated_usage: Option<Usage> = None;
+        let mut current_tool_calls: Vec<OpenRouterStreamingToolCall> = Vec::new();
+
+        while let Some(chunk_result) = stream.next().await {
+            match chunk_result {
+                Ok(chunk) => {
+                    let chunk_str = match std::str::from_utf8(&chunk) {
+                        Ok(s) => s,
+                        Err(e) => {
+                            error!("Failed to parse chunk as UTF-8: {}", e);
+                            continue;
+                        }
+                    };
+
+                    buffer.push_str(chunk_str);
+
+                    // Process complete lines
+                    while let Some(line_end) = buffer.find('\n') {
+                        let line = buffer[..line_end].trim().to_string();
+                        buffer.drain(..line_end + 1);
+
+                        if line.is_empty() {
+                            continue;
+                        }
+
+                        // Parse Server-Sent Events format
+                        if let Some(data) = line.strip_prefix("data: ") {
+                            if data == "[DONE]" {
+                                debug!("Received stream completion marker");
+
+                                let tool_calls = if current_tool_calls.is_empty() {
+                                    None
+                                } else {
+                                    Some(
+                                        current_tool_calls
+                                            .iter()
+                                            .filter_map(|tc| tc.to_tool_call())
+                                            .collect(),
+                                    )
+                                };
+
+                                let final_chunk = CompletionChunk {
+                                    content: String::new(),
+                                    finished: true,
+                                    tool_calls,
+                                    usage: accumulated_usage.clone(),
+                                };
+                                let _ = tx.send(Ok(final_chunk)).await;
+                                return accumulated_usage;
+                            }
+
+                            // Parse the JSON data
+                            match serde_json::from_str::<OpenRouterStreamChunk>(data) {
+                                Ok(chunk_data) => {
+                                    // Handle content
+                                    for choice in &chunk_data.choices {
+                                        if let Some(content) = &choice.delta.content {
+                                            let chunk = CompletionChunk {
+                                                content: content.clone(),
+                                                finished: false,
+                                                tool_calls: None,
+                                                usage: None,
+                                            };
+                                            if tx.send(Ok(chunk)).await.is_err() {
+                                                debug!("Receiver dropped, stopping stream");
+                                                return accumulated_usage;
+                                            }
+                                        }
+
+                                        // Handle tool calls
+                                        if let Some(delta_tool_calls) = &choice.delta.tool_calls {
+                                            for delta_tool_call in delta_tool_calls {
+                                                if let Some(index) = delta_tool_call.index {
+                                                    // Ensure we have enough tool calls in our vector
+                                                    while current_tool_calls.len() <= index {
+                                                        current_tool_calls.push(
+                                                            OpenRouterStreamingToolCall::default(),
+                                                        );
+                                                    }
+
+                                                    let tool_call = &mut current_tool_calls[index];
+
+                                                    if let Some(id) = &delta_tool_call.id {
+                                                        tool_call.id = Some(id.clone());
+                                                    }
+
+                                                    if let Some(function) =
+                                                        &delta_tool_call.function
+                                                    {
+                                                        if let Some(name) = &function.name {
+                                                            tool_call.name = Some(name.clone());
+                                                        }
+                                                        if let Some(arguments) = &function.arguments
+                                                        {
+                                                            tool_call.arguments.push_str(arguments);
+                                                        }
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+
+                                    // Handle usage
+                                    if let Some(usage) = chunk_data.usage {
+                                        accumulated_usage = Some(Usage {
+                                            prompt_tokens: usage.prompt_tokens,
+                                            completion_tokens: usage.completion_tokens,
+                                            total_tokens: usage.total_tokens,
+                                        });
+                                    }
+                                }
+                                Err(e) => {
+                                    debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
+                                }
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Stream error: {}", e);
+                    let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
+                    return accumulated_usage;
+                }
+            }
+        }
+
+        // Send final chunk if we haven't already
+        let tool_calls = if current_tool_calls.is_empty() {
+            None
+        } else {
+            Some(
+                current_tool_calls
+                    .iter()
+                    .filter_map(|tc| tc.to_tool_call())
+                    .collect(),
+            )
+        };
+
+        let final_chunk = CompletionChunk {
+            content: String::new(),
+            finished: true,
+            tool_calls,
+            usage: accumulated_usage.clone(),
+        };
+        let _ = tx.send(Ok(final_chunk)).await;
+
+        accumulated_usage
+    }
+}
+
+#[async_trait]
+impl LLMProvider for OpenRouterProvider {
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
+        debug!(
+            "Processing OpenRouter completion request with {} messages",
+            request.messages.len()
+        );
+
+        let body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            false,
+            request.max_tokens,
+            request.temperature,
+        );
+
+        debug!("Sending request to OpenRouter API: model={}", self.model);
+
+        let mut req = self
+            .client
+            .post(format!("{}/chat/completions", self.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key));
+
+        if let Some(ref referer) = self.http_referer {
+            req = req.header("HTTP-Referer", referer);
+        }
+
+        if let Some(ref title) = self.x_title {
+            req = req.header("X-Title", title);
+        }
+
+        let response = req.json(&body).send().await?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow::anyhow!(
+                "OpenRouter API error {}: {}",
+                status,
+                error_text
+            ));
+        }
+
+        let openrouter_response: OpenRouterResponse = response.json().await?;
+
+        let content = openrouter_response
+            .choices
+            .first()
+            .and_then(|choice| choice.message.content.clone())
+            .unwrap_or_default();
+
+        let usage = Usage {
+            prompt_tokens: openrouter_response.usage.prompt_tokens,
+            completion_tokens: openrouter_response.usage.completion_tokens,
+            total_tokens: openrouter_response.usage.total_tokens,
+        };
+
+        debug!(
+            "OpenRouter completion successful: {} tokens generated",
+            usage.completion_tokens
+        );
+
+        Ok(CompletionResponse {
+            content,
+            usage,
+            model: self.model.clone(),
+        })
+    }
+
+    async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
+        debug!(
+            "Processing OpenRouter streaming request with {} messages",
+            request.messages.len()
+        );
+
+        let body = self.create_request_body(
+            &request.messages,
+            request.tools.as_deref(),
+            true,
+            request.max_tokens,
+            request.temperature,
+        );
+
+        debug!(
+            "Sending streaming request to OpenRouter API: model={}",
+            self.model
+        );
+
+        let mut req = self
+            .client
+            .post(format!("{}/chat/completions", self.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key));
+
+        if let Some(ref referer) = self.http_referer {
+            req = req.header("HTTP-Referer", referer);
+        }
+
+        if let Some(ref title) = self.x_title {
+            req = req.header("X-Title", title);
+        }
+
+        let response = req.json(&body).send().await?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| "Unknown error".to_string());
+            return Err(anyhow::anyhow!(
+                "OpenRouter API error {}: {}",
+                status,
+                error_text
+            ));
+        }
+
+        let stream = response.bytes_stream();
+        let (tx, rx) = mpsc::channel(100);
+
+        // Spawn task to process the stream
+        let provider = self.clone();
+        tokio::spawn(async move {
+            let usage = provider.parse_streaming_response(stream, tx).await;
+            // Log the final usage if available
+            if let Some(usage) = usage {
+                debug!(
+                    "Stream completed with usage - prompt: {}, completion: {}, total: {}",
+                    usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
+                );
+            }
+        });
+
+        Ok(ReceiverStream::new(rx))
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn has_native_tool_calling(&self) -> bool {
+        // OpenRouter supports tool calling via OpenAI-compatible format
+        true
+    }
+
+    fn max_tokens(&self) -> u32 {
+        self.max_tokens.unwrap_or(4096)
+    }
+
+    fn temperature(&self) -> f32 {
+        self._temperature.unwrap_or(0.7)
+    }
+}
+
+fn convert_messages(messages: &[Message]) -> Vec<serde_json::Value> {
+    messages
+        .iter()
+        .map(|msg| {
+            json!({
+                "role": match msg.role {
+                    MessageRole::System => "system",
+                    MessageRole::User => "user",
+                    MessageRole::Assistant => "assistant",
+                },
+                "content": msg.content,
+            })
+        })
+        .collect()
+}
+
+fn convert_tools(tools: &[Tool]) -> Vec<serde_json::Value> {
+    tools
+        .iter()
+        .map(|tool| {
+            json!({
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.input_schema,
+                }
+            })
+        })
+        .collect()
+}
+
+// OpenRouter API response structures (OpenAI-compatible)
+#[derive(Debug, Deserialize)]
+struct OpenRouterResponse {
+    choices: Vec<OpenRouterChoice>,
+    usage: OpenRouterUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterChoice {
+    message: OpenRouterMessage,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenRouterMessage {
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<OpenRouterToolCall>>,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenRouterToolCall {
+    id: String,
+    function: OpenRouterFunction,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Deserialize)]
+struct OpenRouterFunction {
+    name: String,
+    arguments: String,
+}
+
+// Streaming tool call accumulator
+#[derive(Debug, Default)]
+struct OpenRouterStreamingToolCall {
+    id: Option<String>,
+    name: Option<String>,
+    arguments: String,
+}
+
+impl OpenRouterStreamingToolCall {
+    fn to_tool_call(&self) -> Option<ToolCall> {
+        let id = self.id.as_ref()?;
+        let name = self.name.as_ref()?;
+
+        let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
+
+        Some(ToolCall {
+            id: id.clone(),
+            tool: name.clone(),
+            args,
+        })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterUsage {
+    prompt_tokens: u32,
+    completion_tokens: u32,
+    total_tokens: u32,
+}
+
+// Streaming response structures
+#[derive(Debug, Deserialize)]
+struct OpenRouterStreamChunk {
+    choices: Vec<OpenRouterStreamChoice>,
+    usage: Option<OpenRouterUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterStreamChoice {
+    delta: OpenRouterDelta,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterDelta {
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<OpenRouterDeltaToolCall>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterDeltaToolCall {
+    index: Option<usize>,
+    id: Option<String>,
+    function: Option<OpenRouterDeltaFunction>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenRouterDeltaFunction {
+    name: Option<String>,
+    arguments: Option<String>,
+}
\ No newline at end of file
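The heart of `parse_streaming_response` above is its SSE line buffering: bytes arrive in arbitrary chunks, complete lines are peeled off the front of a `String` buffer, and `data: ` payloads are dispatched until the `[DONE]` sentinel. A minimal synchronous sketch of just that loop (hypothetical `feed` helper, no network or JSON):

```rust
/// Append a chunk to the buffer, then dispatch every complete SSE line.
/// Mirrors the buffer.find('\n') / drain loop in parse_streaming_response.
fn feed(buffer: &mut String, chunk: &str, on_data: &mut impl FnMut(&str)) {
    buffer.push_str(chunk);
    while let Some(line_end) = buffer.find('\n') {
        let line = buffer[..line_end].trim().to_string();
        buffer.drain(..line_end + 1);
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return; // stream completion marker
            }
            on_data(data);
        }
    }
}

fn main() {
    let mut buf = String::new();
    let mut seen = Vec::new();
    // A payload split across two network chunks still yields one event.
    feed(&mut buf, "data: {\"a\":1}\nda", &mut |d| seen.push(d.to_string()));
    feed(&mut buf, "ta: [DONE]\n", &mut |d| seen.push(d.to_string()));
    assert_eq!(seen, vec!["{\"a\":1}".to_string()]);
}
```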
diff --git a/crates/tests/openrouter_integration_tests.rs b/crates/tests/openrouter_integration_tests.rs
new file mode 100644
index 00000000..ebbb9bd0
--- /dev/null
+++ b/crates/tests/openrouter_integration_tests.rs
@@ -0,0 +1,337 @@
+//! Integration tests for OpenRouter provider
+//!
+//! These tests verify OpenRouter provider functionality including basic API integration,
+//! streaming support, and provider routing features.
+
+use g3_providers::{
+    CompletionRequest, LLMProvider, Message, MessageRole, OpenRouterProvider, ProviderPreferences,
+    Tool,
+};
+use serde_json::json;
+use std::env;
+use tokio_stream::StreamExt;
+
+/// Helper function to get API key from environment or skip test
+fn get_api_key_or_skip() -> Option<String> {
+    env::var("OPENROUTER_API_KEY").ok()
+}
+
+#[tokio::test]
+async fn test_openrouter_basic_completion() {
+    let Some(api_key) = get_api_key_or_skip() else {
+        println!("Skipping test: OPENROUTER_API_KEY not set");
+        return;
+    };
+
+    let provider = OpenRouterProvider::new(
+        api_key,
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(100),
+        Some(0.7),
+    )
+    .expect("Failed to create OpenRouter provider");
+
+    let request = CompletionRequest {
+        messages: vec![Message::new(
+            MessageRole::User,
+            "Say 'test successful' and nothing else.".to_string(),
+        )],
+        max_tokens: Some(50),
+        temperature: Some(0.7),
+        stream: false,
+        tools: None,
+        disable_thinking: false,
+    };
+
+    let response = provider
+        .complete(request)
+        .await
+        .expect("Completion request failed");
+
+    println!("Response: {}", response.content);
+    assert!(!response.content.is_empty(), "Response should not be empty");
+    assert!(
+        response.usage.total_tokens > 0,
+        "Token usage should be tracked"
+    );
+}
+
+#[tokio::test]
+async fn test_openrouter_streaming() {
+    let Some(api_key) = get_api_key_or_skip() else {
+        println!("Skipping test: OPENROUTER_API_KEY not set");
+        return;
+    };
+
+    let provider = OpenRouterProvider::new(
+        api_key,
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(100),
+        Some(0.7),
+    )
+    .expect("Failed to create OpenRouter provider");
+
+    let request = CompletionRequest {
+        messages: vec![Message::new(
+            MessageRole::User,
+            "Count from 1 to 5.".to_string(),
+        )],
+        max_tokens: Some(50),
+        temperature: Some(0.7),
+        stream: true,
+        tools: None,
+        disable_thinking: false,
+    };
+
+    let mut stream = provider
+        .stream(request)
+        .await
+        .expect("Streaming request failed");
+
+    let mut accumulated_content = String::new();
+    let mut chunk_count = 0;
+    let mut final_usage = None;
+
+    while let Some(chunk_result) = stream.next().await {
+        match chunk_result {
+            Ok(chunk) => {
+                chunk_count += 1;
+                accumulated_content.push_str(&chunk.content);
+
+                if chunk.finished {
+                    final_usage = chunk.usage;
+                    println!("Stream finished after {} chunks", chunk_count);
+                    break;
+                }
+            }
+            Err(e) => {
+                panic!("Stream error: {}", e);
+            }
+        }
+    }
+
+    println!("Accumulated content: {}", accumulated_content);
+    assert!(!accumulated_content.is_empty(), "Should receive content");
+    assert!(chunk_count > 0, "Should receive at least one chunk");
+    assert!(final_usage.is_some(), "Should track token usage");
+}
+
+#[tokio::test]
+async fn test_openrouter_with_provider_preferences() {
+    let Some(api_key) = get_api_key_or_skip() else {
+        println!("Skipping test: OPENROUTER_API_KEY not set");
+        return;
+    };
+
+    let preferences = ProviderPreferences {
+        order: Some(vec!["Anthropic".to_string()]),
+        allow_fallbacks: Some(true),
+        require_parameters: Some(false),
+    };
+
+    let provider = OpenRouterProvider::new(
+        api_key,
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(100),
+        Some(0.7),
+    )
+    .expect("Failed to create OpenRouter provider")
+    .with_provider_preferences(preferences);
+
+    let request = CompletionRequest {
+        messages: vec![Message::new(
+            MessageRole::User,
+            "Reply with 'ok'.".to_string(),
+        )],
+        max_tokens: Some(50),
+        temperature: Some(0.7),
+        stream: false,
+        tools: None,
+        disable_thinking: false,
+    };
+
+    let response = provider
+        .complete(request)
+        .await
+        .expect("Completion request with provider preferences failed");
+
+    assert!(!response.content.is_empty());
+}
+
+#[tokio::test]
+async fn test_openrouter_with_http_headers() {
+    let Some(api_key) = get_api_key_or_skip() else {
+        println!("Skipping test: OPENROUTER_API_KEY not set");
+        return;
+    };
+
+    let provider = OpenRouterProvider::new(
+        api_key,
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(100),
+        Some(0.7),
+    )
+    .expect("Failed to create OpenRouter provider")
+    .with_http_referer("https://example.com".to_string())
+    .with_x_title("G3 Test Suite".to_string());
+
+    let request = CompletionRequest {
+        messages: vec![Message::new(
+            MessageRole::User,
+            "Reply with 'ok'.".to_string(),
+        )],
+        max_tokens: Some(50),
+        temperature: Some(0.7),
+        stream: false,
+        tools: None,
+        disable_thinking: false,
+    };
+
+    let response = provider
+        .complete(request)
+        .await
+        .expect("Completion request with HTTP headers failed");
+
+    assert!(!response.content.is_empty());
+}
+
+#[tokio::test]
+async fn test_openrouter_tool_calling() {
+    let Some(api_key) = get_api_key_or_skip() else {
+        println!("Skipping test: OPENROUTER_API_KEY not set");
+        return;
+    };
+
+    let provider = OpenRouterProvider::new(
+        api_key,
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(500),
+        Some(0.7),
+    )
+    .expect("Failed to create OpenRouter provider");
+
+    let weather_tool = Tool {
+        name: "get_weather".to_string(),
+        description: "Get the current weather for a location".to_string(),
+        input_schema: json!({
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA"
+                }
+            },
+            "required": ["location"]
+        }),
+    };
+
+    let request = CompletionRequest {
+        messages: vec![Message::new(
+            MessageRole::User,
+            "What's the weather like in Tokyo?".to_string(),
+        )],
+        max_tokens: Some(500),
+        temperature: Some(0.7),
+        stream: false,
+        tools: Some(vec![weather_tool]),
+        disable_thinking: false,
+    };
+
+    let response = provider
+        .complete(request)
+        .await
+        .expect("Tool calling request failed");
+
+    println!("Response content: {}", response.content);
+    println!("Token usage: {:?}", response.usage);
+
+    // Note: Tool calling may or may not be invoked depending on model behavior
+    assert!(
+        response.usage.total_tokens > 0,
+        "Token usage should be tracked"
+    );
+}
+
+#[test]
+fn test_provider_preferences_serialization() {
+    let preferences = ProviderPreferences {
+        order: Some(vec!["Anthropic".to_string(), "OpenAI".to_string()]),
+        allow_fallbacks: Some(true),
+        require_parameters: Some(false),
+    };
+
+    let json = serde_json::to_value(&preferences).unwrap();
+    println!("Provider preferences JSON: {}", json);
+
+    assert!(json.get("order").is_some());
+    assert_eq!(json.get("allow_fallbacks").unwrap(), &json!(true));
+    assert_eq!(json.get("require_parameters").unwrap(), &json!(false));
+}
+
+#[test]
+fn test_provider_preferences_partial_serialization() {
+    // Test that None fields are omitted from JSON
+    let preferences = ProviderPreferences {
+        order: None,
+        allow_fallbacks: Some(true),
+        require_parameters: None,
+    };
+
+    let json = serde_json::to_value(&preferences).unwrap();
+    println!("Partial provider preferences JSON: {}", json);
+
+    assert!(
+        !json.as_object().unwrap().contains_key("order"),
+        "None fields should be omitted"
+    );
+    assert!(json.get("allow_fallbacks").is_some());
+    assert!(
+        !json.as_object().unwrap().contains_key("require_parameters"),
+        "None fields should be omitted"
+    );
+}
+
+#[test]
+fn test_openrouter_provider_trait_implementation() {
+    let provider = OpenRouterProvider::new(
+        "test_key".to_string(),
+        Some("anthropic/claude-3.5-sonnet".to_string()),
+        Some(4096),
+        Some(0.7),
+    )
+    .expect("Failed to create provider");
+
+    // Test LLMProvider trait methods
+    assert_eq!(provider.name(), "openrouter");
+    assert_eq!(provider.model(), "anthropic/claude-3.5-sonnet");
+    assert!(provider.has_native_tool_calling());
+    assert_eq!(provider.max_tokens(), 4096);
+    assert_eq!(provider.temperature(), 0.7);
+}
+
+#[test]
+fn test_openrouter_provider_with_custom_name() {
+    let provider = OpenRouterProvider::new_with_name(
+        "openrouter.custom".to_string(),
+        "test_key".to_string(),
+        Some("openai/gpt-4o".to_string()),
+        None,
+        None,
+    )
+    .expect("Failed to create provider");
+
+    assert_eq!(provider.name(), "openrouter.custom");
+    assert_eq!(provider.model(), "openai/gpt-4o");
+}
+
+#[test]
+fn test_openrouter_default_model() {
+    let provider = OpenRouterProvider::new(
+        "test_key".to_string(),
+        None, // No model specified
+        None,
+        None,
+    )
+    .expect("Failed to create provider");
+
+    assert_eq!(provider.model(), "anthropic/claude-3.5-sonnet");
+}
\ No newline at end of file
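Finally, a note on the streaming tool-call path the tests above exercise indirectly: argument deltas arrive as string fragments and only become valid JSON once concatenated, which is why `OpenRouterStreamingToolCall` keeps `arguments` as a plain `String`. A minimal sketch of that accumulation (assumes `serde_json`; the struct is a simplified stand-in):

```rust
use serde_json::Value;

// Stand-in for OpenRouterStreamingToolCall from the provider module.
#[derive(Default)]
struct Accumulator {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}

fn main() {
    let mut acc = Accumulator::default();
    acc.id = Some("call_1".to_string());
    acc.name = Some("get_weather".to_string());
    // Fragments from successive deltas; neither parses as JSON on its own.
    for fragment in ["{\"loca", "tion\": \"Tokyo\"}"] {
        acc.arguments.push_str(fragment);
    }
    let args: Value = serde_json::from_str(&acc.arguments).unwrap_or(Value::Null);
    assert_eq!(args["location"], "Tokyo");
    println!("{} {} -> {}", acc.id.unwrap(), acc.name.unwrap(), args);
}
```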