From 48716a250c29d1eba420966a7a7bd0b810d15263 Mon Sep 17 00:00:00 2001 From: Patrik Date: Fri, 13 Feb 2026 16:18:28 +0100 Subject: [PATCH] test: expand gemini api client coverage and harden error mapping --- src/services/gemini_api_client.py | 10 +-- tests/unit/test_gemini_api_client.py | 93 +++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/src/services/gemini_api_client.py b/src/services/gemini_api_client.py index a9910e4..2dfdf4e 100644 --- a/src/services/gemini_api_client.py +++ b/src/services/gemini_api_client.py @@ -53,7 +53,7 @@ def __init__( logger.debug(f"GeminiAPIClient initialized: model={model_name}") - def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # noqa: PLR0912 + def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # noqa: PLR0912, PLR0915 """ Send text + image prompt to Gemini API. @@ -167,11 +167,13 @@ def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # n except genai.types.generation_types.BlockedPromptException as e: raise APIError(f"Prompt blocked by safety filters: {e}") from e except Exception as e: - if "API_KEY_INVALID" in str(e) or "invalid API key" in str(e).lower(): + message = str(e) + message_lower = message.lower() + if "API_KEY_INVALID" in message or "invalid api key" in message_lower: raise AuthenticationError("Invalid or expired Gemini API key") from e - if "quota" in str(e).lower(): + if "quota" in message_lower: raise APIError(f"API quota exceeded: {e}") from e - if "not found" in str(e).lower(): + if "not found" in message_lower: raise APIError(f"Model not found: {self.model_name}") from e raise APIError(f"Gemini API request failed: {e}") from e diff --git a/tests/unit/test_gemini_api_client.py b/tests/unit/test_gemini_api_client.py index ab4410e..2e7224a 100644 --- a/tests/unit/test_gemini_api_client.py +++ b/tests/unit/test_gemini_api_client.py @@ -10,7 +10,7 @@ import pytest from PIL import Image -from src.lib.exceptions import APIError, OAuthConfigNotFoundError, PayloadTooLargeError +from src.lib.exceptions import APIError, AuthenticationError, OAuthConfigNotFoundError, PayloadTooLargeError from src.models.entities import Screenshot from src.services.gemini_api_client import GeminiAPIClient @@ -64,6 +64,36 @@ def test_get_api_key_raises_when_no_sources_available( client._get_api_key() +def test_get_api_key_reads_from_yaml_config( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + cfg = tmp_path / "config.yaml" + cfg.write_text("gemini:\n api_key: yaml-key\n", encoding="utf-8") + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + monkeypatch.setattr(GeminiAPIClient, "DEFAULT_CONFIG_PATH", cfg) + + client = GeminiAPIClient(api_key=None) + + assert client._get_api_key() == "yaml-key" + + +def test_validate_oauth_token_returns_false_on_unexpected_error(monkeypatch: pytest.MonkeyPatch) -> None: + client = GeminiAPIClient(api_key=None) + monkeypatch.setattr(client, "_get_api_key", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert client.validate_oauth_token() is False + + +def test_refresh_oauth_token_raises_when_validation_fails(monkeypatch: pytest.MonkeyPatch) -> None: + client = GeminiAPIClient(api_key=None) + monkeypatch.setattr(client, "validate_oauth_token", lambda: False) + + with pytest.raises(AuthenticationError, match="API key is invalid and cannot be refreshed"): + client.refresh_oauth_token() + + def test_send_multimodal_prompt_missing_file_raises_api_error() -> None: client = GeminiAPIClient(api_key="gem-key") screenshot = Screenshot( @@ -102,6 +132,37 @@ def test_send_multimodal_prompt_large_payload_raises(sample_screenshot: Screensh client.send_multimodal_prompt("hello", oversized) +def test_send_multimodal_prompt_initializes_model_when_missing( + monkeypatch: pytest.MonkeyPatch, + sample_screenshot: Screenshot, +) -> None: + client = GeminiAPIClient(api_key="gem-key") + monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None) + monkeypatch.setattr( + "src.services.gemini_api_client.genai.GenerativeModel", + lambda **_kwargs: _FakeModel( + SimpleNamespace( + prompt_feedback=None, + candidates=[SimpleNamespace(content=SimpleNamespace(parts=[SimpleNamespace(text="ok")]))], + ) + ), + ) + + assert client.send_multimodal_prompt("hello", sample_screenshot) == "ok" + + +def test_send_multimodal_prompt_raises_when_model_init_returns_none( + monkeypatch: pytest.MonkeyPatch, + sample_screenshot: Screenshot, +) -> None: + client = GeminiAPIClient(api_key="gem-key") + monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None) + monkeypatch.setattr("src.services.gemini_api_client.genai.GenerativeModel", lambda **_kwargs: None) + + with pytest.raises(APIError, match="Gemini model initialization failed"): + client.send_multimodal_prompt("hello", sample_screenshot) + + def test_send_multimodal_prompt_returns_concatenated_text( monkeypatch: pytest.MonkeyPatch, sample_screenshot: Screenshot, @@ -153,3 +214,33 @@ def test_send_multimodal_prompt_raises_on_empty_candidates( with pytest.raises(APIError, match="No response candidates"): client.send_multimodal_prompt("hello", sample_screenshot) + + +@pytest.mark.parametrize( + ("message", "exception_type", "expected"), + [ + ("API_KEY_INVALID", AuthenticationError, "Invalid or expired Gemini API key"), + ("invalid API key", AuthenticationError, "Invalid or expired Gemini API key"), + ("quota exceeded", APIError, "API quota exceeded"), + ("model not found", APIError, "Model not found"), + ("other boom", APIError, "Gemini API request failed"), + ], +) +def test_send_multimodal_prompt_maps_runtime_errors( + monkeypatch: pytest.MonkeyPatch, + sample_screenshot: Screenshot, + message: str, + exception_type: type[Exception], + expected: str, +) -> None: + client = GeminiAPIClient(api_key="gem-key") + monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None) + + class _RaiseModel: + def generate_content(self, *_args, **_kwargs): + raise RuntimeError(message) + + client._model = _RaiseModel() + + with pytest.raises(exception_type, match=expected): + client.send_multimodal_prompt("hello", sample_screenshot)