Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ollama-rs/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum OllamaError {
ReqwestError(#[from] reqwest::Error),
#[error("Internal Ollama error")]
InternalError(InternalOllamaError),
#[error("Ollama aborted the request")]
Abort,
#[error("Error in Ollama")]
Other(String),
}
Expand Down
43 changes: 29 additions & 14 deletions ollama-rs/src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ impl Ollama {
let serialized = serde_json::to_string(&request)
.map_err(|e| e.to_string())
.unwrap();
let builder = self.reqwest_client.post(url);
let mut builder = self.reqwest_client.post(url);

if let Some(timeout) = request.timeout {
builder = builder.timeout(timeout);
}

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());
Expand All @@ -49,20 +53,27 @@ impl Ollama {
));
}

let stream = Box::new(res.bytes_stream().map(|res| match res {
Ok(bytes) => {
let res = serde_json::from_slice::<ChatMessageResponse>(&bytes);
match res {
Ok(res) => Ok(res),
Err(e) => {
eprintln!("Failed to deserialize response: {}", e);
Err(())
}
let stream = Box::new(res.bytes_stream().map(move |res| {
if let Some(abort_signal) = request.abort_signal.as_ref() {
if abort_signal.aborted() {
return Err(());
}
}
Err(e) => {
eprintln!("Failed to read response: {}", e);
Err(())
match res {
Ok(bytes) => {
let res = serde_json::from_slice::<ChatMessageResponse>(&bytes);
match res {
Ok(res) => Ok(res),
Err(e) => {
eprintln!("Failed to deserialize response: {}", e);
Err(())
}
}
}
Err(e) => {
eprintln!("Failed to read response: {}", e);
Err(())
}
}
}));

Expand All @@ -80,7 +91,11 @@ impl Ollama {

let url = format!("{}api/chat", self.url_str());
let serialized = serde_json::to_string(&request)?;
let builder = self.reqwest_client.post(url);
let mut builder = self.reqwest_client.post(url);

if let Some(timeout) = request.timeout {
builder = builder.timeout(timeout);
}

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());
Expand Down
19 changes: 19 additions & 0 deletions ollama-rs/src/generation/chat/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::Serialize;

use crate::generation::{
completion::request::AbortSignal,
options::GenerationOptions,
parameters::FormatType,
tools::{ToolGroup, ToolInfo},
Expand All @@ -22,6 +23,10 @@ pub struct ChatMessageRequest {
pub format: Option<FormatType>,
/// Must be false if tools are provided
pub(crate) stream: bool,
#[serde(skip)]
pub abort_signal: Option<AbortSignal>,
#[serde(skip)]
pub(crate) timeout: Option<std::time::Duration>,
}

impl ChatMessageRequest {
Expand All @@ -35,6 +40,8 @@ impl ChatMessageRequest {
// Stream value will be overwritten by Ollama::send_chat_messages_stream() and Ollama::send_chat_messages() methods
stream: false,
tools: vec![],
abort_signal: None,
timeout: None,
}
}

Expand Down Expand Up @@ -63,4 +70,16 @@ impl ChatMessageRequest {

self
}

/// Attaches an [`AbortSignal`] to this request.
///
/// Any signal stored earlier is replaced. The request is returned by value so
/// the call can be chained with the other builder-style setters.
pub fn abort_signal(self, signal: AbortSignal) -> Self {
    let mut request = self;
    request.abort_signal = Some(signal);
    request
}

/// Stores a timeout for this request.
///
/// Any timeout stored earlier is replaced. The request is returned by value
/// so the call can be chained with the other builder-style setters.
pub fn timeout(self, timeout: std::time::Duration) -> Self {
    let mut request = self;
    request.timeout = Some(timeout);
    request
}
}
41 changes: 28 additions & 13 deletions ollama-rs/src/generation/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ impl Ollama {

let url = format!("{}api/generate", self.url_str());
let serialized = serde_json::to_string(&request)?;
let builder = self.reqwest_client.post(url);
let mut builder = self.reqwest_client.post(url);

if let Some(timeout) = request.timeout {
builder = builder.timeout(timeout);
}

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());
Expand All @@ -47,18 +51,25 @@ impl Ollama {
));
}

let stream = Box::new(res.bytes_stream().map(|res| match res {
Ok(bytes) => {
let res = serde_json::Deserializer::from_slice(&bytes).into_iter();
let res = res
.filter_map(Result::ok) // Filter out the errors
.collect::<Vec<GenerationResponse>>();
Ok(res)
let stream = Box::new(res.bytes_stream().map(move |res| {
if let Some(abort_signal) = request.abort_signal.as_ref() {
if abort_signal.aborted() {
return Err(OllamaError::Abort);
}
}
match res {
Ok(bytes) => {
let res = serde_json::Deserializer::from_slice(&bytes).into_iter();
let res = res
.filter_map(Result::ok) // Filter out the errors
.collect::<Vec<GenerationResponse>>();
Ok(res)
}
Err(e) => Err(OllamaError::Other(format!(
"Failed to read response: {}",
e
))),
}
Err(e) => Err(OllamaError::Other(format!(
"Failed to read response: {}",
e
))),
}));

Ok(std::pin::Pin::from(stream))
Expand All @@ -80,7 +91,11 @@ impl Ollama {
#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder.body(serialized).send().await?;
let mut builder = builder.body(serialized);
if let Some(timeout) = request.timeout {
builder = builder.timeout(timeout);
}
let res = builder.send().await?;

if !res.status().is_success() {
return Err(OllamaError::Other(
Expand Down
42 changes: 42 additions & 0 deletions ollama-rs/src/generation/completion/request.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::{atomic::AtomicBool, Arc};

use serde::Serialize;

use crate::generation::{
Expand All @@ -8,6 +10,28 @@

use super::GenerationContext;

/// A shareable flag used to ask for a request to be aborted.
///
/// The flag lives behind an `Arc`, so every clone of an `AbortSignal` observes
/// and controls the same underlying boolean.
///
/// `Default` is derived so that `AbortSignal::default()` yields a signal whose
/// flag is `false` (not aborted); this also satisfies clippy's
/// `new_without_default` lint for types exposing a zero-argument `new()`.
#[derive(Debug, Clone, Default)]
pub struct AbortSignal {
    /// Shared flag: `true` once an abort has been requested.
    pub(crate) abort_signal: Arc<AtomicBool>,
}

impl AbortSignal {
pub fn new() -> Self {

Check failure on line 19 in ollama-rs/src/generation/completion/request.rs

View workflow job for this annotation

GitHub Actions / Formatting

you should consider adding a `Default` implementation for `AbortSignal`
Self {
abort_signal: Arc::new(AtomicBool::new(false)),
}
}

/// Requests an abort by setting the shared flag to `true`.
///
/// The store uses relaxed ordering: only the flag value itself is published,
/// no other memory accesses are ordered by it.
pub fn abort(&self) {
    use std::sync::atomic::Ordering::Relaxed;

    self.abort_signal.store(true, Relaxed);
}

/// Returns `true` if the shared flag has been set.
///
/// The load uses relaxed ordering.
pub fn aborted(&self) -> bool {
    use std::sync::atomic::Ordering::Relaxed;

    let is_set = self.abort_signal.load(Relaxed);
    is_set
}
}

/// A generation request to Ollama.
#[derive(Debug, Clone, Serialize)]
pub struct GenerationRequest {
Expand All @@ -24,6 +48,10 @@
pub format: Option<FormatType>,
pub keep_alive: Option<KeepAlive>,
pub(crate) stream: bool,
#[serde(skip)]
pub abort_signal: Option<AbortSignal>,
#[serde(skip)]
pub(crate) timeout: Option<std::time::Duration>,
}

impl GenerationRequest {
Expand All @@ -41,6 +69,8 @@
keep_alive: None,
// Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods
stream: false,
abort_signal: None,
timeout: None,
}
}

Expand Down Expand Up @@ -103,4 +133,16 @@
self.keep_alive = Some(keep_alive);
self
}

/// Attaches an [`AbortSignal`] to this request.
///
/// Any signal stored earlier is replaced. The request is returned by value so
/// the call can be chained with the other builder-style setters.
pub fn abort_signal(self, abort_signal: AbortSignal) -> Self {
    let mut request = self;
    request.abort_signal = Some(abort_signal);
    request
}

/// Stores a timeout for this request.
///
/// Any timeout stored earlier is replaced. The request is returned by value
/// so the call can be chained with the other builder-style setters.
pub fn timeout(self, timeout: std::time::Duration) -> Self {
    let mut request = self;
    request.timeout = Some(timeout);
    request
}
}
Loading