20 changes: 17 additions & 3 deletions llm_clients/claude_llm.py
@@ -23,9 +23,10 @@ def __init__(
name: str,
system_prompt: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
**kwargs,
):
-super().__init__(name, system_prompt)
super().__init__(name, system_prompt, max_retries=max_retries)

if not Config.ANTHROPIC_API_KEY:
raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
@@ -98,7 +99,20 @@ async def generate_response(

try:
start_time = time.time()
-response = await self.llm.ainvoke(messages)

# Use retry logic for API call
async def _invoke():
return await self.llm.ainvoke(messages)

def _validate_response(response_obj):
"""Validate that response has non-empty content."""
return bool(response_obj.text and response_obj.text.strip())

response = await self._retry_with_backoff(
_invoke,
operation_name="generate_response",
response_validator=_validate_response,
)
end_time = time.time()

# Extract metadata from response
@@ -148,7 +162,7 @@ async def generate_response(
"error": str(e),
"usage": {},
}
return f"Error generating response: {str(e)}"
raise RuntimeError(f"Error generating response: {str(e)}") from e

async def generate_structured_response(
self, message: Optional[str], response_model: Type[T]
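Since this PR changes generate_response to raise instead of returning an error string, call sites need an exception path. A minimal sketch of the adjustment, assuming a ClaudeLLM class in the module above (the class name and constructor shape are inferred from the file, not shown in this diff):

import asyncio

from llm_clients.claude_llm import ClaudeLLM  # class name assumed from the file name


async def main():
    llm = ClaudeLLM("assistant", system_prompt="Be terse.", max_retries=3)
    try:
        reply = await llm.generate_response("Hello!")
        print(reply)
    except RuntimeError as exc:
        # Post-PR behavior: failures surface as exceptions rather than as
        # "Error generating response: ..." strings returned to the caller.
        print(f"LLM call failed after retries: {exc}")


asyncio.run(main())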
20 changes: 17 additions & 3 deletions llm_clients/gemini_llm.py
@@ -23,9 +23,10 @@ def __init__(
name: str,
system_prompt: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
**kwargs,
):
-super().__init__(name, system_prompt)
super().__init__(name, system_prompt, max_retries=max_retries)

if not Config.GOOGLE_API_KEY:
raise ValueError("GOOGLE_API_KEY not found in environment variables")
@@ -96,7 +97,20 @@ async def generate_response(

try:
start_time = time.time()
-response = await self.llm.ainvoke(messages)

# Use retry logic for API call
async def _invoke():
return await self.llm.ainvoke(messages)

def _validate_response(response_obj):
"""Validate that response has non-empty content."""
return bool(response_obj.text and response_obj.text.strip())

response = await self._retry_with_backoff(
_invoke,
operation_name="generate_response",
response_validator=_validate_response,
)
end_time = time.time()

# Extract metadata from response
@@ -157,7 +171,7 @@ async def generate_response(
"error": str(e),
"usage": {},
}
return f"Error generating response: {str(e)}"
raise RuntimeError(f"Error generating response: {str(e)}") from e

def get_last_response_metadata(self) -> Dict[str, Any]:
"""Get metadata from the last response."""
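Both the Claude and Gemini clients above route their API calls through _retry_with_backoff, defined in llm_interface.py further down this diff. As a quick sanity check of the backoff schedule that method documents (2^attempt seconds, capped at 60):

# Exponential backoff schedule used by _retry_with_backoff: min(2**attempt, 60).
for attempt in range(7):
    print(f"attempt {attempt}: wait {min(2 ** attempt, 60)}s")
# attempt 0: 1s, 1: 2s, 2: 4s, 3: 8s, 4: 16s, 5: 32s, 6: 60s (capped)

With the default max_retries of 3, only the first couple of waits are ever reached.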
30 changes: 26 additions & 4 deletions llm_clients/llama_llm.py
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, List, Optional

from langchain_community.llms import Ollama
@@ -21,9 +22,10 @@ def __init__(
name: str,
system_prompt: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
**kwargs,
):
-super().__init__(name, system_prompt)
super().__init__(name, system_prompt, max_retries=max_retries)

# Use provided model name or fall back to config default
self.model_name = model_name or Config.get_llama_config()["model"]
@@ -59,11 +61,31 @@ async def generate_response(
)

# Ollama doesn't have native async support in langchain-community
-# So we'll use the synchronous version
-response = self.llm.invoke(full_message)
# So we'll use the synchronous version, wrapped in async for retry logic
async def _invoke():
# Run sync invoke in thread pool to avoid blocking
return await asyncio.to_thread(self.llm.invoke, full_message)

def _validate_response(response_obj):
"""Validate that response has non-empty content."""
# Ollama may return string directly or a message object
if isinstance(response_obj, str):
return bool(response_obj and response_obj.strip())
elif hasattr(response_obj, "text"):
return bool(response_obj.text and response_obj.text.strip())
elif hasattr(response_obj, "content"):
return bool(response_obj.content and response_obj.content.strip())
# If we can't determine, assume valid
return True

response = await self._retry_with_backoff(
_invoke,
operation_name="generate_response",
response_validator=_validate_response,
)
return response
except Exception as e:
return f"Error generating response: {str(e)}"
raise RuntimeError(f"Error generating response: {str(e)}") from e

def set_system_prompt(self, system_prompt: str) -> None:
"""Set or update the system prompt."""
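The Ollama validator above has to cope with several possible return shapes. The same logic, exercised standalone with stand-in objects (SimpleNamespace is only a test double here, not a real Ollama return type):

from types import SimpleNamespace


def _validate_response(response_obj):
    """Same polymorphic check as in llama_llm.py above."""
    if isinstance(response_obj, str):
        return bool(response_obj and response_obj.strip())
    elif hasattr(response_obj, "text"):
        return bool(response_obj.text and response_obj.text.strip())
    elif hasattr(response_obj, "content"):
        return bool(response_obj.content and response_obj.content.strip())
    return True  # unknown shape: assume valid


assert _validate_response("hello")
assert not _validate_response("   ")                       # whitespace-only
assert _validate_response(SimpleNamespace(text="hi"))
assert not _validate_response(SimpleNamespace(content=""))
assert _validate_response(object())                        # undetermined: valid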
195 changes: 193 additions & 2 deletions llm_clients/llm_interface.py
@@ -1,5 +1,7 @@
import asyncio
import re
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type, TypeVar
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

from pydantic import BaseModel

@@ -13,9 +15,12 @@ class LLMInterface(ABC):
must support basic text generation and system prompt management.
"""

-def __init__(self, name: str, system_prompt: Optional[str] = None):
def __init__(
self, name: str, system_prompt: Optional[str] = None, max_retries: int = 3

Collaborator commented:

Until now I have tried to avoid (not always successfully) putting defaults outside of the main entry point, as it might introduce subtle bugs.

While it makes the script code a little longer, I wonder if putting an optional arg there and passing it down makes it more coherent with the existing codebase?

Collaborator commented:

Of course, this opens a new question: all the functions are getting bloated, and maybe a config file/arg is now needed.

Collaborator (Author) commented:

Thanks for this! Yeah, I don't like the bloat! I noticed LangChain also supports max_retries, but that's only for LangChain-supported endpoints. I'll revisit this after tending to the other bloat, which is the other open PRs 😅

):
self.name = name
self.system_prompt = system_prompt or ""
self.max_retries = max_retries

@abstractmethod
async def generate_response(
@@ -47,6 +52,192 @@ def get_name(self) -> str:
"""Get the name of this LLM instance."""
return self.name

def _extract_http_status_code(self, exception: Exception) -> Optional[int]:
"""Extract HTTP status code from exception if available.

LangChain and various HTTP libraries wrap HTTP errors differently.
This method attempts to extract the status code from common
exception types.
"""
# Check for status_code attribute (common in HTTPException)
if hasattr(exception, "status_code"):
status_code = getattr(exception, "status_code")
if status_code is not None:
return int(status_code)

# Check for response attribute with status_code
if hasattr(exception, "response"):
response = getattr(exception, "response")
if hasattr(response, "status_code"):
status_code = getattr(response, "status_code")
if status_code is not None:
return int(status_code)
if hasattr(response, "status"):
status = getattr(response, "status")
if status is not None:
return int(status)

# Check for status attribute directly
if hasattr(exception, "status"):
status = getattr(exception, "status")
if status is not None:
return int(status)

# Check exception message for status codes (fallback)
error_str = str(exception).lower()
for code in [429, 500, 502, 503, 504, 529]:
if f"status {code}" in error_str or f"status_code {code}" in error_str:
return code

# Check for "Error Code" pattern in exception message
if "error code" in error_str:
# Try to extract numeric code after "error code"
match = re.search(r"error code[:\s]+(\d+)", error_str, re.IGNORECASE)
if match:
try:
return int(match.group(1))
except (ValueError, TypeError):
pass

return None

def _extract_retry_after(self, exception: Exception) -> Optional[int]:
"""Extract Retry-After header value from exception if available."""
if hasattr(exception, "response"):
response = getattr(exception, "response")
if hasattr(response, "headers"):
headers = getattr(response, "headers")
retry_after = headers.get("Retry-After") or headers.get("retry-after")
if retry_after:
try:
return int(retry_after)
except (ValueError, TypeError):
pass
return None

async def _retry_with_backoff(
self,
func: Callable[[], Any],
operation_name: str = "operation",
response_validator: Optional[Callable[[Any], bool]] = None,
) -> Any:
"""Execute a function with retry logic for transient HTTP errors.

Handles the following HTTP status codes:
- 429 (Too Many Requests): Respects Retry-After header,
otherwise exponential backoff
- 500 (Internal Server Error): Retry 1-3 times with
exponential backoff
- 502 (Bad Gateway): Retry 1-3 times with exponential backoff
- 503 (Service Unavailable): Exponential backoff
- 504 (Gateway Timeout): Exponential backoff
- 529 (Overloaded - Anthropic): Treated like 503 with
exponential backoff

Also retries if response_validator is provided and returns False
(e.g., for empty response content).

Args:
func: Async function to execute
operation_name: Name of operation for error messages
response_validator: Optional function to validate response.
If provided and returns False, will retry the operation.
Should accept the result of func() and return True if valid.

Returns:
Result of func()

Raises:
RuntimeError: If max retries exceeded or non-retryable
error occurs, or if response validation fails after
max retries
"""
retryable_status_codes = {429, 500, 502, 503, 504, 529}
max_retries_for_500_502 = 3 # Limit retries for 500/502

last_exception = None

for attempt in range(self.max_retries):
try:
result = await func()

# Validate response if validator is provided
if response_validator is not None:
if not response_validator(result):
# Response validation failed, treat as retryable error
raise ValueError(
f"Response validation failed in {operation_name}: "
"response content is empty or invalid"
)

return result
except Exception as e:
last_exception = e
status_code = self._extract_http_status_code(e)

# If we can't determine status code, check if it's
# retryable by message
if status_code is None:
error_str = str(e).lower()
# Check for common retryable error messages
retryable_keywords = [
"rate limit",
"too many requests",
"service unavailable",
"internal server error",
"bad gateway",
"gateway timeout",
"overloaded",
"timeout",
"response validation failed",
"response content is empty",
]
if any(keyword in error_str for keyword in retryable_keywords):
# Treat as retryable, use exponential backoff
status_code = 503 # Default for unknown retryable
else:
# Non-retryable error, raise immediately
raise RuntimeError(
f"Error in {operation_name}: {str(e)}"
) from e

# Check if this is a retryable status code
if status_code not in retryable_status_codes:
# Non-retryable error, raise immediately
raise RuntimeError(f"Error in {operation_name}: {str(e)}") from e

# For 500 and 502, limit retries to max_retries_for_500_502
if status_code in {500, 502} and attempt >= max_retries_for_500_502 - 1:
raise RuntimeError(
f"Error in {operation_name} after "
f"{max_retries_for_500_502} retries: {str(e)}"
) from e

# Calculate wait time
if status_code == 429:
# Check for Retry-After header
retry_after = self._extract_retry_after(e)
if retry_after is not None:
wait_time = retry_after
else:
# Exponential backoff: 2^attempt seconds, max 60s
wait_time = min(2**attempt, 60)
elif status_code in {503, 529}:
# Exponential backoff for capacity issues
wait_time = min(2**attempt, 60)
else: # 500, 502, 504
# Exponential backoff for transient errors
wait_time = min(2**attempt, 60)

# Wait before retrying
await asyncio.sleep(wait_time)

# Max retries exceeded
raise RuntimeError(
f"Error in {operation_name} after {self.max_retries} retries: "
f"{str(last_exception)}"
) from last_exception

def __getattr__(self, name):
"""Delegate attribute access to the underlying llm object.

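To see the base-class retry machinery above end to end, a small self-contained demo. The import path matches this diff, but the stub's generate_response signature is assumed, since the abstract signature is folded out of view here:

import asyncio

from llm_clients.llm_interface import LLMInterface


class _StubLLM(LLMInterface):
    # Assumes generate_response is the only abstract method.
    async def generate_response(self, message=None):
        return "ok"


async def main():
    llm = _StubLLM("stub", max_retries=3)
    attempts = {"n": 0}

    async def flaky():
        attempts["n"] += 1
        if attempts["n"] < 3:
            # No status code attached, so the keyword fallback in
            # _retry_with_backoff should classify this as retryable
            # (treated like a 503).
            raise RuntimeError("service unavailable")
        return "hello"

    result = await llm._retry_with_backoff(flaky, operation_name="demo")
    print(result, "after", attempts["n"], "attempts")  # sleeps ~3s (1s + 2s)


asyncio.run(main())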
20 changes: 17 additions & 3 deletions llm_clients/openai_llm.py
@@ -23,9 +23,10 @@ def __init__(
name: str,
system_prompt: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
**kwargs,
):
-super().__init__(name, system_prompt)
super().__init__(name, system_prompt, max_retries=max_retries)

if not Config.OPENAI_API_KEY:
raise ValueError("OPENAI_API_KEY not found in environment variables")
@@ -95,7 +96,20 @@ async def generate_response(

try:
start_time = time.time()
-response = await self.llm.ainvoke(messages)

# Use retry logic for API call
async def _invoke():
return await self.llm.ainvoke(messages)

def _validate_response(response_obj):
"""Validate that response has non-empty content."""
return bool(response_obj.text and response_obj.text.strip())

response = await self._retry_with_backoff(
_invoke,
operation_name="generate_response",
response_validator=_validate_response,
)
end_time = time.time()

# Extract metadata from response - capturing all available fields
Expand Down Expand Up @@ -177,7 +191,7 @@ async def generate_response(
"system_fingerprint": None,
"logprobs": None,
}
return f"Error generating response: {str(e)}"
raise RuntimeError(f"Error generating response: {str(e)}") from e

async def generate_structured_response(
self, message: Optional[str], response_model: Type[T]
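Picking up the thread in llm_interface.py above about argument bloat and a possible config object: one illustrative direction, not part of this PR, would be a small frozen dataclass threaded through the constructors in place of individual keyword arguments:

from dataclasses import dataclass


@dataclass(frozen=True)
class RetryConfig:
    # Hypothetical bundle for the retry knobs discussed in the review.
    max_retries: int = 3
    max_backoff_seconds: int = 60
    retryable_status_codes: frozenset = frozenset({429, 500, 502, 503, 504, 529})


# Each client's __init__ would then take one optional argument:
#     def __init__(self, name, system_prompt=None, retry_config=None):
#         self.retry_config = retry_config or RetryConfig()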