diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs
index 00732e48d..6370cc6e0 100644
--- a/llm_client/src/clients/anthropic.rs
+++ b/llm_client/src/clients/anthropic.rs
@@ -242,6 +242,25 @@ struct AnthropicRequest {
 }
 
 impl AnthropicRequest {
+    pub fn new_chat(
+        messages: Vec<LLMClientMessage>,
+        temperature: f32,
+        top_p: Option<f32>,
+        max_tokens: Option<usize>,
+        model_str: String,
+    ) -> Self {
+        AnthropicRequest {
+            system: vec![],
+            messages,
+            tools: vec![],
+            temperature,
+            stream: true,
+            max_tokens,
+            model: model_str,
+            thinking: None,
+        }
+    }
+
     fn from_client_completion_request(
         completion_request: LLMClientCompletionRequest,
         model_str: String,
@@ -336,6 +355,7 @@ impl AnthropicRequest {
             stream: true,
             max_tokens,
             model: model_str,
+            thinking,
         }
     }
 
@@ -349,6 +369,11 @@ impl AnthropicRequest {
             "user".to_owned(),
             completion_request.prompt().to_owned(),
         )];
+        let thinking = completion_request.thinking_budget().map(|budget| AnthropicThinking {
+            r#type: "enabled".to_string(),
+            budget_tokens: budget,
+        });
+
         AnthropicRequest {
             system: vec![],
             messages,
@@ -357,6 +382,7 @@ impl AnthropicRequest {
             stream: true,
             max_tokens,
             model: model_str,
+            thinking,
         }
     }
 }
diff --git a/llm_client/src/clients/types.rs b/llm_client/src/clients/types.rs
index de003ebdf..51b944a41 100644
--- a/llm_client/src/clients/types.rs
+++ b/llm_client/src/clients/types.rs
@@ -712,9 +712,14 @@ impl LLMClientCompletionStringRequest {
             frequency_penalty,
             stop_words: None,
             max_tokens: None,
+            thinking_budget: None,
         }
     }
 
+    pub fn thinking_budget(&self) -> Option<usize> {
+        self.thinking_budget
+    }
+
     pub fn set_stop_words(mut self, stop_words: Vec<String>) -> Self {
         self.stop_words = Some(stop_words);
         self
@@ -748,6 +753,11 @@ impl LLMClientCompletionStringRequest {
     pub fn get_max_tokens(&self) -> Option<usize> {
         self.max_tokens
     }
+
+    pub fn set_thinking_budget(mut self, thinking_budget: usize) -> Self {
+        self.thinking_budget = Some(thinking_budget);
+        self
+    }
 }
 
 impl LLMClientCompletionRequest {
@@ -764,6 +774,7 @@ impl LLMClientCompletionRequest {
             frequency_penalty,
             stop_words: None,
             max_tokens: None,
+            thinking_budget: None,
         }
     }
 
@@ -1040,7 +1051,32 @@ pub trait LLMClient {
 
 #[cfg(test)]
 mod tests {
-    use super::LLMType;
+    use super::*;
+
+    #[test]
+    fn test_thinking_budget() {
+        let request = LLMClientCompletionRequest::new(
+            LLMType::ClaudeSonnet,
+            vec![LLMClientMessage::user("test".to_string())],
+            0.7,
+            None,
+        );
+        assert_eq!(request.thinking_budget(), None);
+
+        let request = request.set_thinking_budget(16000);
+        assert_eq!(request.thinking_budget(), Some(16000));
+
+        let string_request = LLMClientCompletionStringRequest::new(
+            LLMType::ClaudeSonnet,
+            "test".to_string(),
+            0.7,
+            None,
+        );
+        assert_eq!(string_request.thinking_budget(), None);
+
+        let string_request = string_request.set_thinking_budget(16000);
+        assert_eq!(string_request.thinking_budget(), Some(16000));
+    }
 
     #[test]
     fn test_llm_type_from_string() {
@@ -1048,4 +1084,4 @@ mod tests {
         let str_llm_type = serde_json::to_string(&llm_type).expect("to work");
         assert_eq!(str_llm_type, "");
     }
-}
+}
\ No newline at end of file