122 changes: 105 additions & 17 deletions codex-rs/ollama/src/backend.rs

@@ -1,5 +1,5 @@
 use codex_core::error::{CodexErr, Result};
-use codex_core::{ContentItem, ResponseEvent, ResponseItem, ToolingBridge, TOOLING_SCHEMA};
+use codex_core::{ContentItem, ResponseEvent, ResponseItem, TOOLING_SCHEMA, ToolingBridge};
 use futures::StreamExt;
 use ollama_rs::Ollama;
 use ollama_rs::error::OllamaError;
@@ -16,6 +16,15 @@ pub struct OllamaBackend {
     tool_bridge: Option<Arc<dyn ToolingBridge>>,
 }
 
+/// Events produced while streaming chat completions.
+#[derive(Debug, Clone)]
+pub enum ChatStreamEvent {
+    /// A standard response event from the model.
+    Response(ResponseEvent),
+    /// Error encountered while parsing a streamed chunk.
+    Error(String),
+}
+
 impl OllamaBackend {
     /// Construct a backend pointing at the given base URL, e.g.
     /// `http://localhost:11434`.
@@ -69,10 +78,10 @@
     }
 
     /// Stream a chat completion, invoking `on_event` for each emitted
-    /// [`ResponseEvent`].
+    /// [`ChatStreamEvent`].
     pub async fn chat_stream<F>(&self, model: &str, prompt: &str, mut on_event: F) -> Result<()>
     where
-        F: FnMut(ResponseEvent),
+        F: FnMut(ChatStreamEvent),
     {
         let mut full_prompt = prompt.to_string();
         if self.tool_bridge.is_some() {
@@ -87,29 +96,108 @@
             .await
             .map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;
 
-        let mut buffer = String::new();
+        let mut full_buffer = String::new();
+        let mut json_buffer = String::new();
+        let mut depth = 0usize;
+        let mut in_string = false;
+        let mut escape = false;
 
         while let Some(chunk) = stream.next().await {
-            let parts =
-                chunk.map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;
+            let parts = chunk.map_err(|e| CodexErr::Io(io::Error::other(e.to_string())))?;
             for part in parts {
-                buffer.push_str(&part.response);
+                full_buffer.push_str(&part.response);
                 if self.tool_bridge.is_none() {
-                    on_event(ResponseEvent::OutputTextDelta(part.response));
+                    on_event(ChatStreamEvent::Response(ResponseEvent::OutputTextDelta(
+                        part.response,
+                    )));
                     continue;
                 }
+
+                for ch in part.response.chars() {
+                    json_buffer.push(ch);
+                    if in_string {
+                        if escape {
+                            escape = false;
+                            continue;
+                        }
+                        match ch {
+                            '\\' => escape = true,
+                            '"' => in_string = false,
+                            _ => {}
+                        }
+                        continue;
+                    }
+                    match ch {
+                        '"' => in_string = true,
+                        '{' => depth += 1,
+                        '}' => {
+                            if depth > 0 {
+                                depth -= 1;
+                            }
+                            if depth == 0 {
+                                let text = std::mem::take(&mut json_buffer);
+                                let item = ResponseItem::Message {
+                                    id: None,
+                                    role: "assistant".into(),
+                                    content: vec![ContentItem::OutputText { text: text.clone() }],
+                                };
+                                if let Some(bridge) = &self.tool_bridge {
+                                    match bridge.parse_event(ResponseEvent::OutputItemDone(item)) {
+                                        Ok(events) => {
+                                            for ev in events {
+                                                on_event(ChatStreamEvent::Response(ev));
+                                            }
+                                        }
+                                        Err(err) => {
+                                            on_event(ChatStreamEvent::Error(err.to_string()));
+                                            return Err(err);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        _ => {}
+                    }
+                }
             }
         }
 
-        let item = ResponseItem::Message {
-            id: None,
-            role: "assistant".into(),
-            content: vec![ContentItem::OutputText { text: buffer }],
-        };
-        if let Some(bridge) = &self.tool_bridge {
-            for ev in bridge.parse_event(ResponseEvent::OutputItemDone(item))? {
-                on_event(ev);
+        if self.tool_bridge.is_some() {
+            if depth != 0 {
+                let msg = "incomplete JSON object".to_string();
+                on_event(ChatStreamEvent::Error(msg.clone()));
+                return Err(CodexErr::Json(serde_json::Error::custom(msg)));
+            }
+            if !json_buffer.trim().is_empty() {
+                let text = std::mem::take(&mut json_buffer);
+                let item = ResponseItem::Message {
+                    id: None,
+                    role: "assistant".into(),
+                    content: vec![ContentItem::OutputText { text: text.clone() }],
+                };
+                if let Some(bridge) = &self.tool_bridge {
+                    match bridge.parse_event(ResponseEvent::OutputItemDone(item)) {
+                        Ok(events) => {
+                            for ev in events {
+                                on_event(ChatStreamEvent::Response(ev));
+                            }
+                        }
+                        Err(err) => {
+                            on_event(ChatStreamEvent::Error(err.to_string()));
+                            return Err(err);
+                        }
+                    }
+                }
             }
         } else {
-            on_event(ResponseEvent::OutputItemDone(item));
+            let item = ResponseItem::Message {
+                id: None,
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText { text: full_buffer }],
+            };
+            on_event(ChatStreamEvent::Response(ResponseEvent::OutputItemDone(
+                item,
+            )));
         }
         Ok(())
     }
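
Note on the streaming change above: with a tool bridge installed, chat_stream no longer buffers the whole response before parsing. It scans characters as they arrive and hands each top-level JSON object to the bridge as soon as its brace depth returns to zero; the in_string and escape flags keep braces inside string literals from skewing the count. Below is a self-contained sketch of the same state machine outside the backend (the JsonSplitter name and the demo input are illustrative, not part of the PR):

// Splits a stream of concatenated JSON objects as characters arrive,
// using the same brace-depth/string/escape bookkeeping as the diff.
struct JsonSplitter {
    buf: String,
    depth: usize,
    in_string: bool,
    escape: bool,
}

impl JsonSplitter {
    fn new() -> Self {
        Self { buf: String::new(), depth: 0, in_string: false, escape: false }
    }

    /// Feed one character; returns a complete top-level object when the
    /// unquoted brace depth falls back to zero.
    fn push(&mut self, ch: char) -> Option<String> {
        self.buf.push(ch);
        if self.in_string {
            if self.escape {
                self.escape = false;
            } else {
                match ch {
                    '\\' => self.escape = true,
                    '"' => self.in_string = false,
                    _ => {}
                }
            }
            return None;
        }
        match ch {
            '"' => self.in_string = true,
            '{' => self.depth += 1,
            '}' => {
                self.depth = self.depth.saturating_sub(1);
                if self.depth == 0 {
                    return Some(std::mem::take(&mut self.buf));
                }
            }
            _ => {}
        }
        None
    }
}

fn main() {
    let mut splitter = JsonSplitter::new();
    // Two concatenated objects; the second hides a `}` inside a string.
    for ch in r#"{"a":1}{"b":"}"}"#.chars() {
        if let Some(obj) = splitter.push(ch) {
            println!("object: {obj}");
        }
    }
}

This prints the two objects separately, including the second one whose string value contains a closing brace, which is exactly the case the in_string/escape bookkeeping in the diff guards against.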
2 changes: 1 addition & 1 deletion codex-rs/ollama/src/lib.rs

@@ -5,7 +5,7 @@ mod parser;
 mod pull;
 mod url;
 
-pub use backend::OllamaBackend;
+pub use backend::{ChatStreamEvent, OllamaBackend};
 pub use bridge::OllamaToolBridge;
 pub use bridge::register_ollama_tool_bridge;
 pub use client::OllamaClient;
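
With ChatStreamEvent re-exported, a caller can separate text deltas from parse errors in the callback. A hypothetical usage sketch: the codex_ollama crate path, the model name, and the prompt are placeholders, while the chat_stream signature and ResponseEvent::OutputTextDelta come from the diff above.

use codex_core::{error::Result, ResponseEvent};
use codex_ollama::{ChatStreamEvent, OllamaBackend}; // crate name assumed

async fn demo(backend: &OllamaBackend) -> Result<()> {
    backend
        .chat_stream("llama3", "Reply with a JSON object.", |ev| match ev {
            // Incremental text from the model.
            ChatStreamEvent::Response(ResponseEvent::OutputTextDelta(delta)) => {
                print!("{delta}");
            }
            // Completed items, including anything produced by the tool bridge.
            ChatStreamEvent::Response(_) => {}
            // A chunk failed to parse; chat_stream also returns Err afterwards.
            ChatStreamEvent::Error(msg) => eprintln!("parse error: {msg}"),
        })
        .await
}

Note that Error events arrive through the callback before chat_stream returns the corresponding Err, so a UI can surface the message even when the overall call fails.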