Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions server_tests/test_openai_responses_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,71 @@ def test_create_response_stream_with_tool_calls(self, mock_openai: Mock) -> None
self.assertEqual(final_event["ai_tool_calls"][0]["function_name"], "create_point")
self.assertEqual(final_event["ai_tool_calls"][0]["arguments"], {"x": 5, "y": 10})

@patch("static.openai_api_base.OpenAI")
def test_clear_previous_response_id(self, mock_openai: Mock) -> None:
    """clear_previous_response_id should null out a previously stored ID."""
    responses_api = OpenAIResponsesAPI()
    # Simulate a prior turn having stored a response ID.
    responses_api._previous_response_id = "resp_abc123"

    responses_api.clear_previous_response_id()

    self.assertIsNone(responses_api._previous_response_id)

@patch("static.openai_api_base.OpenAI")
def test_clear_previous_response_id_noop_when_none(self, mock_openai: Mock) -> None:
    """Clearing an already-absent response ID must be a silent no-op."""
    responses_api = OpenAIResponsesAPI()
    # Precondition: a fresh instance starts with no stored ID.
    self.assertIsNone(responses_api._previous_response_id)

    # Must not raise even though there is nothing to clear.
    responses_api.clear_previous_response_id()

    self.assertIsNone(responses_api._previous_response_id)

@patch("static.openai_api_base.OpenAI")
def test_add_partial_message_clears_previous_response_id(self, mock_openai: Mock) -> None:
    """Saving a partial assistant message must discard the stale response ID."""
    responses_api = OpenAIResponsesAPI()
    # A stale ID left behind by an interrupted response with pending tool calls.
    responses_api._previous_response_id = "resp_stale_with_pending_tool_calls"

    responses_api.add_partial_assistant_message("Partial AI response before stop")

    self.assertIsNone(responses_api._previous_response_id)

@patch("static.openai_api_base.OpenAI")
def test_partial_message_still_appends_to_history(self, mock_openai: Mock) -> None:
    """add_partial_assistant_message still records the text via the base class."""
    responses_api = OpenAIResponsesAPI()
    responses_api._previous_response_id = "resp_xyz"
    count_before = len(responses_api.messages)

    responses_api.add_partial_assistant_message("Some partial text")

    # Exactly one assistant message should have been appended to history.
    self.assertEqual(len(responses_api.messages), count_before + 1)
    appended = responses_api.messages[-1]
    self.assertEqual(appended["role"], "assistant")
    self.assertEqual(appended["content"], "Some partial text")
    # The stale response ID must also be gone.
    self.assertIsNone(responses_api._previous_response_id)

@patch("static.openai_api_base.OpenAI")
def test_add_empty_partial_message_still_clears_response_id(self, mock_openai: Mock) -> None:
    """An empty partial message must still clear the stale response ID.

    Interrupting a tool-call-only response leaves the stream buffer empty.
    The backend has to drop the stale ID anyway, even though nothing is
    appended to the conversation history.
    """
    responses_api = OpenAIResponsesAPI()
    responses_api._previous_response_id = "resp_pending_tool_calls"
    count_before = len(responses_api.messages)

    responses_api.add_partial_assistant_message("")

    # Base class skips empty content, so history length is unchanged...
    self.assertEqual(len(responses_api.messages), count_before)
    # ...but the stale response ID is cleared regardless.
    self.assertIsNone(responses_api._previous_response_id)


class TestOpenAIResponsesAPIIntegration(unittest.TestCase):
"""Integration tests that actually call the OpenAI Responses API.
Expand Down
7 changes: 4 additions & 3 deletions static/client/ai_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,9 +1869,10 @@ def stop_ai_processing(self) -> None:
self._stop_requested = True
self._abort_current_stream()
self._cancel_response_timeout()
# Save partial response to conversation history before finalizing
if self._stream_buffer and self._stream_buffer.strip():
self._save_partial_response(self._stream_buffer)
# Always notify the backend so it can clear stale conversation state
# (e.g. previous_response_id pointing to unanswered tool calls).
# The backend handles empty text gracefully.
self._save_partial_response(self._stream_buffer or "")
self._finalize_stream_message()
self._print_system_message_in_chat("Generation stopped.")
self._enable_send_controls()
Expand Down
16 changes: 16 additions & 0 deletions static/openai_responses_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ def reset_conversation(self) -> None:
self._previous_response_id = None
self._log("[Responses API] Conversation reset, cleared previous_response_id")

def clear_previous_response_id(self) -> None:
    """Clear the stored response ID (e.g. after user interruption)."""
    # Remember whether anything was stored so we only log a real clear.
    had_stored_id = self._previous_response_id is not None
    self._previous_response_id = None
    if had_stored_id:
        self._log("[Responses API] Cleared previous_response_id")

def add_partial_assistant_message(self, content: str) -> None:
    """Record a partial assistant message, then invalidate the stored response ID.

    An interruption can leave the previous response with tool calls that
    will never receive results. Dropping the ID keeps the next request
    from chaining onto that broken server-side state.
    """
    # The base class decides whether the (possibly empty) text is appended.
    super().add_partial_assistant_message(content)
    # Clear unconditionally: even an empty partial means the turn was cut short.
    self.clear_previous_response_id()

def _is_regular_message_turn(self) -> bool:
"""Check if this is a regular user message turn (not a tool call continuation).

Expand Down
8 changes: 5 additions & 3 deletions static/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,14 +1013,16 @@ def save_partial_response() -> ResponseReturnValue:
)

partial_message = request_payload.get("partial_message", "")
if not isinstance(partial_message, str) or not partial_message.strip():
if not isinstance(partial_message, str):
return AppManager.make_response(
message="No partial message to save",
message="Invalid partial message",
status="error",
code=400,
)

# Add the partial response to all API conversation histories
# Always notify all APIs so they can clear stale conversation
# state (e.g. previous_response_id after interrupted tool calls).
# The base class skips appending empty text to history.
app.ai_api.add_partial_assistant_message(partial_message)
app.responses_api.add_partial_assistant_message(partial_message)
for provider in app.providers.values():
Expand Down