Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/services/gemini_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(

logger.debug(f"GeminiAPIClient initialized: model={model_name}")

def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # noqa: PLR0912
def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # noqa: PLR0912, PLR0915
"""
Send text + image prompt to Gemini API.

Expand Down Expand Up @@ -167,11 +167,13 @@ def send_multimodal_prompt(self, text: str, screenshot: Screenshot) -> str: # n
except genai.types.generation_types.BlockedPromptException as e:
raise APIError(f"Prompt blocked by safety filters: {e}") from e
except Exception as e:
if "API_KEY_INVALID" in str(e) or "invalid API key" in str(e).lower():
message = str(e)
message_lower = message.lower()
if "API_KEY_INVALID" in message or "invalid api key" in message_lower:
raise AuthenticationError("Invalid or expired Gemini API key") from e
if "quota" in str(e).lower():
if "quota" in message_lower:
raise APIError(f"API quota exceeded: {e}") from e
if "not found" in str(e).lower():
if "not found" in message_lower:
raise APIError(f"Model not found: {self.model_name}") from e
raise APIError(f"Gemini API request failed: {e}") from e

Expand Down
93 changes: 92 additions & 1 deletion tests/unit/test_gemini_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from PIL import Image

from src.lib.exceptions import APIError, OAuthConfigNotFoundError, PayloadTooLargeError
from src.lib.exceptions import APIError, AuthenticationError, OAuthConfigNotFoundError, PayloadTooLargeError
from src.models.entities import Screenshot
from src.services.gemini_api_client import GeminiAPIClient

Expand Down Expand Up @@ -64,6 +64,36 @@ def test_get_api_key_raises_when_no_sources_available(
client._get_api_key()


def test_get_api_key_reads_from_yaml_config(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
cfg = tmp_path / "config.yaml"
cfg.write_text("gemini:\n api_key: yaml-key\n", encoding="utf-8")
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
monkeypatch.setattr(GeminiAPIClient, "DEFAULT_CONFIG_PATH", cfg)

client = GeminiAPIClient(api_key=None)

assert client._get_api_key() == "yaml-key"


def test_validate_oauth_token_returns_false_on_unexpected_error(monkeypatch: pytest.MonkeyPatch) -> None:
client = GeminiAPIClient(api_key=None)
monkeypatch.setattr(client, "_get_api_key", lambda: (_ for _ in ()).throw(RuntimeError("boom")))

assert client.validate_oauth_token() is False


def test_refresh_oauth_token_raises_when_validation_fails(monkeypatch: pytest.MonkeyPatch) -> None:
client = GeminiAPIClient(api_key=None)
monkeypatch.setattr(client, "validate_oauth_token", lambda: False)

with pytest.raises(AuthenticationError, match="API key is invalid and cannot be refreshed"):
client.refresh_oauth_token()


def test_send_multimodal_prompt_missing_file_raises_api_error() -> None:
client = GeminiAPIClient(api_key="gem-key")
screenshot = Screenshot(
Expand Down Expand Up @@ -102,6 +132,37 @@ def test_send_multimodal_prompt_large_payload_raises(sample_screenshot: Screensh
client.send_multimodal_prompt("hello", oversized)


def test_send_multimodal_prompt_initializes_model_when_missing(
monkeypatch: pytest.MonkeyPatch,
sample_screenshot: Screenshot,
) -> None:
client = GeminiAPIClient(api_key="gem-key")
monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None)
monkeypatch.setattr(
"src.services.gemini_api_client.genai.GenerativeModel",
lambda **_kwargs: _FakeModel(
SimpleNamespace(
prompt_feedback=None,
candidates=[SimpleNamespace(content=SimpleNamespace(parts=[SimpleNamespace(text="ok")]))],
)
),
)

assert client.send_multimodal_prompt("hello", sample_screenshot) == "ok"


def test_send_multimodal_prompt_raises_when_model_init_returns_none(
monkeypatch: pytest.MonkeyPatch,
sample_screenshot: Screenshot,
) -> None:
client = GeminiAPIClient(api_key="gem-key")
monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None)
monkeypatch.setattr("src.services.gemini_api_client.genai.GenerativeModel", lambda **_kwargs: None)

with pytest.raises(APIError, match="Gemini model initialization failed"):
client.send_multimodal_prompt("hello", sample_screenshot)


def test_send_multimodal_prompt_returns_concatenated_text(
monkeypatch: pytest.MonkeyPatch,
sample_screenshot: Screenshot,
Expand Down Expand Up @@ -153,3 +214,33 @@ def test_send_multimodal_prompt_raises_on_empty_candidates(

with pytest.raises(APIError, match="No response candidates"):
client.send_multimodal_prompt("hello", sample_screenshot)


@pytest.mark.parametrize(
("message", "exception_type", "expected"),
[
("API_KEY_INVALID", AuthenticationError, "Invalid or expired Gemini API key"),
("invalid API key", AuthenticationError, "Invalid or expired Gemini API key"),
("quota exceeded", APIError, "API quota exceeded"),
("model not found", APIError, "Model not found"),
("other boom", APIError, "Gemini API request failed"),
],
)
def test_send_multimodal_prompt_maps_runtime_errors(
monkeypatch: pytest.MonkeyPatch,
sample_screenshot: Screenshot,
message: str,
exception_type: type[Exception],
expected: str,
) -> None:
client = GeminiAPIClient(api_key="gem-key")
monkeypatch.setattr("src.services.gemini_api_client.genai.configure", lambda **_kwargs: None)

class _RaiseModel:
def generate_content(self, *_args, **_kwargs):
raise RuntimeError(message)

client._model = _RaiseModel()

with pytest.raises(exception_type, match=expected):
client.send_multimodal_prompt("hello", sample_screenshot)
Loading