diff --git a/Cargo.toml b/Cargo.toml
index 6dfe403..7b3eaf7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ tree-sitter = "0.20"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
+rig-core = "0.16.0"
 reqwest = { version = "0.11", features = ["json"] }
 futures = "0.3"
 lazy_static = "1.4"
diff --git a/src/analyzer/llm_client.rs b/src/analyzer/llm_client.rs
index 12f1282..041fd7d 100644
--- a/src/analyzer/llm_client.rs
+++ b/src/analyzer/llm_client.rs
@@ -1,21 +1,16 @@
 use crate::error::EbiError;
 use crate::models::{AnalysisRequest, AnalysisResult, AnalysisType, OutputLanguage};
-use reqwest;
-use serde::{Deserialize, Serialize};
+use rig::completion::{CompletionModel, AssistantContent};
+use rig::providers::{anthropic, gemini, openai};
+use rig::client::CompletionClient;
 use std::collections::HashSet;
 use std::future::Future;
 use std::pin::Pin;
 use std::time::Duration;
-use tokio::time::timeout;
-
-// Constants for Claude API configuration
-const CLAUDE_DEFAULT_MAX_TOKENS: u32 = 1000;
-const CLAUDE_DEFAULT_TEMPERATURE: f32 = 0.3;
 
 #[derive(Debug, Clone)]
 pub struct LlmConfig {
     pub model_name: String,
-    pub api_endpoint: String,
     pub api_key: Option<String>,
     pub timeout_seconds: u64,
     pub max_retries: u32,
@@ -23,89 +18,6 @@ pub struct LlmConfig {
     pub temperature: Option<f32>,
 }
 
-#[derive(Debug, Serialize)]
-struct LlmApiRequest {
-    model: String,
-    messages: Vec<ChatMessage>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    max_tokens: Option<u32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    max_completion_tokens: Option<u32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    temperature: Option<f32>,
-}
-
-#[derive(Debug, Serialize)]
-struct ChatMessage {
-    role: String,
-    content: String,
-}
-
-#[derive(Debug, Serialize)]
-struct ResponsesApiRequest {
-    model: String,
-    input: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    reasoning: Option<ResponsesReasoning>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    text: Option<ResponsesText>,
-}
-
-#[derive(Debug, Serialize)]
-struct ResponsesReasoning {
-    effort: String,
-}
-
-#[derive(Debug, Serialize)]
-struct ResponsesText {
-    verbosity: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct LlmApiResponse {
-    choices: Vec<Choice>,
-    usage: Option<Usage>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponsesApiResponse {
-    #[serde(default)]
-    output: Vec<ResponsesOutput>,
-    #[serde(default)]
-    output_text: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponsesOutput {
-    #[serde(default)]
-    content: Vec<ResponsesContent>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponsesContent {
-    #[serde(rename = "type")]
-    content_type: String,
-    text: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct Choice {
-    message: ResponseMessage,
-    finish_reason: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponseMessage {
-    content: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct Usage {
-    total_tokens: u32,
-    prompt_tokens: u32,
-    completion_tokens: u32,
-}
-
 pub trait LlmProvider: Send + Sync {
     fn analyze<'a>(
         &'a self,
@@ -116,11 +28,16 @@ pub trait LlmProvider: Send + Sync {
 }
 
 fn extract_summary_from_response(response: &str) -> String {
+    if response.trim().is_empty() {
+        return "No response from model.".to_string();
+    }
+
     let lines: Vec<&str> = response.lines().collect();
 
+    // First, try to find explicit summary sections
     for (i, line) in lines.iter().enumerate() {
         let line_lower = line.to_lowercase();
-        if line_lower.contains("summary:") || line_lower.contains("概要:") {
+        if line_lower.contains("summary:") || line_lower.contains("概要:") || line_lower.contains("script purpose:") {
             let summary_lines: Vec<&str> = lines
                 .iter()
                 .skip(i + 1)
@@ -131,17 +48,21 @@ fn extract_summary_from_response(response: &str) -> String {
                     }
 
                     let lower = trimmed.to_lowercase();
-                    !lower.contains("analysis:") && !lower.contains("分析:")
+                    !lower.contains("analysis:") && !lower.contains("分析:") && !lower.contains("specific findings:")
                 })
                 .copied()
                 .collect();
 
             if !summary_lines.is_empty() {
-                return summary_lines.join("\n").trim().to_string();
+                let summary = summary_lines.join("\n").trim().to_string();
+                if !summary.is_empty() && summary.len() > 10 {
+                    return clean_summary_text(summary);
+                }
             }
         }
     }
 
+    // If no explicit summary found, extract meaningful content
    let meaningful_lines: Vec<&str> = lines
        .iter()
        .filter(|line| {
@@ -150,24 +71,33 @@ fn extract_summary_from_response(response: &str) -> String {
                return false;
            }
 
-            if line_clean.starts_with('#') {
+            // Skip headers, metadata, and structural elements
+            if line_clean.starts_with('#') || line_clean.starts_with("```") {
                 return false;
             }
 
             let uppercase = line_clean.to_uppercase();
-            if uppercase.starts_with("RISK LEVEL:") || uppercase.starts_with("CONFIDENCE:") {
+            if uppercase.starts_with("RISK LEVEL:")
+                || uppercase.starts_with("CONFIDENCE:")
+                || uppercase.starts_with("MODEL:")
+                || uppercase.starts_with("→ CONCERN:")
+                || uppercase.starts_with("→ CONTEXT:")
+                || line_clean.starts_with("Line ") {
                 return false;
             }
 
-            true
+            // Include meaningful content lines
+            line_clean.len() > 10 && !line_clean.starts_with("Line ")
         })
-        .take(6)
+        .take(8)
         .copied()
         .collect();
 
     let raw_summary = if meaningful_lines.is_empty() {
+        // Fallback: take first non-empty lines
         response
             .lines()
+            .filter(|line| !line.trim().is_empty() && line.trim().len() > 5)
             .take(3)
             .collect::<Vec<&str>>()
             .join("\n")
@@ -175,31 +105,36 @@ fn extract_summary_from_response(response: &str) -> String {
             .take(600)
             .collect()
     } else {
-        meaningful_lines.join("\n").chars().take(1200).collect()
+        meaningful_lines.join(" ").chars().take(1200).collect()
     };
 
-    let mut summary = clean_summary_text(raw_summary);
+    let summary = clean_summary_text(raw_summary);
 
-    if summary.trim().len() < 5 {
-        summary = response
+    if summary.trim().len() < 10 {
+        // Last resort: take any substantial content
+        let fallback = response
             .lines()
             .filter(|line| {
                 let trimmed = line.trim();
                 !trimmed.is_empty()
+                    && trimmed.len() > 15
                     && !trimmed.to_uppercase().starts_with("RISK LEVEL:")
                     && !trimmed.to_uppercase().starts_with("CONFIDENCE:")
+                    && !trimmed.starts_with("Line ")
+                    && !trimmed.starts_with("→")
             })
-            .take(6)
-            .map(|line| line.replace('|', " "))
+            .take(3)
             .collect::<Vec<&str>>()
-            .join("\n");
+            .join(" ");
 
-        if summary.trim().is_empty() {
-            summary = "Summary not provided by model.".to_string();
+        if fallback.trim().len() > 10 {
+            clean_summary_text(fallback)
+        } else {
+            format!("Analysis completed for script content ({})", response.chars().take(50).collect::<String>().trim())
         }
+    } else {
+        summary
     }
-
-    summary
 }
 
 fn parse_explicit_risk_level(line: &str) -> Option<crate::models::RiskLevel> {
@@ -230,7 +165,6 @@ fn extract_risk_token(
     let marker_pos = lower_line.find(&marker_lower)?;
     let after_marker = &original_line[marker_pos + marker.len()..];
 
-    // Split on colon-like separators first, then take the leading token
     let token_section = after_marker
         .split(|c| c == ':' || c == '：')
         .nth(1)
@@ -524,417 +458,96 @@ fn clean_summary_text(summary: String) -> String {
     }
 }
 
-fn is_responses_model(model: &str) -> bool {
-    let candidate = model.to_lowercase();
-    candidate.starts_with("gpt-5")
+pub struct RigLlmClient {
+    config: LlmConfig,
+    provider: RigProvider,
 }
 
-pub struct OpenAiCompatibleClient {
-    config: LlmConfig,
-    client: reqwest::Client,
+enum RigProvider {
+    OpenAI(openai::Client),
+    OpenAIResponses(openai::Client),
+    Anthropic(anthropic::Client),
+    Gemini(gemini::Client),
 }
 
-impl OpenAiCompatibleClient {
+impl RigLlmClient {
     pub fn new(config: LlmConfig) -> Result<Self, EbiError> {
-        let client = reqwest::Client::builder()
-            .timeout(Duration::from_secs(config.timeout_seconds))
-            .build()
-            .map_err(|e| {
-                EbiError::LlmClientError(format!("Failed to create HTTP client: {}", e))
-            })?;
-
-        Ok(Self { config, client })
+        let provider = create_provider(&config)?;
+        Ok(Self { config, provider })
     }
 
     async fn make_api_request(&self, request: &AnalysisRequest) -> Result<String, EbiError> {
         let prompt = self.build_prompt(request);
-        let use_responses_api = is_responses_model(&self.config.model_name);
-        let chat_request;
-        let responses_request;
-        if use_responses_api {
-            chat_request = None;
-            responses_request = Some(self.build_responses_request(prompt, &request.analysis_type));
-        } else {
-            chat_request = Some(self.build_api_request(
-                prompt,
-                &request.analysis_type,
-                &request.output_language,
-            ));
-            responses_request = None;
-        }
-
-        let mut retries = 0;
-        loop {
-            let timeout_secs = request.timeout_seconds.min(self.config.timeout_seconds);
-            let timeout_duration = Duration::from_secs(timeout_secs);
+        let system_prompt = self.build_system_prompt(&request.analysis_type, &request.output_language);
 
-            let mut http_request = self
-                .client
-                .post(&self.config.api_endpoint)
-                .header("Content-Type", "application/json");
-
-            if let Some(ref req) = chat_request {
-                http_request = http_request.json(req);
-            } else if let Some(ref req) = responses_request {
-                http_request = http_request.json(req);
+        match &self.provider {
+            RigProvider::OpenAI(client) | RigProvider::OpenAIResponses(client) => {
+                let model = client.completion_model(&self.config.model_name);
+                self.send_completion_request(model, &prompt, system_prompt).await
             }
-
-            if let Some(ref api_key) = self.config.api_key {
-                if !api_key.is_empty() {
-                    http_request =
-                        http_request.header("Authorization", format!("Bearer {}", api_key));
-                }
+            RigProvider::Anthropic(client) => {
+                let model = client.completion_model(&self.config.model_name);
+                self.send_completion_request(model, &prompt, system_prompt).await
             }
-
-            let response = timeout(timeout_duration, http_request.send()).await;
-
-            match response {
-                Ok(Ok(resp)) => {
-                    if resp.status().is_success() {
-                        if use_responses_api {
-                            let api_response: ResponsesApiResponse =
-                                resp.json().await.map_err(|e| {
-                                    EbiError::LlmClientError(format!(
-                                        "Failed to parse response: {}",
-                                        e
-                                    ))
-                                })?;
-
-                            if let Some(text) = api_response.output_text {
-                                if !text.trim().is_empty() {
-                                    return Ok(text);
-                                }
-                            }
-
-                            for output in api_response.output {
-                                for piece in output.content {
-                                    if piece.content_type == "output_text" {
-                                        if let Some(text) = piece.text {
-                                            return Ok(text);
-                                        }
-                                    }
-                                }
-                            }
-
-                            return Ok(String::new());
-                        } else {
-                            let api_response: LlmApiResponse = resp.json().await.map_err(|e| {
-                                EbiError::LlmClientError(format!("Failed to parse response: {}", e))
-                            })?;
-
-                            if let Some(choice) = api_response.choices.first() {
-                                return Ok(choice.message.content.clone());
-                            } else {
-                                return Err(EbiError::LlmClientError(
-                                    "No response choices received".to_string(),
-                                ));
-                            }
-                        }
-                    } else {
-                        let status = resp.status();
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|_| "Unknown error".to_string());
-
-                        if retries < self.config.max_retries
-                            && (status.is_server_error() || status == 429)
-                        {
-                            retries += 1;
-                            tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                            continue;
-                        }
-
-                        return Err(EbiError::LlmClientError(format!(
-                            "API request failed with status {}: {}",
-                            status, error_text
-                        )));
-                    }
-                }
-                Ok(Err(e)) => {
-                    if retries < self.config.max_retries {
-                        retries += 1;
-                        tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                        continue;
-                    }
-                    return Err(EbiError::LlmClientError(format!("Network error: {}", e)));
-                }
-                Err(_) => {
-                    return Err(EbiError::AnalysisTimeout {
-                        timeout: timeout_secs,
-                    });
-                }
+            RigProvider::Gemini(client) => {
+                let model = client.completion_model(&self.config.model_name);
+                self.send_completion_request(model, &prompt, system_prompt).await
             }
         }
     }
 
-    fn build_prompt(&self, request: &AnalysisRequest) -> String {
-        use crate::analyzer::prompts::PromptTemplate;
-
-        match request.analysis_type {
-            AnalysisType::CodeVulnerability => PromptTemplate::build_vulnerability_analysis_prompt(
-                &request.content,
-                &request.context.language,
-                &request.context.source,
-                &request.output_language,
-            ),
-            AnalysisType::InjectionDetection => PromptTemplate::build_injection_analysis_prompt(
-                &request.content,
-                &request.context.language,
-                &request.context.source,
-                &request.output_language,
-            ),
-            AnalysisType::DetailedRiskAnalysis => {
-                // For now, pass empty initial findings - this could be enhanced later
-                PromptTemplate::build_detailed_risk_analysis_prompt(
-                    &request.content,
-                    &request.context.language,
-                    &request.context.source,
-                    &request.output_language,
-                    &[], // initial_findings - could be enhanced to pass actual findings
-                )
-            }
-            AnalysisType::SpecificThreatAnalysis => {
-                // For now, pass empty focus lines - this could be enhanced later
-                PromptTemplate::build_specific_threat_analysis_prompt(
-                    &request.content,
-                    &request.context.language,
-                    &request.context.source,
-                    &request.output_language,
-                    &[], // focus_lines - could be enhanced to pass specific line numbers
-                )
+    async fn send_completion_request<M: CompletionModel>(
+        &self,
+        model: M,
+        prompt: &str,
+        system_prompt: String,
+    ) -> Result<String, EbiError> {
+        let mut builder = model
+            .completion_request(prompt)
+            .preamble(system_prompt);
+
+        // Skip temperature for models that don't support it (like GPT-5 series and o1 series)
+        if let Some(temp) = self.config.temperature {
+            let model_name = &self.config.model_name;
+            if !model_name.starts_with("gpt-5") && !model_name.starts_with("o1") && !model_name.starts_with("o3") && !model_name.starts_with("o4") {
+                builder = builder.temperature(temp as f64);
             }
         }
-
-    fn build_api_request(
-        &self,
-        prompt: String,
-        analysis_type: &AnalysisType,
-        output_language: &OutputLanguage,
-    ) -> LlmApiRequest {
-        build_llm_api_request(
-            &self.config.model_name,
-            prompt,
-            analysis_type,
-            output_language,
-        )
-    }
-
-    fn build_responses_request(
-        &self,
-        prompt: String,
-        _analysis_type: &AnalysisType,
-    ) -> ResponsesApiRequest {
-        ResponsesApiRequest {
-            model: self.config.model_name.clone(),
-            input: prompt,
-            reasoning: Some(ResponsesReasoning {
-                effort: "minimal".to_string(),
-            }),
-            text: Some(ResponsesText {
-                verbosity: "low".to_string(),
-            }),
+        // Set max_tokens with provider-specific defaults
+        if let Some(max_tokens) = self.config.max_tokens {
+            builder = builder.max_tokens(max_tokens as u64);
+        } else {
+            // Use provider-specific defaults
+            let default_tokens = match &self.provider {
+                RigProvider::Anthropic(_) => 4096, // Claude models typically need explicit max_tokens
+                _ => 1000,
+            };
+            builder = builder.max_tokens(default_tokens);
         }
-}
 
-impl LlmProvider for OpenAiCompatibleClient {
-    fn analyze<'a>(
-        &'a self,
-        request: &'a AnalysisRequest,
-    ) -> Pin<Box<dyn Future<Output = Result<AnalysisResult, EbiError>> + Send + 'a>> {
-        Box::pin(async move {
-            let start_time = std::time::Instant::now();
-
-            let response_content = self.make_api_request(request).await?;
-
-            let duration_ms = start_time.elapsed().as_millis() as u64;
-
-            // Parse the response to extract risk level and summary
-            let (risk_level, summary, confidence) =
-                Self::parse_analysis_response(&response_content);
-
-            let result = AnalysisResult::new(
-                request.analysis_type.clone(),
-                self.config.model_name.clone(),
-                duration_ms,
-            )
-            .with_risk_level(risk_level)
-            .with_summary(summary)
-            .with_confidence(confidence)
-            .with_details(response_content);
+        let response = builder
+            .send()
+            .await
+            .map_err(|e| EbiError::LlmClientError(format!("Request failed for model {}: {}", self.config.model_name, e)))?;
 
-            Ok(result)
-        })
-    }
-
-    fn get_model_name(&self) -> &str {
-        &self.config.model_name
-    }
-
-    fn get_timeout(&self) -> Duration {
-        Duration::from_secs(self.config.timeout_seconds)
-    }
-}
-
-impl OpenAiCompatibleClient {
-    fn parse_analysis_response(response: &str) -> (crate::models::RiskLevel, String, f32) {
-        let response_lower = response.to_lowercase();
-
-        let risk_level = determine_risk_level(response, &response_lower);
-        let mut summary = extract_summary_from_response(response);
-        if let Some(legitimacy_line) = extract_legitimacy_line(response) {
-            let summary_lower = summary.to_lowercase();
-            if summary.is_empty()
-                || (!summary_lower.contains("legitimacy assessment")
-                    && !summary.contains("正当性評価")
-                    && !summary.contains("正当"))
-            {
-                summary = if summary.is_empty() {
-                    legitimacy_line
-                } else {
-                    format!("{}\n{}", legitimacy_line, summary)
-                };
+        // Extract the text content from the response
+        let mut extracted_text = String::new();
+        for content in response.choice.iter() {
+            if let AssistantContent::Text(text_content) = content {
+                extracted_text.push_str(&text_content.text);
             }
         }
-        let confidence = calculate_confidence(response, &response_lower);
-
-        (risk_level, summary, confidence)
-    }
-}
-
-// Anthropic Claude API structures
-#[derive(Debug, Serialize)]
-struct ClaudeApiRequest {
-    model: String,
-    max_tokens: u32,
-    messages: Vec<ClaudeMessage>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    temperature: Option<f32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    system: Option<String>,
-}
-
-#[derive(Debug, Serialize)]
-struct ClaudeMessage {
-    role: String,
-    content: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct ClaudeApiResponse {
-    content: Vec<ClaudeContent>,
-    usage: Option<ClaudeUsage>,
-    #[serde(rename = "stop_reason")]
-    stop_reason: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ClaudeContent {
-    text: String,
-    #[serde(rename = "type")]
-    content_type: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct ClaudeUsage {
-    #[serde(rename = "input_tokens")]
-    input_tokens: u32,
-    #[serde(rename = "output_tokens")]
-    output_tokens: u32,
-}
-
-pub struct ClaudeClient {
-    config: LlmConfig,
-    client: reqwest::Client,
-}
-
-impl ClaudeClient {
-    pub fn new(config: LlmConfig) -> Result<Self, EbiError> {
-        let client = reqwest::Client::builder()
-            .timeout(Duration::from_secs(config.timeout_seconds))
-            .build()
-            .map_err(|e| {
-                EbiError::LlmClientError(format!("Failed to create HTTP client: {}", e))
-            })?;
-
-        Ok(Self { config, client })
-    }
-
-    async fn make_api_request(&self, request: &AnalysisRequest) -> Result<String, EbiError> {
-        let prompt = self.build_prompt(request);
-        let api_request =
-            self.build_api_request(prompt, &request.analysis_type, &request.output_language);
-
-        let mut retries = 0;
-        loop {
-            let timeout_secs = request.timeout_seconds.min(self.config.timeout_seconds);
-            let timeout_duration = Duration::from_secs(timeout_secs);
-
-            let http_request = self
-                .client
-                .post(&self.config.api_endpoint)
-                .header("Content-Type", "application/json")
-                .header(
-                    "x-api-key",
-                    self.config.api_key.as_ref().unwrap_or(&"".to_string()),
-                )
-                .header("anthropic-version", "2023-06-01")
-                .json(&api_request);
-
-            let response = timeout(timeout_duration, http_request.send()).await;
-
-            match response {
-                Ok(Ok(resp)) => {
-                    if resp.status().is_success() {
-                        let api_response: ClaudeApiResponse = resp.json().await.map_err(|e| {
-                            EbiError::LlmClientError(format!("Failed to parse response: {}", e))
-                        })?;
-
-                        if let Some(content) = api_response.content.first() {
-                            return Ok(content.text.clone());
-                        } else {
-                            return Err(EbiError::LlmClientError(
-                                "No response content received".to_string(),
-                            ));
-                        }
-                    } else {
-                        let status = resp.status();
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|_| "Unknown error".to_string());
-
-                        if retries < self.config.max_retries
-                            && (status.is_server_error() || status == 429)
-                        {
-                            retries += 1;
-                            tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                            continue;
-                        }
-
-                        return Err(EbiError::LlmClientError(format!(
-                            "API request failed with status {}: {}",
-                            status, error_text
-                        )));
-                    }
-                }
-                Ok(Err(e)) => {
-                    if retries < self.config.max_retries {
-                        retries += 1;
-                        tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                        continue;
-                    }
-                    return Err(EbiError::LlmClientError(format!("Network error: {}", e)));
-                }
-                Err(_) => {
-                    return Err(EbiError::AnalysisTimeout {
-                        timeout: timeout_secs,
-                    });
-                }
-            }
-        }
+
+        // Handle empty responses
+        if extracted_text.trim().is_empty() {
+            return Err(EbiError::LlmClientError(
+                format!("Model {} returned empty response. Response had {} choice(s). This may indicate an API issue, authentication problem, or model configuration issue.",
+                    self.config.model_name, response.choice.len())
+            ));
         }
+
+        Ok(extracted_text)
     }
 
     fn build_prompt(&self, request: &AnalysisRequest) -> String {
@@ -959,7 +572,7 @@ impl ClaudeClient {
                     &request.context.language,
                     &request.context.source,
                     &request.output_language,
-                    &[], // No initial findings in base LLM call
+                    &[],
                 )
             }
             AnalysisType::SpecificThreatAnalysis => {
@@ -968,74 +581,17 @@ impl ClaudeClient {
                     &request.context.language,
                     &request.context.source,
                     &request.output_language,
-                    &[], // No focus lines in base LLM call
+                    &[],
                 )
             }
         }
     }
 
-    fn build_api_request(
-        &self,
-        prompt: String,
-        analysis_type: &AnalysisType,
-        output_language: &OutputLanguage,
-    ) -> ClaudeApiRequest {
+    fn build_system_prompt(&self, analysis_type: &AnalysisType, output_language: &OutputLanguage) -> String {
         use crate::analyzer::prompts::PromptTemplate;
-        let system_prompt = PromptTemplate::build_system_prompt(analysis_type, output_language);
-
-        ClaudeApiRequest {
-            model: self.config.model_name.clone(),
-            max_tokens: self.config.max_tokens.unwrap_or(CLAUDE_DEFAULT_MAX_TOKENS),
-            messages: vec![ClaudeMessage {
-                role: "user".to_string(),
-                content: prompt,
-            }],
-            temperature: self.config.temperature.or(Some(CLAUDE_DEFAULT_TEMPERATURE)),
-            system: Some(system_prompt),
-        }
-    }
-}
-
-impl LlmProvider for ClaudeClient {
-    fn analyze<'a>(
-        &'a self,
-        request: &'a AnalysisRequest,
-    ) -> Pin<Box<dyn Future<Output = Result<AnalysisResult, EbiError>> + Send + 'a>> {
-        Box::pin(async move {
-            let start_time = std::time::Instant::now();
-
-            let response_content = self.make_api_request(request).await?;
-
-            let duration_ms = start_time.elapsed().as_millis() as u64;
-
-            // Parse the response to extract risk level and summary
-            let (risk_level, summary, confidence) =
-                Self::parse_analysis_response(&response_content);
-
-            let result = AnalysisResult::new(
-                request.analysis_type.clone(),
-                self.config.model_name.clone(),
-                duration_ms,
-            )
-            .with_risk_level(risk_level)
-            .with_summary(summary)
-            .with_confidence(confidence)
-            .with_details(response_content);
-
-            Ok(result)
-        })
-    }
-
-    fn get_model_name(&self) -> &str {
-        &self.config.model_name
-    }
-
-    fn get_timeout(&self) -> Duration {
-        Duration::from_secs(self.config.timeout_seconds)
+        PromptTemplate::build_system_prompt(analysis_type, output_language)
     }
-}
 
-impl ClaudeClient {
     fn parse_analysis_response(response: &str) -> (crate::models::RiskLevel, String, f32) {
         let response_lower = response.to_lowercase();
@@ -1061,204 +617,7 @@ impl ClaudeClient {
     }
 }
 
-// Gemini API structures
-#[derive(Debug, Serialize)]
-struct GeminiApiRequest {
-    contents: Vec<GeminiContent>,
-    generation_config: Option<GeminiGenerationConfig>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct GeminiContent {
-    parts: Vec<GeminiPart>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct GeminiPart {
-    text: String,
-}
-
-#[derive(Debug, Serialize)]
-struct GeminiGenerationConfig {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    max_output_tokens: Option<u32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    temperature: Option<f32>,
-}
-
-#[derive(Debug, Deserialize)]
-struct GeminiApiResponse {
-    candidates: Vec<GeminiCandidate>,
-    usage_metadata: Option<GeminiUsageMetadata>,
-}
-
-#[derive(Debug, Deserialize)]
-struct GeminiCandidate {
-    content: GeminiContent,
-    finish_reason: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct GeminiUsageMetadata {
-    prompt_token_count: u32,
-    candidates_token_count: u32,
-    total_token_count: u32,
-}
-
-pub struct GeminiClient {
-    config: LlmConfig,
-    client: reqwest::Client,
-}
-
-impl GeminiClient {
-    pub fn new(config: LlmConfig) -> Result<Self, EbiError> {
-        let client = reqwest::Client::builder()
-            .timeout(Duration::from_secs(config.timeout_seconds))
-            .build()
-            .map_err(|e| {
-                EbiError::LlmClientError(format!("Failed to create HTTP client: {}", e))
-            })?;
-
-        Ok(Self { config, client })
-    }
-
-    async fn make_api_request(&self, request: &AnalysisRequest) -> Result<String, EbiError> {
-        let prompt = self.build_prompt(request);
-        let api_request =
-            self.build_api_request(prompt, &request.analysis_type, &request.output_language);
-
-        let mut retries = 0;
-        loop {
-            let timeout_secs = request.timeout_seconds.min(self.config.timeout_seconds);
-            let timeout_duration = Duration::from_secs(timeout_secs);
-
-            let mut http_request = self
-                .client
-                .post(&self.config.api_endpoint)
-                .header("Content-Type", "application/json")
-                .json(&api_request);
-
-            if let Some(ref api_key) = self.config.api_key {
-                if !api_key.is_empty() {
-                    http_request = http_request.query(&[("key", api_key)]);
-                }
-            }
-
-            let response = timeout(timeout_duration, http_request.send()).await;
-
-            match response {
-                Ok(Ok(resp)) => {
-                    if resp.status().is_success() {
-                        let api_response: GeminiApiResponse = resp.json().await.map_err(|e| {
-                            EbiError::LlmClientError(format!("Failed to parse response: {}", e))
-                        })?;
-
-                        if let Some(candidate) = api_response.candidates.first() {
-                            if let Some(part) = candidate.content.parts.first() {
-                                return Ok(part.text.clone());
-                            }
-                        }
-                        return Err(EbiError::LlmClientError(
-                            "No response content received".to_string(),
-                        ));
-                    } else {
-                        let status = resp.status();
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|_| "Unknown error".to_string());
-
-                        if retries < self.config.max_retries
-                            && (status.is_server_error() || status == 429)
-                        {
-                            retries += 1;
-                            tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                            continue;
-                        }
-
-                        return Err(EbiError::LlmClientError(format!(
-                            "API request failed with status {}: {}",
-                            status, error_text
-                        )));
-                    }
-                }
-                Ok(Err(e)) => {
-                    if retries < self.config.max_retries {
-                        retries += 1;
-                        tokio::time::sleep(Duration::from_millis(1000 * retries as u64)).await;
-                        continue;
-                    }
-                    return Err(EbiError::LlmClientError(format!("Network error: {}", e)));
-                }
-                Err(_) => {
-                    return Err(EbiError::AnalysisTimeout {
-                        timeout: timeout_secs,
-                    });
-                }
-            }
-        }
-    }
-
-    fn build_prompt(&self, request: &AnalysisRequest) -> String {
-        use crate::analyzer::prompts::PromptTemplate;
-
-        match request.analysis_type {
-            AnalysisType::CodeVulnerability => PromptTemplate::build_vulnerability_analysis_prompt(
-                &request.content,
-                &request.context.language,
-                &request.context.source,
-                &request.output_language,
-            ),
-            AnalysisType::InjectionDetection => PromptTemplate::build_injection_analysis_prompt(
-                &request.content,
-                &request.context.language,
-                &request.context.source,
-                &request.output_language,
-            ),
-            AnalysisType::DetailedRiskAnalysis => {
-                PromptTemplate::build_detailed_risk_analysis_prompt(
-                    &request.content,
-                    &request.context.language,
-                    &request.context.source,
-                    &request.output_language,
-                    &[], // No initial findings in base LLM call
-                )
-            }
-            AnalysisType::SpecificThreatAnalysis => {
-                PromptTemplate::build_specific_threat_analysis_prompt(
-                    &request.content,
-                    &request.context.language,
-                    &request.context.source,
-                    &request.output_language,
-                    &[], // No focus lines in base LLM call
-                )
-            }
-        }
-    }
-
-    fn build_api_request(
-        &self,
-        prompt: String,
-        analysis_type: &AnalysisType,
-        output_language: &OutputLanguage,
-    ) -> GeminiApiRequest {
-        use crate::analyzer::prompts::PromptTemplate;
-        let system_prompt = PromptTemplate::build_system_prompt(analysis_type, output_language);
-        let full_prompt = format!("{}\n\n{}", system_prompt, prompt);
-
-        GeminiApiRequest {
-            contents: vec![GeminiContent {
-                parts: vec![GeminiPart { text: full_prompt }],
-            }],
-            generation_config: Some(GeminiGenerationConfig {
-                max_output_tokens: self.config.max_tokens,
-                temperature: self.config.temperature,
-            }),
-        }
-    }
-}
-
-impl LlmProvider for GeminiClient {
+impl LlmProvider for RigLlmClient {
     fn analyze<'a>(
         &'a self,
         request: &'a AnalysisRequest,
@@ -1270,7 +629,6 @@ impl LlmProvider for GeminiClient {
 
         let duration_ms = start_time.elapsed().as_millis() as u64;
 
-        // Parse the response to extract risk level and summary
         let (risk_level, summary, confidence) =
             Self::parse_analysis_response(&response_content);
 
@@ -1297,130 +655,60 @@ impl LlmProvider for GeminiClient {
     }
 }
 
-impl GeminiClient {
-    fn parse_analysis_response(response: &str) -> (crate::models::RiskLevel, String, f32) {
-        let response_lower = response.to_lowercase();
+fn create_provider(config: &LlmConfig) -> Result<RigProvider, EbiError> {
+    let model_name = config.model_name.trim();
 
-        let risk_level = determine_risk_level(response, &response_lower);
-        let mut summary = extract_summary_from_response(response);
-        if let Some(legitimacy_line) = extract_legitimacy_line(response) {
-            let summary_lower = summary.to_lowercase();
-            if summary.is_empty()
-                || (!summary_lower.contains("legitimacy assessment")
-                    && !summary.contains("正当性評価")
-                    && !summary.contains("正当"))
-            {
-                summary = if summary.is_empty() {
-                    legitimacy_line
-                } else {
-                    format!("{}\n{}", legitimacy_line, summary)
-                };
-            }
-        }
-        let confidence = calculate_confidence(response, &response_lower);
+    if is_openai_model(model_name) {
+        let api_key = config.api_key.clone()
+            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
+            .ok_or_else(|| EbiError::LlmClientError("OpenAI API key not found".to_string()))?;
 
-        (risk_level, summary, confidence)
+        let client = openai::Client::new(&api_key);
+
+        // Use ResponsesCompletionModel for newer models
+        if model_name.starts_with("gpt-5") || model_name.starts_with("o") {
+            Ok(RigProvider::OpenAIResponses(client))
+        } else {
+            Ok(RigProvider::OpenAI(client))
+        }
+    } else if is_claude_model(model_name) {
+        let api_key = config.api_key.clone()
+            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
+            .ok_or_else(|| EbiError::LlmClientError("Anthropic API key not found".to_string()))?;
+
+        let client = anthropic::Client::new(&api_key);
+        Ok(RigProvider::Anthropic(client))
+    } else if is_gemini_model(model_name) {
+        let api_key = config.api_key.clone()
+            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
+            .ok_or_else(|| EbiError::LlmClientError("Gemini API key not found".to_string()))?;
+
+        let client = gemini::Client::new(&api_key);
+        Ok(RigProvider::Gemini(client))
+    } else {
+        Err(EbiError::LlmClientError(format!(
+            "Unsupported model '{}'. Use OpenAI (gpt-*), Anthropic (claude-*), or Gemini (gemini-*) models",
+            model_name
+        )))
     }
 }
 
-// Factory function to create LLM clients
 pub fn create_llm_client(
     model: &str,
     api_key: Option<String>,
     timeout_seconds: u64,
 ) -> Result<Box<dyn LlmProvider>, EbiError> {
-    // Determine API endpoint based on model
-    let endpoint_override = std::env::var("EBI_LLM_API_ENDPOINT").ok();
-    let trimmed_model = model.trim();
-
-    let (api_endpoint, actual_model) = if let Some(endpoint) = endpoint_override {
-        (endpoint, trimmed_model.to_string())
-    } else if is_openai_model(trimmed_model) {
-        if is_responses_model(trimmed_model) {
-            (
-                "https://api.openai.com/v1/responses".to_string(),
-                trimmed_model.to_string(),
-            )
-        } else {
-            (
-                "https://api.openai.com/v1/chat/completions".to_string(),
-                trimmed_model.to_string(),
-            )
-        }
-    } else if is_claude_model(trimmed_model) {
-        (
-            "https://api.anthropic.com/v1/messages".to_string(),
-            trimmed_model.to_string(),
-        )
-    } else if is_gemini_model(trimmed_model) {
-        (
-            format!(
-                "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
-                trimmed_model
-            ),
-            trimmed_model.to_string(),
-        )
-    } else {
-        return Err(EbiError::LlmClientError(format!(
-            "Unsupported model '{trimmed_model}'. Specify a supported model or set EBI_LLM_API_ENDPOINT for custom integrations",
-        )));
-    };
-
     let config = LlmConfig {
-        model_name: actual_model,
-        api_endpoint,
+        model_name: model.to_string(),
         api_key,
         timeout_seconds,
         max_retries: 3,
-        max_tokens: None,
-        temperature: None,
-    };
-
-    let client: Box<dyn LlmProvider> = if is_claude_model(&config.model_name) {
-        Box::new(ClaudeClient::new(config)?)
-    } else if is_gemini_model(&config.model_name) {
-        Box::new(GeminiClient::new(config)?)
-    } else {
-        Box::new(OpenAiCompatibleClient::new(config)?)
-    };
-    Ok(client)
-}
-
-fn build_llm_api_request(
-    model_name: &str,
-    prompt: String,
-    analysis_type: &AnalysisType,
-    output_language: &OutputLanguage,
-) -> LlmApiRequest {
-    use crate::analyzer::prompts::PromptTemplate;
-
-    let uses_reasoning = uses_reasoning_parameters(model_name);
-    let system_prompt = PromptTemplate::build_system_prompt(analysis_type, output_language);
-
-    let mut api_request = LlmApiRequest {
-        model: model_name.to_string(),
-        messages: vec![
-            ChatMessage {
-                role: "system".to_string(),
-                content: system_prompt,
-            },
-            ChatMessage {
-                role: "user".to_string(),
-                content: prompt,
-            },
-        ],
-        max_tokens: None,
-        max_completion_tokens: None,
-        temperature: if uses_reasoning { None } else { Some(0.3) },
+        max_tokens: Some(1000),
+        temperature: Some(0.3),
     };
 
-    if uses_reasoning {
-        api_request.max_completion_tokens = Some(1000);
-    } else {
-        api_request.max_tokens = Some(1000);
-    }
-
-    api_request
+    let client = RigLlmClient::new(config)?;
+    Ok(Box::new(client))
 }
 
 fn is_openai_model(model: &str) -> bool {
@@ -1441,21 +729,9 @@ fn is_claude_model(model: &str) -> bool {
 
 fn is_gemini_model(model: &str) -> bool {
     let candidate = model.strip_prefix("gemini/").unwrap_or(model);
-
     candidate.starts_with("gemini-")
 }
 
-fn uses_reasoning_parameters(model: &str) -> bool {
-    let candidate = model.strip_prefix("openai/").unwrap_or(model);
-    let candidate = candidate.strip_prefix("ft:").unwrap_or(candidate);
-
-    candidate.starts_with("o1")
-        || candidate.starts_with("o3")
-        || candidate.starts_with("o4")
-        || candidate.starts_with("gpt-5")
-        || candidate.starts_with("gpt-4.1")
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1479,7 +755,7 @@ mod tests {
     fn test_response_parsing() {
         let response = "Risk Level: HIGH\nThis script contains potential vulnerabilities including command injection.";
         let (risk_level, summary, confidence) =
-            OpenAiCompatibleClient::parse_analysis_response(response);
+            RigLlmClient::parse_analysis_response(response);
 
         assert_eq!(risk_level, crate::models::RiskLevel::High);
         assert!(summary.contains("vulnerabilities"));
         assert!(confidence > 0.5);
@@ -1502,116 +778,39 @@ mod tests {
     }
 
     #[test]
-    fn test_o_series_models_supported_by_detection() {
-        assert!(super::is_openai_model("o1-mini"));
-        assert!(super::is_openai_model("o3-preview"));
-        assert!(super::is_openai_model("o4-mini"));
-        assert!(super::is_openai_model("gpt-5-mini"));
-
-        assert!(super::uses_reasoning_parameters("o1-mini"));
-        assert!(super::uses_reasoning_parameters("o3-preview"));
-        assert!(super::uses_reasoning_parameters("o4-mini"));
-        assert!(super::uses_reasoning_parameters("gpt-5-mini"));
-        assert!(super::uses_reasoning_parameters("gpt-4.1"));
-
-        assert!(!super::uses_reasoning_parameters("gpt-4o"));
-        assert!(!super::uses_reasoning_parameters("gpt-4o-mini"));
-        assert!(!super::uses_reasoning_parameters("gpt-3.5-turbo"));
-    }
-
-    #[test]
-    fn test_build_api_request_switches_token_parameters() {
-        use crate::models::{AnalysisType, OutputLanguage};
-        let request = super::build_llm_api_request(
-            "gpt-5-mini",
-            "prompt".to_string(),
-            &AnalysisType::CodeVulnerability,
-            &OutputLanguage::English,
-        );
-        assert!(request.max_tokens.is_none());
-        assert_eq!(request.max_completion_tokens, Some(1000));
-        assert!(request.temperature.is_none());
-
-        let classic_request = super::build_llm_api_request(
-            "gpt-4o",
-            "prompt".to_string(),
-            &AnalysisType::CodeVulnerability,
-            &OutputLanguage::English,
-        );
-        assert_eq!(classic_request.max_tokens, Some(1000));
-        assert!(classic_request.max_completion_tokens.is_none());
-        assert_eq!(classic_request.temperature, Some(0.3));
-    }
-
-    #[test]
-    fn test_claude_model_detection() {
-        assert!(super::is_claude_model("claude-3.5-sonnet"));
-        assert!(super::is_claude_model("claude-3.5-haiku"));
-        assert!(super::is_claude_model("claude-3-opus"));
-        assert!(super::is_claude_model("claude-3-sonnet"));
-        assert!(super::is_claude_model("claude-3-haiku"));
-        assert!(super::is_claude_model("claude-2"));
-        assert!(super::is_claude_model("claude-instant"));
-        assert!(super::is_claude_model("anthropic/claude-3.5-sonnet"));
-
-        assert!(!super::is_claude_model("gpt-4"));
-        assert!(!super::is_claude_model("gemini-pro"));
-        assert!(!super::is_claude_model("unknown-model"));
-    }
+    fn test_model_detection() {
+        assert!(is_openai_model("gpt-4"));
+        assert!(is_openai_model("gpt-4o"));
+        assert!(is_openai_model("o1-mini"));
+        assert!(is_openai_model("o3-preview"));
 
-    #[cfg_attr(target_os = "macos", ignore)]
-    #[test]
-    fn test_claude_client_creation() {
-        let client =
-            super::create_llm_client("claude-3.5-sonnet", Some("test-key".to_string()), 60);
-        assert!(client.is_ok());
+        assert!(is_claude_model("claude-3.5-sonnet"));
+        assert!(is_claude_model("claude-3-opus"));
+        assert!(is_claude_model("anthropic/claude-3.5-sonnet"));
 
-        let client = client.unwrap();
-        assert_eq!(client.get_model_name(), "claude-3.5-sonnet");
+        assert!(is_gemini_model("gemini-pro"));
+        assert!(is_gemini_model("gemini-1.5-pro"));
+        assert!(is_gemini_model("gemini/gemini-2.5-flash"));
     }
 
     #[test]
     fn test_claude_response_parsing() {
         let response = "Risk Level: HIGH\nThis script contains potential vulnerabilities including command injection.";
         let (risk_level, summary, confidence) =
-            super::ClaudeClient::parse_analysis_response(response);
+            RigLlmClient::parse_analysis_response(response);
 
         assert_eq!(risk_level, crate::models::RiskLevel::High);
         assert!(summary.contains("vulnerabilities"));
         assert!(confidence > 0.5);
     }
 
-    #[test]
-    fn test_gemini_model_detection() {
-        assert!(super::is_gemini_model("gemini-pro"));
-        assert!(super::is_gemini_model("gemini-1.5-pro"));
-        assert!(super::is_gemini_model("gemini-1.5-flash"));
-        assert!(super::is_gemini_model("gemini-2.0-flash-exp"));
-        assert!(super::is_gemini_model("gemini-2.5-flash"));
-        assert!(super::is_gemini_model("gemini/gemini-1.5-pro"));
-
-        assert!(!super::is_gemini_model("gpt-4"));
-        assert!(!super::is_gemini_model("claude-3.5-sonnet"));
-        assert!(!super::is_gemini_model("unknown-model"));
-    }
-
-    #[cfg_attr(target_os = "macos", ignore)]
-    #[test]
-    fn test_gemini_client_creation() {
-        let client = super::create_llm_client("gemini-2.5-flash", Some("test-key".to_string()), 60);
-        assert!(client.is_ok());
-
-        let client = client.unwrap();
-        assert_eq!(client.get_model_name(), "gemini-2.5-flash");
-    }
-
     #[test]
     fn test_gemini_response_parsing() {
         let response = "Risk Level: HIGH\nThis script contains potential vulnerabilities including command injection.";
-        let (risk_level, summary, confidence) = GeminiClient::parse_analysis_response(response);
+        let (risk_level, summary, confidence) = RigLlmClient::parse_analysis_response(response);
 
         assert_eq!(risk_level, crate::models::RiskLevel::High);
         assert!(summary.contains("vulnerabilities"));
         assert!(confidence > 0.5);
     }
-}
+}
\ No newline at end of file
diff --git a/src/analyzer/mod.rs b/src/analyzer/mod.rs
index e8e3c1c..adff81d 100644
--- a/src/analyzer/mod.rs
+++ b/src/analyzer/mod.rs
@@ -6,6 +6,6 @@ pub mod orchestrator;
 pub mod prompts;
 
 pub use aggregator::AnalysisAggregator;
-pub use llm_client::{create_llm_client, LlmConfig, LlmProvider, OpenAiCompatibleClient};
+pub use llm_client::{create_llm_client, LlmConfig, LlmProvider};
 pub use orchestrator::AnalysisOrchestrator;
 pub use prompts::PromptTemplate;
diff --git a/src/analyzer/orchestrator.rs b/src/analyzer/orchestrator.rs
index 5b34b87..8acc001 100644
--- a/src/analyzer/orchestrator.rs
+++ b/src/analyzer/orchestrator.rs
@@ -388,11 +388,11 @@ impl AnalysisOrchestrator {
 // Helper functions for creating orchestrators with common configurations
 impl AnalysisOrchestrator {
     pub fn for_development() -> Result<Self, EbiError> {
-        Self::new("gpt-5-mini", None, 60, 2)
+        Self::new("gpt-4o", None, 60, 2)
     }
 
     pub fn for_production(api_key: String) -> Result<Self, EbiError> {
-        Self::new("gpt-5-mini", Some(api_key), 120, 3)
+        Self::new("gpt-4o", Some(api_key), 120, 3)
     }
 
     pub fn for_testing() -> Result<Self, EbiError> {
diff --git a/src/cli/args.rs b/src/cli/args.rs
index 6e4d4ad..7ee8309 100644
--- a/src/cli/args.rs
+++ b/src/cli/args.rs
@@ -14,7 +14,7 @@ pub struct Cli {
     pub lang: Option<String>,
 
     /// LLM model to use for analysis
-    #[arg(short = 'm', long, default_value = "gpt-5-mini")]
+    #[arg(short = 'm', long, default_value = "gpt-4o")]
     pub model: String,
 
     /// Maximum time for LLM analysis in seconds (10-300)
@@ -201,7 +201,7 @@ mod tests {
         let cli = Cli::try_parse_from(args).unwrap();
 
         assert_eq!(cli.command_and_args, vec!["bash"]);
-        assert_eq!(cli.model, "gpt-5-mini");
+        assert_eq!(cli.model, "gpt-4o");
         assert_eq!(cli.timeout, 300);
         assert!(cli.lang.is_none());
         assert!(!cli.verbose);