From 16119f734c85a76494da8d4543e3f24c55de11cf Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Mon, 13 Jan 2025 11:14:40 +0100 Subject: [PATCH 1/2] Added abort request for generate_stream --- ollama-rs/src/error.rs | 2 ++ ollama-rs/src/generation/completion/mod.rs | 29 +++++++++------- .../src/generation/completion/request.rs | 33 +++++++++++++++++++ 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/ollama-rs/src/error.rs b/ollama-rs/src/error.rs index 23100b0..5f7993d 100644 --- a/ollama-rs/src/error.rs +++ b/ollama-rs/src/error.rs @@ -15,6 +15,8 @@ pub enum OllamaError { ReqwestError(#[from] reqwest::Error), #[error("Internal Ollama error")] InternalError(InternalOllamaError), + #[error("Ollama aborted the request")] + Abort, #[error("Error in Ollama")] Other(String), } diff --git a/ollama-rs/src/generation/completion/mod.rs b/ollama-rs/src/generation/completion/mod.rs index 4202c85..6b1065f 100644 --- a/ollama-rs/src/generation/completion/mod.rs +++ b/ollama-rs/src/generation/completion/mod.rs @@ -47,18 +47,25 @@ impl Ollama { )); } - let stream = Box::new(res.bytes_stream().map(|res| match res { - Ok(bytes) => { - let res = serde_json::Deserializer::from_slice(&bytes).into_iter(); - let res = res - .filter_map(Result::ok) // Filter out the errors - .collect::>(); - Ok(res) + let stream = Box::new(res.bytes_stream().map(move |res| { + if let Some(abort_signal) = request.abort_signal.as_ref() { + if abort_signal.aborted() { + return Err(OllamaError::Abort); + } + } + match res { + Ok(bytes) => { + let res = serde_json::Deserializer::from_slice(&bytes).into_iter(); + let res = res + .filter_map(Result::ok) // Filter out the errors + .collect::>(); + Ok(res) + } + Err(e) => Err(OllamaError::Other(format!( + "Failed to read response: {}", + e + ))), } - Err(e) => Err(OllamaError::Other(format!( - "Failed to read response: {}", - e - ))), })); Ok(std::pin::Pin::from(stream)) diff --git a/ollama-rs/src/generation/completion/request.rs b/ollama-rs/src/generation/completion/request.rs index d7605b3..ee20df1 100644 --- a/ollama-rs/src/generation/completion/request.rs +++ b/ollama-rs/src/generation/completion/request.rs @@ -1,3 +1,5 @@ +use std::sync::{atomic::AtomicBool, Arc}; + use serde::Serialize; use crate::generation::{ @@ -8,6 +10,28 @@ use crate::generation::{ use super::GenerationContext; +#[derive(Debug, Clone)] +pub struct AbortSignal { + pub(crate) abort_signal: Arc, +} + +impl AbortSignal { + pub fn new() -> Self { + Self { + abort_signal: Arc::new(AtomicBool::new(false)), + } + } + + pub fn abort(&self) { + self.abort_signal + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + pub fn aborted(&self) -> bool { + self.abort_signal.load(std::sync::atomic::Ordering::Relaxed) + } +} + /// A generation request to Ollama. #[derive(Debug, Clone, Serialize)] pub struct GenerationRequest { @@ -24,6 +48,8 @@ pub struct GenerationRequest { pub format: Option, pub keep_alive: Option, pub(crate) stream: bool, + #[serde(skip)] + pub abort_signal: Option, } impl GenerationRequest { @@ -41,6 +67,7 @@ impl GenerationRequest { keep_alive: None, // Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods stream: false, + abort_signal: None, } } @@ -103,4 +130,10 @@ impl GenerationRequest { self.keep_alive = Some(keep_alive); self } + + /// Sets the abort signal for the request + pub fn abort_signal(mut self, abort_signal: AbortSignal) -> Self { + self.abort_signal = Some(abort_signal); + self + } } From 9994cb5cc1c8a7ec312710e0dd4a465f37ffd58b Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Mon, 13 Jan 2025 12:26:46 +0100 Subject: [PATCH 2/2] Added abort signal for chat requests and timeout support --- ollama-rs/src/generation/chat/mod.rs | 43 +++++++++++++------ ollama-rs/src/generation/chat/request.rs | 19 ++++++++ ollama-rs/src/generation/completion/mod.rs | 12 +++++- .../src/generation/completion/request.rs | 9 ++++ 4 files changed, 67 insertions(+), 16 deletions(-) diff --git a/ollama-rs/src/generation/chat/mod.rs b/ollama-rs/src/generation/chat/mod.rs index 1143185..c69a518 100644 --- a/ollama-rs/src/generation/chat/mod.rs +++ b/ollama-rs/src/generation/chat/mod.rs @@ -36,7 +36,11 @@ impl Ollama { let serialized = serde_json::to_string(&request) .map_err(|e| e.to_string()) .unwrap(); - let builder = self.reqwest_client.post(url); + let mut builder = self.reqwest_client.post(url); + + if let Some(timeout) = request.timeout { + builder = builder.timeout(timeout); + } #[cfg(feature = "headers")] let builder = builder.headers(self.request_headers.clone()); @@ -49,20 +53,27 @@ impl Ollama { )); } - let stream = Box::new(res.bytes_stream().map(|res| match res { - Ok(bytes) => { - let res = serde_json::from_slice::(&bytes); - match res { - Ok(res) => Ok(res), - Err(e) => { - eprintln!("Failed to deserialize response: {}", e); - Err(()) - } + let stream = Box::new(res.bytes_stream().map(move |res| { + if let Some(abort_signal) = request.abort_signal.as_ref() { + if abort_signal.aborted() { + return Err(()); } } - Err(e) => { - eprintln!("Failed to read response: {}", e); - Err(()) + match res { + Ok(bytes) => { + let res = serde_json::from_slice::(&bytes); + match res { + Ok(res) => Ok(res), + Err(e) => { + eprintln!("Failed to deserialize response: {}", e); + Err(()) + } + } + } + Err(e) => { + eprintln!("Failed to read response: {}", e); + Err(()) + } } })); @@ -80,7 +91,11 @@ impl Ollama { let url = format!("{}api/chat", self.url_str()); let serialized = serde_json::to_string(&request)?; - let builder = self.reqwest_client.post(url); + let mut builder = self.reqwest_client.post(url); + + if let Some(timeout) = request.timeout { + builder = builder.timeout(timeout); + } #[cfg(feature = "headers")] let builder = builder.headers(self.request_headers.clone()); diff --git a/ollama-rs/src/generation/chat/request.rs b/ollama-rs/src/generation/chat/request.rs index 68e33a4..686b274 100644 --- a/ollama-rs/src/generation/chat/request.rs +++ b/ollama-rs/src/generation/chat/request.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::generation::{ + completion::request::AbortSignal, options::GenerationOptions, parameters::FormatType, tools::{ToolGroup, ToolInfo}, @@ -22,6 +23,10 @@ pub struct ChatMessageRequest { pub format: Option, /// Must be false if tools are provided pub(crate) stream: bool, + #[serde(skip)] + pub abort_signal: Option, + #[serde(skip)] + pub(crate) timeout: Option, } impl ChatMessageRequest { @@ -35,6 +40,8 @@ impl ChatMessageRequest { // Stream value will be overwritten by Ollama::send_chat_messages_stream() and Ollama::send_chat_messages() methods stream: false, tools: vec![], + abort_signal: None, + timeout: None, } } @@ -63,4 +70,16 @@ impl ChatMessageRequest { self } + + /// Sets the abort signal for the request + pub fn abort_signal(mut self, signal: AbortSignal) -> Self { + self.abort_signal = Some(signal); + self + } + + /// Sets the timeout for the request + pub fn timeout(mut self, timeout: std::time::Duration) -> Self { + self.timeout = Some(timeout); + self + } } diff --git a/ollama-rs/src/generation/completion/mod.rs b/ollama-rs/src/generation/completion/mod.rs index 6b1065f..57729f1 100644 --- a/ollama-rs/src/generation/completion/mod.rs +++ b/ollama-rs/src/generation/completion/mod.rs @@ -34,7 +34,11 @@ impl Ollama { let url = format!("{}api/generate", self.url_str()); let serialized = serde_json::to_string(&request)?; - let builder = self.reqwest_client.post(url); + let mut builder = self.reqwest_client.post(url); + + if let Some(timeout) = request.timeout { + builder = builder.timeout(timeout); + } #[cfg(feature = "headers")] let builder = builder.headers(self.request_headers.clone()); @@ -87,7 +91,11 @@ impl Ollama { #[cfg(feature = "headers")] let builder = builder.headers(self.request_headers.clone()); - let res = builder.body(serialized).send().await?; + let mut builder = builder.body(serialized); + if let Some(timeout) = request.timeout { + builder = builder.timeout(timeout); + } + let res = builder.send().await?; if !res.status().is_success() { return Err(OllamaError::Other( diff --git a/ollama-rs/src/generation/completion/request.rs b/ollama-rs/src/generation/completion/request.rs index ee20df1..0d0dda3 100644 --- a/ollama-rs/src/generation/completion/request.rs +++ b/ollama-rs/src/generation/completion/request.rs @@ -50,6 +50,8 @@ pub struct GenerationRequest { pub(crate) stream: bool, #[serde(skip)] pub abort_signal: Option, + #[serde(skip)] + pub(crate) timeout: Option, } impl GenerationRequest { @@ -68,6 +70,7 @@ impl GenerationRequest { // Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods stream: false, abort_signal: None, + timeout: None, } } @@ -136,4 +139,10 @@ impl GenerationRequest { self.abort_signal = Some(abort_signal); self } + + /// Sets the timeout for the request + pub fn timeout(mut self, timeout: std::time::Duration) -> Self { + self.timeout = Some(timeout); + self + } }