From 6da22fb1d889d48a658839c5a691bee662e3753f Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Mon, 13 Oct 2025 00:28:59 +0200 Subject: [PATCH 1/2] Fix OpenAI per-call API key handling --- src/connectors/openai.py | 158 ++++++++++-------- .../test_precision_payload_mapping.py | 48 ++++++ 2 files changed, 133 insertions(+), 73 deletions(-) diff --git a/src/connectors/openai.py b/src/connectors/openai.py index e9414fe81..cc91c574b 100644 --- a/src/connectors/openai.py +++ b/src/connectors/openai.py @@ -225,88 +225,100 @@ async def chat_completions( processed_messages: list[Any], effective_model: str, identity: IAppIdentityConfig | None = None, + api_key: str | None = None, **kwargs: Any, ) -> ResponseEnvelope | StreamingResponseEnvelope: - # Perform health check if enabled (for subclasses that support it) - await self._ensure_healthy() - - # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) - # (the frontend controller converts from frontend-specific format to domain format) - # Backends should ONLY convert FROM domain TO backend-specific format - # Type assertion: we know from architectural design that request_data is ChatRequest-like - from typing import cast - - from src.core.domain.chat import CanonicalChatRequest, ChatRequest - - if not isinstance(request_data, ChatRequest): - raise TypeError( - f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " - "Backend connectors should only receive domain-format requests." - ) - # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature - domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) + # Allow callers to supply a one-off API key (e.g., multi-tenant flows). + # Temporarily replace the connector-level key for the duration of this + # call so that header construction and health checks use it. + original_api_key = self.api_key + if api_key is not None: + self.api_key = api_key - # Ensure identity headers are scoped to the current request only. - self.identity = identity - - # Prepare the payload using a helper so subclasses and tests can - # override or patch payload construction logic easily. - payload = await self._prepare_payload( - domain_request, processed_messages, effective_model - ) - headers_override = kwargs.pop("headers_override", None) - headers: dict[str, str] | None = None + # Perform health check if enabled (for subclasses that support it) + try: + await self._ensure_healthy() - if headers_override is not None: - # Avoid mutating the caller-provided mapping while preserving any - # Authorization header we compute from the configured API key. - headers = dict(headers_override) + # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) + # (the frontend controller converts from frontend-specific format to domain format) + # Backends should ONLY convert FROM domain TO backend-specific format + # Type assertion: we know from architectural design that request_data is ChatRequest-like + from typing import cast - try: - base_headers = self.get_headers() - except Exception: - base_headers = None + from src.core.domain.chat import CanonicalChatRequest, ChatRequest - if base_headers: - merged_headers = dict(base_headers) - merged_headers.update(headers) - headers = merged_headers - else: - try: - # Always update the cached identity so that per-request - # identity headers do not leak between calls. Downstream - # callers rely on identity-specific headers being scoped to - # a single request. - self.identity = identity - headers = self.get_headers() - except Exception: - headers = None + if not isinstance(request_data, ChatRequest): + raise TypeError( + f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " + "Backend connectors should only receive domain-format requests." + ) + # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature + domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) - api_base = kwargs.get("openai_url") or self.api_base_url - url = f"{api_base.rstrip('/')}/chat/completions" + # Ensure identity headers are scoped to the current request only. + self.identity = identity - if domain_request.stream: - # Return a domain-level streaming envelope (raw bytes iterator) - try: - content_iterator = await self._handle_streaming_response( - url, - payload, - headers, - domain_request.session_id or "", - "openai", - ) - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=str(e)) - return StreamingResponseEnvelope( - content=content_iterator, - media_type="text/event-stream", - headers={}, - ) - else: - # Return a domain ResponseEnvelope for non-streaming - return await self._handle_non_streaming_response( - url, payload, headers, domain_request.session_id or "" + # Prepare the payload using a helper so subclasses and tests can + # override or patch payload construction logic easily. + payload = await self._prepare_payload( + domain_request, processed_messages, effective_model ) + headers_override = kwargs.pop("headers_override", None) + headers: dict[str, str] | None = None + + if headers_override is not None: + # Avoid mutating the caller-provided mapping while preserving any + # Authorization header we compute from the configured API key. + headers = dict(headers_override) + + try: + base_headers = self.get_headers() + except Exception: + base_headers = None + + if base_headers: + merged_headers = dict(base_headers) + merged_headers.update(headers) + headers = merged_headers + else: + try: + # Always update the cached identity so that per-request + # identity headers do not leak between calls. Downstream + # callers rely on identity-specific headers being scoped to + # a single request. + self.identity = identity + headers = self.get_headers() + except Exception: + headers = None + + api_base = kwargs.get("openai_url") or self.api_base_url + url = f"{api_base.rstrip('/')}/chat/completions" + + if domain_request.stream: + # Return a domain-level streaming envelope (raw bytes iterator) + try: + content_iterator = await self._handle_streaming_response( + url, + payload, + headers, + domain_request.session_id or "", + "openai", + ) + except AuthenticationError as e: + raise HTTPException(status_code=401, detail=str(e)) + return StreamingResponseEnvelope( + content=content_iterator, + media_type="text/event-stream", + headers={}, + ) + else: + # Return a domain ResponseEnvelope for non-streaming + return await self._handle_non_streaming_response( + url, payload, headers, domain_request.session_id or "" + ) + finally: + if api_key is not None: + self.api_key = original_api_key async def _prepare_payload( self, diff --git a/tests/unit/connectors/test_precision_payload_mapping.py b/tests/unit/connectors/test_precision_payload_mapping.py index 99da8b099..b968fbab4 100644 --- a/tests/unit/connectors/test_precision_payload_mapping.py +++ b/tests/unit/connectors/test_precision_payload_mapping.py @@ -12,6 +12,7 @@ from src.connectors.openrouter import OpenRouterBackend from src.core.config.app_config import AppConfig from src.core.domain.chat import ChatMessage, ChatRequest +from src.core.domain.responses import ResponseEnvelope from src.core.services.translation_service import TranslationService @@ -45,6 +46,53 @@ async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response: assert captured_payload.get("top_p") == 0.34 +@pytest.mark.asyncio +async def test_openai_connector_uses_per_call_api_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = httpx.AsyncClient() + connector = OpenAIConnector( + client, AppConfig(), translation_service=TranslationService() + ) + connector.disable_health_check() + connector.api_key = None + + observed_headers: list[dict[str, str] | None] = [] + + async def fake_handle( + self: OpenAIConnector, + url: str, + payload: dict[str, Any], + headers: dict[str, str] | None, + session_id: str, + ) -> ResponseEnvelope: + observed_headers.append(headers) + return ResponseEnvelope(content={}, status_code=200, headers={}) + + monkeypatch.setattr( + OpenAIConnector, + "_handle_non_streaming_response", + fake_handle, + ) + + request = ChatRequest(model="gpt-4o", messages=_messages(), stream=False) + + try: + await connector.chat_completions( + request, + request.messages, + request.model, + api_key="per-call-token", + ) + finally: + await client.aclose() + + assert observed_headers and observed_headers[0] is not None + assert ( + observed_headers[0].get("Authorization") == "Bearer per-call-token" + ) + + @pytest.mark.asyncio async def test_openai_payload_uses_processed_messages_with_list_content( monkeypatch: pytest.MonkeyPatch, From 1b816702a5fa15cc4c287d834e059dd96cd9dc90 Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Wed, 15 Oct 2025 23:08:27 +0200 Subject: [PATCH 2/2] Avoid mutating API key for per-call overrides --- src/connectors/openai.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/connectors/openai.py b/src/connectors/openai.py index cc91c574b..8ffc2de0d 100644 --- a/src/connectors/openai.py +++ b/src/connectors/openai.py @@ -95,13 +95,14 @@ def _resolve_translation_service() -> TranslationService: ) return service - def get_headers(self) -> dict[str, str]: + def get_headers(self, api_key: str | None = None) -> dict[str, str]: """Return request headers including API key and per-request identity.""" headers: dict[str, str] = {} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" + effective_api_key = api_key if api_key is not None else self.api_key + if effective_api_key: + headers["Authorization"] = f"Bearer {effective_api_key}" if self.identity: try: @@ -143,7 +144,7 @@ async def initialize(self, **kwargs: Any) -> None: logger.warning("Failed to fetch models: %s", e, exc_info=True) # Log the error but don't fail initialization - async def _perform_health_check(self) -> bool: + async def _perform_health_check(self, api_key: str | None = None) -> bool: """Perform a health check by testing API connectivity. This method tests actual API connectivity by making a simple request to verify @@ -154,11 +155,12 @@ async def _perform_health_check(self) -> bool: """ try: # Test API connectivity with a simple models endpoint request - if not self.api_key: + effective_api_key = api_key if api_key is not None else self.api_key + if not effective_api_key: logger.warning("Health check failed - no API key available") return False - headers = self.get_headers() + headers = self.get_headers(effective_api_key) if not headers.get("Authorization"): logger.warning("Health check failed - no authorization header") return False @@ -183,7 +185,7 @@ async def _perform_health_check(self) -> bool: ) return False - async def _ensure_healthy(self) -> None: + async def _ensure_healthy(self, api_key: str | None = None) -> None: """Ensure the backend is healthy before use. This method performs health checks on first use, similar to how @@ -198,7 +200,7 @@ async def _ensure_healthy(self) -> None: f"Performing first-use health check for {self.backend_type} backend" ) - healthy = await self._perform_health_check() + healthy = await self._perform_health_check(api_key) if not healthy: logger.warning( "Health check did not pass; continuing with lazy verification on first request" @@ -228,16 +230,14 @@ async def chat_completions( api_key: str | None = None, **kwargs: Any, ) -> ResponseEnvelope | StreamingResponseEnvelope: + original_identity = self.identity + # Allow callers to supply a one-off API key (e.g., multi-tenant flows). - # Temporarily replace the connector-level key for the duration of this - # call so that header construction and health checks use it. - original_api_key = self.api_key - if api_key is not None: - self.api_key = api_key + effective_api_key = api_key if api_key is not None else self.api_key # Perform health check if enabled (for subclasses that support it) try: - await self._ensure_healthy() + await self._ensure_healthy(effective_api_key) # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) # (the frontend controller converts from frontend-specific format to domain format) @@ -272,7 +272,7 @@ async def chat_completions( headers = dict(headers_override) try: - base_headers = self.get_headers() + base_headers = self.get_headers(effective_api_key) except Exception: base_headers = None @@ -287,7 +287,7 @@ async def chat_completions( # callers rely on identity-specific headers being scoped to # a single request. self.identity = identity - headers = self.get_headers() + headers = self.get_headers(effective_api_key) except Exception: headers = None @@ -317,8 +317,7 @@ async def chat_completions( url, payload, headers, domain_request.session_id or "" ) finally: - if api_key is not None: - self.api_key = original_api_key + self.identity = original_identity async def _prepare_payload( self,