diff --git a/README.md b/README.md
index 3187c2cd..0c7b4491 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,7 @@

## 🎯 Vision

+AI agents are clearly the future, and the entire workforce will either be replaced by AI agents or at least be using them. While I am a quant building agents for algo trading, I will be contributing to all different types of AI agent flows and placing all of the agents here for free, 100% open source, because I believe code is the great equalizer and we have never seen a regime shift like this, so I need to get this code to the people. Feel free to join [our discord](https://discord.gg/8UPuVZ53bh) if you believe AI agents will be integrated into the workforce.
@@ -13,318 +14,68 @@ feel free to join [our discord](https://discord.gg/8UPuVZ53bh) if you beleive ai

⭐️ [first full concise documentation video (watch here)](https://youtu.be/RlqzkSgDKDc)

-⭐️ [second full walkthrough video(watch here)](https://youtu.be/tjY24JR8Cso?si=Za-PQ2L79US6cu2T)
-
-⭐️ [third full walkthrough w/ big updates, new models, new agents(watch here)](https://youtu.be/qZv6IFIkk6I)
-
-📀 follow all updates here on youtube in this playlist: https://www.youtube.com/playlist?list=PLXrNVMjRZUJg4M4uz52iGd1LhXXGVbIFz
-
---
-
-## 🤖 All Available Agents
-
-**⚠️ For live trading agents: Only use these AFTER thoroughly backtesting your strategies!**
-
-### Backtesting & Research Agents
-- **RBI Agent** (`rbi_agent.py`): Uses DeepSeek to research trading strategies based on YouTube videos, PDFs, or text you provide, then codes out the backtest automatically
-- **RBI Parallel Agent** (`rbi_agent_pp_multi.py`): Parallel version with 18 threads, tests across 20+ data sources, web dashboard included
-- **Research Agent** (`research_agent.py`): Fills the ideas.txt file so the RBI agent can run forever
-
-### Live Trading Agents
-- **Trading Agent** (`trading_agent.py`): **DUAL-MODE AI trading system** - Toggle between single model (fast ~10s) or swarm mode (6-model consensus ~45-60s). Swarm mode queries Claude 4.5, GPT-5, Gemini 2.5, Grok-4, DeepSeek, and DeepSeek-R1 local for majority vote trading decisions. Configure via `USE_SWARM_MODE` in config.py
-- **Strategy Agent** (`strategy_agent.py`): Manages and executes trading strategies placed in the strategies folder
-- **Risk Agent** (`risk_agent.py`): Monitors and manages portfolio risk, enforcing position limits and PnL thresholds
-- **Copy Agent** (`copy_agent.py`): Monitors copy bot for potential trades
-- **Swarm Agent** (`swarm_agent.py`): Queries 6 AI models in parallel (Claude 4.5, GPT-5, Gemini 2.5, Grok-4, DeepSeek, DeepSeek-R1 local), generates AI consensus summary, returns clean JSON with model mapping for easy parsing 🐝
-
-### Market Analysis Agents
-- **Whale Agent** (`whale_agent.py`): Monitors whale activity and announces when a whale enters the market
-- **Sentiment Agent** (`sentiment_agent.py`): Analyzes Twitter sentiment for crypto tokens with voice announcements
-- **Chart Agent** (`chartanalysis_agent.py`): Looks at any crypto chart and analyzes it with AI to make a buy/sell/nothing recommendation
-- **Funding Agent** (`funding_agent.py`): Monitors funding rates across exchanges and uses AI to analyze opportunities, providing voice alerts for extreme funding situations with technical context 🌙
-- **Liquidation Agent** (`liquidation_agent.py`): Tracks liquidation events with configurable time windows (15min/1hr/4hr), providing AI analysis and voice alerts for significant liquidation spikes 💦
-- **Listing Arbitrage Agent** (`listingarb_agent.py`): Identifies promising Solana tokens on CoinGecko before they reach major exchanges like Binance and Coinbase, using parallel AI analysis for technical and fundamental insights
-- **Funding Arbitrage Agent** (`fundingarb_agent.py`): Tracks the funding rate on HyperLiquid to find funding rate arbitrage opportunities between HL and Solana
-- **New or Top Tokens Agent** (`new_or_top_agent.py`): Looks at the new tokens and the top tokens from CoinGecko API
-
-### Solana-Specific Agents
-- **Sniper Agent** (`sniper_agent.py`): Watches for new Solana token launches, analyzes them, and maybe snipes
-- **TX Agent** (`tx_agent.py`): Watches transactions made by your copy list and prints them out with optional auto tab open
-- **Solana Agent** (`solana_agent.py`): Looks at the sniper agent and the TX agent to select which memes may be interesting
-
-### Content Creation Agents
-- **Chat Agent** (`chat_agent.py`): Monitors YouTube live stream chat, moderates & responds to known questions. Absolute fire.
-- **Twitter Agent** (`tweet_agent.py`): Takes in text and creates tweets using DeepSeek or other models
-- **Video Agent** (`video_agent.py`): Takes in text to create videos by creating audio snippets using ElevenLabs and combining with raw_video footage
-- **Clips Agent** (`clips_agent.py`): Helps clip long videos into shorter ones so you can upload to your YouTube and get paid. More info: https://discord.gg/XAw8US9aHT
-- **Real-Time Clips Agent** (`realtime_clips_agent.py`): Makes real-time clips of streamers using OBS
-- **Phone Agent** (`phone_agent.py`): An AI agent that can take phone calls for you
-
-### Specialized Agents
-- **Focus Agent** (`focus_agent.py`): Randomly samples audio during coding sessions to maintain productivity, providing focus scores and voice alerts when focus drops (~$10/month, perfect for voice-to-code workflows)
-- **Million Agent** (`million_agent.py`): Uses million context window from Gemini to pull in a knowledge base
-- **TikTok Agent** (`tiktok_agent.py`): Scrolls TikTok and gets screenshots of the video + comments to extract consumer data to feed into algos. Sometimes called social arbitrage
-- **Compliance Agent** (`compliance_agent.py`): Analyzes TikTok ads for Facebook advertising compliance, extracting frames and transcribing audio to check against FB guidelines
-- **Housecoin Agent** (`housecoin_agent.py`): DCA (dollar cost average) agent with AI confirmation layer using Grok-4 for the thesis: 1 House = 1 Housecoin 🏠
-- **Polymarket Agent** (`polymarket_agent.py`): Connects to the live trades feed via WebSocket and analyzes with the swarm agent to see which markets could be interesting to trade
-
-
-## ⚠️ Critical Disclaimers
-
-*There is no token associated with this project and there never will be. any token launched is not affiliated with this project, moon dev will never dm you. be careful. don't send funds anywhere*
-
-**PLEASE READ CAREFULLY:**
-
-1. This is an experimental research project, NOT a trading system
-2. There are NO plug-and-play solutions for guaranteed profits
-3. We do NOT provide trading strategies
-4. Success depends entirely on YOUR:
-   - Trading strategy
-   - Risk management
-   - Market research
-   - Testing and validation
-   - Overall trading approach
-
-5. NO AI agent can guarantee profitable trading
-6. You MUST develop and validate your own trading approach
-7. Trading involves substantial risk of loss
-8. Past performance does not indicate future results
-
-**⚠️ IMPORTANT: This is an experimental project. There are NO guarantees of profitability. Trading involves substantial risk of loss.**
-
-## 👂 Looking for Updates?
-Project updates will be posted in Discord, join here: [discord.gg/8UPuVZ53bh](https://discord.gg/8UPuVZ53bh)
-
-## 🔗 Links
-- Free Algo Trading Roadmap: [moondev.com](https://moondev.com)
-- Algo Trading Education: [algotradecamp.com](https://algotradecamp.com)
-- Business Contact [moon@algotradecamp.com](mailto:moon@algotradecamp.com)
+## 🔌 Multi-Provider LLM Support

---
-
-## 🚀 Quick Start Guide - RBI Backtesting Agent
-
-**Why Start with Backtesting?**
-
-Before running ANY trading algorithm or AI agent with real money, you MUST backtest your strategies. Backtesting shows you how a strategy would have performed on historical data. The RBI (Research-Based Inference) Agent automates this entire process for you.
-
-**What is the RBI Agent?**
-
-The RBI Agent takes your trading ideas (from YouTube videos, PDFs, or plain text) and:
-1. 🧠 Uses AI to understand the trading strategy
-2. 💻 Codes a complete backtest using the `backtesting.py` library
-3. 📊 Tests across 20+ different market data sources
-4. ✅ Only saves strategies that pass a 1% return threshold
-5. 🎯 Tries to optimize strategies to hit a 50% target return
+
+This project now includes a provider-agnostic LLM client interface that supports multiple providers, including OpenAI, Groq, and Ollama. You can easily switch between providers without changing your code.
-**Python Version:** 3.10.9 was used during development

+### Configuration
-### Step 1: ⭐ Star & Fork the Repo
-- Click the star button to save it to your GitHub favorites
-- Fork to your GitHub account to get your own copy
-- This lets you make changes and track updates
+
+Set up your LLM provider using environment variables:
-### Step 2: 💻 Clone to Your Machine
+
+#### Using OpenAI:
```bash
-git clone https://github.com/YOUR_USERNAME/moon-dev-ai-agents-for-trading.git
-cd moon-dev-ai-agents-for-trading
+export LLM_PROVIDER=openai
+export LLM_API_KEY=your_openai_api_key
+export LLM_MODEL=gpt-4
```
-**Recommended IDEs:**
-- [Cursor](https://www.cursor.com/) - AI-enabled coding
-- [Windsurfer](https://codeium.com/) - AI-enabled coding
-
-### Step 3: 🔑 Set Up Environment Variables
-
-The RBI Agent needs API keys to function. Create a `.env` file in the root directory:
-
+#### Using Groq:
```bash
-# Copy the example file
-cp .env.example .env
+export LLM_PROVIDER=groq
+export LLM_API_KEY=your_groq_api_key
+export LLM_MODEL=mixtral-8x7b-32768
```
-**Required API Keys for RBI Agent:**
-
+#### Using Ollama (local):
```bash
-# AI Model APIs (you need at least ONE of these)
-ANTHROPIC_KEY=your_anthropic_api_key_here       # Claude models (recommended)
-OPENAI_KEY=your_openai_api_key_here             # GPT models
-DEEPSEEK_KEY=your_deepseek_api_key_here         # DeepSeek models (cheap!)
-GROQ_API_KEY=your_groq_api_key_here             # Groq (fast inference)
-GEMINI_KEY=your_gemini_api_key_here             # Google Gemini
-XAI_API_KEY=your_xai_api_key_here               # Grok models
-
-# Market Data APIs (for downloading price data)
-BIRDEYE_API_KEY=your_birdeye_api_key_here       # Solana token data
-COINGECKO_API_KEY=your_coingecko_api_key_here   # Crypto market data
+export LLM_PROVIDER=ollama
+export LLM_BASE_URL=http://localhost:11434  # Optional, defaults to localhost:11434
+export LLM_MODEL=llama2
```
-**Where to Get API Keys:**
-- **Anthropic Claude**: https://console.anthropic.com/
-- **OpenAI GPT**: https://platform.openai.com/api-keys
-- **DeepSeek**: https://platform.deepseek.com/ (very cheap, great for backtesting)
-- **Groq**: https://console.groq.com/
-- **Google Gemini**: https://aistudio.google.com/app/apikey
-- **xAI Grok**: https://console.x.ai/
-- **BirdEye**: https://birdeye.so/ (Solana data)
-- **CoinGecko**: https://www.coingecko.com/en/api
+### Usage Example
-⚠️ **Never commit or share your `.env` file! It's in .gitignore for your safety.**

```python
+from src.llm import LLMClientFactory, ChatMessage
-### Step 4: 📦 Install Dependencies
+
+# Create client from environment variables
+client = LLMClientFactory.create_from_env()
-Using conda (recommended):
```bash
-conda create -n tflow python=3.10.9
-conda activate tflow
-pip install -r requirements.txt
```
+
+# Or create with explicit configuration
+client = LLMClientFactory.create_client(
+    provider="groq",
+    config={"api_key": "your_api_key"}
+)
-Or using pip directly:
```bash
-pip install -r requirements.txt
+
+# Use the client
+messages = [ChatMessage(role="user", content="Hello!")]
+response = client.chat(messages, model="mixtral-8x7b-32768")
+print(response.content)
```
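+
+The same adapter behind the `openai` provider can also target any OpenAI-style endpoint via a custom `base_url` (see the provider list below). A minimal sketch; the endpoint URL, model name, and `provider_name` label here are illustrative, so check your provider's docs for real values:
+
+```python
+from src.llm import LLMClientFactory, ChatMessage
+
+# Point the OpenAI-compatible adapter at a different endpoint (DeepSeek as an example).
+client = LLMClientFactory.create_client(
+    provider="openai-compatible",
+    config={
+        "api_key": "your_provider_api_key",
+        "base_url": "https://api.deepseek.com/v1",  # assumed OpenAI-style endpoint
+        "provider_name": "DeepSeek"  # optional label surfaced by client.provider_name
+    }
+)
+
+messages = [ChatMessage(role="user", content="Hello!")]
+print(client.chat(messages, model="deepseek-chat").content)
+```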
-### Step 5: 🧪 Run Your First Backtest
-
-**Option A: Single Strategy Test**
-
-Create a file called `ideas.txt` in `src/data/rbi_pp_multi/`:
+
+### Supported Providers
-
-```
-Buy when RSI < 30 and sell when RSI > 70
-```
-
-Then run:
-```bash
-python src/agents/rbi_agent_pp_multi.py
-```
+
+- **OpenAI**: Official OpenAI API (GPT-4, GPT-3.5, etc.)
+- **OpenAI-Compatible**: Any OpenAI-compatible API (xAI, DeepSeek, etc.) - set custom `base_url`
+- **Groq**: Fast inference with Mixtral, Llama, and other models
+- **Ollama**: Local model inference
-**Option B: Use the Web Dashboard**
+
+### Running Tests
-
-Start the dashboard:
```bash
-cd src/data/rbi_pp_multi
-python app.py
+pytest tests/test_llm_client.py
```
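+
+Beyond chat, the client interface also defines `embed()` and an optional `tts()`. In this PR, embeddings are implemented on the OpenAI-compatible and Ollama adapters (the Groq adapter raises `NotImplementedError`), and TTS only on the OpenAI-compatible adapter. A minimal sketch, assuming an OpenAI key; the embedding model name is illustrative:
+
+```python
+from src.llm import LLMClientFactory
+
+client = LLMClientFactory.create_client(
+    provider="openai",
+    config={"api_key": "your_openai_api_key"}
+)
+
+# embed() returns an EmbeddingResponse carrying the raw vector
+vector = client.embed("hello world", model="text-embedding-3-small")
+print(len(vector.embedding))
+
+# tts() returns a TTSResponse; audio_data is raw bytes (mp3 by default)
+speech = client.tts("hello world", voice="alloy")
+with open("hello.mp3", "wb") as f:
+    f.write(speech.audio_data)
+```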
-Open browser to: `http://localhost:8000`
-
-Click "New Backtests" and enter your strategy ideas!
-
-### Step 6: 📊 Understanding Results
-
-The agent will:
-- Process your strategy idea
-- Generate backtest code
-- Test across 20+ market datasets (BTC, ETH, SOL, etc.)
-- Show results in a table with:
-  - Return %
-  - Buy & Hold %
-  - Max Drawdown
-  - Sharpe Ratio
-  - Sortino Ratio
-  - Number of Trades
-
-**Only strategies returning > 1% are saved to the CSV.**
-
-Results are saved to:
-- `src/data/rbi_pp_multi/backtest_stats.csv` - All passing backtests
-- `src/data/rbi_pp_multi/user_folders/` - Organized by run name
-
-### Step 7: 🔍 Analyze Backtest Code
-
-Find your strategy files in:
-```
-src/data/rbi_pp_multi/10_25_2025_09_08/
-```
-
-Each successful backtest has:
-- **Python file**: The actual backtest code you can review and modify
-- **Results**: Performance metrics
-
-**Read the code!** This is how you learn what works and what doesn't.
-
---
-
-## 🎯 Configuration - RBI Agent
-
-All settings are in `src/agents/rbi_agent_pp_multi.py` (lines 130-132):
-
-```python
-# 🎯 PROFIT TARGET CONFIGURATION
-TARGET_RETURN = 50  # Target return in % (AI tries to optimize to this)
-SAVE_IF_OVER_RETURN = 1.0  # Save backtest to CSV if return > this %
-```
-
-**How it works:**
-- AI tries to optimize strategies to hit **50% return**
-- But ANY backtest returning **> 1%** gets saved to CSV
-- This way you can review all decent strategies, not just perfect ones
-
-**Other Settings:**
-```python
-MAX_WORKERS = 18  # Number of parallel threads (adjust based on your CPU)
-DEBUG_BACKTEST_ERRORS = True  # Auto-fix coding errors with AI
-MAX_DEBUG_ITERATIONS = 10  # How many times to try fixing errors
-```
-
---
-
-## 📚 Advanced: Adding Custom Data Sources
-
-Want to test on your own tokens? Edit the data list in `rbi_agent_pp_multi.py` (lines 157-178):
-
-```python
-ALL_DATA_CONFIGS = [
-    # Crypto data from CoinGecko/BirdEye
-    {'symbol': 'BTC-USD', 'timeframe': '15m', 'days_back': 90},
-    {'symbol': 'ETH-USD', 'timeframe': '15m', 'days_back': 90},
-    {'symbol': 'SOL-USD', 'timeframe': '15m', 'days_back': 90},
-
-    # Add your own token (Solana contract address)
-    {'symbol': 'YOUR_TOKEN_ADDRESS', 'timeframe': '1H', 'days_back': 30},
-]
-```
-
-The agent will automatically download and cache the data.
-
---
-
-## 🗺️ ROADMAP
-
-### In Progress
-- [x] **HyperLiquid Perps Integration** ✅
-- [x] **Swarm Consensus Trading** ✅
-- [x] **RBI Parallel Backtesting** ✅
-
-### Coming Soon
-- [ ] **Polymarket Integration** - Prediction market trading
-- [ ] **Base Chain Integration** - L2 network support
-- [ ] **Extended Integration** - Additional exchange support
-- [ ] **HyperLiquid Spot Trading** - Spot market support
-- [ ] **Trending Agent** - Spots leaders on HyperLiquid
-- [ ] **Position Sizing Agent** - Volume/liquidation-based sizing
-- [ ] **Regime Agents** - Adaptive strategy switching
-- [ ] **Polymarket Sweeper Agent** - Follow successful prediction traders
-
-### Future Ideas
-- [ ] **Lighter Integration**
-- [ ] **Pacifica Integration**
-- [ ] **Hibachi Integration**
-- [ ] **Aster Integration**
-- [ ] **HyperEVM Support**
-
---
-
-*Built with love by Moon Dev - Pioneering the future of AI-powered trading*
-
-## 📜 Detailed Disclaimer
-The content presented is for educational and informational purposes only and does not constitute financial advice. All trading involves risk and may not be suitable for all investors. You should carefully consider your investment objectives, level of experience, and risk appetite before investing.
-
-Past performance is not indicative of future results. There is no guarantee that any trading strategy or algorithm discussed will result in profits or will not incur losses.
-
-**CFTC Disclaimer:** Commodity Futures Trading Commission (CFTC) regulations require disclosure of the risks associated with trading commodities and derivatives. There is a substantial risk of loss in trading and investing.
-
-I am not a licensed financial advisor or a registered broker-dealer. Content & code is based on personal research perspectives and should not be relied upon as a guarantee of success in trading.
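+
+### Adding a New Provider (sketch)
+
+To support a provider that isn't listed, subclass `LLMClient` and register the adapter with the factory. The sketch below is illustrative only: `EchoAdapter` is a toy stand-in for a real API, and registering through the `PROVIDERS` dict is an assumption based on how the factory resolves names:
+
+```python
+from typing import List, Optional
+
+from src.llm import LLMClient, ChatMessage, ChatResponse, LLMClientFactory
+
+
+class EchoAdapter(LLMClient):
+    """Toy adapter that echoes the last user message back."""
+
+    def _validate_config(self) -> None:
+        pass  # a real adapter would check required keys here
+
+    def chat(self, messages: List[ChatMessage], model: str,
+             temperature: float = 0.7, max_tokens: Optional[int] = None,
+             **kwargs) -> ChatResponse:
+        # A real adapter would call the provider's API here
+        return ChatResponse(content=messages[-1].content, model=model)
+
+    def embed(self, text: str, model: str, **kwargs):
+        raise NotImplementedError("echo provider has no embeddings")
+
+    @property
+    def provider_name(self) -> str:
+        return "Echo"
+
+
+# Register the adapter so the factory can resolve it by name
+LLMClientFactory.PROVIDERS["echo"] = EchoAdapter
+client = LLMClientFactory.create_client("echo", {})
+```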
+
+*For more information about the trading agents and other features, please refer to the sections below.*
diff --git a/src/llm/__init__.py b/src/llm/__init__.py
new file mode 100644
index 00000000..6f810111
--- /dev/null
+++ b/src/llm/__init__.py
@@ -0,0 +1,27 @@
+"""Provider-agnostic LLM client package."""
+
+from .client import (
+    LLMClient,
+    ChatMessage,
+    ChatResponse,
+    EmbeddingResponse,
+    TTSResponse
+)
+from .adapters import (
+    OpenAICompatibleAdapter,
+    GroqAdapter,
+    OllamaAdapter
+)
+from .factory import LLMClientFactory
+
+__all__ = [
+    "LLMClient",
+    "ChatMessage",
+    "ChatResponse",
+    "EmbeddingResponse",
+    "TTSResponse",
+    "OpenAICompatibleAdapter",
+    "GroqAdapter",
+    "OllamaAdapter",
+    "LLMClientFactory"
+]
diff --git a/src/llm/adapters.py b/src/llm/adapters.py
new file mode 100644
index 00000000..bebb3486
--- /dev/null
+++ b/src/llm/adapters.py
@@ -0,0 +1,231 @@
+"""Concrete LLM client adapters for different providers."""
+
+from typing import List, Optional, Dict, Any
+from openai import OpenAI
+from groq import Groq as GroqClient
+import requests
+
+from .client import (
+    LLMClient,
+    ChatMessage,
+    ChatResponse,
+    EmbeddingResponse,
+    TTSResponse
+)
+
+
+class OpenAICompatibleAdapter(LLMClient):
+    """Adapter for OpenAI-compatible APIs (OpenAI, DeepSeek, xAI, etc.)."""
+
+    def _validate_config(self) -> None:
+        """Validate the configuration."""
+        if "api_key" not in self.config:
+            raise ValueError("OpenAI-compatible adapter requires 'api_key' in config")
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        self.client = OpenAI(
+            api_key=self.config["api_key"],
+            base_url=self.config.get("base_url")  # Optional custom endpoint
+        )
+
+    def chat(self,
+             messages: List[ChatMessage],
+             model: str,
+             temperature: float = 0.7,
+             max_tokens: Optional[int] = None,
+             **kwargs) -> ChatResponse:
+        """Generate a chat completion using OpenAI-compatible API."""
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in messages],
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            return ChatResponse(
+                content=response.choices[0].message.content,
+                model=response.model,
+                finish_reason=response.choices[0].finish_reason,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens
+                } if response.usage else None
+            )
+        except Exception as e:
+            raise RuntimeError(f"Chat completion failed: {str(e)}")
+
+    def embed(self,
+              text: str,
+              model: str,
+              **kwargs) -> EmbeddingResponse:
+        """Generate an embedding using OpenAI-compatible API."""
+        try:
+            response = self.client.embeddings.create(
+                model=model,
+                input=text,
+                **kwargs
+            )
+
+            return EmbeddingResponse(
+                embedding=response.data[0].embedding,
+                model=response.model
+            )
+        except Exception as e:
+            raise RuntimeError(f"Embedding generation failed: {str(e)}")
+
+    def tts(self,
+            text: str,
+            voice: str = "alloy",
+            **kwargs) -> TTSResponse:
+        """Generate text-to-speech audio (OpenAI only)."""
+        try:
+            response = self.client.audio.speech.create(
+                model=kwargs.get("model", "tts-1"),
+                voice=voice,
+                input=text
+            )
+
+            return TTSResponse(
+                audio_data=response.content,
+                format=kwargs.get("response_format", "mp3")
+            )
+        except Exception as e:
+            raise RuntimeError(f"TTS generation failed: {str(e)}")
+
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return self.config.get("provider_name", "OpenAI-Compatible")
+
+
+class GroqAdapter(LLMClient):
+    """Adapter for Groq API."""
+
+    def _validate_config(self) -> None:
+        """Validate the configuration."""
+        if "api_key" not in self.config:
+            raise ValueError("Groq adapter requires 'api_key' in config")
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        self.client = GroqClient(api_key=self.config["api_key"])
+
+    def chat(self,
+             messages: List[ChatMessage],
+             model: str,
+             temperature: float = 0.7,
+             max_tokens: Optional[int] = None,
+             **kwargs) -> ChatResponse:
+        """Generate a chat completion using Groq API."""
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in messages],
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            return ChatResponse(
+                content=response.choices[0].message.content,
+                model=response.model,
+                finish_reason=response.choices[0].finish_reason,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens
+                } if response.usage else None
+            )
+        except Exception as e:
+            raise RuntimeError(f"Chat completion failed: {str(e)}")
+
+    def embed(self,
+              text: str,
+              model: str,
+              **kwargs) -> EmbeddingResponse:
+        """Groq doesn't currently support embeddings."""
+        raise NotImplementedError("Groq adapter does not support embeddings")
+
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "Groq"
+
+
+class OllamaAdapter(LLMClient):
+    """Adapter for Ollama local API."""
+
+    def _validate_config(self) -> None:
+        """Validate the configuration."""
+        # Ollama doesn't require an API key
+        pass
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        self.base_url = self.config.get("base_url", "http://localhost:11434")
+
+    def chat(self,
+             messages: List[ChatMessage],
+             model: str,
+             temperature: float = 0.7,
+             max_tokens: Optional[int] = None,
+             **kwargs) -> ChatResponse:
+        """Generate a chat completion using Ollama API."""
+        try:
+            url = f"{self.base_url}/api/chat"
+            payload = {
+                "model": model,
+                "messages": [{"role": msg.role, "content": msg.content} for msg in messages],
+                "stream": False,
+                "options": {
+                    "temperature": temperature,
+                }
+            }
+
+            if max_tokens:
+                payload["options"]["num_predict"] = max_tokens
+
+            # Timeout keeps a stalled local server from hanging callers indefinitely
+            response = requests.post(url, json=payload, timeout=120)
+            response.raise_for_status()
+            data = response.json()
+
+            return ChatResponse(
+                content=data["message"]["content"],
+                model=data["model"],
+                finish_reason=data.get("done_reason", "stop")
+            )
+        except Exception as e:
+            raise RuntimeError(f"Chat completion failed: {str(e)}")
+
+    def embed(self,
+              text: str,
+              model: str,
+              **kwargs) -> EmbeddingResponse:
+        """Generate an embedding using Ollama API."""
+        try:
+            url = f"{self.base_url}/api/embeddings"
+            payload = {
+                "model": model,
+                "prompt": text
+            }
+
+            response = requests.post(url, json=payload, timeout=120)
+            response.raise_for_status()
+            data = response.json()
+
+            return EmbeddingResponse(
+                embedding=data["embedding"],
+                model=model
+            )
+        except Exception as e:
+            raise RuntimeError(f"Embedding generation failed: {str(e)}")
+
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "Ollama"
diff --git a/src/llm/client.py b/src/llm/client.py
new file mode 100644
index 00000000..d7767ea8
--- /dev/null
+++ b/src/llm/client.py
@@ -0,0 +1,116 @@
+"""Provider-agnostic LLM client interface."""
+
+from abc import ABC, abstractmethod
+from typing import Optional, Dict, List, Any
+from dataclasses import dataclass
+
+
+@dataclass
+class ChatMessage:
+    """Represents a chat message."""
+    role: str
+    content: str
+
+
+@dataclass
+class ChatResponse:
+    """Represents a chat completion response."""
+    content: str
+    model: str
+    finish_reason: Optional[str] = None
+    usage: Optional[Dict[str, int]] = None
+
+
+@dataclass
+class EmbeddingResponse:
+    """Represents an embedding response."""
+    embedding: List[float]
+    model: str
+
+
+@dataclass
+class TTSResponse:
+    """Represents a text-to-speech response."""
+    audio_data: bytes
+    format: str = "mp3"
+
+
+class LLMClient(ABC):
+    """Abstract base class for LLM provider clients."""
+
+    def __init__(self, config: Dict[str, Any]):
+        """Initialize the client with configuration.
+
+        Args:
+            config: Provider-specific configuration (api_key, base_url, etc.)
+        """
+        self.config = config
+        self._validate_config()
+
+    @abstractmethod
+    def _validate_config(self) -> None:
+        """Validate the configuration for this provider."""
+        pass
+
+    @abstractmethod
+    def chat(self,
+             messages: List[ChatMessage],
+             model: str,
+             temperature: float = 0.7,
+             max_tokens: Optional[int] = None,
+             **kwargs) -> ChatResponse:
+        """Generate a chat completion.
+
+        Args:
+            messages: List of chat messages
+            model: Model identifier
+            temperature: Sampling temperature (0.0-2.0 for most providers)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            ChatResponse with the generated content
+        """
+        pass
+
+    @abstractmethod
+    def embed(self,
+              text: str,
+              model: str,
+              **kwargs) -> EmbeddingResponse:
+        """Generate an embedding for the given text.
+
+        Args:
+            text: Text to embed
+            model: Model identifier
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            EmbeddingResponse with the embedding vector
+        """
+        pass
+
+    def tts(self,
+            text: str,
+            voice: str = "alloy",
+            **kwargs) -> TTSResponse:
+        """Generate text-to-speech audio (optional method).
+
+        Args:
+            text: Text to convert to speech
+            voice: Voice identifier
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            TTSResponse with audio data
+
+        Raises:
+            NotImplementedError: If TTS is not supported by this provider
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support TTS")
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Return the name of this provider."""
+        pass
diff --git a/src/llm/factory.py b/src/llm/factory.py
new file mode 100644
index 00000000..f104d2e3
--- /dev/null
+++ b/src/llm/factory.py
@@ -0,0 +1,87 @@
+"""Factory for creating LLM clients based on configuration."""
+
+import os
+from typing import Dict, Any, Optional
+from .client import LLMClient
+from .adapters import OpenAICompatibleAdapter, GroqAdapter, OllamaAdapter
+
+
+class LLMClientFactory:
+    """Factory for creating provider-specific LLM clients."""
+
+    # Map provider names to adapter classes
+    PROVIDERS = {
+        "openai": OpenAICompatibleAdapter,
+        "openai-compatible": OpenAICompatibleAdapter,
+        "groq": GroqAdapter,
+        "ollama": OllamaAdapter,
+    }
+
+    @staticmethod
+    def create_client(provider: str, config: Dict[str, Any]) -> LLMClient:
+        """Create an LLM client for the specified provider.
+
+        Args:
+            provider: Provider name (openai, groq, ollama, etc.)
+            config: Provider-specific configuration
+
+        Returns:
+            LLMClient instance for the provider
+
+        Raises:
+            ValueError: If provider is not supported
+        """
+        provider_lower = provider.lower()
+
+        if provider_lower not in LLMClientFactory.PROVIDERS:
+            supported = ", ".join(LLMClientFactory.PROVIDERS.keys())
+            raise ValueError(
+                f"Unsupported provider: {provider}. "
+                f"Supported providers: {supported}"
+            )
+
+        adapter_class = LLMClientFactory.PROVIDERS[provider_lower]
+        return adapter_class(config)
+
+    @staticmethod
+    def create_from_env(provider: Optional[str] = None) -> LLMClient:
+        """Create an LLM client from environment variables.
+
+        Args:
+            provider: Provider name. If None, uses LLM_PROVIDER env var
+
+        Returns:
+            LLM client configured from environment
+
+        Environment Variables:
+            LLM_PROVIDER: Provider name (openai, groq, ollama)
+            LLM_API_KEY: API key for the provider
+            LLM_BASE_URL: Custom base URL (optional, for OpenAI-compatible APIs)
+            LLM_MODEL: Default model to use
+        """
+        if provider is None:
+            provider = os.getenv("LLM_PROVIDER", "openai")
+
+        config = {}
+
+        # Get API key from environment (required for most providers except Ollama)
+        api_key = os.getenv("LLM_API_KEY")
+        if api_key:
+            config["api_key"] = api_key
+
+        # Get custom base URL if provided
+        base_url = os.getenv("LLM_BASE_URL")
+        if base_url:
+            config["base_url"] = base_url
+
+        # Get default model
+        model = os.getenv("LLM_MODEL")
+        if model:
+            config["default_model"] = model
+
+        return LLMClientFactory.create_client(provider, config)
+
+    @staticmethod
+    def get_supported_providers() -> list:
+        """Return list of supported provider names."""
+        return list(LLMClientFactory.PROVIDERS.keys())
diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py
new file mode 100644
index 00000000..728c0f12
--- /dev/null
+++ b/tests/test_llm_client.py
@@ -0,0 +1,112 @@
+"""Basic tests for the LLM client adapters."""
+
+import pytest
+from unittest.mock import Mock, patch
+from src.llm import LLMClientFactory, ChatMessage, ChatResponse
+
+
+def test_factory_create_openai_client():
+    """Test creating an OpenAI-compatible client via factory."""
+    config = {"api_key": "test-key"}
+    client = LLMClientFactory.create_client("openai", config)
+    assert client is not None
+    assert client.provider_name == "OpenAI-Compatible"
+
+
+def test_factory_create_groq_client():
+    """Test creating a Groq client via factory."""
+    config = {"api_key": "test-key"}
+    client = LLMClientFactory.create_client("groq", config)
+    assert client is not None
+    assert client.provider_name == "Groq"
+
+
+def test_factory_create_ollama_client():
+    """Test creating an Ollama client via factory."""
+    config = {}
+    client = LLMClientFactory.create_client("ollama", config)
+    assert client is not None
+    assert client.provider_name == "Ollama"
+
+
+def test_factory_unsupported_provider():
+    """Test that unsupported provider raises ValueError."""
+    with pytest.raises(ValueError, match="Unsupported provider"):
+        LLMClientFactory.create_client("unsupported", {})
+
+
+def test_factory_get_supported_providers():
+    """Test getting list of supported providers."""
+    providers = LLMClientFactory.get_supported_providers()
+    assert "openai" in providers
+    assert "groq" in providers
+    assert "ollama" in providers
+
+
+@patch("src.llm.adapters.OpenAI")
+def test_openai_chat(mock_openai_class):
+    """Test chat completion with OpenAI adapter."""
+    # Mock the OpenAI client
+    mock_client = Mock()
+    mock_openai_class.return_value = mock_client
+
+    # Mock the response
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="Test response"), finish_reason="stop")]
+    mock_response.model = "gpt-4"
+    mock_response.usage = Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+    mock_client.chat.completions.create.return_value = mock_response
+
+    # Create client and test chat
+    from src.llm.adapters import OpenAICompatibleAdapter
+    client = OpenAICompatibleAdapter({"api_key": "test-key"})
+
+    messages = [ChatMessage(role="user", content="Hello")]
+    response = client.chat(messages, "gpt-4")
+
+    assert response.content == "Test response"
+    assert response.model == "gpt-4"
+    assert response.finish_reason == "stop"
+    assert response.usage["total_tokens"] == 30
+
+
+@patch("src.llm.adapters.requests.post")
+def test_ollama_chat(mock_post):
+    """Test chat completion with Ollama adapter."""
+    # Mock the response
+    mock_response = Mock()
+    mock_response.json.return_value = {
+        "message": {"content": "Test response"},
+        "model": "llama2",
+        "done_reason": "stop"
+    }
+    mock_post.return_value = mock_response
+
+    # Create client and test chat
+    from src.llm.adapters import OllamaAdapter
+    client = OllamaAdapter({})
+
+    messages = [ChatMessage(role="user", content="Hello")]
+    response = client.chat(messages, "llama2")
+
+    assert response.content == "Test response"
+    assert response.model == "llama2"
+    assert response.finish_reason == "stop"
+
+
+@patch("src.llm.adapters.OpenAI")
+def test_openai_chat_error_handling(mock_openai_class):
+    """Test error handling in chat completion."""
+    # Mock the OpenAI client to raise an exception
+    mock_client = Mock()
+    mock_openai_class.return_value = mock_client
+    mock_client.chat.completions.create.side_effect = Exception("API Error")
+
+    # Create client and test error handling
+    from src.llm.adapters import OpenAICompatibleAdapter
+    client = OpenAICompatibleAdapter({"api_key": "test-key"})
+
+    messages = [ChatMessage(role="user", content="Hello")]
+
+    with pytest.raises(RuntimeError, match="Chat completion failed"):
+        client.chat(messages, "gpt-4")
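+
+
+# A sketch of an additional test for the env-driven factory path. It relies on
+# pytest's built-in monkeypatch fixture and the env var names used in factory.py;
+# the key value is a dummy, and no network call happens at construction time.
+def test_factory_create_from_env(monkeypatch):
+    """Sketch: create_from_env should pick up provider and key from the environment."""
+    monkeypatch.setenv("LLM_PROVIDER", "groq")
+    monkeypatch.setenv("LLM_API_KEY", "test-key")
+    monkeypatch.delenv("LLM_BASE_URL", raising=False)
+    monkeypatch.delenv("LLM_MODEL", raising=False)
+
+    client = LLMClientFactory.create_from_env()
+
+    assert client.provider_name == "Groq"
+    assert client.config["api_key"] == "test-key"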