diff --git a/server_tests/test_openai_responses_api.py b/server_tests/test_openai_responses_api.py index 8eaabc7..2ea6724 100644 --- a/server_tests/test_openai_responses_api.py +++ b/server_tests/test_openai_responses_api.py @@ -584,6 +584,71 @@ def test_create_response_stream_with_tool_calls(self, mock_openai: Mock) -> None self.assertEqual(final_event["ai_tool_calls"][0]["function_name"], "create_point") self.assertEqual(final_event["ai_tool_calls"][0]["arguments"], {"x": 5, "y": 10}) + @patch("static.openai_api_base.OpenAI") + def test_clear_previous_response_id(self, mock_openai: Mock) -> None: + """Test clear_previous_response_id resets the stored ID.""" + api = OpenAIResponsesAPI() + api._previous_response_id = "resp_abc123" + + api.clear_previous_response_id() + + self.assertIsNone(api._previous_response_id) + + @patch("static.openai_api_base.OpenAI") + def test_clear_previous_response_id_noop_when_none(self, mock_openai: Mock) -> None: + """Test clear_previous_response_id is a no-op when ID is already None.""" + api = OpenAIResponsesAPI() + self.assertIsNone(api._previous_response_id) + + api.clear_previous_response_id() # Should not raise + + self.assertIsNone(api._previous_response_id) + + @patch("static.openai_api_base.OpenAI") + def test_add_partial_message_clears_previous_response_id(self, mock_openai: Mock) -> None: + """Test add_partial_assistant_message clears stale previous_response_id.""" + api = OpenAIResponsesAPI() + api._previous_response_id = "resp_stale_with_pending_tool_calls" + + api.add_partial_assistant_message("Partial AI response before stop") + + self.assertIsNone(api._previous_response_id) + + @patch("static.openai_api_base.OpenAI") + def test_partial_message_still_appends_to_history(self, mock_openai: Mock) -> None: + """Test add_partial_assistant_message still appends the message via super().""" + api = OpenAIResponsesAPI() + api._previous_response_id = "resp_xyz" + initial_count = len(api.messages) + + api.add_partial_assistant_message("Some partial text") + + # Message should be appended + self.assertEqual(len(api.messages), initial_count + 1) + self.assertEqual(api.messages[-1]["role"], "assistant") + self.assertEqual(api.messages[-1]["content"], "Some partial text") + # And ID should be cleared + self.assertIsNone(api._previous_response_id) + + @patch("static.openai_api_base.OpenAI") + def test_add_empty_partial_message_still_clears_response_id(self, mock_openai: Mock) -> None: + """Test that empty partial message still clears stale previous_response_id. + + When the user interrupts a tool-call-only response, the stream buffer + is empty. The backend must still clear the stale response ID even though + no message is appended to history. + """ + api = OpenAIResponsesAPI() + api._previous_response_id = "resp_pending_tool_calls" + initial_count = len(api.messages) + + api.add_partial_assistant_message("") + + # No message should be appended (base class skips empty content) + self.assertEqual(len(api.messages), initial_count) + # But the stale response ID must still be cleared + self.assertIsNone(api._previous_response_id) + class TestOpenAIResponsesAPIIntegration(unittest.TestCase): """Integration tests that actually call the OpenAI Responses API. diff --git a/static/client/ai_interface.py b/static/client/ai_interface.py index 0254761..8b09f86 100644 --- a/static/client/ai_interface.py +++ b/static/client/ai_interface.py @@ -1869,9 +1869,10 @@ def stop_ai_processing(self) -> None: self._stop_requested = True self._abort_current_stream() self._cancel_response_timeout() - # Save partial response to conversation history before finalizing - if self._stream_buffer and self._stream_buffer.strip(): - self._save_partial_response(self._stream_buffer) + # Always notify the backend so it can clear stale conversation state + # (e.g. previous_response_id pointing to unanswered tool calls). + # The backend handles empty text gracefully. + self._save_partial_response(self._stream_buffer or "") self._finalize_stream_message() self._print_system_message_in_chat("Generation stopped.") self._enable_send_controls() diff --git a/static/openai_responses_api.py b/static/openai_responses_api.py index aaa92c4..3ba8718 100644 --- a/static/openai_responses_api.py +++ b/static/openai_responses_api.py @@ -61,6 +61,22 @@ def reset_conversation(self) -> None: self._previous_response_id = None self._log("[Responses API] Conversation reset, cleared previous_response_id") + def clear_previous_response_id(self) -> None: + """Clear the stored response ID (e.g. after user interruption).""" + if self._previous_response_id is not None: + self._log("[Responses API] Cleared previous_response_id") + self._previous_response_id = None + + def add_partial_assistant_message(self, content: str) -> None: + """Add a partial assistant message and clear stale response ID. + + When the user interrupts, the previous response may have pending + tool calls that will never be answered. Clearing the ID prevents + the next request from referencing that broken state. + """ + super().add_partial_assistant_message(content) + self.clear_previous_response_id() + def _is_regular_message_turn(self) -> bool: """Check if this is a regular user message turn (not a tool call continuation). diff --git a/static/routes.py b/static/routes.py index 63a52b5..7ae9200 100644 --- a/static/routes.py +++ b/static/routes.py @@ -1013,14 +1013,16 @@ def save_partial_response() -> ResponseReturnValue: ) partial_message = request_payload.get("partial_message", "") - if not isinstance(partial_message, str) or not partial_message.strip(): + if not isinstance(partial_message, str): return AppManager.make_response( - message="No partial message to save", + message="Invalid partial message", status="error", code=400, ) - # Add the partial response to all API conversation histories + # Always notify all APIs so they can clear stale conversation + # state (e.g. previous_response_id after interrupted tool calls). + # The base class skips appending empty text to history. app.ai_api.add_partial_assistant_message(partial_message) app.responses_api.add_partial_assistant_message(partial_message) for provider in app.providers.values():