Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions model_gateway/src/routers/mcp_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ use tracing::{debug, warn};
/// Default maximum tool loop iterations (safety limit).
pub const DEFAULT_MAX_ITERATIONS: usize = 10;

/// Compute the effective tool call limit from an optional user-specified max.
///
/// A user-supplied limit is capped at [`DEFAULT_MAX_ITERATIONS`]; when no
/// user limit is given, the default itself is returned.
#[inline]
pub fn effective_tool_call_limit(max_tool_calls: Option<usize>) -> usize {
    max_tool_calls
        .map(|user_max| user_max.min(DEFAULT_MAX_ITERATIONS))
        .unwrap_or(DEFAULT_MAX_ITERATIONS)
}

/// Protocol-agnostic MCP server descriptor for connection setup.
///
/// Contains only the fields needed by [`connect_mcp_servers`]. Each router
Expand Down
10 changes: 2 additions & 8 deletions model_gateway/src/routers/openai/mcp/tool_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,8 @@ impl StreamingToolHandler {

/// Process an SSE event and determine what action to take
pub fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction {
// Always feed to accumulator for storage
self.accumulator.ingest_block(&format!(
"{}data: {}",
event_name
.map(|n| format!("event: {n}\n"))
.unwrap_or_default(),
data
));
// Always feed to accumulator for storage (bypasses format+reparse overhead)
self.accumulator.ingest_event(event_name, data);

let parsed: Value = match serde_json::from_str(data) {
Ok(v) => v,
Expand Down
68 changes: 26 additions & 42 deletions model_gateway/src/routers/openai/mcp/tool_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,25 @@ use tracing::{debug, info, warn};
use super::tool_handler::FunctionCallInProgress;
use crate::{
observability::metrics::{metrics_labels, Metrics},
routers::{error, header_utils::apply_request_headers, mcp_utils::DEFAULT_MAX_ITERATIONS},
routers::{
error,
header_utils::apply_request_headers,
mcp_utils::{effective_tool_call_limit, DEFAULT_MAX_ITERATIONS},
},
};

/// Format and emit one SSE block (`event:` line plus `data:` line) on the
/// client channel.
///
/// Returns `false` when the receiver has been dropped, i.e. the client
/// disconnected.
#[inline]
fn send_sse_event(
    tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
    event_name: &str,
    payload: &Value,
) -> bool {
    tx.send(Ok(Bytes::from(format!(
        "event: {event_name}\ndata: {payload}\n\n"
    ))))
    .is_ok()
}
Comment on lines +37 to +47
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

LGTM! Clean SSE helper consolidating inline formatting.

The helper correctly formats the SSE block and returns connection status. The #[inline] hint is appropriate for this small, hot-path function.

Note: An identical send_sse_event helper exists in streaming.rs (lines 205-212). The duplication is acceptable given both are module-private, but could be consolidated into a shared utility if more call sites emerge.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@model_gateway/src/routers/openai/mcp/tool_loop.rs` around lines 37 - 47,
Duplicate send_sse_event helpers exist (this file's send_sse_event and the one
in streaming.rs); extract the shared logic into a single private utility (e.g.,
a new module or helper function referenced by both modules) and replace both
local implementations with calls to that shared function to reduce duplication
while preserving current behavior and signature (keep the parameters tx:
&mpsc::UnboundedSender<Result<Bytes, io::Error>>, event_name: &str, payload:
&Value and return bool).


/// State for tracking multi-turn tool calling loop
pub(crate) struct ToolLoopState {
/// Current iteration number (starts at 0, increments with each tool call)
Expand Down Expand Up @@ -311,13 +327,8 @@ pub(crate) fn send_mcp_list_tools_events(
"item": tools_item_empty
});
*sequence_number += 1;
let event1 = format!(
"event: {}\ndata: {}\n\n",
OutputItemEvent::ADDED,
event1_payload
);
if tx.send(Ok(Bytes::from(event1))).is_err() {
return false; // Client disconnected
if !send_sse_event(tx, OutputItemEvent::ADDED, &event1_payload) {
return false;
}

// Event 2: response.mcp_list_tools.in_progress
Expand All @@ -328,12 +339,7 @@ pub(crate) fn send_mcp_list_tools_events(
"item_id": item_id
});
*sequence_number += 1;
let event2 = format!(
"event: {}\ndata: {}\n\n",
McpEvent::LIST_TOOLS_IN_PROGRESS,
event2_payload
);
if tx.send(Ok(Bytes::from(event2))).is_err() {
if !send_sse_event(tx, McpEvent::LIST_TOOLS_IN_PROGRESS, &event2_payload) {
return false;
}

Expand All @@ -345,12 +351,7 @@ pub(crate) fn send_mcp_list_tools_events(
"item_id": item_id
});
*sequence_number += 1;
let event3 = format!(
"event: {}\ndata: {}\n\n",
McpEvent::LIST_TOOLS_COMPLETED,
event3_payload
);
if tx.send(Ok(Bytes::from(event3))).is_err() {
if !send_sse_event(tx, McpEvent::LIST_TOOLS_COMPLETED, &event3_payload) {
return false;
}

Expand All @@ -362,12 +363,7 @@ pub(crate) fn send_mcp_list_tools_events(
"item": tools_item_full
});
*sequence_number += 1;
let event4 = format!(
"event: {}\ndata: {}\n\n",
OutputItemEvent::DONE,
event4_payload
);
tx.send(Ok(Bytes::from(event4))).is_ok()
send_sse_event(tx, OutputItemEvent::DONE, &event4_payload)
}

/// Send intermediate event during tool execution (searching/interpreting).
Expand Down Expand Up @@ -403,8 +399,7 @@ fn send_tool_call_intermediate_event(
});
*sequence_number += 1;

let event = format!("event: {event_type}\ndata: {event_payload}\n\n");
tx.send(Ok(Bytes::from(event))).is_ok()
send_sse_event(tx, event_type, &event_payload)
}

/// Send tool call completion events after tool execution.
Expand Down Expand Up @@ -444,9 +439,7 @@ fn send_tool_call_completion_events(
"item_id": item_id
});
*sequence_number += 1;

let completed_event = format!("event: {completed_event_type}\ndata: {completed_payload}\n\n");
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
if !send_sse_event(tx, completed_event_type, &completed_payload) {
return false;
}

Expand All @@ -458,13 +451,7 @@ fn send_tool_call_completion_events(
"item": tool_call_item
});
*sequence_number += 1;

let done_event = format!(
"event: {}\ndata: {}\n\n",
OutputItemEvent::DONE,
done_payload
);
tx.send(Ok(Bytes::from(done_event))).is_ok()
send_sse_event(tx, OutputItemEvent::DONE, &done_payload)
}

/// Inject MCP metadata into a streaming response
Expand Down Expand Up @@ -564,10 +551,7 @@ pub(crate) async fn execute_tool_loop(
function_calls.len()
);

let effective_limit = match max_tool_calls {
Some(user_max) => user_max.min(DEFAULT_MAX_ITERATIONS),
None => DEFAULT_MAX_ITERATIONS,
};
let effective_limit = effective_tool_call_limit(max_tool_calls);

for call in function_calls {
state.total_calls += 1;
Expand Down
32 changes: 15 additions & 17 deletions model_gateway/src/routers/openai/responses/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ impl StreamingResponseAccumulator {
self.process_block(block);
}

/// Feed the accumulator with a pre-split event name and data payload.
///
/// Skips the format-then-reparse round trip of [`ingest_block`] when the
/// caller already holds the event name and data separately (e.g. straight
/// out of SSE parsing). Empty payloads are ignored.
pub fn ingest_event(&mut self, event_name: Option<&str>, data: &str) {
    if !data.is_empty() {
        self.handle_event(event_name, data);
    }
}

/// Consume the accumulator and produce the best-effort final response value.
pub fn into_final_response(mut self) -> Option<Value> {
if self.completed_response.is_some() {
Expand All @@ -62,16 +73,15 @@ impl StreamingResponseAccumulator {
if let Some(resp) = &self.completed_response {
return Some(resp.clone());
}
self.build_fallback_response_snapshot()
self.assemble_fallback(self.output_items.clone())
}

fn build_fallback_response_snapshot(&self) -> Option<Value> {
fn assemble_fallback(&self, mut output_items: Vec<(usize, Value)>) -> Option<Value> {
let mut response = self.initial_response.clone()?;

if let Some(obj) = response.as_object_mut() {
obj.insert("status".to_string(), Value::String("completed".to_string()));

let mut output_items = self.output_items.clone();
output_items.sort_by_key(|(index, _)| *index);
let outputs: Vec<Value> = output_items.into_iter().map(|(_, item)| item).collect();
obj.insert("output".to_string(), Value::Array(outputs));
Expand Down Expand Up @@ -131,19 +141,7 @@ impl StreamingResponseAccumulator {
}

fn build_fallback_response(&mut self) -> Option<Value> {
let mut response = self.initial_response.clone()?;

if let Some(obj) = response.as_object_mut() {
obj.insert("status".to_string(), Value::String("completed".to_string()));

self.output_items.sort_by_key(|(index, _)| *index);
let outputs: Vec<Value> = std::mem::take(&mut self.output_items)
.into_iter()
.map(|(_, item)| item)
.collect();
obj.insert("output".to_string(), Value::Array(outputs));
}

Some(response)
let output_items = std::mem::take(&mut self.output_items);
self.assemble_fallback(output_items)
}
}
6 changes: 5 additions & 1 deletion model_gateway/src/routers/openai/responses/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ impl ChunkProcessor {
Ok(s) => Cow::Borrowed(s),
Err(_) => Cow::Owned(String::from_utf8_lossy(chunk).into_owned()),
};
// Fast path: no \r means no CRLF normalization needed
if !chunk_str.contains('\r') {
self.pending.push_str(&chunk_str);
return;
}
let mut chars = chunk_str.chars().peekable();
while let Some(c) = chars.next() {
if c == '\r' && chars.peek() == Some(&'\n') {
// Skip \r when followed by \n
continue;
}
self.pending.push(c);
Expand Down
Loading
Loading