diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec4b2a4..d9a8fbb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -118,6 +118,7 @@ jobs: env: GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + HUGGING_FACE_API: ${{ secrets.HUGGING_FACE_API }} run: pytest tests/integrations/ -v --cov=shekel --cov-report=xml --cov-append - name: Upload integration coverage diff --git a/CHANGELOG.md b/CHANGELOG.md index eb0a127..dea9b42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- **Google Gemini Provider Adapter** (`shekel/providers/gemini.py`) — Native support for the `google-genai` SDK + - Patches `google.genai.models.Models.generate_content` (non-streaming) and `generate_content_stream` (streaming) as two separate methods + - Token extraction from `response.usage_metadata.prompt_token_count` / `candidates_token_count` + - Model name captured from `model` kwarg before the call (not available in Gemini response objects) + - New pricing entries: `gemini-2.0-flash`, `gemini-2.5-flash`, `gemini-2.5-pro` + - Install via `pip install shekel[gemini]` +- **HuggingFace Provider Adapter** (`shekel/providers/huggingface.py`) — Support for `huggingface_hub.InferenceClient` + - Patches `InferenceClient.chat_completion` (the underlying method for `.chat.completions.create`) + - OpenAI-compatible token extraction (`usage.prompt_tokens` / `usage.completion_tokens`) + - Graceful handling when models don't return usage in streaming responses + - Install via `pip install shekel[huggingface]` +- **Integration tests** for both new adapters with real API calls (skip gracefully on quota errors) +- **Examples**: `examples/gemini_demo.py`, `examples/huggingface_demo.py` +- **Documentation**: `docs/integrations/gemini.md`, `docs/integrations/huggingface.md` + ## [0.2.5] - 2026-03-11 ### Added diff 
--git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..7a53d60 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,11 @@ +# CLAUDE.md — Development Guidelines + +## Test File Naming + +Tests must be organized **by domain**, not by implementation unit or coverage goal. + +- **Good**: `test_openai_wrappers.py`, `test_gemini_wrappers.py`, `test_fallback.py` +- **Bad**: `test_patch_coverage.py`, `test_patching.py`, `test_coverage_for_x.py` + +Name test files after the feature or domain being exercised, not after the module +being covered or the motivation for writing the tests. diff --git a/docs/installation.md b/docs/installation.md index 469f7d4..feb70e3 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -35,6 +35,22 @@ If you're using models from both providers: pip install shekel[all] ``` +### Google Gemini + +For Google Gemini via the `google-genai` SDK: + +```bash +pip install shekel[gemini] +``` + +### HuggingFace Inference API + +For HuggingFace's `InferenceClient`: + +```bash +pip install shekel[huggingface] +``` + ### LiteLLM (100+ Providers) For access to OpenAI, Anthropic, Gemini, Cohere, Ollama, Azure, Bedrock, and 90+ more through a unified interface: @@ -107,6 +123,8 @@ Shekel has zero required dependencies beyond the Python standard library. 
The Op | `openai>=1.0.0` | Optional | Track OpenAI API costs | | `anthropic>=0.7.0` | Optional | Track Anthropic API costs | | `litellm>=1.0.0` | Optional | Track costs via LiteLLM (100+ providers) | +| `google-genai>=1.0.0` | Optional | Track Google Gemini costs (native SDK) | +| `huggingface-hub>=0.20.0` | Optional | Track HuggingFace Inference API costs | | `tokencost>=0.1.0` | Optional | Support 400+ models | | `click>=8.0.0` | Optional | CLI tools | diff --git a/docs/integrations/gemini.md b/docs/integrations/gemini.md new file mode 100644 index 0000000..be371ed --- /dev/null +++ b/docs/integrations/gemini.md @@ -0,0 +1,159 @@ +# Google Gemini Integration + +Shekel tracks costs and enforces budgets for [Google Gemini](https://ai.google.dev/) via the official `google-genai` Python SDK. + +## Installation + +```bash +pip install shekel[gemini] +``` + +## Why a dedicated adapter? + +Unlike OpenAI and Anthropic, Gemini uses its own SDK (`google-genai`) that makes direct API calls — it does **not** route through the OpenAI SDK. Without a dedicated adapter, `budget()` would be completely blind to Gemini spend. + +Shekel's `GeminiAdapter` patches two methods at runtime: + +- `google.genai.models.Models.generate_content` — non-streaming calls +- `google.genai.models.Models.generate_content_stream` — streaming calls + +All other Shekel features (nested budgets, fallback models, `BudgetExceededError`) work identically. 
+ +## Basic Integration + +```python +import google.genai as genai +from shekel import budget + +client = genai.Client(api_key="your-gemini-key") + +with budget(max_usd=1.00) as b: + response = client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Explain quantum computing in one sentence.", + ) + print(response.candidates[0].content.parts[0].text) + print(f"Cost: ${b.spent:.6f}") +``` + +## Streaming + +Gemini streaming uses a **separate method** (`generate_content_stream`) rather than a `stream=True` kwarg — Shekel patches both: + +```python +with budget(max_usd=1.00) as b: + for chunk in client.models.generate_content_stream( + model="gemini-2.0-flash-lite", + contents="List three benefits of Python.", + ): + if chunk.candidates: + print(chunk.candidates[0].content.parts[0].text, end="", flush=True) + print() + print(f"Cost: ${b.spent:.6f}") +``` + +## Nested Budgets + +Track costs across multi-step Gemini workflows: + +```python +with budget(max_usd=5.00, name="pipeline") as total: + with budget(max_usd=1.00, name="research") as research: + client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Summarise recent AI trends.", + ) + + with budget(max_usd=2.00, name="analysis") as analysis: + client.models.generate_content( + model="gemini-2.0-flash", + contents="Analyse the implications of those trends.", + ) + +print(f"Research: ${research.spent:.6f}") +print(f"Analysis: ${analysis.spent:.6f}") +print(f"Total: ${total.spent:.6f}") +print(total.tree()) +``` + +## Fallback Models + +Switch to a cheaper Gemini model when spend reaches a threshold: + +```python +with budget( + max_usd=0.50, + fallback={"at_pct": 0.8, "model": "gemini-2.0-flash-lite"}, +) as b: + # Starts with gemini-2.0-flash; auto-switches at 80% ($0.40) + response = client.models.generate_content( + model="gemini-2.0-flash", + contents="Write a detailed market analysis.", + ) + +if b.model_switched: + print(f"Switched to fallback at 
${b.switched_at_usd:.4f}") +``` + +!!! note "Same-provider fallback only" + Fallback must be another Gemini model. Cross-provider fallback (e.g. Gemini → GPT-4o) is not supported. + +## Budget Enforcement + +Stop a runaway Gemini loop automatically: + +```python +from shekel import BudgetExceededError + +try: + with budget(max_usd=2.00) as b: + for _ in range(100): # Shekel stops this when budget runs out + client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Analyse this document.", + ) +except BudgetExceededError as e: + print(f"Stopped at ${e.spent:.4f} — saved the rest of the budget.") +``` + +## Supported Models and Pricing + +| Model | Input (per 1k tokens) | Output (per 1k tokens) | +|---|---|---| +| `gemini-2.5-pro` | $0.00125 | $0.01000 | +| `gemini-2.5-flash` | $0.00015 | $0.00060 | +| `gemini-2.0-flash` | $0.00010 | $0.00040 | +| `gemini-2.0-flash-lite` | $0.00010 | $0.00040 | +| `gemini-1.5-pro` | $0.00125 | $0.00500 | +| `gemini-1.5-flash` | $0.000075 | $0.00030 | + +Shekel uses prefix matching, so `gemini-2.0-flash-001` and similar versioned names resolve automatically (`gemini-2.0-flash-lite` resolves to the `gemini-2.0-flash` rate via the same prefix match). + +## Custom Pricing + +For models not in the pricing table, pass `price_per_1k_tokens`: + +```python +with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 0.0001, "output": 0.0003}, +) as b: + client.models.generate_content( + model="gemini-3-flash-preview", + contents="Hello.", + ) +``` + +## Tips for Gemini + Shekel + +1. **Use `generate_content_stream` for long responses** — streaming lets you stop mid-generation if the budget is hit +2. **Wrap at the workflow level**, not per-call, for accurate total cost tracking +3. **Set `warn_at=0.8`** to log a warning before the budget cap triggers +4. 
**Gemini free tier has per-minute limits** — use exponential backoff for production workloads + +## Next Steps + +- [HuggingFace Integration](huggingface.md) +- [Nested Budgets](../usage/nested-budgets.md) +- [Fallback Models](../usage/fallback-models.md) +- [Extending Shekel](../extending.md) diff --git a/docs/integrations/huggingface.md b/docs/integrations/huggingface.md new file mode 100644 index 0000000..9fab7b2 --- /dev/null +++ b/docs/integrations/huggingface.md @@ -0,0 +1,144 @@ +# HuggingFace Integration + +Shekel tracks costs and enforces budgets for [HuggingFace Inference API](https://huggingface.co/docs/inference-providers/en/index) via the `huggingface-hub` Python SDK's `InferenceClient`. + +## Installation + +```bash +pip install shekel[huggingface] +``` + +## Why a dedicated adapter? + +HuggingFace's `InferenceClient` uses its own HTTP layer — it does **not** call the OpenAI SDK under the hood. Without a dedicated adapter, `budget()` would be completely blind to HuggingFace spend. + +Shekel's `HuggingFaceAdapter` patches `InferenceClient.chat_completion` at runtime. Since `client.chat.completions.create()` delegates to `chat_completion` internally, all calls through either interface are tracked automatically. + +## Important: Custom Pricing Required + +!!! warning "No bundled HuggingFace pricing" + HuggingFace hosts thousands of models with varying pricing. Shekel has no standard pricing table for HuggingFace models. + + **Always pass `price_per_1k_tokens` to `budget()`** so Shekel can calculate costs: + + ```python + with budget(max_usd=1.00, price_per_1k_tokens={"input": 0.001, "output": 0.001}): + ... + ``` + + If you omit this, `b.spent` will always be `0.0` even though tokens were consumed. 
+ +## Basic Integration + +```python +from huggingface_hub import InferenceClient +from shekel import budget + +client = InferenceClient(token="your-hf-token") + +with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, +) as b: + response = client.chat.completions.create( + model="meta-llama/Llama-3.2-1B-Instruct", + messages=[{"role": "user", "content": "Explain transformers in one sentence."}], + max_tokens=50, + ) + print(response.choices[0].message.content) + print(f"Cost: ${b.spent:.6f}") +``` + +## Streaming + +```python +with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, +) as b: + full_text = "" + for chunk in client.chat.completions.create( + model="meta-llama/Llama-3.2-1B-Instruct", + messages=[{"role": "user", "content": "List three ML frameworks."}], + max_tokens=60, + stream=True, + ): + delta = chunk.choices[0].delta.content + if delta: + full_text += delta + print(delta, end="", flush=True) + print() + print(f"Cost: ${b.spent:.6f}") +``` + +!!! note "Streaming usage availability" + Many HuggingFace-hosted models do not return `usage` data in streaming chunks. In that case, `b.spent` will be `0.0` for streaming calls even if tokens were consumed. Non-streaming calls generally do return usage data. 
+ +## Nested Budgets + +```python +with budget( + max_usd=5.00, + name="pipeline", + price_per_1k_tokens={"input": 0.001, "output": 0.001}, +) as total: + with budget( + max_usd=1.00, + name="step-1", + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as step1: + client.chat.completions.create( + model="meta-llama/Llama-3.2-1B-Instruct", + messages=[{"role": "user", "content": "Summarise this document."}], + max_tokens=100, + ) + +print(f"Step 1: ${step1.spent:.6f}") +print(f"Total: ${total.spent:.6f}") +``` + +## Budget Enforcement + +```python +from shekel import BudgetExceededError + +try: + with budget( + max_usd=0.10, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as b: + for _ in range(100): # Shekel stops this when budget runs out + client.chat.completions.create( + model="meta-llama/Llama-3.2-1B-Instruct", + messages=[{"role": "user", "content": "Analyse this."}], + max_tokens=50, + ) +except BudgetExceededError as e: + print(f"Stopped at ${e.spent:.4f}") +``` + +## Free vs Paid Models + +HuggingFace offers two tiers for inference: + +| Tier | Description | Pricing | +|---|---|---| +| Free (Serverless) | Limited RPM, shared infrastructure | Free but rate-limited | +| PRO / Inference Endpoints | Dedicated infrastructure | Pay per token / per hour | + +For most chat models, use `InferenceClient` with an `hf_*` token. Free-tier models may return 503 when overloaded — add retry logic for production use. + +## Tips for HuggingFace + Shekel + +1. **Always set `price_per_1k_tokens`** — there is no default pricing for HuggingFace models +2. **Use non-streaming calls for accurate cost tracking** — many models omit usage in streaming +3. **Check model availability** — not all models are available on HuggingFace's serverless API +4. **Handle 503 errors** — free-tier endpoints can be temporarily unavailable under load +5. 
**Use `max_tokens`** to cap response length and control costs + +## Next Steps + +- [Google Gemini Integration](gemini.md) +- [Nested Budgets](../usage/nested-budgets.md) +- [Budget Enforcement](../usage/budget-enforcement.md) +- [Extending Shekel](../extending.md) diff --git a/docs/models.md b/docs/models.md index 78ef4c1..7048d03 100644 --- a/docs/models.md +++ b/docs/models.md @@ -28,8 +28,14 @@ These models have zero-dependency pricing built into shekel: | Model | Input / 1K | Output / 1K | Use Case | |-------|-----------|-------------|----------| -| **gemini-1.5-flash** | $0.0000750 | $0.000300 | Fastest, cheapest | +| **gemini-2.5-pro** | $0.00125 | $0.01000 | Most capable Gemini | +| **gemini-2.5-flash** | $0.000150 | $0.000600 | Fast, cost-efficient | +| **gemini-2.0-flash** | $0.000100 | $0.000400 | Latest flash model | | **gemini-1.5-pro** | $0.00125 | $0.00500 | Balanced quality/cost | +| **gemini-1.5-flash** | $0.0000750 | $0.000300 | Fastest, cheapest | + +!!! note "Native Gemini SDK support" + To track costs when calling Gemini via the `google-genai` SDK directly (not through LiteLLM), install `shekel[gemini]`. See [Google Gemini Integration](integrations/gemini.md). ## Version Resolution diff --git a/examples/gemini_demo.py b/examples/gemini_demo.py new file mode 100644 index 0000000..b33c352 --- /dev/null +++ b/examples/gemini_demo.py @@ -0,0 +1,92 @@ +# Requires: pip install shekel[gemini] +""" +Gemini demo: budget enforcement with shekel + google-genai SDK. + +Shekel patches google.genai.models.Models at runtime so every +client.models.generate_content() and generate_content_stream() call +is automatically tracked inside an active budget(). + +Shows three patterns: +1. Basic generate_content with budget tracking +2. Streaming with token accumulation +3. 
Fallback to a cheaper Gemini model at a spend threshold +""" + +import os + + +def main() -> None: + try: + import google.genai as genai + except ImportError: + print("Missing dependency: google-genai") + print("Run: pip install shekel[gemini]") + return + + api_key = os.environ.get("GEMINI_API_KEY") + if not api_key: + print("Set GEMINI_API_KEY to run this demo.") + return + + from shekel import BudgetExceededError, budget + + client = genai.Client(api_key=api_key) + + # ------------------------------------------------------------------ + # 1. Basic budget enforcement + # ------------------------------------------------------------------ + print("=== Basic budget enforcement ===") + try: + with budget(max_usd=0.10, name="demo", warn_at=0.8) as b: + response = client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="What is 2+2? Answer in one word.", + ) + text = response.candidates[0].content.parts[0].text + print(f"Answer: {text.strip()}") + print(f"Spent: ${b.spent:.6f} / ${b.limit:.2f}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 2. Streaming with token accumulation + # ------------------------------------------------------------------ + print("\n=== Streaming ===") + try: + with budget(max_usd=0.10, name="streaming") as b: + full_text = "" + for chunk in client.models.generate_content_stream( + model="gemini-2.0-flash-lite", + contents="Count from 1 to 5, one number per line.", + ): + if chunk.candidates: + for part in chunk.candidates[0].content.parts: + full_text += part.text + print(f"Response: {full_text.strip()}") + print(f"Spent: ${b.spent:.6f}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 3. 
Fallback model when threshold is reached + # ------------------------------------------------------------------ + print("\n=== Fallback model ===") + with budget( + max_usd=0.0001, + name="fallback-demo", + fallback={"at_pct": 0.5, "model": "gemini-2.0-flash-lite"}, + ) as b: + try: + client.models.generate_content( + model="gemini-2.0-flash", + contents="What is the capital of France?", + ) + except BudgetExceededError: + pass + if b.model_switched: + print(f"Switched to fallback at ${b.switched_at_usd:.8f}") + print(f"Total: ${b.spent:.6f}") + + +if __name__ == "__main__": + main() diff --git a/examples/huggingface_demo.py b/examples/huggingface_demo.py new file mode 100644 index 0000000..c4b7317 --- /dev/null +++ b/examples/huggingface_demo.py @@ -0,0 +1,102 @@ +# Requires: pip install shekel[huggingface] +""" +HuggingFace InferenceClient demo: budget enforcement with shekel. + +Shekel patches InferenceClient.chat_completion at runtime so every +client.chat.completions.create() call is automatically tracked inside +an active budget(). + +Note: HuggingFace has no standard pricing table. Always pass +price_per_1k_tokens={'input': X, 'output': Y} to budget() so Shekel +knows the cost per token for the model you're using. + +Shows three patterns: +1. Basic chat completion with budget tracking +2. Streaming response with token tracking +3. 
BudgetExceededError handling to cap runaway costs +""" + +import os + +_MODEL = "meta-llama/Llama-3.2-1B-Instruct" + + +def main() -> None: + try: + from huggingface_hub import InferenceClient + except ImportError: + print("Missing dependency: huggingface-hub") + print("Run: pip install shekel[huggingface]") + return + + api_key = os.environ.get("HUGGING_FACE_API") + if not api_key: + print("Set HUGGING_FACE_API to run this demo.") + return + + from shekel import BudgetExceededError, budget + + client = InferenceClient(token=api_key) + + # Custom pricing — required for HuggingFace (no bundled price table) + pricing = {"input": 0.001, "output": 0.001} # $0.001 per 1k tokens + + # ------------------------------------------------------------------ + # 1. Basic budget enforcement + # ------------------------------------------------------------------ + print("=== Basic budget enforcement ===") + try: + with budget(max_usd=0.10, name="demo", price_per_1k_tokens=pricing) as b: + response = client.chat.completions.create( + model=_MODEL, + messages=[{"role": "user", "content": "What is 2+2? Answer in one word."}], + max_tokens=10, + ) + text = response.choices[0].message.content + print(f"Answer: {text.strip()}") + print(f"Spent: ${b.spent:.6f} / ${b.limit:.2f}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 2. 
Streaming with token tracking + # ------------------------------------------------------------------ + print("\n=== Streaming ===") + try: + with budget(max_usd=0.10, name="streaming", price_per_1k_tokens=pricing) as b: + full_text = "" + for chunk in client.chat.completions.create( + model=_MODEL, + messages=[{"role": "user", "content": "Count from 1 to 3."}], + max_tokens=20, + stream=True, + ): + delta = chunk.choices[0].delta.content + if delta: + full_text += delta + print(f"Response: {full_text.strip()}") + print(f"Spent: ${b.spent:.6f}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 3. BudgetExceededError stops runaway calls + # ------------------------------------------------------------------ + print("\n=== Budget cap ===") + try: + with budget( + max_usd=0.000001, + name="cap-demo", + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client.chat.completions.create( + model=_MODEL, + messages=[{"role": "user", "content": "Hello."}], + max_tokens=5, + ) + except BudgetExceededError as e: + print(f"Caught BudgetExceededError at ${e.spent:.8f} — call stopped cleanly.") + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index 47623db..cf71065 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -119,6 +119,8 @@ nav: - CrewAI: integrations/crewai.md - OpenAI: integrations/openai.md - Anthropic: integrations/anthropic.md + - Google Gemini: integrations/gemini.md + - HuggingFace: integrations/huggingface.md - Langfuse: integrations/langfuse.md - Reference: - CLI Tools: cli.md diff --git a/pyproject.toml b/pyproject.toml index 1823050..3a150f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,9 @@ openai = ["openai>=1.0.0"] anthropic = ["anthropic>=0.7.0"] langfuse = ["langfuse>=2.0.0"] litellm = ["litellm>=1.0.0"] -all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0"] +gemini = 
["google-genai>=1.0.0"] +huggingface = ["huggingface-hub>=0.20.0"] +all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0", "google-genai>=1.0.0", "huggingface-hub>=0.20.0"] all-models = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "tokencost>=0.1.0"] cli = ["click>=8.0.0"] dev = [ @@ -78,6 +80,8 @@ dev = [ "pytest-cov>=4.0.0", "click>=8.0.0", "mkdocs-material>=9.0.0", + "google-genai>=1.0.0", + "huggingface-hub>=0.20.0", ] [project.scripts] @@ -139,6 +143,14 @@ ignore_missing_imports = true module = "litellm" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["google.genai", "google.genai.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["huggingface_hub", "huggingface_hub.*"] +ignore_missing_imports = true + [[tool.mypy.overrides]] module = ["_pytest", "_pytest.*"] follow_imports = "skip" diff --git a/shekel/_patch.py b/shekel/_patch.py index e532df8..ddeb711 100644 --- a/shekel/_patch.py +++ b/shekel/_patch.py @@ -58,6 +58,7 @@ def _validate_same_provider(fallback_model: str, current_provider: str) -> None: """Raise ValueError if fallback model is from a different provider.""" is_anthropic = fallback_model.startswith("claude-") is_openai = any(fallback_model.startswith(p) for p in ("gpt-", "o1", "o2", "o3", "o4", "text-")) + is_gemini = fallback_model.startswith("gemini-") if current_provider == "openai" and is_anthropic: raise ValueError( @@ -72,6 +73,18 @@ def _validate_same_provider(fallback_model: str, current_provider: str) -> None: f"Cross-provider fallback is not supported in v0.2. " f"Use an Anthropic model as fallback (e.g. fallback='claude-3-haiku-20240307')." ) + if current_provider == "gemini" and not is_gemini: + raise ValueError( + f"shekel: fallback model '{fallback_model}' does not appear to be a Gemini model " + f"but the current call is to Gemini. Cross-provider fallback is not supported. " + f"Use a Gemini model as fallback (e.g. fallback='gemini-2.0-flash')." 
+ ) + if current_provider == "huggingface" and (is_openai or is_anthropic or is_gemini): + raise ValueError( + f"shekel: fallback model '{fallback_model}' does not appear to be a HuggingFace model " + f"but the current call is to HuggingFace. Cross-provider fallback is not supported. " + f"Use a HuggingFace model as fallback (e.g. fallback='HuggingFaceH4/zephyr-7b-beta')." + ) def _apply_fallback_if_needed(active_budget: Any, kwargs: dict[str, Any], provider: str) -> None: @@ -418,3 +431,120 @@ async def _wrap_litellm_stream_async(stream: Any) -> Any: finally: it, ot, m = seen[-1] if seen else (0, 0, "unknown") _record(it, ot, m) + + +# --------------------------------------------------------------------------- +# Gemini sync wrapper (google-genai SDK) +# --------------------------------------------------------------------------- + + +def _gemini_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + original = _originals.get("gemini_sync") + if original is None: + raise RuntimeError("shekel: gemini original not stored") + + # Capture model name from kwargs before call (not available in response) + model_name: str = kwargs.get("model", None) or "unknown" + + active_budget = _context.get_active_budget() + if active_budget is not None: + _apply_fallback_if_needed(active_budget, kwargs, "gemini") + # Re-read model in case fallback rewrote it + model_name = kwargs.get("model", None) or model_name + + response = original(self, *args, **kwargs) + input_tokens, output_tokens, _ = _extract_gemini_tokens(response) + _record(input_tokens, output_tokens, model_name) + return response + + +def _gemini_stream_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + original = _originals.get("gemini_stream") + if original is None: + raise RuntimeError("shekel: gemini stream original not stored") + + model_name: str = kwargs.get("model", None) or "unknown" + + active_budget = _context.get_active_budget() + if active_budget is not None: + 
_apply_fallback_if_needed(active_budget, kwargs, "gemini") + model_name = kwargs.get("model", None) or model_name + + stream = original(self, *args, **kwargs) + return _wrap_gemini_stream(stream, model_name) + + +def _wrap_gemini_stream(stream: Any, model_name: str) -> Generator[Any, None, None]: + seen: list[tuple[int, int]] = [] + try: + for chunk in stream: + usage = getattr(chunk, "usage_metadata", None) + if usage is not None: + try: + it = usage.prompt_token_count or 0 + ot = usage.candidates_token_count or 0 + seen.append((it, ot)) + except AttributeError: + pass + yield chunk + finally: + if seen: + it, ot = seen[-1] + else: + it, ot = 0, 0 + _record(it, ot, model_name) + + +def _extract_gemini_tokens(response: Any) -> tuple[int, int, str]: + try: + usage = response.usage_metadata + if usage is None: + return 0, 0, "unknown" + input_tokens = usage.prompt_token_count or 0 + output_tokens = usage.candidates_token_count or 0 + return input_tokens, output_tokens, "unknown" + except AttributeError: + return 0, 0, "unknown" + + +# --------------------------------------------------------------------------- +# HuggingFace sync wrapper (huggingface-hub SDK) +# --------------------------------------------------------------------------- + + +def _huggingface_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + original = _originals.get("huggingface_sync") + if original is None: + raise RuntimeError("shekel: huggingface original not stored") + + active_budget = _context.get_active_budget() + if active_budget is not None: + _apply_fallback_if_needed(active_budget, kwargs, "huggingface") + + if kwargs.get("stream") is True: + stream = original(self, *args, **kwargs) + return _wrap_huggingface_stream(stream) + + response = original(self, *args, **kwargs) + input_tokens, output_tokens, model = _extract_openai_tokens(response) + _record(input_tokens, output_tokens, model) + return response + + +def _wrap_huggingface_stream(stream: Any) -> Generator[Any, None, None]: 
+ seen: list[tuple[int, int, str]] = [] + try: + for chunk in stream: + usage = getattr(chunk, "usage", None) + if usage is not None: + try: + it = usage.prompt_tokens or 0 + ot = usage.completion_tokens or 0 + m = getattr(chunk, "model", None) or "unknown" + seen.append((it, ot, m)) + except AttributeError: + pass + yield chunk + finally: + it, ot, m = seen[-1] if seen else (0, 0, "unknown") + _record(it, ot, m) diff --git a/shekel/prices.json b/shekel/prices.json index 0107dc6..b7bd7a4 100644 --- a/shekel/prices.json +++ b/shekel/prices.json @@ -45,5 +45,17 @@ "gemini-1.5-pro": { "input_per_1k": 0.00125, "output_per_1k": 0.005 + }, + "gemini-2.0-flash": { + "input_per_1k": 0.0001, + "output_per_1k": 0.0004 + }, + "gemini-2.5-flash": { + "input_per_1k": 0.00015, + "output_per_1k": 0.0006 + }, + "gemini-2.5-pro": { + "input_per_1k": 0.00125, + "output_per_1k": 0.01 } } diff --git a/shekel/providers/__init__.py b/shekel/providers/__init__.py index 23aec21..6891cb5 100644 --- a/shekel/providers/__init__.py +++ b/shekel/providers/__init__.py @@ -23,6 +23,20 @@ except ImportError: pass +try: + from shekel.providers.gemini import GeminiAdapter + + ADAPTER_REGISTRY.register(GeminiAdapter()) +except ImportError: + pass + +try: + from shekel.providers.huggingface import HuggingFaceAdapter + + ADAPTER_REGISTRY.register(HuggingFaceAdapter()) +except ImportError: + pass + __all__ = [ "ADAPTER_REGISTRY", "ProviderAdapter", @@ -30,4 +44,6 @@ "OpenAIAdapter", "AnthropicAdapter", "LiteLLMAdapter", + "GeminiAdapter", + "HuggingFaceAdapter", ] diff --git a/shekel/providers/gemini.py b/shekel/providers/gemini.py new file mode 100644 index 0000000..51d27a1 --- /dev/null +++ b/shekel/providers/gemini.py @@ -0,0 +1,112 @@ +"""Google Gemini provider adapter for Shekel LLM cost tracking. 
+ +Patches google.genai.models.Models.generate_content (non-streaming) and +google.genai.models.Models.generate_content_stream (streaming) to intercept +API calls and record token costs inside active budgets. + +Usage metadata comes from response.usage_metadata with fields: + - prompt_token_count + - candidates_token_count + +The model name is NOT included in the response object; the wrapper captures it +from the 'model' kwarg before the call and passes it to _record(). +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +from shekel.providers.base import ProviderAdapter + + +class GeminiAdapter(ProviderAdapter): + """Adapter for Google Gemini's generate_content API (google-genai SDK).""" + + def __init__(self) -> None: + self._originals: dict[str, Any] = {} + + @property + def name(self) -> str: + return "gemini" + + def install_patches(self) -> None: + """Monkey-patch google.genai.models.Models generate_content methods.""" + from shekel import _patch + + try: + import google.genai.models as gm + + if "gemini_sync" not in _patch._originals: + _patch._originals["gemini_sync"] = gm.Models.generate_content + _patch._originals["gemini_stream"] = gm.Models.generate_content_stream + gm.Models.generate_content = _patch._gemini_sync_wrapper # type: ignore[method-assign] + gm.Models.generate_content_stream = _patch._gemini_stream_wrapper # type: ignore[method-assign] + except ImportError: + pass + + def remove_patches(self) -> None: + """Restore original google.genai.models.Models methods.""" + from shekel import _patch + + try: + import google.genai.models as gm + + if "gemini_sync" in _patch._originals: + gm.Models.generate_content = _patch._originals.pop("gemini_sync") # type: ignore[method-assign] + if "gemini_stream" in _patch._originals: + gm.Models.generate_content_stream = _patch._originals.pop("gemini_stream") # type: ignore[method-assign] + except ImportError: + pass + + def extract_tokens(self, response: 
Any) -> tuple[int, int, str]: + """Extract tokens from Gemini non-streaming response. + + Uses response.usage_metadata.prompt_token_count / + response.usage_metadata.candidates_token_count. + + Model name is NOT available in the response — returns 'unknown'. + The wrapper captures the model name from kwargs before the call. + """ + try: + usage = response.usage_metadata + if usage is None: + return 0, 0, "unknown" + input_tokens = usage.prompt_token_count or 0 + output_tokens = usage.candidates_token_count or 0 + return input_tokens, output_tokens, "unknown" + except AttributeError: + return 0, 0, "unknown" + + def detect_streaming(self, kwargs: dict[str, Any], response: Any) -> bool: + """Gemini streaming uses a separate method — never stream=True kwarg.""" + return False + + def wrap_stream(self, stream: Any) -> Generator[Any, None, tuple[int, int, str]]: + """Wrap Gemini streaming response to collect usage_metadata from chunks.""" + seen: list[tuple[int, int]] = [] + for chunk in stream: + usage = getattr(chunk, "usage_metadata", None) + if usage is not None: + try: + it = usage.prompt_token_count or 0 + ot = usage.candidates_token_count or 0 + seen.append((it, ot)) + except AttributeError: + pass + yield chunk + if seen: + it, ot = seen[-1] + else: + it, ot = 0, 0 + return it, ot, "unknown" + + def validate_fallback(self, fallback_model: str) -> None: + """Validate that fallback model is a Gemini model.""" + if not fallback_model.startswith("gemini-"): + raise ValueError( + f"shekel: fallback model '{fallback_model}' does not appear to be a " + f"Google Gemini model but the current call is to Gemini. " + f"Cross-provider fallback is not supported. " + f"Use a Gemini model as fallback (e.g. fallback='gemini-2.0-flash')." 
+ ) diff --git a/shekel/providers/huggingface.py b/shekel/providers/huggingface.py new file mode 100644 index 0000000..07de987 --- /dev/null +++ b/shekel/providers/huggingface.py @@ -0,0 +1,121 @@ +"""HuggingFace provider adapter for Shekel LLM cost tracking. + +Patches huggingface_hub.inference._client.InferenceClient.chat_completion to +intercept API calls and record token costs inside active budgets. + +HuggingFace uses an OpenAI-compatible response format: + - response.usage.prompt_tokens + - response.usage.completion_tokens + - response.model + +Note: Many HuggingFace models do not return usage data in streaming responses. +Shekel handles this gracefully by recording zero tokens when usage is absent. + +Since HuggingFace Inference API pricing varies per model and is not standardised, +you should always pass price_per_1k_tokens to budget() when tracking costs: + + with budget(max_usd=0.10, price_per_1k_tokens={"input": 0.001, "output": 0.001}): + client.chat_completion(...) +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +from shekel.providers.base import ProviderAdapter + + +class HuggingFaceAdapter(ProviderAdapter): + """Adapter for HuggingFace Inference API (huggingface-hub SDK).""" + + def __init__(self) -> None: + self._originals: dict[str, Any] = {} + + @property + def name(self) -> str: + return "huggingface" + + def install_patches(self) -> None: + """Monkey-patch InferenceClient.chat_completion.""" + from shekel import _patch + + try: + from huggingface_hub.inference import _client + + if "huggingface_sync" not in _patch._originals: + _patch._originals["huggingface_sync"] = _client.InferenceClient.chat_completion + _client.InferenceClient.chat_completion = _patch._huggingface_sync_wrapper # type: ignore[method-assign] + except ImportError: + pass + + def remove_patches(self) -> None: + """Restore original InferenceClient.chat_completion.""" + from shekel import _patch + + try: + from 
huggingface_hub.inference import _client + + if "huggingface_sync" in _patch._originals: + _client.InferenceClient.chat_completion = _patch._originals.pop("huggingface_sync") # type: ignore[method-assign] + except ImportError: + pass + + def extract_tokens(self, response: Any) -> tuple[int, int, str]: + """Extract tokens from HuggingFace non-streaming response. + + Uses OpenAI-compatible format: + response.usage.prompt_tokens / response.usage.completion_tokens + """ + try: + usage = response.usage + if usage is None: + model = getattr(response, "model", None) or "unknown" + return 0, 0, model + input_tokens = usage.prompt_tokens or 0 + output_tokens = usage.completion_tokens or 0 + model = getattr(response, "model", None) or "unknown" + return input_tokens, output_tokens, model + except AttributeError: + return 0, 0, "unknown" + + def detect_streaming(self, kwargs: dict[str, Any], response: Any) -> bool: + """Detect streaming via the 'stream' kwarg.""" + return kwargs.get("stream") is True + + def wrap_stream(self, stream: Any) -> Generator[Any, None, tuple[int, int, str]]: + """Wrap HuggingFace streaming response to collect token counts. + + Many HuggingFace models do not return usage in streaming chunks. + Returns (0, 0, 'unknown') gracefully when usage is absent. 
+ """ + seen: list[tuple[int, int, str]] = [] + for chunk in stream: + usage = getattr(chunk, "usage", None) + if usage is not None: + try: + it = usage.prompt_tokens or 0 + ot = usage.completion_tokens or 0 + m = getattr(chunk, "model", None) or "unknown" + seen.append((it, ot, m)) + except AttributeError: + pass + yield chunk + return seen[-1] if seen else (0, 0, "unknown") + + def validate_fallback(self, fallback_model: str) -> None: + """Validate that fallback model is a HuggingFace model (org/model format).""" + is_openai = any( + fallback_model.startswith(p) for p in ("gpt-", "o1", "o2", "o3", "o4", "text-") + ) + is_anthropic = fallback_model.startswith("claude-") + is_gemini = fallback_model.startswith("gemini-") + + if is_openai or is_anthropic or is_gemini: + raise ValueError( + f"shekel: fallback model '{fallback_model}' does not appear to be a " + f"HuggingFace model but the current call is to HuggingFace. " + f"Cross-provider fallback is not supported. " + f"Use a HuggingFace model as fallback " + f"(e.g. fallback='HuggingFaceH4/zephyr-7b-beta')." + ) diff --git a/tests/integrations/test_gemini_sdk_integration.py b/tests/integrations/test_gemini_sdk_integration.py new file mode 100644 index 0000000..1a6434b --- /dev/null +++ b/tests/integrations/test_gemini_sdk_integration.py @@ -0,0 +1,325 @@ +"""Integration tests for the Gemini SDK adapter (google-genai). + +Real-API tests (TestGeminiSDKRealIntegration) require GEMINI_API_KEY env var +and are skipped without it. + +Mock tests (TestGeminiSDKMockIntegration) run without any API keys and verify +the adapter's patch lifecycle and token extraction end-to-end. 
+""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from shekel import budget +from shekel.exceptions import BudgetExceededError + +try: + import google.genai as genai + import google.genai.models as gm + from google.genai.errors import ClientError as GeminiClientError + + GENAI_AVAILABLE = True +except ImportError: + GENAI_AVAILABLE = False + GeminiClientError = Exception # type: ignore[assignment,misc] + +pytestmark = pytest.mark.skipif(not GENAI_AVAILABLE, reason="google-genai not installed") + + +# --------------------------------------------------------------------------- +# Real-API tests +# --------------------------------------------------------------------------- + + +class TestGeminiSDKRealIntegration: + """Tests that call the real Gemini API via the google-genai SDK.""" + + @pytest.fixture + def api_key(self) -> str | None: + return os.getenv("GEMINI_API_KEY") + + @pytest.fixture + def available(self, api_key: str | None) -> bool: + return bool(api_key and GENAI_AVAILABLE) + + @pytest.fixture + def client(self, api_key: str | None, available: bool) -> Any: + if not available or api_key is None: + pytest.skip("Gemini API not available") + return genai.Client(api_key=api_key) + + @staticmethod + def _maybe_skip_quota(exc: Exception) -> None: + """Call pytest.skip() if exc is a Gemini quota/rate-limit error.""" + msg = str(exc) + if "429" in msg or "RESOURCE_EXHAUSTED" in msg or "quota" in msg.lower(): + pytest.skip(f"Gemini free-tier quota exhausted: {msg[:120]}") + + def test_basic_generate_content_tracks_spend(self, client: Any, available: bool) -> None: + """budget() tracks spend from client.models.generate_content().""" + if not available: + pytest.skip("Gemini API not available") + + try: + with budget(max_usd=1.00) as b: + response = client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Say hello in one word.", + ) + assert response is 
not None + assert b.spent > 0, "Expected spend > 0 after a real Gemini call" + except Exception as exc: + self._maybe_skip_quota(exc) + raise + + def test_streaming_generate_content_tracks_spend(self, client: Any, available: bool) -> None: + """budget() tracks spend when iterating generate_content_stream().""" + if not available: + pytest.skip("Gemini API not available") + + try: + with budget(max_usd=1.00) as b: + chunks = list( + client.models.generate_content_stream( + model="gemini-2.0-flash-lite", + contents="Count to three.", + ) + ) + assert len(chunks) > 0 + assert b.spent >= 0 + except Exception as exc: + self._maybe_skip_quota(exc) + raise + + def test_budget_exceeded_raises_error(self, client: Any, available: bool) -> None: + """BudgetExceededError is raised when budget is exhausted.""" + if not available: + pytest.skip("Gemini API not available") + + try: + with pytest.raises((BudgetExceededError, GeminiClientError)): + with budget( + max_usd=0.000001, + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Hello.", + ) + except Exception as exc: + self._maybe_skip_quota(exc) + raise + + def test_nested_budgets_roll_up(self, client: Any, available: bool) -> None: + """Spend in an inner budget is reflected in the outer budget.""" + if not available: + pytest.skip("Gemini API not available") + + try: + with budget(max_usd=5.00, name="outer") as outer: + with budget(max_usd=1.00, name="inner") as inner: + client.models.generate_content( + model="gemini-2.0-flash-lite", + contents="Say yes.", + ) + assert inner.spent > 0 + assert outer.spent >= inner.spent + except Exception as exc: + self._maybe_skip_quota(exc) + raise + + def test_fallback_model_within_gemini(self, client: Any, available: bool) -> None: + """Fallback rewrites model kwarg to cheaper Gemini model at threshold.""" + if not available: + pytest.skip("Gemini API not available") + + try: + with budget( + 
max_usd=0.001, + fallback={"at_pct": 0.01, "model": "gemini-2.0-flash-lite"}, + ) as b: + try: + client.models.generate_content( + model="gemini-2.0-flash", + contents="Say hi.", + ) + except BudgetExceededError: + pass # acceptable — budget very small + assert b.spent >= 0 + except Exception as exc: + self._maybe_skip_quota(exc) + raise + + +# --------------------------------------------------------------------------- +# Mock tests — always run, no API key required +# --------------------------------------------------------------------------- + + +class TestGeminiSDKMockIntegration: + """Verify adapter lifecycle and token extraction without API calls.""" + + def test_patch_install_and_remove_lifecycle(self) -> None: + """Adapter patches Models.generate_content and generate_content_stream.""" + from shekel.providers.gemini import GeminiAdapter + + original_gc = gm.Models.generate_content + original_gcs = gm.Models.generate_content_stream + + adapter = GeminiAdapter() + try: + adapter.install_patches() + assert gm.Models.generate_content is not original_gc + assert gm.Models.generate_content_stream is not original_gcs + finally: + adapter.remove_patches() + + assert gm.Models.generate_content is original_gc + assert gm.Models.generate_content_stream is original_gcs + + def test_budget_records_spend_from_mock_response(self) -> None: + """budget() records correct cost from mocked generate_content call.""" + + class FakeUsage: + prompt_token_count = 50 + candidates_token_count = 25 + + class FakeResponse: + usage_metadata = FakeUsage() + + def fake_generate_content(self: Any, **kwargs: Any) -> FakeResponse: + return FakeResponse() + + with patch.object(gm.Models, "generate_content", fake_generate_content): + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + client = genai.Client(api_key="fake-key") + client.models.generate_content( + model="gemini-2.0-flash", + contents="hello", + ) + + # 50 input + 25 output at $1/1k each = 
$0.075 + assert b.spent > 0 + + def test_streaming_mock_records_spend(self) -> None: + """budget() records correct cost from mocked generate_content_stream.""" + + class FakeUsage: + prompt_token_count = 30 + candidates_token_count = 15 + + class FakeChunk: + usage_metadata = None + + class FakeChunkWithUsage: + usage_metadata = FakeUsage() + + def fake_stream(self: Any, **kwargs: Any) -> Any: + yield FakeChunk() + yield FakeChunk() + yield FakeChunkWithUsage() + + with patch.object(gm.Models, "generate_content_stream", fake_stream): + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + client = genai.Client(api_key="fake-key") + list( + client.models.generate_content_stream( + model="gemini-2.0-flash", + contents="count", + ) + ) + + assert b.spent > 0 + + def test_budget_exceeded_from_mock(self) -> None: + """BudgetExceededError raised from mocked call when budget tiny.""" + + class FakeUsage: + prompt_token_count = 1000 + candidates_token_count = 500 + + class FakeResponse: + usage_metadata = FakeUsage() + + def fake_generate_content(self: Any, **kwargs: Any) -> FakeResponse: + return FakeResponse() + + with patch.object(gm.Models, "generate_content", fake_generate_content): + with pytest.raises(BudgetExceededError): + with budget( + max_usd=0.000001, + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client = genai.Client(api_key="fake-key") + client.models.generate_content( + model="gemini-2.0-flash", + contents="hello", + ) + + def test_no_crash_without_google_genai(self) -> None: + """install_patches() is a no-op when google-genai is not importable.""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with patch.dict( + "sys.modules", + {"google": None, "google.genai": None, "google.genai.models": None}, + ): + try: + adapter.install_patches() + except Exception: + pass # ImportError is acceptable + + def test_extract_tokens_from_real_response_shape(self) -> None: + 
"""extract_tokens handles the real Gemini response structure.""" + from shekel.providers.gemini import GeminiAdapter + + class FakeUsageMeta: + prompt_token_count = 42 + candidates_token_count = 17 + + class FakeResp: + usage_metadata = FakeUsageMeta() + + adapter = GeminiAdapter() + it, ot, model = adapter.extract_tokens(FakeResp()) + assert it == 42 + assert ot == 17 + assert model == "unknown" # model never in Gemini response + + def test_model_kwarg_captured_for_pricing(self) -> None: + """Wrapper uses model kwarg for cost calculation, not response field.""" + from shekel import _patch + + class FakeUsage: + prompt_token_count = 100 + candidates_token_count = 50 + + class FakeResponse: + usage_metadata = FakeUsage() + + original = MagicMock(return_value=FakeResponse()) + _patch._originals["gemini_sync"] = original + + try: + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + fake_self = MagicMock() + _patch._gemini_sync_wrapper(fake_self, model="gemini-2.0-flash", contents="hi") + finally: + _patch._originals.pop("gemini_sync", None) + + assert b.spent > 0 diff --git a/tests/integrations/test_huggingface_integration.py b/tests/integrations/test_huggingface_integration.py new file mode 100644 index 0000000..9f329aa --- /dev/null +++ b/tests/integrations/test_huggingface_integration.py @@ -0,0 +1,351 @@ +"""Integration tests for the HuggingFace adapter (huggingface-hub InferenceClient). + +Real-API tests (TestHuggingFaceRealIntegration) require HUGGING_FACE_API env var +and are skipped without it. + +Mock tests (TestHuggingFaceMockIntegration) run without any API keys and verify +the adapter's patch lifecycle and token extraction end-to-end. 
+""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from shekel import budget +from shekel.exceptions import BudgetExceededError + +try: + from huggingface_hub import InferenceClient + from huggingface_hub.errors import BadRequestError as HFBadRequestError + from huggingface_hub.inference import _client as hf_client + + HF_AVAILABLE = True +except ImportError: + HF_AVAILABLE = False + HFBadRequestError = Exception # type: ignore[assignment,misc] + +pytestmark = pytest.mark.skipif(not HF_AVAILABLE, reason="huggingface-hub not installed") + +_HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct" + + +# --------------------------------------------------------------------------- +# Real-API tests +# --------------------------------------------------------------------------- + + +class TestHuggingFaceRealIntegration: + """Tests that call the real HuggingFace Inference API.""" + + @pytest.fixture + def api_key(self) -> str | None: + return os.getenv("HUGGING_FACE_API") + + @pytest.fixture + def available(self, api_key: str | None) -> bool: + return bool(api_key and HF_AVAILABLE) + + @pytest.fixture + def client(self, api_key: str | None, available: bool) -> Any: + if not available or api_key is None: + pytest.skip("HuggingFace API not available") + return InferenceClient(token=api_key) + + def _skip_on_api_error(self, exc: Exception) -> None: + """Skip test gracefully on model-not-supported or provider errors.""" + msg = str(exc) + if "not supported" in msg or "model_not_supported" in msg or "503" in msg: + pytest.skip(f"HuggingFace model unavailable: {msg[:120]}") + + def test_basic_chat_completion_tracks_spend(self, client: Any, available: bool) -> None: + """budget() tracks spend from InferenceClient.chat.completions.create().""" + if not available: + pytest.skip("HuggingFace API not available") + + try: + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 0.001, "output": 
0.001}, + ) as b: + response = client.chat.completions.create( + model=_HF_MODEL, + messages=[{"role": "user", "content": "Say hello in one word."}], + max_tokens=10, + ) + assert response is not None + except HFBadRequestError as e: + self._skip_on_api_error(e) + raise + + # Usage data availability varies by model — just assert no crash + assert b.spent >= 0 + + def test_streaming_chat_completion_tracks_spend(self, client: Any, available: bool) -> None: + """budget() handles streaming chat_completion call.""" + if not available: + pytest.skip("HuggingFace API not available") + + try: + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as b: + chunks = list( + client.chat.completions.create( + model=_HF_MODEL, + messages=[{"role": "user", "content": "Say hi."}], + max_tokens=5, + stream=True, + ) + ) + assert len(chunks) > 0 + except HFBadRequestError as e: + self._skip_on_api_error(e) + raise + + assert b.spent >= 0 + + def test_budget_exceeded_raises_error(self, client: Any, available: bool) -> None: + """BudgetExceededError is raised when budget is exhausted.""" + if not available: + pytest.skip("HuggingFace API not available") + + try: + with pytest.raises(BudgetExceededError): + with budget( + max_usd=0.000001, + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client.chat.completions.create( + model=_HF_MODEL, + messages=[{"role": "user", "content": "Hello."}], + max_tokens=5, + ) + except HFBadRequestError as e: + self._skip_on_api_error(e) + raise + + def test_nested_budgets_roll_up(self, client: Any, available: bool) -> None: + """Inner budget spend is visible in outer budget.""" + if not available: + pytest.skip("HuggingFace API not available") + + try: + with budget( + max_usd=5.00, + name="outer", + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as outer: + with budget( + max_usd=1.00, + name="inner", + price_per_1k_tokens={"input": 0.001, "output": 0.001}, + ) as inner: + 
client.chat.completions.create( + model=_HF_MODEL, + messages=[{"role": "user", "content": "Yes or no?"}], + max_tokens=5, + ) + assert outer.spent >= inner.spent + except HFBadRequestError as e: + self._skip_on_api_error(e) + raise + + +# --------------------------------------------------------------------------- +# Mock tests — always run, no API key required +# --------------------------------------------------------------------------- + + +class TestHuggingFaceMockIntegration: + """Verify adapter lifecycle and token extraction without API calls.""" + + def test_patch_install_and_remove_lifecycle(self) -> None: + """Adapter patches InferenceClient.chat_completion.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + original = hf_client.InferenceClient.chat_completion + + adapter = HuggingFaceAdapter() + try: + adapter.install_patches() + assert hf_client.InferenceClient.chat_completion is not original + finally: + adapter.remove_patches() + + assert hf_client.InferenceClient.chat_completion is original + + def test_budget_records_spend_from_mock_response(self) -> None: + """budget() records correct cost from a mocked chat_completion call.""" + + class FakeUsage: + prompt_tokens = 80 + completion_tokens = 40 + + class FakeResponse: + model = _HF_MODEL + usage = FakeUsage() + + def fake_chat_completion(self: Any, messages: Any, **kwargs: Any) -> FakeResponse: + return FakeResponse() + + with patch.object(hf_client.InferenceClient, "chat_completion", fake_chat_completion): + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + client = InferenceClient(token="fake-token") + client.chat_completion( + messages=[{"role": "user", "content": "hello"}], + model=_HF_MODEL, + ) + + # 80 input + 40 output at $1/1k each = $0.12 + assert b.spent > 0 + + def test_streaming_mock_records_spend(self) -> None: + """budget() records correct cost from mocked streaming call.""" + + class FakeUsage: + prompt_tokens = 30 + 
completion_tokens = 15 + + class FakeChunk: + model = _HF_MODEL + usage = None + + class FakeChunkWithUsage: + model = _HF_MODEL + usage = FakeUsage() + + def fake_stream(self: Any, messages: Any, **kwargs: Any) -> Any: + yield FakeChunk() + yield FakeChunk() + yield FakeChunkWithUsage() + + with patch.object(hf_client.InferenceClient, "chat_completion", fake_stream): + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + client = InferenceClient(token="fake-token") + list( + client.chat_completion( + messages=[{"role": "user", "content": "hello"}], + model=_HF_MODEL, + stream=True, + ) + ) + + assert b.spent > 0 + + def test_budget_exceeded_from_mock(self) -> None: + """BudgetExceededError raised when mocked usage pushes over tiny budget.""" + + class FakeUsage: + prompt_tokens = 1000 + completion_tokens = 500 + + class FakeResponse: + model = _HF_MODEL + usage = FakeUsage() + + def fake_chat_completion(self: Any, messages: Any, **kwargs: Any) -> FakeResponse: + return FakeResponse() + + with patch.object(hf_client.InferenceClient, "chat_completion", fake_chat_completion): + with pytest.raises(BudgetExceededError): + with budget( + max_usd=0.000001, + price_per_1k_tokens={"input": 100.0, "output": 100.0}, + ): + client = InferenceClient(token="fake-token") + client.chat_completion( + messages=[{"role": "user", "content": "hello"}], + model=_HF_MODEL, + ) + + def test_no_crash_without_huggingface_hub(self) -> None: + """install_patches() is a no-op when huggingface-hub is not importable.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with patch.dict( + "sys.modules", + { + "huggingface_hub": None, + "huggingface_hub.inference": None, + "huggingface_hub.inference._client": None, + }, + ): + try: + adapter.install_patches() + except Exception: + pass # ImportError is acceptable + + def test_extract_tokens_from_real_response_shape(self) -> None: + """extract_tokens handles 
the real HuggingFace response structure.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + class FakeUsage: + prompt_tokens = 60 + completion_tokens = 30 + + class FakeResp: + model = _HF_MODEL + usage = FakeUsage() + + adapter = HuggingFaceAdapter() + it, ot, model = adapter.extract_tokens(FakeResp()) + assert it == 60 + assert ot == 30 + assert model == _HF_MODEL + + def test_no_usage_in_response_returns_zeros(self) -> None: + """When response.usage is None, extract_tokens returns (0, 0, model).""" + from shekel.providers.huggingface import HuggingFaceAdapter + + class FakeResp: + model = _HF_MODEL + usage = None + + adapter = HuggingFaceAdapter() + it, ot, model = adapter.extract_tokens(FakeResp()) + assert it == 0 + assert ot == 0 + assert model == _HF_MODEL + + def test_wrapper_directly_records_spend(self) -> None: + """_huggingface_sync_wrapper records cost via _patch._record.""" + from shekel import _patch + + class FakeUsage: + prompt_tokens = 200 + completion_tokens = 100 + + class FakeResponse: + model = _HF_MODEL + usage = FakeUsage() + + original = MagicMock(return_value=FakeResponse()) + _patch._originals["huggingface_sync"] = original + + try: + with budget( + max_usd=1.00, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ) as b: + fake_self = MagicMock() + _patch._huggingface_sync_wrapper( + fake_self, + [{"role": "user", "content": "hi"}], + model=_HF_MODEL, + ) + finally: + _patch._originals.pop("huggingface_sync", None) + + assert b.spent > 0 diff --git a/tests/providers/test_gemini_adapter.py b/tests/providers/test_gemini_adapter.py new file mode 100644 index 0000000..81fbcf7 --- /dev/null +++ b/tests/providers/test_gemini_adapter.py @@ -0,0 +1,454 @@ +"""Unit tests for the Gemini provider adapter. 
+ +Tests cover: +- Adapter name and isinstance checks +- Token extraction from usage_metadata +- Stream detection via separate generate_content_stream method +- Stream wrapping and usage_metadata collection +- Fallback model validation (must be gemini-*) +- Patch install/remove lifecycle +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from tests.providers.conftest import ProviderTestBase + +# --------------------------------------------------------------------------- +# Lazy import guard — skip all tests if google-genai is not installed +# --------------------------------------------------------------------------- + + +def _gemini_available() -> bool: + try: + import google.genai # noqa: F401 + + return True + except ImportError: + return False + + +pytestmark = pytest.mark.skipif(not _gemini_available(), reason="google-genai not installed") + + +# --------------------------------------------------------------------------- +# Mock fixtures +# --------------------------------------------------------------------------- + + +class MockUsageMetadata: + """Mock Gemini usage_metadata object.""" + + def __init__(self, prompt_token_count: int = 0, candidates_token_count: int = 0) -> None: + self.prompt_token_count = prompt_token_count + self.candidates_token_count = candidates_token_count + + +class MockGeminiResponse: + """Mock Gemini generate_content response.""" + + def __init__( + self, + usage_metadata: MockUsageMetadata | None = None, + model_version: str | None = None, + ) -> None: + self.usage_metadata = usage_metadata + self.model_version = model_version + + +class MockGeminiStreamChunk: + """Mock Gemini streaming chunk.""" + + def __init__(self, usage_metadata: MockUsageMetadata | None = None) -> None: + self.usage_metadata = usage_metadata + + +# --------------------------------------------------------------------------- +# Helper to build 
a mock Gemini stream +# --------------------------------------------------------------------------- + + +def make_gemini_stream( + input_tokens: int = 0, output_tokens: int = 0 +) -> Generator[MockGeminiStreamChunk, None, None]: + """Yield content chunks then a final chunk with usage_metadata.""" + yield MockGeminiStreamChunk() + yield MockGeminiStreamChunk() + yield MockGeminiStreamChunk( + usage_metadata=MockUsageMetadata( + prompt_token_count=input_tokens, + candidates_token_count=output_tokens, + ) + ) + + +# --------------------------------------------------------------------------- +# TestGeminiAdapterBasic +# --------------------------------------------------------------------------- + + +class TestGeminiAdapterBasic(ProviderTestBase): + """Test adapter name and base class membership.""" + + def test_name(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + assert adapter.name == "gemini" + + def test_isinstance_provider_adapter(self) -> None: + from shekel.providers.base import ProviderAdapter + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + assert isinstance(adapter, ProviderAdapter) + + +# --------------------------------------------------------------------------- +# TestGeminiTokenExtraction +# --------------------------------------------------------------------------- + + +class TestGeminiTokenExtraction(ProviderTestBase): + """Test extract_tokens() for various response shapes.""" + + def test_normal_response(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + response = MockGeminiResponse( + usage_metadata=MockUsageMetadata(prompt_token_count=100, candidates_token_count=50) + ) + it, ot, model = adapter.extract_tokens(response) + assert it == 100 + assert ot == 50 + assert model == "unknown" # model not in response; wrapper passes model name + + def test_none_usage_metadata(self) -> None: + from shekel.providers.gemini import 
GeminiAdapter + + adapter = GeminiAdapter() + response = MockGeminiResponse(usage_metadata=None) + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + assert model == "unknown" + + def test_missing_usage_metadata_attr(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + response = object() + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + assert model == "unknown" + + def test_zero_tokens(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + response = MockGeminiResponse( + usage_metadata=MockUsageMetadata(prompt_token_count=0, candidates_token_count=0) + ) + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + + def test_none_token_counts(self) -> None: + """usage_metadata exists but token counts are None — should return 0.""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + + class BrokenUsage: + prompt_token_count = None + candidates_token_count = None + + response = MockGeminiResponse(usage_metadata=BrokenUsage()) # type: ignore[arg-type] + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + + +# --------------------------------------------------------------------------- +# TestGeminiStreamDetection +# --------------------------------------------------------------------------- + + +class TestGeminiStreamDetection(ProviderTestBase): + """Test detect_streaming() — Gemini uses a separate method so stream kwarg is False.""" + + def test_detect_streaming_no_stream_kwarg(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + # generate_content_stream is a separate method; no stream kwarg is passed + assert adapter.detect_streaming({}, None) is False + + def test_detect_streaming_with_false_kwarg(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + 
adapter = GeminiAdapter() + assert adapter.detect_streaming({"stream": False}, None) is False + + def test_detect_streaming_empty_kwargs(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + assert adapter.detect_streaming({}, MagicMock()) is False + + +# --------------------------------------------------------------------------- +# TestGeminiStreamWrapping +# --------------------------------------------------------------------------- + + +class TestGeminiStreamWrapping(ProviderTestBase): + """Test wrap_stream() collects usage_metadata from chunks.""" + + def test_yields_all_chunks(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + stream = make_gemini_stream(input_tokens=10, output_tokens=5) + chunks = list(adapter.wrap_stream(stream)) + assert len(chunks) == 3 + + def test_collects_usage_metadata(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + stream = make_gemini_stream(input_tokens=100, output_tokens=50) + gen = adapter.wrap_stream(stream) + seen: list[Any] = [] + try: + while True: + seen.append(next(gen)) + except StopIteration as e: + it, ot, model = e.value + assert it == 100 + assert ot == 50 + assert model == "unknown" + + def test_no_usage_returns_zeros(self) -> None: + """Stream with no usage chunks returns (0, 0, 'unknown').""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + + def no_usage_stream() -> Generator[MockGeminiStreamChunk, None, None]: + yield MockGeminiStreamChunk() + yield MockGeminiStreamChunk() + + gen = adapter.wrap_stream(no_usage_stream()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert it == 0 + assert ot == 0 + assert model == "unknown" + + def test_empty_stream_returns_zeros(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + + def empty() -> Generator[Any, 
None, None]: + return + yield # noqa: F704 + + gen = adapter.wrap_stream(empty()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert (it, ot, model) == (0, 0, "unknown") + + +# --------------------------------------------------------------------------- +# TestGeminiFallbackValidation +# --------------------------------------------------------------------------- + + +class TestGeminiFallbackValidation(ProviderTestBase): + """Test validate_fallback() rejects non-gemini models.""" + + def test_accepts_gemini_model(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + adapter.validate_fallback("gemini-2.0-flash") # should not raise + + def test_accepts_gemini_25_pro(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + adapter.validate_fallback("gemini-2.5-pro") + + def test_rejects_gpt_model(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with pytest.raises(ValueError, match="gemini"): + adapter.validate_fallback("gpt-4o-mini") + + def test_rejects_claude_model(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with pytest.raises(ValueError, match="gemini"): + adapter.validate_fallback("claude-3-haiku-20240307") + + def test_rejects_arbitrary_model(self) -> None: + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with pytest.raises(ValueError): + adapter.validate_fallback("llama-3-8b") + + +# --------------------------------------------------------------------------- +# TestGeminiPatching +# --------------------------------------------------------------------------- + + +class TestGeminiPatching(ProviderTestBase): + """Test install_patches() / remove_patches() lifecycle.""" + + def test_install_patches_replaces_generate_content(self) -> None: + import google.genai.models as gm + + from shekel import 
_patch + from shekel.providers.gemini import GeminiAdapter + + original_gc = gm.Models.generate_content + original_gcs = gm.Models.generate_content_stream + + adapter = GeminiAdapter() + try: + adapter.install_patches() + assert "gemini_sync" in _patch._originals + assert "gemini_stream" in _patch._originals + assert gm.Models.generate_content is not original_gc + assert gm.Models.generate_content_stream is not original_gcs + finally: + adapter.remove_patches() + # Restore to originals in case test fails mid-way + if gm.Models.generate_content is not original_gc: + gm.Models.generate_content = original_gc # type: ignore[method-assign] + if gm.Models.generate_content_stream is not original_gcs: + gm.Models.generate_content_stream = original_gcs # type: ignore[method-assign] + + def test_remove_patches_restores_originals(self) -> None: + import google.genai.models as gm + + from shekel import _patch + from shekel.providers.gemini import GeminiAdapter + + original_gc = gm.Models.generate_content + original_gcs = gm.Models.generate_content_stream + + adapter = GeminiAdapter() + adapter.install_patches() + adapter.remove_patches() + + assert "gemini_sync" not in _patch._originals + assert "gemini_stream" not in _patch._originals + assert gm.Models.generate_content is original_gc + assert gm.Models.generate_content_stream is original_gcs + + def test_install_patches_idempotent(self) -> None: + """Calling install_patches() twice should not double-wrap.""" + import google.genai.models as gm + + from shekel.providers.gemini import GeminiAdapter + + original_gc = gm.Models.generate_content + + adapter = GeminiAdapter() + try: + adapter.install_patches() + patched_first = gm.Models.generate_content + adapter.install_patches() + # Should still be the same patched function + assert gm.Models.generate_content is patched_first + finally: + adapter.remove_patches() + if gm.Models.generate_content is not original_gc: + gm.Models.generate_content = original_gc # type: 
ignore[method-assign] + + def test_remove_patches_safe_without_install(self) -> None: + """remove_patches() before install_patches() should not raise.""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + adapter.remove_patches() # Should not raise + + def test_no_import_safety(self) -> None: + """install_patches() is a no-op when google-genai is not importable.""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with patch.dict( + "sys.modules", {"google": None, "google.genai": None, "google.genai.models": None} + ): + # Even with the module mocked out, should not raise + try: + adapter.install_patches() + except Exception: + pass # ImportError is acceptable + + def test_remove_patches_safe_with_missing_import(self) -> None: + """Lines 59-60: remove_patches() catches ImportError when google-genai unavailable.""" + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + with patch.dict( + "sys.modules", {"google": None, "google.genai": None, "google.genai.models": None} + ): + adapter.remove_patches() # must not raise + + +# --------------------------------------------------------------------------- +# TestGeminiWrapStreamAttributeError +# --------------------------------------------------------------------------- + + +class TestGeminiWrapStreamAttributeError(ProviderTestBase): + """Test wrap_stream() handles chunks with broken usage_metadata attributes.""" + + def test_wrap_stream_swallows_usage_attribute_error(self) -> None: + """Lines 95-96: chunk whose usage_metadata attrs raise AttributeError is skipped.""" + from unittest.mock import MagicMock + + from shekel.providers.gemini import GeminiAdapter + + adapter = GeminiAdapter() + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream() -> Any: # type: ignore[return] + chunk = MagicMock() + chunk.usage_metadata = BrokenUsage() + yield chunk + + gen = 
adapter.wrap_stream(stream()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert it == 0 + assert ot == 0 + assert model == "unknown" diff --git a/tests/providers/test_huggingface_adapter.py b/tests/providers/test_huggingface_adapter.py new file mode 100644 index 0000000..d49c08c --- /dev/null +++ b/tests/providers/test_huggingface_adapter.py @@ -0,0 +1,460 @@ +"""Unit tests for the HuggingFace provider adapter. + +Tests cover: +- Adapter name and isinstance checks +- Token extraction from OpenAI-compatible usage format +- Stream detection via stream kwarg +- Stream wrapping and usage collection +- Fallback model validation +- Patch install/remove lifecycle +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest + +from tests.providers.conftest import ProviderTestBase + +# --------------------------------------------------------------------------- +# Lazy import guard — skip all tests if huggingface-hub is not installed +# --------------------------------------------------------------------------- + + +def _huggingface_available() -> bool: + try: + from huggingface_hub.inference import _client # noqa: F401 + + return True + except ImportError: + return False + + +pytestmark = pytest.mark.skipif( + not _huggingface_available(), reason="huggingface-hub not installed" +) + + +# --------------------------------------------------------------------------- +# Mock fixtures +# --------------------------------------------------------------------------- + + +class MockHFUsage: + """Mock HuggingFace ChatCompletionOutputUsage (OpenAI-compatible).""" + + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None: + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + + +class MockHFResponse: + """Mock HuggingFace chat_completion response.""" + + def __init__( + self, + model: str = 
"HuggingFaceH4/zephyr-7b-beta", + usage: MockHFUsage | None = None, + ) -> None: + self.model = model + self.usage = usage + + +class MockHFStreamChunk: + """Mock HuggingFace streaming chunk.""" + + def __init__( + self, + model: str | None = None, + usage: MockHFUsage | None = None, + ) -> None: + self.model = model + self.usage = usage + + +# --------------------------------------------------------------------------- +# Helper to build a mock HF stream +# --------------------------------------------------------------------------- + + +def make_hf_stream( + model: str = "HuggingFaceH4/zephyr-7b-beta", + input_tokens: int = 0, + output_tokens: int = 0, +) -> Generator[MockHFStreamChunk, None, None]: + """Yield content chunks then a final chunk with usage.""" + yield MockHFStreamChunk(model=model) + yield MockHFStreamChunk(model=model) + yield MockHFStreamChunk( + model=model, + usage=MockHFUsage(prompt_tokens=input_tokens, completion_tokens=output_tokens), + ) + + +# --------------------------------------------------------------------------- +# TestHuggingFaceAdapterBasic +# --------------------------------------------------------------------------- + + +class TestHuggingFaceAdapterBasic(ProviderTestBase): + """Test adapter name and base class membership.""" + + def test_name(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + assert adapter.name == "huggingface" + + def test_isinstance_provider_adapter(self) -> None: + from shekel.providers.base import ProviderAdapter + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + assert isinstance(adapter, ProviderAdapter) + + +# --------------------------------------------------------------------------- +# TestHuggingFaceTokenExtraction +# --------------------------------------------------------------------------- + + +class TestHuggingFaceTokenExtraction(ProviderTestBase): + """Test extract_tokens() for various 
response shapes.""" + + def test_normal_response(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + response = MockHFResponse( + model="HuggingFaceH4/zephyr-7b-beta", + usage=MockHFUsage(prompt_tokens=80, completion_tokens=40), + ) + it, ot, model = adapter.extract_tokens(response) + assert it == 80 + assert ot == 40 + assert model == "HuggingFaceH4/zephyr-7b-beta" + + def test_none_usage(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + response = MockHFResponse(usage=None) + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + + def test_missing_usage_attr(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + response = object() + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + assert model == "unknown" + + def test_zero_tokens(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + response = MockHFResponse(usage=MockHFUsage(prompt_tokens=0, completion_tokens=0)) + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + + def test_none_token_counts(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + + class BrokenUsage: + prompt_tokens = None + completion_tokens = None + + response = MockHFResponse(usage=BrokenUsage()) # type: ignore[arg-type] + it, ot, model = adapter.extract_tokens(response) + assert it == 0 + assert ot == 0 + + +# --------------------------------------------------------------------------- +# TestHuggingFaceStreamDetection +# --------------------------------------------------------------------------- + + +class TestHuggingFaceStreamDetection(ProviderTestBase): + """Test detect_streaming() — uses stream=True kwarg.""" + + def 
test_detect_stream_true(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + assert adapter.detect_streaming({"stream": True}, None) is True + + def test_detect_stream_false(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + assert adapter.detect_streaming({"stream": False}, None) is False + + def test_detect_no_stream_kwarg(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + assert adapter.detect_streaming({}, None) is False + + +# --------------------------------------------------------------------------- +# TestHuggingFaceStreamWrapping +# --------------------------------------------------------------------------- + + +class TestHuggingFaceStreamWrapping(ProviderTestBase): + """Test wrap_stream() collects usage from chunks.""" + + def test_yields_all_chunks(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + stream = make_hf_stream(input_tokens=10, output_tokens=5) + chunks = list(adapter.wrap_stream(stream)) + assert len(chunks) == 3 + + def test_collects_usage(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + stream = make_hf_stream( + model="HuggingFaceH4/zephyr-7b-beta", + input_tokens=100, + output_tokens=50, + ) + gen = adapter.wrap_stream(stream) + seen: list[Any] = [] + try: + while True: + seen.append(next(gen)) + except StopIteration as e: + it, ot, model = e.value + assert it == 100 + assert ot == 50 + assert model == "HuggingFaceH4/zephyr-7b-beta" + + def test_no_usage_returns_zeros(self) -> None: + """Streaming chunks with no usage return (0, 0, 'unknown').""" + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + + def no_usage_stream() -> Generator[MockHFStreamChunk, None, None]: + yield 
MockHFStreamChunk() + yield MockHFStreamChunk() + + gen = adapter.wrap_stream(no_usage_stream()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert it == 0 + assert ot == 0 + assert model == "unknown" + + def test_empty_stream_returns_zeros(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + + def empty() -> Generator[Any, None, None]: + return + yield # noqa: unreachable + + gen = adapter.wrap_stream(empty()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert (it, ot, model) == (0, 0, "unknown") + + +# --------------------------------------------------------------------------- +# TestHuggingFaceFallbackValidation +# --------------------------------------------------------------------------- + + +class TestHuggingFaceFallbackValidation(ProviderTestBase): + """Test validate_fallback() rejects non-HF models.""" + + def test_accepts_hf_model(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + adapter.validate_fallback("HuggingFaceH4/zephyr-7b-beta") # should not raise + + def test_accepts_any_slash_model(self) -> None: + """Any model with an org/model format is accepted.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + adapter.validate_fallback("mistralai/Mistral-7B-Instruct-v0.3") + + def test_rejects_gpt_model(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with pytest.raises(ValueError, match="(?i)huggingface"): + adapter.validate_fallback("gpt-4o-mini") + + def test_rejects_claude_model(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with pytest.raises(ValueError, match="(?i)huggingface"): + adapter.validate_fallback("claude-3-haiku-20240307") + + def 
test_rejects_gemini_model(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with pytest.raises(ValueError, match="(?i)huggingface"): + adapter.validate_fallback("gemini-2.0-flash") + + +# --------------------------------------------------------------------------- +# TestHuggingFacePatching +# --------------------------------------------------------------------------- + + +class TestHuggingFacePatching(ProviderTestBase): + """Test install_patches() / remove_patches() lifecycle.""" + + def test_install_patches_replaces_chat_completion(self) -> None: + from huggingface_hub.inference import _client + + from shekel import _patch + from shekel.providers.huggingface import HuggingFaceAdapter + + original = _client.InferenceClient.chat_completion + + adapter = HuggingFaceAdapter() + try: + adapter.install_patches() + assert "huggingface_sync" in _patch._originals + assert _client.InferenceClient.chat_completion is not original + finally: + adapter.remove_patches() + if _client.InferenceClient.chat_completion is not original: + _client.InferenceClient.chat_completion = original # type: ignore[method-assign] + + def test_remove_patches_restores_original(self) -> None: + from huggingface_hub.inference import _client + + from shekel import _patch + from shekel.providers.huggingface import HuggingFaceAdapter + + original = _client.InferenceClient.chat_completion + + adapter = HuggingFaceAdapter() + adapter.install_patches() + adapter.remove_patches() + + assert "huggingface_sync" not in _patch._originals + assert _client.InferenceClient.chat_completion is original + + def test_install_patches_idempotent(self) -> None: + """Calling install_patches() twice should not double-wrap.""" + from huggingface_hub.inference import _client + + from shekel.providers.huggingface import HuggingFaceAdapter + + original = _client.InferenceClient.chat_completion + + adapter = HuggingFaceAdapter() + try: + adapter.install_patches() + 
patched_first = _client.InferenceClient.chat_completion + adapter.install_patches() + assert _client.InferenceClient.chat_completion is patched_first + finally: + adapter.remove_patches() + if _client.InferenceClient.chat_completion is not original: + _client.InferenceClient.chat_completion = original # type: ignore[method-assign] + + def test_remove_patches_safe_without_install(self) -> None: + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + adapter.remove_patches() # Should not raise + + def test_no_import_safety(self) -> None: + """install_patches() is a no-op when huggingface-hub is not importable.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with patch.dict( + "sys.modules", + { + "huggingface_hub": None, + "huggingface_hub.inference": None, + "huggingface_hub.inference._client": None, + }, + ): + try: + adapter.install_patches() + except Exception: + pass # ImportError is acceptable + + def test_remove_patches_safe_with_missing_import(self) -> None: + """Lines 61-62: remove_patches() catches ImportError when huggingface-hub unavailable.""" + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + with patch.dict( + "sys.modules", + { + "huggingface_hub": None, + "huggingface_hub.inference": None, + "huggingface_hub.inference._client": None, + }, + ): + adapter.remove_patches() # must not raise + + +# --------------------------------------------------------------------------- +# TestHuggingFaceWrapStreamAttributeError +# --------------------------------------------------------------------------- + + +class TestHuggingFaceWrapStreamAttributeError(ProviderTestBase): + """Test wrap_stream() handles chunks with broken usage attributes.""" + + def test_wrap_stream_swallows_usage_attribute_error(self) -> None: + """Lines 101-102: chunk whose usage attrs raise AttributeError is skipped.""" + from unittest.mock import 
MagicMock + + from shekel.providers.huggingface import HuggingFaceAdapter + + adapter = HuggingFaceAdapter() + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream() -> Any: # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + gen = adapter.wrap_stream(stream()) + try: + while True: + next(gen) + except StopIteration as e: + it, ot, model = e.value + assert it == 0 + assert ot == 0 + assert model == "unknown" diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index a72e735..2693aa8 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -297,3 +297,25 @@ def test_litellm_import_error_is_swallowed(self): import shekel.providers as providers_mod importlib.reload(providers_mod) + + def test_gemini_import_error_is_swallowed(self): + """Lines 30-31: ImportError when google-genai is absent is silently ignored.""" + import importlib + import sys + from unittest.mock import patch + + with patch.dict(sys.modules, {"shekel.providers.gemini": None}): + import shekel.providers as providers_mod + + importlib.reload(providers_mod) + + def test_huggingface_import_error_is_swallowed(self): + """Lines 37-38: ImportError when huggingface-hub is absent is silently ignored.""" + import importlib + import sys + from unittest.mock import patch + + with patch.dict(sys.modules, {"shekel.providers.huggingface": None}): + import shekel.providers as providers_mod + + importlib.reload(providers_mod) diff --git a/tests/test_anthropic_wrappers.py b/tests/test_anthropic_wrappers.py new file mode 100644 index 0000000..b6c8791 --- /dev/null +++ b/tests/test_anthropic_wrappers.py @@ -0,0 +1,116 @@ +"""Tests for Anthropic provider wrappers in shekel/_patch.py.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from shekel import budget + +ANTHROPIC_CREATE = 
"anthropic.resources.messages.Messages.create" + + +def test_anthropic_malformed_response_records_zero() -> None: + """Response missing .usage attribute records $0 rather than crashing.""" + + class NoUsage: + model = "claude-3-5-sonnet-20241022" + + with patch(ANTHROPIC_CREATE, return_value=NoUsage()): + with budget(max_usd=1.00) as b: + import anthropic + + client = anthropic.Anthropic(api_key="test") + client.messages.create(model="claude-3-5-sonnet-20241022", messages=[], max_tokens=10) + + assert b.spent == pytest.approx(0.0) + + +def test_anthropic_sync_wrapper_raises_if_no_original() -> None: + """RuntimeError when anthropic_sync not in _originals.""" + from shekel._patch import _anthropic_sync_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="anthropic original not stored"): + _anthropic_sync_wrapper(None) + + +def test_wrap_anthropic_stream_swallows_message_start_attribute_error() -> None: + """Broken message_start event is handled without crashing.""" + from shekel._patch import _wrap_anthropic_stream + + class BrokenMessage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + class MessageStartEvent: + type = "message_start" + message = BrokenMessage() + + list(_wrap_anthropic_stream(iter([MessageStartEvent()]))) + + +def test_wrap_anthropic_stream_swallows_message_delta_attribute_error() -> None: + """Broken message_delta event is handled without crashing.""" + from shekel._patch import _wrap_anthropic_stream + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + class MessageDeltaEvent: + type = "message_delta" + usage = BrokenUsage() + + list(_wrap_anthropic_stream(iter([MessageDeltaEvent()]))) + + +@pytest.mark.asyncio +async def test_anthropic_async_wrapper_raises_if_no_original() -> None: + """RuntimeError when anthropic_async not in _originals.""" + from shekel._patch import _anthropic_async_wrapper + + with 
patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="anthropic async original not stored"): + await _anthropic_async_wrapper(None) + + +@pytest.mark.asyncio +async def test_wrap_anthropic_stream_async_swallows_message_start_error() -> None: + """Broken message_start in async stream is handled without crashing.""" + from shekel._patch import _wrap_anthropic_stream_async + + class BrokenMessage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + class MessageStartEvent: + type = "message_start" + message = BrokenMessage() + + async def stream(): # type: ignore[return] + yield MessageStartEvent() + + async for _ in _wrap_anthropic_stream_async(stream()): + pass + + +@pytest.mark.asyncio +async def test_wrap_anthropic_stream_async_swallows_message_delta_error() -> None: + """Broken message_delta in async stream is handled without crashing.""" + from shekel._patch import _wrap_anthropic_stream_async + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + class MessageDeltaEvent: + type = "message_delta" + usage = BrokenUsage() + + async def stream(): # type: ignore[return] + yield MessageDeltaEvent() + + async for _ in _wrap_anthropic_stream_async(stream()): + pass diff --git a/tests/test_budget.py b/tests/test_budget.py index 328f135..27e894d 100644 --- a/tests/test_budget.py +++ b/tests/test_budget.py @@ -252,3 +252,41 @@ async def test_async_track_only_no_exception() -> None: assert b.spent > 0.0 assert b.limit is None + + +# --------------------------------------------------------------------------- +# _record edge cases +# --------------------------------------------------------------------------- + + +def test_record_with_no_active_budget_is_noop() -> None: + """_record outside a budget() context returns silently.""" + from shekel._patch import _record + + _record(100, 50, "gpt-4o") + + +def test_record_swallows_pricing_exception() -> None: + """If calculate_cost 
raises, cost falls back to 0.0 rather than crashing.""" + from unittest.mock import patch + + from shekel._patch import _record + + with budget(max_usd=1.0) as b: + with patch("shekel._pricing.calculate_cost", side_effect=RuntimeError("bad")): + _record(100, 50, "gpt-4o") + assert b.spent == pytest.approx(0.0) + + +def test_record_swallows_adapter_emit_exception() -> None: + """If AdapterRegistry.emit_event raises, the exception is swallowed.""" + from unittest.mock import patch + + from shekel._patch import _record + + with budget(max_usd=1.0): + with patch( + "shekel.integrations.AdapterRegistry.emit_event", + side_effect=RuntimeError("adapter crash"), + ): + _record(100, 50, "gpt-4o-mini") # must not raise diff --git a/tests/test_fallback.py b/tests/test_fallback.py index 6f9b501..93200f8 100644 --- a/tests/test_fallback.py +++ b/tests/test_fallback.py @@ -556,3 +556,48 @@ def test_fallback_empty_string_raises() -> None: """budget(fallback={...}) with empty model raises ValueError at init.""" with pytest.raises(ValueError, match="non-empty string"): budget(max_usd=1.00, fallback={"at_pct": 0.8, "model": ""}) + + +# --------------------------------------------------------------------------- +# Cross-provider fallback validation (_validate_same_provider) +# --------------------------------------------------------------------------- + + +def test_validate_same_provider_anthropic_rejects_openai_model() -> None: + """Anthropic provider + OpenAI fallback model raises ValueError.""" + from shekel._patch import _validate_same_provider + + with pytest.raises(ValueError, match="OpenAI model"): + _validate_same_provider("gpt-4o", "anthropic") + + +def test_validate_same_provider_gemini_rejects_non_gemini() -> None: + """Gemini provider + non-Gemini fallback model raises ValueError.""" + from shekel._patch import _validate_same_provider + + with pytest.raises(ValueError, match="Gemini"): + _validate_same_provider("gpt-4o", "gemini") + + +def 
test_validate_same_provider_huggingface_rejects_openai() -> None: + """HuggingFace provider + OpenAI fallback raises ValueError.""" + from shekel._patch import _validate_same_provider + + with pytest.raises(ValueError, match="HuggingFace"): + _validate_same_provider("gpt-4o", "huggingface") + + +def test_validate_same_provider_huggingface_rejects_anthropic() -> None: + """HuggingFace provider + Anthropic fallback raises ValueError.""" + from shekel._patch import _validate_same_provider + + with pytest.raises(ValueError, match="HuggingFace"): + _validate_same_provider("claude-3-haiku-20240307", "huggingface") + + +def test_validate_same_provider_huggingface_rejects_gemini() -> None: + """HuggingFace provider + Gemini fallback raises ValueError.""" + from shekel._patch import _validate_same_provider + + with pytest.raises(ValueError, match="HuggingFace"): + _validate_same_provider("gemini-2.0-flash", "huggingface") diff --git a/tests/test_gemini_wrappers.py b/tests/test_gemini_wrappers.py new file mode 100644 index 0000000..274033a --- /dev/null +++ b/tests/test_gemini_wrappers.py @@ -0,0 +1,159 @@ +"""Tests for Google Gemini provider wrappers in shekel/_patch.py.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from shekel import budget + + +def test_gemini_sync_wrapper_raises_if_no_original() -> None: + """RuntimeError when gemini_sync not in _originals.""" + from shekel._patch import _gemini_sync_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="gemini original not stored"): + _gemini_sync_wrapper(None) + + +def test_gemini_sync_wrapper_records_tokens() -> None: + """Wrapper records token counts through an active budget.""" + from shekel._patch import _gemini_sync_wrapper + + class MockUsage: + prompt_token_count = 100 + candidates_token_count = 50 + + class MockResponse: + usage_metadata = MockUsage() + + def fake_sync(self: object, *args: object, **kwargs: 
object) -> MockResponse: + return MockResponse() + + def fake_stream(self: object, *args: object, **kwargs: object) -> MockResponse: + return MockResponse() + + # Include both gemini_sync and gemini_stream so install_patches() skips re-install + with patch.dict( + "shekel._patch._originals", + {"gemini_sync": fake_sync, "gemini_stream": fake_stream}, + ): + with budget(max_usd=1.0) as b: + result = _gemini_sync_wrapper(None, model="gemini-2.0-flash") + + assert isinstance(result, MockResponse) + assert b.spent > 0 + + +def test_gemini_sync_wrapper_no_budget() -> None: + """Wrapper works correctly when no budget context is active.""" + from shekel._patch import _gemini_sync_wrapper + + class MockUsage: + prompt_token_count = 10 + candidates_token_count = 5 + + class MockResponse: + usage_metadata = MockUsage() + + def fake_original(self: object, *args: object, **kwargs: object) -> MockResponse: + return MockResponse() + + with patch.dict("shekel._patch._originals", {"gemini_sync": fake_original}): + result = _gemini_sync_wrapper(None, model="gemini-2.0-flash") + + assert isinstance(result, MockResponse) + + +def test_gemini_stream_wrapper_raises_if_no_original() -> None: + """RuntimeError when gemini_stream not in _originals.""" + from shekel._patch import _gemini_stream_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="gemini stream original not stored"): + list(_gemini_stream_wrapper(None)) + + +def test_gemini_stream_wrapper_yields_chunks() -> None: + """Stream wrapper yields all chunks from the underlying generator.""" + from shekel._patch import _gemini_stream_wrapper + + class MockChunk: + usage_metadata = None + + def fake_sync(self: object, *args: object, **kwargs: object) -> None: + return None + + def fake_stream(self: object, *args: object, **kwargs: object): # type: ignore[return] + yield MockChunk() + yield MockChunk() + + # Include both keys so budget's install_patches() skips re-installing + with 
patch.dict( + "shekel._patch._originals", + {"gemini_sync": fake_sync, "gemini_stream": fake_stream}, + ): + with budget(max_usd=1.0): + chunks = list(_gemini_stream_wrapper(None, model="gemini-2.0-flash")) + + assert len(chunks) == 2 + + +def test_wrap_gemini_stream_swallows_usage_attribute_error() -> None: + """Chunk whose usage_metadata attrs raise AttributeError is skipped.""" + from shekel._patch import _wrap_gemini_stream + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage_metadata = BrokenUsage() + yield chunk + + list(_wrap_gemini_stream(stream(), "gemini-2.0-flash")) # must not raise + + +def test_wrap_gemini_stream_records_usage() -> None: + """usage_metadata tokens are captured and charged to the active budget.""" + from shekel._patch import _wrap_gemini_stream + + class MockUsage: + prompt_token_count = 50 + candidates_token_count = 25 + + class MockChunk: + usage_metadata = MockUsage() + + def stream(): # type: ignore[return] + yield MockChunk() + + # Include both gemini keys so install_patches() is skipped inside budget context + with patch.dict( + "shekel._patch._originals", + {"gemini_sync": object(), "gemini_stream": object()}, + ): + with budget(max_usd=1.0) as b: + list(_wrap_gemini_stream(stream(), "gemini-2.0-flash")) + + assert b.spent > 0 + + +def test_extract_gemini_tokens_attribute_error() -> None: + """Response with no attributes returns (0, 0, 'unknown').""" + from shekel._patch import _extract_gemini_tokens + + response = MagicMock(spec=[]) # no attributes + assert _extract_gemini_tokens(response) == (0, 0, "unknown") + + +def test_extract_gemini_tokens_none_usage() -> None: + """response.usage_metadata is None returns (0, 0, 'unknown').""" + from shekel._patch import _extract_gemini_tokens + + response = MagicMock() + response.usage_metadata = None + assert _extract_gemini_tokens(response) == (0, 0, "unknown") diff 
--git a/tests/test_huggingface_wrappers.py b/tests/test_huggingface_wrappers.py new file mode 100644 index 0000000..ae62c0d --- /dev/null +++ b/tests/test_huggingface_wrappers.py @@ -0,0 +1,98 @@ +"""Tests for HuggingFace provider wrappers in shekel/_patch.py.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from shekel import budget + + +def test_huggingface_sync_wrapper_raises_if_no_original() -> None: + """RuntimeError when huggingface_sync not in _originals.""" + from shekel._patch import _huggingface_sync_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="huggingface original not stored"): + _huggingface_sync_wrapper(None) + + +def test_huggingface_sync_wrapper_records_tokens() -> None: + """Non-streaming path extracts tokens and charges the active budget.""" + from shekel._patch import _huggingface_sync_wrapper + + class MockUsage: + prompt_tokens = 80 + completion_tokens = 40 + + class MockResponse: + model = "HuggingFaceH4/zephyr-7b-beta" + usage = MockUsage() + + def fake_original(self: object, *args: object, **kwargs: object) -> MockResponse: + return MockResponse() + + # Include the key so install_patches() skips re-install inside budget context + with patch.dict("shekel._patch._originals", {"huggingface_sync": fake_original}): + with budget(max_usd=1.0, price_per_1k_tokens={"input": 0.001, "output": 0.001}) as b: + result = _huggingface_sync_wrapper(None) + + assert isinstance(result, MockResponse) + assert b.spent > 0 + + +def test_huggingface_sync_wrapper_stream_path() -> None: + """stream=True delegates to the streaming path and returns a generator.""" + from shekel._patch import _huggingface_sync_wrapper + + class MockChunk: + usage = None + + def fake_original(self: object, *args: object, **kwargs: object): # type: ignore[return] + yield MockChunk() + + with patch.dict("shekel._patch._originals", {"huggingface_sync": fake_original}): + gen = 
_huggingface_sync_wrapper(None, stream=True) + chunks = list(gen) + + assert len(chunks) == 1 + + +def test_wrap_huggingface_stream_swallows_usage_attribute_error() -> None: + """Chunk whose usage attrs raise AttributeError is skipped without crashing.""" + from shekel._patch import _wrap_huggingface_stream + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + list(_wrap_huggingface_stream(stream())) # must not raise + + +def test_wrap_huggingface_stream_records_usage() -> None: + """Usage tokens from streaming chunks are charged to the active budget.""" + from shekel._patch import _wrap_huggingface_stream + + class MockUsage: + prompt_tokens = 60 + completion_tokens = 30 + + class MockChunk: + model = "HuggingFaceH4/zephyr-7b-beta" + usage = MockUsage() + + def stream(): # type: ignore[return] + yield MockChunk() + + # Prevent install_patches() from overwriting _originals inside budget context + with patch.dict("shekel._patch._originals", {"huggingface_sync": object()}): + with budget(max_usd=1.0, price_per_1k_tokens={"input": 0.001, "output": 0.001}) as b: + list(_wrap_huggingface_stream(stream())) + + assert b.spent > 0 diff --git a/tests/test_litellm_wrappers.py b/tests/test_litellm_wrappers.py new file mode 100644 index 0000000..4b0ff28 --- /dev/null +++ b/tests/test_litellm_wrappers.py @@ -0,0 +1,126 @@ +"""Tests for LiteLLM provider wrappers in shekel/_patch.py.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from shekel import budget + + +def test_litellm_sync_wrapper_raises_if_no_original() -> None: + """RuntimeError when litellm_sync not in _originals.""" + from shekel._patch import _litellm_sync_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="litellm original not stored"): + 
_litellm_sync_wrapper() + + +def test_wrap_litellm_stream_swallows_chunk_attribute_error() -> None: + """Broken usage attrs in litellm stream chunk handled without crashing.""" + from shekel._patch import _wrap_litellm_stream + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + list(_wrap_litellm_stream(stream())) + + +@pytest.mark.asyncio +async def test_litellm_async_wrapper_raises_if_no_original() -> None: + """RuntimeError when litellm_async not in _originals.""" + from shekel._patch import _litellm_async_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="litellm async original not stored"): + await _litellm_async_wrapper() + + +@pytest.mark.asyncio +async def test_litellm_async_wrapper_stream_path() -> None: + """stream=True branch in async wrapper returns an async generator.""" + from shekel._patch import _litellm_async_wrapper + + async def mock_async_stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = MagicMock() + chunk.usage.prompt_tokens = 10 + chunk.usage.completion_tokens = 5 + chunk.model = "gpt-4o-mini" + yield chunk + + original = AsyncMock(return_value=mock_async_stream()) + + mock_budget = MagicMock() + mock_budget._using_fallback = False + + with patch("shekel._patch._originals", {"litellm_async": original}): + with patch("shekel._context.get_active_budget", return_value=mock_budget): + stream = await _litellm_async_wrapper(model="gpt-4o-mini", messages=[], stream=True) + assert stream is not None + async for _ in stream: + pass + + +@pytest.mark.asyncio +async def test_wrap_litellm_stream_async_records_cost() -> None: + """Async litellm stream records tokens from the final usage chunk.""" + from shekel._patch import _wrap_litellm_stream_async + + async def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = 
None + yield chunk + final = MagicMock() + final.usage = MagicMock() + final.usage.prompt_tokens = 100 + final.usage.completion_tokens = 50 + final.model = "gpt-4o-mini" + yield final + + with budget(max_usd=1.0) as b: + async for _ in _wrap_litellm_stream_async(stream()): + pass + assert b.spent > 0 + + +@pytest.mark.asyncio +async def test_wrap_litellm_stream_async_swallows_attribute_error() -> None: + """Broken usage in async litellm chunk is handled without crashing.""" + from shekel._patch import _wrap_litellm_stream_async + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + async def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + async for _ in _wrap_litellm_stream_async(stream()): + pass + + +@pytest.mark.asyncio +async def test_wrap_litellm_stream_async_no_usage_fallback() -> None: + """No usage chunks — records $0 rather than crashing.""" + from shekel._patch import _wrap_litellm_stream_async + + async def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = None + yield chunk + + with budget(max_usd=1.0) as b: + async for _ in _wrap_litellm_stream_async(stream()): + pass + assert b.spent == pytest.approx(0.0) diff --git a/tests/test_openai_wrappers.py b/tests/test_openai_wrappers.py new file mode 100644 index 0000000..4394081 --- /dev/null +++ b/tests/test_openai_wrappers.py @@ -0,0 +1,105 @@ +"""Tests for OpenAI provider wrappers in shekel/_patch.py.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from shekel import budget + +OPENAI_CREATE = "openai.resources.chat.completions.Completions.create" + + +def test_record_unknown_model_falls_back_to_zero_cost() -> None: + """Unknown model with no price override records $0 rather than crashing.""" + fake = MagicMock() + fake.model = "gpt-999-not-real" + fake.usage.prompt_tokens = 100 + fake.usage.completion_tokens = 50 + + with 
patch(OPENAI_CREATE, return_value=fake): + with budget(max_usd=1.00) as b: + import openai + + client = openai.OpenAI(api_key="test") + client.chat.completions.create(model="gpt-999-not-real", messages=[]) + + assert b.spent == pytest.approx(0.0) + + +def test_extract_openai_tokens_attribute_error() -> None: + """Response with no attributes returns (0, 0, 'unknown').""" + from shekel._patch import _extract_openai_tokens + + response = MagicMock(spec=[]) # no attributes at all + assert _extract_openai_tokens(response) == (0, 0, "unknown") + + +def test_openai_sync_wrapper_raises_if_no_original() -> None: + """RuntimeError when openai_sync not in _originals.""" + from shekel._patch import _openai_sync_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="openai original not stored"): + _openai_sync_wrapper(None) + + +def test_wrap_openai_stream_swallows_chunk_attribute_error() -> None: + """Chunk whose usage attrs raise AttributeError is handled without crashing.""" + from shekel._patch import _wrap_openai_stream + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + raise AttributeError(name) + + def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + list(_wrap_openai_stream(stream())) # must not raise + + +@pytest.mark.asyncio +async def test_openai_async_wrapper_raises_if_no_original() -> None: + """RuntimeError when openai_async not in _originals.""" + from shekel._patch import _openai_async_wrapper + + with patch("shekel._patch._originals", {}): + with pytest.raises(RuntimeError, match="openai async original not stored"): + await _openai_async_wrapper(None) + + +@pytest.mark.asyncio +async def test_wrap_openai_stream_async_swallows_attribute_error() -> None: + """Broken usage in async stream chunk is handled without crashing.""" + from shekel._patch import _wrap_openai_stream_async + + class BrokenUsage: + def __getattr__(self, name: str) -> None: + 
raise AttributeError(name) + + async def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = BrokenUsage() + yield chunk + + async for _ in _wrap_openai_stream_async(stream()): + pass + + +@pytest.mark.asyncio +async def test_wrap_openai_stream_async_no_usage_chunks() -> None: + """When no chunk has usage, records $0 rather than crashing.""" + from shekel._patch import _wrap_openai_stream_async + + async def stream(): # type: ignore[return] + chunk = MagicMock() + chunk.usage = None + yield chunk + + with budget(max_usd=1.0) as b: + async for _ in _wrap_openai_stream_async(stream()): + pass + assert b.spent == pytest.approx(0.0) diff --git a/tests/test_patch_coverage.py b/tests/test_patch_coverage.py deleted file mode 100644 index 2252081..0000000 --- a/tests/test_patch_coverage.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Tests to reach 100% coverage of shekel/_patch.py. - -Each test targets specific uncovered lines identified by coverage analysis. -""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -# --------------------------------------------------------------------------- -# _validate_same_provider (line 69) -# --------------------------------------------------------------------------- - - -def test_validate_same_provider_anthropic_rejects_openai_model(): - """Line 69: anthropic provider + openai fallback model raises ValueError.""" - from shekel._patch import _validate_same_provider - - with pytest.raises(ValueError, match="OpenAI model"): - _validate_same_provider("gpt-4o", "anthropic") - - -# --------------------------------------------------------------------------- -# _extract_openai_tokens (lines 101-102) -# --------------------------------------------------------------------------- - - -def test_extract_openai_tokens_attribute_error(): - """Lines 101-102: response with no attributes returns (0, 0, 'unknown').""" - from shekel._patch import _extract_openai_tokens - - 
response = MagicMock(spec=[]) # no attributes at all - assert _extract_openai_tokens(response) == (0, 0, "unknown") - - -# --------------------------------------------------------------------------- -# _record (lines 121-122 and 141-143) -# --------------------------------------------------------------------------- - - -def test_record_swallows_pricing_exception(): - """Lines 121-122: if calculate_cost raises, cost falls back to 0.0.""" - from shekel import budget - from shekel._patch import _record - - with budget(max_usd=1.0) as b: - with patch("shekel._pricing.calculate_cost", side_effect=RuntimeError("bad")): - _record(100, 50, "gpt-4o") - assert b.spent == pytest.approx(0.0) - - -def test_record_swallows_adapter_emit_exception(): - """Lines 141-143: if AdapterRegistry.emit_event raises, exception is swallowed.""" - from shekel import budget - from shekel._patch import _record - - with budget(max_usd=1.0): - with patch( - "shekel.integrations.AdapterRegistry.emit_event", - side_effect=RuntimeError("adapter crash"), - ): - _record(100, 50, "gpt-4o-mini") # must not raise - - -# --------------------------------------------------------------------------- -# _openai_sync_wrapper (line 154) -# --------------------------------------------------------------------------- - - -def test_openai_sync_wrapper_raises_if_no_original(): - """Line 154: RuntimeError when openai_sync not in _originals.""" - from shekel._patch import _openai_sync_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="openai original not stored"): - _openai_sync_wrapper(None) - - -# --------------------------------------------------------------------------- -# _wrap_openai_stream (lines 182-183) -# --------------------------------------------------------------------------- - - -def test_wrap_openai_stream_swallows_chunk_attribute_error(): - """Lines 182-183: chunk whose usage attrs raise AttributeError is handled.""" - from shekel._patch import 
_wrap_openai_stream - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - def stream(): - chunk = MagicMock() - chunk.usage = BrokenUsage() - yield chunk - - list(_wrap_openai_stream(stream())) # must not raise - - -# --------------------------------------------------------------------------- -# _openai_async_wrapper (line 201) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_openai_async_wrapper_raises_if_no_original(): - """Line 201: RuntimeError when openai_async not in _originals.""" - from shekel._patch import _openai_async_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="openai async original not stored"): - await _openai_async_wrapper(None) - - -# --------------------------------------------------------------------------- -# _wrap_openai_stream_async (lines 229-230 and 236) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_wrap_openai_stream_async_swallows_attribute_error(): - """Lines 229-230: broken usage in async stream chunk is handled.""" - from shekel._patch import _wrap_openai_stream_async - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - async def stream(): - chunk = MagicMock() - chunk.usage = BrokenUsage() - yield chunk - - async for _ in _wrap_openai_stream_async(stream()): - pass - - -@pytest.mark.asyncio -async def test_wrap_openai_stream_async_no_usage_chunks(): - """Line 236: when no chunk has usage, falls back to (0, 0, 'unknown').""" - from shekel import budget - from shekel._patch import _wrap_openai_stream_async - - async def stream(): - chunk = MagicMock() - chunk.usage = None - yield chunk - - with budget(max_usd=1.0) as b: - async for _ in _wrap_openai_stream_async(stream()): - pass - assert b.spent == pytest.approx(0.0) - - -# 
--------------------------------------------------------------------------- -# _anthropic_sync_wrapper (line 248) -# --------------------------------------------------------------------------- - - -def test_anthropic_sync_wrapper_raises_if_no_original(): - """Line 248: RuntimeError when anthropic_sync not in _originals.""" - from shekel._patch import _anthropic_sync_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="anthropic original not stored"): - _anthropic_sync_wrapper(None) - - -# --------------------------------------------------------------------------- -# _wrap_anthropic_stream (lines 277-278 and 282-283) -# --------------------------------------------------------------------------- - - -def test_wrap_anthropic_stream_swallows_message_start_attribute_error(): - """Lines 277-278: broken message_start event handled gracefully.""" - from shekel._patch import _wrap_anthropic_stream - - class BrokenMessage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - class MessageStartEvent: - type = "message_start" - message = BrokenMessage() - - list(_wrap_anthropic_stream(iter([MessageStartEvent()]))) - - -def test_wrap_anthropic_stream_swallows_message_delta_attribute_error(): - """Lines 282-283: broken message_delta event handled gracefully.""" - from shekel._patch import _wrap_anthropic_stream - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - class MessageDeltaEvent: - type = "message_delta" - usage = BrokenUsage() - - list(_wrap_anthropic_stream(iter([MessageDeltaEvent()]))) - - -# --------------------------------------------------------------------------- -# _anthropic_async_wrapper (line 297) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_anthropic_async_wrapper_raises_if_no_original(): - """Line 297: RuntimeError when anthropic_async not in _originals.""" - from 
shekel._patch import _anthropic_async_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="anthropic async original not stored"): - await _anthropic_async_wrapper(None) - - -# --------------------------------------------------------------------------- -# _wrap_anthropic_stream_async (lines 325-326 and 330-331) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_wrap_anthropic_stream_async_swallows_message_start_error(): - """Lines 325-326: broken message_start in async stream handled.""" - from shekel._patch import _wrap_anthropic_stream_async - - class BrokenMessage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - class MessageStartEvent: - type = "message_start" - message = BrokenMessage() - - async def stream(): - yield MessageStartEvent() - - async for _ in _wrap_anthropic_stream_async(stream()): - pass - - -@pytest.mark.asyncio -async def test_wrap_anthropic_stream_async_swallows_message_delta_error(): - """Lines 330-331: broken message_delta in async stream handled.""" - from shekel._patch import _wrap_anthropic_stream_async - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - class MessageDeltaEvent: - type = "message_delta" - usage = BrokenUsage() - - async def stream(): - yield MessageDeltaEvent() - - async for _ in _wrap_anthropic_stream_async(stream()): - pass - - -# --------------------------------------------------------------------------- -# _litellm_sync_wrapper (line 345) -# --------------------------------------------------------------------------- - - -def test_litellm_sync_wrapper_raises_if_no_original(): - """Line 345: RuntimeError when litellm_sync not in _originals.""" - from shekel._patch import _litellm_sync_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="litellm original not stored"): - 
_litellm_sync_wrapper() - - -# --------------------------------------------------------------------------- -# _wrap_litellm_stream (lines 372-373) -# --------------------------------------------------------------------------- - - -def test_wrap_litellm_stream_swallows_chunk_attribute_error(): - """Lines 372-373: broken usage attrs in litellm stream chunk handled.""" - from shekel._patch import _wrap_litellm_stream - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - def stream(): - chunk = MagicMock() - chunk.usage = BrokenUsage() - yield chunk - - list(_wrap_litellm_stream(stream())) - - -# --------------------------------------------------------------------------- -# _litellm_async_wrapper (lines 388 and 395-397) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_litellm_async_wrapper_raises_if_no_original(): - """Line 388: RuntimeError when litellm_async not in _originals.""" - from shekel._patch import _litellm_async_wrapper - - with patch("shekel._patch._originals", {}): - with pytest.raises(RuntimeError, match="litellm async original not stored"): - await _litellm_async_wrapper() - - -@pytest.mark.asyncio -async def test_litellm_async_wrapper_stream_path(): - """Lines 395-397: stream=True branch in async wrapper returns async generator.""" - from shekel._patch import _litellm_async_wrapper - - async def mock_async_stream(): - chunk = MagicMock() - chunk.usage = MagicMock() - chunk.usage.prompt_tokens = 10 - chunk.usage.completion_tokens = 5 - chunk.model = "gpt-4o-mini" - yield chunk - - original = AsyncMock(return_value=mock_async_stream()) - - mock_budget = MagicMock() - mock_budget._using_fallback = False - - with patch("shekel._patch._originals", {"litellm_async": original}): - with patch("shekel._context.get_active_budget", return_value=mock_budget): - stream = await _litellm_async_wrapper(model="gpt-4o-mini", messages=[], stream=True) - 
assert stream is not None - # drain the generator to hit the finally block - async for _ in stream: - pass - - -# --------------------------------------------------------------------------- -# _wrap_litellm_stream_async (lines 406-420) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_wrap_litellm_stream_async_records_cost(): - """Lines 406-414: async litellm stream records tokens from usage chunk.""" - from shekel import budget - from shekel._patch import _wrap_litellm_stream_async - - async def stream(): - chunk = MagicMock() - chunk.usage = None - yield chunk - final = MagicMock() - final.usage = MagicMock() - final.usage.prompt_tokens = 100 - final.usage.completion_tokens = 50 - final.model = "gpt-4o-mini" - yield final - - with budget(max_usd=1.0) as b: - async for _ in _wrap_litellm_stream_async(stream()): - pass - assert b.spent > 0 - - -@pytest.mark.asyncio -async def test_wrap_litellm_stream_async_swallows_attribute_error(): - """Lines 415-416: broken usage in async litellm chunk handled.""" - from shekel._patch import _wrap_litellm_stream_async - - class BrokenUsage: - def __getattr__(self, name: str) -> None: - raise AttributeError(name) - - async def stream(): - chunk = MagicMock() - chunk.usage = BrokenUsage() - yield chunk - - async for _ in _wrap_litellm_stream_async(stream()): - pass - - -@pytest.mark.asyncio -async def test_wrap_litellm_stream_async_no_usage_fallback(): - """Line 420: no usage chunks → falls back to (0, 0, 'unknown').""" - from shekel import budget - from shekel._patch import _wrap_litellm_stream_async - - async def stream(): - chunk = MagicMock() - chunk.usage = None - yield chunk - - with budget(max_usd=1.0) as b: - async for _ in _wrap_litellm_stream_async(stream()): - pass - assert b.spent == pytest.approx(0.0) diff --git a/tests/test_patching.py b/tests/test_patching.py deleted file mode 100644 index d8d842a..0000000 --- a/tests/test_patching.py +++ 
/dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from unittest.mock import patch - -import pytest - -from shekel import budget -from shekel._patch import _record - -OPENAI_CREATE = "openai.resources.chat.completions.Completions.create" -ANTHROPIC_CREATE = "anthropic.resources.messages.Messages.create" - - -def test_record_with_no_active_budget_is_noop() -> None: - """_record outside a budget() context should return silently.""" - _record(100, 50, "gpt-4o") - - -def test_record_unknown_model_falls_back_to_zero_cost() -> None: - """Unknown model with no price override records $0 rather than crashing.""" - from unittest.mock import MagicMock - - fake = MagicMock() - fake.model = "gpt-999-not-real" - fake.usage.prompt_tokens = 100 - fake.usage.completion_tokens = 50 - - with patch(OPENAI_CREATE, return_value=fake): - with budget(max_usd=1.00) as b: - import openai - - client = openai.OpenAI(api_key="test") - client.chat.completions.create(model="gpt-999-not-real", messages=[]) - - assert b.spent == pytest.approx(0.0) - - -def test_anthropic_malformed_response_records_zero() -> None: - """Response missing .usage attribute records $0 rather than crashing.""" - - class NoUsage: - model = "claude-3-5-sonnet-20241022" - - with patch(ANTHROPIC_CREATE, return_value=NoUsage()): - with budget(max_usd=1.00) as b: - import anthropic - - client = anthropic.Anthropic(api_key="test") - client.messages.create(model="claude-3-5-sonnet-20241022", messages=[], max_tokens=10) - - assert b.spent == pytest.approx(0.0) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 30b208a..1e53916 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -274,7 +274,7 @@ def test_prices_json_schema() -> None: def test_list_models_returns_ten() -> None: - assert len(list_models()) == 10 + assert len(list_models()) == 13 def test_anthropic_model_in_list() -> None: