diff --git a/src/kit/models/__init__.py b/src/kit/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/kit/models/base.py b/src/kit/models/base.py new file mode 100644 index 00000000..2b4ea9de --- /dev/null +++ b/src/kit/models/base.py @@ -0,0 +1,20 @@ +""" +Base classes and protocols for LLM models. +""" + +from typing import Protocol, runtime_checkable + + +# Define a Protocol for LLM clients to help with type checking +@runtime_checkable +class LLMClientProtocol(Protocol): + """Protocol defining the interface for LLM clients.""" + + # This is a structural protocol - any object with compatible methods will be accepted + pass + + +class LLMError(Exception): + """Custom exception for LLM related errors.""" + + pass diff --git a/src/kit/models/config.py b/src/kit/models/config.py new file mode 100644 index 00000000..ce1e7beb --- /dev/null +++ b/src/kit/models/config.py @@ -0,0 +1,57 @@ +""" +LLM provider configuration classes. +""" + +import os +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class OpenAIConfig: + """Configuration for OpenAI API access.""" + + api_key: Optional[str] = field(default_factory=lambda: os.environ.get("OPENAI_API_KEY")) + model: str = "gpt-4o" + temperature: float = 0.7 + max_tokens: int = 1000 # Default max tokens for summary + base_url: Optional[str] = None + + def __post_init__(self): + if not self.api_key: + raise ValueError( + "OpenAI API key not found. Set OPENAI_API_KEY environment variable or pass api_key directly." + ) + + +@dataclass +class AnthropicConfig: + """Configuration for Anthropic API access.""" + + api_key: Optional[str] = field(default_factory=lambda: os.environ.get("ANTHROPIC_API_KEY")) + model: str = "claude-3-opus-20240229" + temperature: float = 0.7 + max_tokens: int = 1000 # Corresponds to Anthropic's max_tokens_to_sample + + def __post_init__(self): + if not self.api_key: + raise ValueError( + "Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable or pass api_key directly." + ) + + +@dataclass +class GoogleConfig: + """Configuration for Google Generative AI API access.""" + + api_key: Optional[str] = field(default_factory=lambda: os.environ.get("GOOGLE_API_KEY")) + model: str = "gemini-1.5-pro-latest" + temperature: Optional[float] = 0.7 + max_output_tokens: Optional[int] = 1000 # Corresponds to Gemini's max_output_tokens + model_kwargs: Optional[Dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + if not self.api_key: + raise ValueError( + "Google API key not found. Set GOOGLE_API_KEY environment variable or pass api_key directly." + ) diff --git a/src/kit/models/llm_client.py b/src/kit/models/llm_client.py new file mode 100644 index 00000000..d7ab1e94 --- /dev/null +++ b/src/kit/models/llm_client.py @@ -0,0 +1,273 @@ +"""LLM client interfaces and implementations.""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union + +from kit.models.base import LLMError +from kit.models.config import AnthropicConfig, GoogleConfig, OpenAIConfig +from kit.models.llm_utils import count_openai_chat_tokens + +# Conditionally import google.genai +try: + import google.genai as genai + from google.genai import types as genai_types +except ImportError: + genai = None # type: ignore + genai_types = None # type: ignore + +logger = logging.getLogger(__name__) + +# Constants +OPENAI_MAX_PROMPT_TOKENS = 15000 # Max tokens for the prompt to OpenAI + + +class LLMClient(ABC): + """Base class for LLM clients.""" + + @abstractmethod + def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str: + """Generate a completion from the LLM. + + Args: + system_prompt: The system prompt to use. + user_prompt: The user prompt to use. + model_name: Optional model name to override the default. + + Returns: + The generated completion text. + + Raises: + LLMError: If there was an error generating the completion. + """ + pass + + @staticmethod + def create_client(config: Union[OpenAIConfig, AnthropicConfig, GoogleConfig]) -> "LLMClient": + """Factory method to create an appropriate LLM client. + + Args: + config: The LLM configuration to use. + + Returns: + An LLMClient instance. + + Raises: + TypeError: If config is None or an unsupported configuration type. + LLMError: If there was an error initializing the client. + """ + # Require a valid config + if config is None: + raise TypeError("LLM configuration must be provided") + + if isinstance(config, OpenAIConfig): + return OpenAIClient(config) + elif isinstance(config, AnthropicConfig): + return AnthropicClient(config) + elif isinstance(config, GoogleConfig): + return GoogleClient(config) + else: + raise TypeError(f"Unsupported LLM configuration type: {type(config)}") + + +class OpenAIClient(LLMClient): + """Client for OpenAI's API.""" + + def __init__(self, config: OpenAIConfig): + """Initialize with OpenAI configuration. + + Args: + config: The OpenAI configuration. + + Raises: + LLMError: If the OpenAI SDK is not available. + """ + self.config = config + try: + from openai import OpenAI + + if self.config.base_url: + self.client = OpenAI(api_key=self.config.api_key, base_url=self.config.base_url) + else: + self.client = OpenAI(api_key=self.config.api_key) + except ImportError: + raise LLMError("OpenAI SDK (openai) not available. Please install it.") + + def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str: + """Generate a completion using OpenAI's API. + + Args: + system_prompt: The system prompt to use. + user_prompt: The user prompt to use. + model_name: Optional model name to override the config's model. + + Returns: + The generated completion text. + + Raises: + LLMError: If there was an error generating the completion. + """ + # Use provided model_name or fall back to config + actual_model = model_name if model_name is not None else self.config.model + + messages_for_api = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # Check token count + prompt_token_count = count_openai_chat_tokens(messages_for_api, actual_model) + if prompt_token_count is not None and prompt_token_count > OPENAI_MAX_PROMPT_TOKENS: + return f"Completion generation failed: OpenAI prompt too large ({prompt_token_count} tokens). Limit is {OPENAI_MAX_PROMPT_TOKENS} tokens." + + try: + response = self.client.chat.completions.create( + model=actual_model, + messages=messages_for_api, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + ) + + if response.usage: + logger.debug(f"OpenAI API usage: {response.usage}") + + return response.choices[0].message.content + except Exception as e: + logger.error(f"Error communicating with OpenAI API: {e}") + raise LLMError(f"Error communicating with OpenAI API: {e}") from e + + +class AnthropicClient(LLMClient): + """Client for Anthropic's API.""" + + def __init__(self, config: AnthropicConfig): + """Initialize with Anthropic configuration. + + Args: + config: The Anthropic configuration. + + Raises: + LLMError: If the Anthropic SDK is not available. + """ + self.config = config + try: + from anthropic import Anthropic + + self.client = Anthropic(api_key=self.config.api_key) + except ImportError: + raise LLMError("Anthropic SDK (anthropic) not available. Please install it.") + + def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str: + """Generate a completion using Anthropic's API. + + Args: + system_prompt: The system prompt to use. + user_prompt: The user prompt to use. + model_name: Optional model name to override the config's model. + + Returns: + The generated completion text. + + Raises: + LLMError: If there was an error generating the completion. + """ + # Use provided model_name or fall back to config + actual_model = model_name if model_name is not None else self.config.model + + try: + response = self.client.messages.create( + model=actual_model, + system=system_prompt, + messages=[{"role": "user", "content": user_prompt}], + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, + ) + + return response.content[0].text + except Exception as e: + logger.error(f"Error communicating with Anthropic API: {e}") + raise LLMError(f"Error communicating with Anthropic API: {e}") from e + + +class GoogleClient(LLMClient): + """Client for Google's Generative AI API.""" + + def __init__(self, config: GoogleConfig): + """Initialize with Google configuration. + + Args: + config: The Google configuration. + + Raises: + LLMError: If the Google Gen AI SDK is not available. + """ + self.config = config + if genai is None: + raise LLMError("Google Gen AI SDK (google-genai) not available. Please install it.") + + try: + self.client = genai.Client(api_key=self.config.api_key) + except Exception as e: + raise LLMError(f"Error initializing Google Gen AI client: {e}") from e + + def generate_completion(self, system_prompt: str, user_prompt: str, model_name: Optional[str] = None) -> str: + """Generate a completion using Google's Generative AI API. + + Args: + system_prompt: The system prompt to use (Note: currently not used by Google's API directly). + user_prompt: The user prompt to use. + model_name: Optional model name to override the config's model. + + Returns: + The generated completion text. + + Raises: + LLMError: If there was an error generating the completion. + """ + # Use provided model_name or fall back to config + actual_model = model_name if model_name is not None else self.config.model + + if genai_types is None: + raise LLMError( + "Google Gen AI SDK (google-genai) types not available. SDK might not be installed correctly." + ) + + # Prepare generation config from model_kwargs + generation_config_params: Dict[str, Any] = ( + self.config.model_kwargs.copy() if self.config.model_kwargs is not None else {} + ) + + if self.config.temperature is not None: + generation_config_params["temperature"] = self.config.temperature + if self.config.max_output_tokens is not None: + generation_config_params["max_output_tokens"] = self.config.max_output_tokens + + final_sdk_params = generation_config_params if generation_config_params else None + + # TODO: Incorporate system_prompt into user_prompt for Google models + # Since Google models don't have a direct system prompt parameter, + # we might need to combine them or use a different approach + + try: + response = self.client.models.generate_content( + model=actual_model, contents=user_prompt, generation_config=final_sdk_params + ) + + # Check for blocked prompt + if ( + hasattr(response, "prompt_feedback") + and response.prompt_feedback + and response.prompt_feedback.block_reason + ): + logger.warning(f"Google LLM prompt blocked. Reason: {response.prompt_feedback.block_reason}") + return f"Completion generation failed: Prompt blocked by API (Reason: {response.prompt_feedback.block_reason})" + + # Check for empty response + if not response.text: + logger.warning(f"Google LLM returned no text. Response: {response}") + return "Completion generation failed: No text returned by API." + + return response.text + except Exception as e: + logger.error(f"Error communicating with Google Gen AI API: {e}") + raise LLMError(f"Error communicating with Google Gen AI API: {e}") from e diff --git a/src/kit/models/llm_utils.py b/src/kit/models/llm_utils.py new file mode 100644 index 00000000..9b821a96 --- /dev/null +++ b/src/kit/models/llm_utils.py @@ -0,0 +1,154 @@ +"""Utilities for working with LLMs.""" + +import logging +from typing import Any, Dict, List, Optional + +import tiktoken + +logger = logging.getLogger(__name__) + +# Cache for tiktoken encoders +_tokenizer_cache: Dict[str, Any] = {} + + +def get_tokenizer(model_name: str): + """Get a tokenizer for a specific model. + + Args: + model_name: The name of the model to get a tokenizer for. + + Returns: + A tokenizer for the specified model, or None if no tokenizer is available. + """ + if model_name in _tokenizer_cache: + return _tokenizer_cache[model_name] + try: + encoding = tiktoken.encoding_for_model(model_name) + _tokenizer_cache[model_name] = encoding + return encoding + except KeyError: + try: + # Fallback for models not directly in tiktoken.model.MODEL_TO_ENCODING + encoding = tiktoken.get_encoding("cl100k_base") + _tokenizer_cache[model_name] = encoding + return encoding + except Exception as e: + logger.warning( + f"Could not load tiktoken encoder for {model_name} due to {e}, token count will be approximate (char count)." + ) + return None + + +def count_tokens(text: str, model_name: Optional[str] = None) -> int: + """Count the number of tokens in a text string for a given model. + + Args: + text: The text to count tokens for. + model_name: The name of the model to count tokens for. If None, defaults to "gpt-4o". + + Returns: + The number of tokens in the text. + """ + if not text: + return 0 + + # Use a default model if none specified + if model_name is None: + model_name = "gpt-4o" # Default fallback + + try: + # Try to use tiktoken for accurate token counting + if tiktoken: + try: + if model_name in _tokenizer_cache: + encoder = _tokenizer_cache[model_name] + else: + try: + encoder = tiktoken.encoding_for_model(model_name) + except KeyError: + # Model not found, use cl100k_base as fallback + encoder = tiktoken.get_encoding("cl100k_base") + _tokenizer_cache[model_name] = encoder + + return len(encoder.encode(text)) + except Exception as e: + logger.warning(f"Error using tiktoken for model {model_name}: {e}") + # Fall through to character-based approximation + else: + logger.warning( + f"No tiktoken encoder found for model {model_name}, token count will be approximate (char count)." + ) + except NameError: + # tiktoken not available + logger.warning("tiktoken not available, token count will be approximate (char count).") + + # Fallback: approximate token count based on characters (4 chars ~= 1 token) + return len(text) // 4 + + +def count_openai_chat_tokens(messages: List[Dict[str, str]], model_name: str) -> Optional[int]: + """Return the number of tokens used by a list of messages for OpenAI chat models. + + Args: + messages: A list of messages to count tokens for. + model_name: The name of the model to count tokens for. + + Returns: + The number of tokens in the messages, or None if the tokens could not be counted. + """ + encoding = get_tokenizer(model_name) + if not encoding: + logger.warning(f"Cannot count OpenAI chat tokens for {model_name}, no tiktoken encoder available.") + return None + + # Logic adapted from OpenAI cookbook for counting tokens for chat completions + # See: https://github.com/openai/openai-cookbook/blob/main/examples/how_to_count_tokens_with_tiktoken.ipynb + if model_name in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model_name == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model_name: # Covers general gpt-3.5-turbo and variants not explicitly listed + # Defaulting to newer model token counts as a general heuristic + logger.debug(f"Using token counting parameters for gpt-3.5-turbo-0613 for model {model_name}.") + tokens_per_message = 3 + tokens_per_name = 1 + elif "gpt-4" in model_name: # Covers general gpt-4 and variants not explicitly listed + logger.debug(f"Using token counting parameters for gpt-4-0613 for model {model_name}.") + tokens_per_message = 3 + tokens_per_name = 1 + else: + # Fallback for unknown models; this might not be perfectly accurate. + logger.warning( + f"count_openai_chat_tokens() may not be accurate for model {model_name}. " + f"It's not explicitly handled. Using default token counting parameters (3 tokens/message, 1 token/name). " + f"See OpenAI's documentation for details on your specific model." + ) + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + if value is None: # Ensure value is not None before attempting to encode + logger.debug(f"Encountered None value for key '{key}' in message, skipping for token counting.") + continue + try: + num_tokens += len(encoding.encode(str(value))) # Ensure value is string + except Exception as e: + # This catch is a safeguard; tiktoken should handle most string inputs. + logger.error(f"Could not encode value for token counting: '{str(value)[:50]}...', error: {e}") + return None # Inability to encode part of message means count is unreliable + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> (approximates assistant's first tokens) + return num_tokens diff --git a/src/kit/summaries.py b/src/kit/summaries.py index 7700d317..ccc8864a 100644 --- a/src/kit/summaries.py +++ b/src/kit/summaries.py @@ -1,29 +1,12 @@ """Handles code summarization using LLMs.""" import logging -import os -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Union, runtime_checkable +from typing import TYPE_CHECKING, Any, Optional, Union -import tiktoken - - -# Define a Protocol for LLM clients to help with type checking -@runtime_checkable -class LLMClientProtocol(Protocol): - """Protocol defining the interface for LLM clients.""" - - # This is a structural protocol - any object with compatible methods will be accepted - pass - - -# Conditionally import google.genai -try: - import google.genai as genai - from google.genai import types as genai_types -except ImportError: - genai = None # type: ignore - genai_types = None # type: ignore +from kit.models.base import LLMError +from kit.models.config import AnthropicConfig, GoogleConfig, OpenAIConfig +from kit.models.llm_client import LLMClient +from kit.models.llm_utils import count_tokens logger = logging.getLogger(__name__) @@ -32,211 +15,28 @@ class LLMClientProtocol(Protocol): from kit.repository import Repository -class LLMError(Exception): - """Custom exception for LLM related errors.""" - - pass - - class SymbolNotFoundError(Exception): """Custom exception for when a symbol (function, class) is not found.""" pass -@dataclass -class OpenAIConfig: - """Configuration for OpenAI API access.""" - - api_key: Optional[str] = field(default_factory=lambda: os.environ.get("OPENAI_API_KEY")) - model: str = "gpt-4o" - temperature: float = 0.7 - max_tokens: int = 1000 # Default max tokens for summary - base_url: Optional[str] = None - - def __post_init__(self): - if not self.api_key: - raise ValueError( - "OpenAI API key not found. Set OPENAI_API_KEY environment variable or pass api_key directly." - ) - - -@dataclass -class AnthropicConfig: - """Configuration for Anthropic API access.""" - - api_key: Optional[str] = field(default_factory=lambda: os.environ.get("ANTHROPIC_API_KEY")) - model: str = "claude-3-opus-20240229" - temperature: float = 0.7 - max_tokens: int = 1000 # Corresponds to Anthropic's max_tokens_to_sample - - def __post_init__(self): - if not self.api_key: - raise ValueError( - "Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable or pass api_key directly." - ) - - -@dataclass -class GoogleConfig: - """Configuration for Google Generative AI API access.""" - - api_key: Optional[str] = field(default_factory=lambda: os.environ.get("GOOGLE_API_KEY")) - model: str = "gemini-1.5-pro-latest" - temperature: Optional[float] = 0.7 - max_output_tokens: Optional[int] = 1000 # Corresponds to Gemini's max_output_tokens - model_kwargs: Optional[Dict[str, Any]] = field(default_factory=dict) - - def __post_init__(self): - if not self.api_key: - raise ValueError( - "Google API key not found. Set GOOGLE_API_KEY environment variable or pass api_key directly." - ) - - # todo: make configurable MAX_CODE_LENGTH_CHARS = 50000 # Max characters for a single function/class summary MAX_FILE_SUMMARIZE_CHARS = 25000 # Max characters for file content in summarize_file -OPENAI_MAX_PROMPT_TOKENS = 15000 # Max tokens for the prompt to OpenAI class Summarizer: """Provides methods to summarize code using a configured LLM.""" - _tokenizer_cache: Dict[str, Any] = {} # Cache for tiktoken encoders - config: Optional[Union[OpenAIConfig, AnthropicConfig, GoogleConfig]] repo: "Repository" - _llm_client: Optional[Any] # type: ignore - - def _get_tokenizer(self, model_name: str): - if model_name in self._tokenizer_cache: - return self._tokenizer_cache[model_name] - try: - encoding = tiktoken.encoding_for_model(model_name) - self._tokenizer_cache[model_name] = encoding - return encoding - except KeyError: - try: - # Fallback for models not directly in tiktoken.model.MODEL_TO_ENCODING - encoding = tiktoken.get_encoding("cl100k_base") - self._tokenizer_cache[model_name] = encoding - return encoding - except Exception as e: - logger.warning( - f"Could not load tiktoken encoder for {model_name} due to {e}, token count will be approximate (char count)." - ) - return None - - def _count_tokens(self, text: str, model_name: Optional[str] = None) -> int: - """Count the number of tokens in a text string for a given model.""" - if not text: - return 0 - - # Use model from config if available, otherwise use a default - if model_name is None: - if self.config is not None and hasattr(self.config, "model"): - model_name = self.config.model - else: - # Default to a common model if no config or model specified - model_name = "gpt-4o" # Default fallback - - try: - # Try to use tiktoken for accurate token counting - if tiktoken: - try: - if model_name in self._tokenizer_cache: - encoder = self._tokenizer_cache[model_name] - else: - try: - encoder = tiktoken.encoding_for_model(model_name) - except KeyError: - # Model not found, use cl100k_base as fallback - encoder = tiktoken.get_encoding("cl100k_base") - self._tokenizer_cache[model_name] = encoder - - return len(encoder.encode(text)) - except Exception as e: - logger.warning(f"Error using tiktoken for model {model_name}: {e}") - # Fall through to character-based approximation - else: - logger.warning( - f"No tiktoken encoder found for model {model_name}, token count will be approximate (char count)." - ) - except NameError: - # tiktoken not available - logger.warning("tiktoken not available, token count will be approximate (char count).") - - # Fallback: approximate token count based on characters (4 chars ~= 1 token) - return len(text) // 4 - - def _count_openai_chat_tokens(self, messages: List[Dict[str, str]], model_name: str) -> Optional[int]: - """Return the number of tokens used by a list of messages for OpenAI chat models.""" - encoding = self._get_tokenizer(model_name) - if not encoding: - logger.warning(f"Cannot count OpenAI chat tokens for {model_name}, no tiktoken encoder available.") - return None - - # Logic adapted from OpenAI cookbook for counting tokens for chat completions - # See: https://github.com/openai/openai-cookbook/blob/main/examples/how_to_count_tokens_with_tiktoken.ipynb - if model_name in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model_name == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_name = -1 # if there's a name, the role is omitted - elif "gpt-3.5-turbo" in model_name: # Covers general gpt-3.5-turbo and variants not explicitly listed - # Defaulting to newer model token counts as a general heuristic - logger.debug(f"Using token counting parameters for gpt-3.5-turbo-0613 for model {model_name}.") - tokens_per_message = 3 - tokens_per_name = 1 - elif "gpt-4" in model_name: # Covers general gpt-4 and variants not explicitly listed - logger.debug(f"Using token counting parameters for gpt-4-0613 for model {model_name}.") - tokens_per_message = 3 - tokens_per_name = 1 - else: - # Fallback for unknown models; this might not be perfectly accurate. - # Raise an error or use a default if this model is not supported by tiktoken's encoding_for_model - # For now, using a common default and logging a warning. - logger.warning( - f"_count_openai_chat_tokens() may not be accurate for model {model_name}. " - f"It's not explicitly handled. Using default token counting parameters (3 tokens/message, 1 token/name). " - f"See OpenAI's documentation for details on your specific model." - ) - tokens_per_message = 3 - tokens_per_name = 1 - - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - if value is None: # Ensure value is not None before attempting to encode - logger.debug(f"Encountered None value for key '{key}' in message, skipping for token counting.") - continue - try: - num_tokens += len(encoding.encode(str(value))) # Ensure value is string - except Exception as e: - # This catch is a safeguard; tiktoken should handle most string inputs. - logger.error(f"Could not encode value for token counting: '{str(value)[:50]}...', error: {e}") - return None # Inability to encode part of message means count is unreliable - if key == "name": - num_tokens += tokens_per_name - num_tokens += ( - 3 # every reply is primed with <|start|>assistant<|message|> (approximates assistant's first tokens) - ) - return num_tokens + _llm_client: LLMClient + config: Optional[Union[OpenAIConfig, AnthropicConfig, GoogleConfig]] def __init__( self, repo: "Repository", - config: Optional[Union[OpenAIConfig, AnthropicConfig, GoogleConfig]] = None, - llm_client: Optional[Any] = None, + config: Union[OpenAIConfig, AnthropicConfig, GoogleConfig], ): """ Initializes the Summarizer. @@ -244,102 +44,23 @@ def __init__( Args: repo: The kit.Repository instance containing the code. config: LLM configuration (OpenAIConfig, AnthropicConfig, or GoogleConfig). - If None, defaults to OpenAIConfig. - llm_client: Optional pre-initialized LLM client. If None, client will be - lazy-loaded on first use based on the config. + This is required to specify which LLM provider to use. + + Raises: + TypeError: If config is not provided or has an unsupported type. """ self.repo = repo - self._llm_client = llm_client # Store provided llm_client directly - self.config = config # Store provided config - - if self._llm_client is None: - # Only create/setup LLM if a client wasn't directly provided - if self.config is None: - # If no config is provided either, default to OpenAIConfig - # This will raise ValueError if OPENAI_API_KEY is not set. - self.config = OpenAIConfig() - - if isinstance(self.config, OpenAIConfig): - try: - import openai - - if self.config.base_url: - self._llm_client = openai.OpenAI(api_key=self.config.api_key, base_url=self.config.base_url) - else: - self._llm_client = openai.OpenAI(api_key=self.config.api_key) - except ImportError: - raise LLMError("OpenAI SDK (openai) not available. Please install it.") - elif isinstance(self.config, AnthropicConfig): - try: - import anthropic - - self._llm_client = anthropic.Anthropic(api_key=self.config.api_key) - except ImportError: - raise LLMError("Anthropic SDK (anthropic) not available. Please install it.") - elif isinstance(self.config, GoogleConfig): - try: - import google.genai as genai - - self._llm_client = genai.Client(api_key=self.config.api_key) # Use the new client - except ImportError: - raise LLMError("Google Gen AI SDK (google-genai) not available. Please install it.") - else: - # This case implies self.config was set to something unexpected if self._llm_client was None - # and self.config was also None initially. Or self.config was passed with an invalid type. - if self.config is not None: # Only raise if config is of an unsupported type - raise TypeError(f"Unsupported LLM configuration type: {type(self.config)}") - # If self.config is None here, it means OpenAIConfig() failed, but that should raise its own error. - # Or, it implies a logic flaw if this path is reached with self.config being None. - # If _llm_client was provided, we assume it's configured and ready. - # self.config might be None if only llm_client was passed. + self.config = config + + # Create LLM client using factory method + self._llm_client = LLMClient.create_client(config) def _get_llm_client(self) -> Any: - """Lazy loads the appropriate LLM client based on self.config.""" - if self._llm_client is not None: - return self._llm_client + """Returns the LLM client. - try: - if isinstance(self.config, OpenAIConfig): - from openai import OpenAI # Local import for OpenAI client - - if self.config.base_url: - client = OpenAI(api_key=self.config.api_key, base_url=self.config.base_url) - else: - client = OpenAI(api_key=self.config.api_key) - elif isinstance(self.config, AnthropicConfig): - from anthropic import Anthropic # Local import for Anthropic client - - client = Anthropic(api_key=self.config.api_key) # type: ignore # Different client type - elif isinstance(self.config, GoogleConfig): - if genai is None or genai_types is None: - raise LLMError( - "Google Gen AI SDK (google-genai) is not installed. Please install it to use Google models." - ) - # API key is picked up from GOOGLE_API_KEY env var by default if not passed to Client() - # However, we have it in self.config.api_key, so we pass it explicitly. - client = genai.Client(api_key=self.config.api_key) # type: ignore # Different client type - else: - # This case should ideally be prevented by the __init__ type check, - # but as a safeguard: - raise LLMError(f"Unsupported LLM configuration type: {type(self.config)}") - - self._llm_client = client - return self._llm_client - except ImportError as e: - sdk_name = "" - if "openai" in str(e).lower(): - sdk_name = "openai" - elif "anthropic" in str(e).lower(): - sdk_name = "anthropic" - # google-genai import is handled by genai being None - if sdk_name: - raise LLMError( - f"{sdk_name.capitalize()} SDK not installed. Please install it to use {sdk_name.capitalize()} models." - ) from e - raise # Re-raise if it's a different import error - except Exception as e: - logger.error(f"Error initializing LLM client: {e}") - raise LLMError(f"Error initializing LLM client: {e}") from e + This method is maintained for backward compatibility with tests. + """ + return self._llm_client def summarize_file(self, file_path: str) -> str: """ @@ -388,99 +109,18 @@ def summarize_file(self, file_path: str) -> str: system_prompt_text = "You are an expert assistant skilled in creating concise and informative code summaries." user_prompt_text = f"Summarize the following code from the file '{file_path}'. Provide a high-level overview of its purpose, key components, and functionality. Focus on what the code does, not just how it's written. The code is:\n\n```\n{file_content}\n```" - client = self._get_llm_client() - summary = "" - - logger.debug(f"System Prompt for {file_path}: {system_prompt_text}") - logger.debug(f"User Prompt for {file_path} (first 200 chars): {user_prompt_text[:200]}...") # Get model name from config if available, otherwise pass None for default model_name = self.config.model if self.config is not None and hasattr(self.config, "model") else None - token_count = self._count_tokens(user_prompt_text, model_name) + token_count = count_tokens(user_prompt_text, model_name) + if token_count is not None: logger.debug(f"Estimated tokens for user prompt ({file_path}): {token_count}") else: logger.debug(f"Approximate characters for user prompt ({file_path}): {len(user_prompt_text)}") try: - # If a custom llm_client was provided without a config, use it directly - if self.config is None: - # For custom llm_client without config, assume it knows how to handle the prompt - # This is used in tests with FakeOpenAI - try: - # Try OpenAI-style interface first - response = client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - ) - summary = response.choices[0].message.content - except (AttributeError, TypeError) as e: - # If that fails, the client might have a different interface - logger.warning(f"Custom LLM client doesn't support OpenAI-style interface: {e}") - raise LLMError(f"Custom LLM client without config doesn't support expected interface: {e}") - elif isinstance(self.config, OpenAIConfig): - messages_for_api = [ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - prompt_token_count = self._count_openai_chat_tokens(messages_for_api, self.config.model) - if prompt_token_count is not None and prompt_token_count > OPENAI_MAX_PROMPT_TOKENS: - summary = f"Summary generation failed: OpenAI prompt too large ({prompt_token_count} tokens). Limit is {OPENAI_MAX_PROMPT_TOKENS} tokens." - else: - response = client.chat.completions.create( - model=self.config.model, - messages=messages_for_api, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - ) - summary = response.choices[0].message.content - if response.usage: - logger.debug(f"OpenAI API usage for {file_path}: {response.usage}") - elif isinstance(self.config, AnthropicConfig): - response = client.messages.create( - model=self.config.model, - system=system_prompt_text, - messages=[{"role": "user", "content": user_prompt_text}], - max_tokens=self.config.max_tokens, - temperature=self.config.temperature, - ) - summary = response.content[0].text - elif isinstance(self.config, GoogleConfig): - if not genai_types: - raise LLMError( - "Google Gen AI SDK (google-genai) types not available. SDK might not be installed correctly." - ) - - generation_config_params: Dict[str, Any] = ( - self.config.model_kwargs.copy() if self.config.model_kwargs is not None else {} - ) - - if self.config.temperature is not None: - generation_config_params["temperature"] = self.config.temperature - if self.config.max_output_tokens is not None: - generation_config_params["max_output_tokens"] = self.config.max_output_tokens - - final_sdk_params = generation_config_params if generation_config_params else None - - response = client.models.generate_content( - model=self.config.model, contents=user_prompt_text, generation_config=final_sdk_params - ) - # Check for blocked prompt first - if ( - hasattr(response, "prompt_feedback") - and response.prompt_feedback - and response.prompt_feedback.block_reason - ): - logger.warning( - f"Google LLM prompt for file {file_path} blocked. Reason: {response.prompt_feedback.block_reason}" - ) - summary = f"Summary generation failed: Prompt blocked by API (Reason: {response.prompt_feedback.block_reason})" - elif not response.text: - logger.warning(f"Google LLM returned no text for file {file_path}. Response: {response}") - summary = "Summary generation failed: No text returned by API." - else: - summary = response.text + # Generate the summary using the LLM client + summary = self._llm_client.generate_completion(system_prompt_text, user_prompt_text, model_name) if not summary or not summary.strip(): logger.warning(f"LLM returned an empty or whitespace-only summary for file {file_path}.") @@ -536,103 +176,14 @@ def summarize_function(self, file_path: str, function_name: str) -> str: system_prompt_text = "You are an expert assistant skilled in creating concise code summaries for functions." user_prompt_text = f"Summarize the following function named '{function_name}' from the file '{file_path}'. Describe its purpose, parameters, and return value. The function definition is:\n\n```\n{function_code}\n```" - client = self._get_llm_client() - summary = "" - - logger.debug(f"System Prompt for {function_name} in {file_path}: {system_prompt_text}") - logger.debug(f"User Prompt for {function_name} in {file_path} (first 200 chars): {user_prompt_text[:200]}...") # Get model name from config if available, otherwise pass None for default model_name = self.config.model if self.config is not None and hasattr(self.config, "model") else None - token_count = self._count_tokens(user_prompt_text, model_name) + token_count = count_tokens(user_prompt_text, model_name) logger.debug(f"Token count for {function_name} in {file_path}: {token_count}") try: - # If a custom llm_client was provided without a config, use it directly - if self.config is None: - # For custom llm_client without config, assume it knows how to handle the prompt - # This is used in tests with FakeOpenAI - try: - # Try OpenAI-style interface first - response = client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - ) - summary = response.choices[0].message.content - except (AttributeError, TypeError) as e: - # If that fails, the client might have a different interface - # In a real implementation, you'd need more robust handling here - logger.warning(f"Custom LLM client doesn't support OpenAI-style interface: {e}") - raise LLMError(f"Custom LLM client without config doesn't support expected interface: {e}") - elif isinstance(self.config, OpenAIConfig): - messages_for_api = [ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - prompt_token_count = self._count_openai_chat_tokens(messages_for_api, self.config.model) - if prompt_token_count is not None and prompt_token_count > OPENAI_MAX_PROMPT_TOKENS: - summary = f"Summary generation failed: OpenAI prompt too large ({prompt_token_count} tokens). Limit is {OPENAI_MAX_PROMPT_TOKENS} tokens." - else: - response = client.chat.completions.create( - model=self.config.model, - messages=messages_for_api, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - ) - summary = response.choices[0].message.content - if response.usage: - logger.debug(f"OpenAI API usage for {function_name} in {file_path}: {response.usage}") - elif isinstance(self.config, AnthropicConfig): - response = client.messages.create( - model=self.config.model, - system=system_prompt_text, - messages=[{"role": "user", "content": user_prompt_text}], - max_tokens=self.config.max_tokens, - temperature=self.config.temperature, - ) - summary = response.content[0].text - # Anthropic usage might be in response.usage (confirm API docs) - # Example: logger.debug(f"Anthropic API usage for {function_name} in {file_path}: {response.usage}") - elif isinstance(self.config, GoogleConfig): - if not genai_types: - raise LLMError( - "Google Gen AI SDK (google-genai) types not available. SDK might not be installed correctly." - ) - - generation_config_params: Dict[str, Any] = ( - self.config.model_kwargs.copy() if self.config.model_kwargs is not None else {} - ) - - if self.config.temperature is not None: - generation_config_params["temperature"] = self.config.temperature - if self.config.max_output_tokens is not None: - generation_config_params["max_output_tokens"] = self.config.max_output_tokens - - final_sdk_params = generation_config_params if generation_config_params else None - - response = client.models.generate_content( - model=self.config.model, contents=user_prompt_text, generation_config=final_sdk_params - ) - if ( - hasattr(response, "prompt_feedback") - and response.prompt_feedback - and response.prompt_feedback.block_reason - ): - logger.warning( - f"Google LLM prompt for {function_name} in {file_path} blocked. Reason: {response.prompt_feedback.block_reason}" - ) - summary = f"Summary generation failed: Prompt blocked by API (Reason: {response.prompt_feedback.block_reason})" - elif not response.text: - logger.warning( - f"Google LLM returned no text for {function_name} in {file_path}. Response: {response}" - ) - summary = "Summary generation failed: No text returned by API." - else: - summary = response.text - else: - # This should never happen with our current logic, but as a safeguard - raise LLMError(f"Unsupported LLM configuration type: {type(self.config) if self.config else None}") + # Generate the summary using the LLM client + summary = self._llm_client.generate_completion(system_prompt_text, user_prompt_text, model_name) if not summary or not summary.strip(): logger.warning( @@ -690,99 +241,14 @@ def summarize_class(self, file_path: str, class_name: str) -> str: system_prompt_text = "You are an expert assistant skilled in creating concise code summaries for classes." user_prompt_text = f"Summarize the following class named '{class_name}' from the file '{file_path}'. Describe its purpose, key attributes, and main methods. The class definition is:\n\n```\n{class_code}\n```" - client = self._get_llm_client() - summary = "" - - logger.debug(f"System Prompt for {class_name} in {file_path}: {system_prompt_text}") - logger.debug(f"User Prompt for {class_name} (first 200 chars): {user_prompt_text[:200]}...") # Get model name from config if available, otherwise pass None for default model_name = self.config.model if self.config is not None and hasattr(self.config, "model") else None - token_count = self._count_tokens(user_prompt_text, model_name) + token_count = count_tokens(user_prompt_text, model_name) logger.debug(f"Token count for {class_name} in {file_path}: {token_count}") try: - # If a custom llm_client was provided without a config, use it directly - if self.config is None: - # For custom llm_client without config, assume it knows how to handle the prompt - # This is used in tests with FakeOpenAI - try: - # Try OpenAI-style interface first - response = client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - ) - summary = response.choices[0].message.content - except (AttributeError, TypeError) as e: - logger.warning(f"Custom LLM client doesn't support OpenAI-style interface: {e}") - raise LLMError(f"Custom LLM client without config doesn't support expected interface: {e}") - elif isinstance(self.config, OpenAIConfig): - messages_for_api = [ - {"role": "system", "content": system_prompt_text}, - {"role": "user", "content": user_prompt_text}, - ] - prompt_token_count = self._count_openai_chat_tokens(messages_for_api, self.config.model) - if prompt_token_count is not None and prompt_token_count > OPENAI_MAX_PROMPT_TOKENS: - summary = f"Summary generation failed: OpenAI prompt too large ({prompt_token_count} tokens). Limit is {OPENAI_MAX_PROMPT_TOKENS} tokens." - else: - response = client.chat.completions.create( - model=self.config.model, - messages=messages_for_api, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - ) - summary = response.choices[0].message.content - if response.usage: - logger.debug(f"OpenAI API usage for {class_name} in {file_path}: {response.usage}") - elif isinstance(self.config, AnthropicConfig): - response = client.messages.create( - model=self.config.model, - system=system_prompt_text, - messages=[{"role": "user", "content": user_prompt_text}], - max_tokens=self.config.max_tokens, - temperature=self.config.temperature, - ) - summary = response.content[0].text - # Anthropic usage might be in response.usage (confirm API docs) - # Example: logger.debug(f"Anthropic API usage for {class_name} in {file_path}: {response.usage}") - elif isinstance(self.config, GoogleConfig): - if not genai_types: - raise LLMError( - "Google Gen AI SDK (google-genai) types not available. SDK might not be installed correctly." - ) - - generation_config_params: Dict[str, Any] = ( - self.config.model_kwargs.copy() if self.config.model_kwargs is not None else {} - ) - - if self.config.temperature is not None: - generation_config_params["temperature"] = self.config.temperature - if self.config.max_output_tokens is not None: - generation_config_params["max_output_tokens"] = self.config.max_output_tokens - - final_sdk_params = generation_config_params if generation_config_params else None - - response = client.models.generate_content( - model=self.config.model, contents=user_prompt_text, generation_config=final_sdk_params - ) - if ( - hasattr(response, "prompt_feedback") - and response.prompt_feedback - and response.prompt_feedback.block_reason - ): - logger.warning( - f"Google LLM prompt for {class_name} in {file_path} blocked. Reason: {response.prompt_feedback.block_reason}" - ) - summary = f"Summary generation failed: Prompt blocked by API (Reason: {response.prompt_feedback.block_reason})" - elif not response.text: - logger.warning(f"Google LLM returned no text for {class_name} in {file_path}. Response: {response}") - summary = "Summary generation failed: No text returned by API." - else: - summary = response.text - else: - # This should never happen with our current logic, but as a safeguard - raise LLMError(f"Unsupported LLM configuration type: {type(self.config) if self.config else None}") + # Generate the summary using the LLM client + summary = self._llm_client.generate_completion(system_prompt_text, user_prompt_text, model_name) if not summary or not summary.strip(): logger.warning(