diff --git a/codex-rs/ollama/src/backend.rs b/codex-rs/ollama/src/backend.rs
index a422e40c4e2..b60c8447794 100644
--- a/codex-rs/ollama/src/backend.rs
+++ b/codex-rs/ollama/src/backend.rs
@@ -1,5 +1,5 @@
 use codex_core::error::{CodexErr, Result};
-use codex_core::{ContentItem, ResponseEvent, ResponseItem, ToolingBridge, TOOLING_SCHEMA};
+use codex_core::{ContentItem, ResponseEvent, ResponseItem, TOOLING_SCHEMA, ToolingBridge};
 use futures::StreamExt;
 use ollama_rs::Ollama;
 use ollama_rs::error::OllamaError;
@@ -16,6 +16,15 @@ pub struct OllamaBackend {
     tool_bridge: Option<Box<dyn ToolingBridge>>,
 }

+/// Events produced while streaming chat completions.
+#[derive(Debug, Clone)]
+pub enum ChatStreamEvent {
+    /// A standard response event from the model.
+    Response(ResponseEvent),
+    /// Error encountered while parsing a streamed chunk.
+    Error(String),
+}
+
 impl OllamaBackend {
     /// Construct a backend pointing at the given base URL, e.g.
     /// `http://localhost:11434`.
@@ -69,10 +78,10 @@ impl OllamaBackend {
     }

     /// Stream a chat completion, invoking `on_event` for each emitted
-    /// [`ResponseEvent`].
+    /// [`ChatStreamEvent`].
     pub async fn chat_stream<F>(&self, model: &str, prompt: &str, mut on_event: F) -> Result<()>
     where
-        F: FnMut(ResponseEvent),
+        F: FnMut(ChatStreamEvent),
     {
         let mut full_prompt = prompt.to_string();
         if self.tool_bridge.is_some() {
@@ -87,29 +96,108 @@
             .await
             .map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;

-        let mut buffer = String::new();
+        let mut full_buffer = String::new();
+        let mut json_buffer = String::new();
+        let mut depth = 0usize;
+        let mut in_string = false;
+        let mut escape = false;
+
         while let Some(chunk) = stream.next().await {
-            let parts =
-                chunk.map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;
+            let parts = chunk.map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;
             for part in parts {
-                buffer.push_str(&part.response);
+                full_buffer.push_str(&part.response);
                 if self.tool_bridge.is_none() {
-                    on_event(ResponseEvent::OutputTextDelta(part.response));
+                    on_event(ChatStreamEvent::Response(ResponseEvent::OutputTextDelta(
+                        part.response,
+                    )));
+                    continue;
+                }
+
+                for ch in part.response.chars() {
+                    json_buffer.push(ch);
+                    if in_string {
+                        if escape {
+                            escape = false;
+                            continue;
+                        }
+                        match ch {
+                            '\\' => escape = true,
+                            '"' => in_string = false,
+                            _ => {}
+                        }
+                        continue;
+                    }
+                    match ch {
+                        '"' => in_string = true,
+                        '{' => depth += 1,
+                        '}' => {
+                            if depth > 0 {
+                                depth -= 1;
+                            }
+                            if depth == 0 {
+                                let text = std::mem::take(&mut json_buffer);
+                                let item = ResponseItem::Message {
+                                    id: None,
+                                    role: "assistant".into(),
+                                    content: vec![ContentItem::OutputText { text: text.clone() }],
+                                };
+                                if let Some(bridge) = &self.tool_bridge {
+                                    match bridge.parse_event(ResponseEvent::OutputItemDone(item)) {
+                                        Ok(events) => {
+                                            for ev in events {
+                                                on_event(ChatStreamEvent::Response(ev));
+                                            }
+                                        }
+                                        Err(err) => {
+                                            on_event(ChatStreamEvent::Error(err.to_string()));
+                                            return Err(err);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        _ => {}
+                    }
                 }
             }
         }

-        let item = ResponseItem::Message {
-            id: None,
-            role: "assistant".into(),
-            content: vec![ContentItem::OutputText { text: buffer }],
-        };
-        if let Some(bridge) = &self.tool_bridge {
-            for ev in bridge.parse_event(ResponseEvent::OutputItemDone(item))? {
-                on_event(ev);
+        if self.tool_bridge.is_some() {
+            if depth != 0 {
+                let msg = "incomplete JSON object".to_string();
+                on_event(ChatStreamEvent::Error(msg.clone()));
+                return Err(CodexErr::Json(serde_json::Error::custom(msg)));
+            }
+            if !json_buffer.trim().is_empty() {
+                let text = std::mem::take(&mut json_buffer);
+                let item = ResponseItem::Message {
+                    id: None,
+                    role: "assistant".into(),
+                    content: vec![ContentItem::OutputText { text: text.clone() }],
+                };
+                if let Some(bridge) = &self.tool_bridge {
+                    match bridge.parse_event(ResponseEvent::OutputItemDone(item)) {
+                        Ok(events) => {
+                            for ev in events {
+                                on_event(ChatStreamEvent::Response(ev));
+                            }
+                        }
+                        Err(err) => {
+                            on_event(ChatStreamEvent::Error(err.to_string()));
+                            return Err(err);
+                        }
+                    }
+                }
             }
         } else {
-            on_event(ResponseEvent::OutputItemDone(item));
+            let item = ResponseItem::Message {
+                id: None,
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText { text: full_buffer }],
+            };
+            on_event(ChatStreamEvent::Response(ResponseEvent::OutputItemDone(
+                item,
+            )));
         }
         Ok(())
     }
diff --git a/codex-rs/ollama/src/lib.rs b/codex-rs/ollama/src/lib.rs
index a3defa2ff6a..a3b4c789ac1 100644
--- a/codex-rs/ollama/src/lib.rs
+++ b/codex-rs/ollama/src/lib.rs
@@ -5,7 +5,7 @@ mod parser;
 mod pull;
 mod url;

-pub use backend::OllamaBackend;
+pub use backend::{ChatStreamEvent, OllamaBackend};
 pub use bridge::OllamaToolBridge;
 pub use bridge::register_ollama_tool_bridge;
 pub use client::OllamaClient;
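A minimal sketch of how a caller might consume the reworked `chat_stream` API. This is not part of the change: the `codex_ollama` crate path, the `stream_demo` function, the model name, and the prompt are all assumptions for illustration; the `{:?}` formatting relies on the `Debug` derive that `ChatStreamEvent` (and therefore `ResponseEvent`) carries in the diff.

```rust
use codex_core::ResponseEvent;
use codex_ollama::{ChatStreamEvent, OllamaBackend};

// Hypothetical caller showing how the two ChatStreamEvent variants split.
async fn stream_demo(backend: &OllamaBackend) -> codex_core::error::Result<()> {
    backend
        .chat_stream("llama3", "List the repo's crates.", |event| match event {
            // Raw token deltas arrive here when no tool bridge is registered.
            ChatStreamEvent::Response(ResponseEvent::OutputTextDelta(delta)) => {
                print!("{delta}");
            }
            // Completed items, including bridge-parsed tool events, land here.
            ChatStreamEvent::Response(other) => println!("\nevent: {other:?}"),
            // Parse failures are reported here before chat_stream returns Err.
            ChatStreamEvent::Error(msg) => eprintln!("\nstream error: {msg}"),
        })
        .await
}
```

Note the ordering guarantee the diff establishes: a `ChatStreamEvent::Error` is always emitted to the callback before `chat_stream` itself returns the underlying `Err`, so a caller can surface a partial-parse message without losing the error value.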