diff --git a/README.md b/README.md index 45523d7..05d03eb 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ ## 特性 - **无框架设计**: 从零实现,不依赖LangChain等重型框架 -- **多LLM支持**: 支持DeepSeek、OpenAI等主流大语言模型 +- **多LLM支持**: 支持各种模型服务商提供的LLM,如DeepSeek、OpenAI、MiniMax等 - **智能搜索**: 集成Tavily搜索引擎,提供高质量网络搜索 - **反思机制**: 多轮反思优化,确保研究深度和完整性 - **状态管理**: 完整的研究过程状态跟踪和恢复 @@ -66,14 +66,15 @@ python --version ```bash git clone -cd Demo\ DeepSearch\ Agent +cd DeepSearchAgent-Demo ``` ### 3. 安装依赖 ```bash -# 激活虚拟环境(推荐) -conda activate pytorch_python11 # 或者使用其他虚拟环境 +# 使用conda创建并激活虚拟环境 +conda create -n deepresearch python=3.11 -y +conda activate deepresearch # 安装依赖 pip install -r requirements.txt @@ -81,25 +82,19 @@ pip install -r requirements.txt ### 4. 配置API密钥 -项目根目录下已有`config.py`配置文件,请直接编辑此文件设置您的API密钥: +请复制项目根目录下的`config.py.example`配置文件改名为`config.py`,请直接编辑此文件设置您使用服务商的base url,API Key,模型名称以及Tavily搜索API Key: ```python # Deep Search Agent 配置文件 -# 请在这里填入您的API密钥 -# DeepSeek API Key -DEEPSEEK_API_KEY = "your_deepseek_api_key_here" +# 请在这里填入服务商的Base URL、API Key和模型名称 +BASE_URL = "" +API_KEY = "" +MODEL_NAME = "" -# OpenAI API Key (可选) -OPENAI_API_KEY = "your_openai_api_key_here" +# 请在这里填入Tavily搜索API Key,前往https://app.tavily.com/home获取 +TAVILY_API_KEY = "" -# Tavily搜索API Key -TAVILY_API_KEY = "your_tavily_api_key_here" - -# 配置参数 -DEFAULT_LLM_PROVIDER = "deepseek" -DEEPSEEK_MODEL = "deepseek-chat" -OPENAI_MODEL = "gpt-4o-mini" MAX_REFLECTIONS = 2 SEARCH_RESULTS_PER_QUERY = 3 @@ -167,15 +162,15 @@ from src import DeepSearchAgent, Config # 自定义配置 config = Config( - default_llm_provider="deepseek", - deepseek_model="deepseek-chat", max_reflections=3, # 增加反思次数 max_search_results=5, # 增加搜索结果数 output_dir="my_reports" # 自定义输出目录 ) # 设置API密钥 -config.deepseek_api_key = "your_api_key" +config.base_url = "base_url" +config.api_key = "your_api_key" +config.model = "model_name" config.tavily_api_key = "your_tavily_key" agent = DeepSearchAgent(config) @@ -188,8 +183,7 @@ Demo DeepSearch Agent/ ├── src/ # 核心代码 │ ├── llms/ # LLM调用模块 │ │ ├── base.py # LLM基类 -│ │ ├── deepseek.py # DeepSeek实现 -│ │ └── openai_llm.py # OpenAI实现 +│ │ ├── llm.py # 通用LLM实现 │ ├── nodes/ # 处理节点 │ │ ├── base_node.py # 节点基类 │ │ ├── report_structure_node.py # 结构生成 @@ -341,15 +335,12 @@ class DeepSearchAgent: ```python class Config: - # API密钥 - deepseek_api_key: Optional[str] - openai_api_key: Optional[str] - tavily_api_key: Optional[str] + """配置类""" + base_url: Optional[str] = None + api_key: Optional[str] = None + tavily_api_key: Optional[str] = None - # 模型配置 - default_llm_provider: str = "deepseek" - deepseek_model: str = "deepseek-chat" - openai_model: str = "gpt-4o-mini" + model: str = "" # 搜索配置 max_search_results: int = 3 @@ -412,16 +403,6 @@ print(f"研究进度: {progress['progress_percentage']}%") ## 高级功能 -### 多模型支持 - -```python -# 使用DeepSeek -config = Config(default_llm_provider="deepseek") - -# 使用OpenAI -config = Config(default_llm_provider="openai", openai_model="gpt-4o") -``` - ### 自定义输出 ```python @@ -435,10 +416,7 @@ config = Config( ### Q: 支持哪些LLM? -A: 目前支持: -- **DeepSeek**: 推荐使用,性价比高 -- **OpenAI**: GPT-4o、GPT-4o-mini等 -- 可以通过继承`BaseLLM`类轻松添加其他模型 +A: 支持各种模型服务商提供的LLM,如DeepSeek、OpenAI、MiniMax等。只需在配置文件中设置相应的Base URL、API Key和模型名称即可。 ### Q: 如何获取API密钥? @@ -446,6 +424,7 @@ A: - **DeepSeek**: 访问 [DeepSeek平台](https://platform.deepseek.com/) 注册获取 - **Tavily**: 访问 [Tavily](https://tavily.com/) 注册获取(每月1000次免费) - **OpenAI**: 访问 [OpenAI平台](https://platform.openai.com/) 获取 +- **MiniMax**: 访问 [MiniMax AI](https://platform.minimaxi.com/user-center/basic-information/interface-key) 获取 获取密钥后,直接编辑项目根目录的`config.py`文件填入即可。 @@ -481,9 +460,8 @@ A: 当前主要支持Tavily,但可以通过修改`src/tools/search.py`添加 ## 致谢 -- 感谢 [DeepSeek](https://www.deepseek.com/) 提供优秀的LLM服务 +- 感谢 [DeepSeek](https://www.deepseek.com/) [MiniMax](https://minimaxi.com/) 提供优秀的LLM服务 - 感谢 [Tavily](https://tavily.com/) 提供高质量的搜索API - --- 如果这个项目对您有帮助,请给个Star! diff --git a/config.py.example b/config.py.example new file mode 100644 index 0000000..9181da7 --- /dev/null +++ b/config.py.example @@ -0,0 +1,16 @@ +# Deep Search Agent 配置文件 + +# 请在这里填入服务商的Base URL、API Key和模型名称 +BASE_URL = "" +API_KEY = "" +MODEL_NAME = "" + +# 请在这里填入Tavily搜索API Key,前往https://app.tavily.com/home获取 +TAVILY_API_KEY = "" + + +MAX_REFLECTIONS = 2 +SEARCH_RESULTS_PER_QUERY = 3 +SEARCH_CONTENT_MAX_LENGTH = 20000 +OUTPUT_DIR = "reports" +SAVE_INTERMEDIATE_STATES = True diff --git a/examples/advanced_usage.py b/examples/advanced_usage.py index fa444d4..1b211e6 100644 --- a/examples/advanced_usage.py +++ b/examples/advanced_usage.py @@ -9,7 +9,7 @@ # 添加项目根目录到Python路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) -from src import DeepSearchAgent, Config +from src import DeepSearchAgent, Config, load_config from src.utils.config import print_config @@ -22,22 +22,30 @@ def advanced_example(): try: # 自定义配置 print("正在创建自定义配置...") - config = Config( - # 使用OpenAI而不是DeepSeek - default_llm_provider="openai", - openai_model="gpt-4o-mini", - # 自定义搜索参数 - max_search_results=5, # 更多搜索结果 - max_reflections=3, # 更多反思次数 - max_content_length=15000, - # 自定义输出 - output_dir="custom_reports", - save_intermediate_states=True - ) - + config = load_config() + config.max_search_results = 5 + config.max_reflections = 3 + config.max_content_length = 15000 + config.output_dir = "custom_reports" + config.save_intermediate_states = True + # config = Config( + # # 填入您使用的模型服务商的 base url, api key和模型名称 + # # base_url="...", + # # api_key="", # 您的API密钥 + # # model="...", + # # 自定义搜索参数 + # max_search_results=5, # 更多搜索结果 + # max_reflections=3, # 更多反思次数 + # max_content_length=15000, + # # 自定义输出 + # output_dir="custom_reports", + # save_intermediate_states=True + # ) # 从环境变量设置API密钥 - config.openai_api_key = os.getenv("OPENAI_API_KEY") - config.tavily_api_key = os.getenv("TAVILY_API_KEY") + if not config.api_key: + config.api_key = os.getenv("API_KEY") + if not config.tavily_api_key: + config.tavily_api_key = os.getenv("TAVILY_API_KEY") if not config.validate(): print("配置验证失败,请检查API密钥设置") @@ -51,9 +59,9 @@ def advanced_example(): # 执行多个研究任务 queries = [ - "深度学习在医疗领域的应用", - "区块链技术的最新发展", - "可持续能源技术趋势" + "深度学习在三维重建领域的应用", + # "中国基座大模型的最新发展", + ] for i, query in enumerate(queries, 1): @@ -96,7 +104,7 @@ def state_management_example(): try: # 创建配置 - config = Config.from_env() + config = load_config() if not config.validate(): print("配置验证失败") return @@ -104,14 +112,14 @@ def state_management_example(): # 创建Agent agent = DeepSearchAgent(config) - query = "量子计算的发展现状" + query = "SAM模型的技术细节与应用" print(f"开始研究: {query}") # 执行研究 final_report = agent.research(query) # 保存状态 - state_file = "custom_reports/quantum_computing_state.json" + state_file = "custom_reports/sam_model_state.json" agent.save_state(state_file) print(f"状态已保存到: {state_file}") diff --git a/src/agent.py b/src/agent.py index c01b5f0..92a9310 100644 --- a/src/agent.py +++ b/src/agent.py @@ -8,7 +8,7 @@ from datetime import datetime from typing import Optional, Dict, Any, List -from .llms import DeepSeekLLM, OpenAILLM, BaseLLM +from .llms import LLM, BaseLLM from .nodes import ( ReportStructureNode, FirstSearchNode, @@ -52,18 +52,7 @@ def __init__(self, config: Optional[Config] = None): def _initialize_llm(self) -> BaseLLM: """初始化LLM客户端""" - if self.config.default_llm_provider == "deepseek": - return DeepSeekLLM( - api_key=self.config.deepseek_api_key, - model_name=self.config.deepseek_model - ) - elif self.config.default_llm_provider == "openai": - return OpenAILLM( - api_key=self.config.openai_api_key, - model_name=self.config.openai_model - ) - else: - raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") + return LLM(self.config.api_key, self.config.base_url, self.config.model) def _initialize_nodes(self): """初始化处理节点""" diff --git a/src/llms/__init__.py b/src/llms/__init__.py index 2e1602a..b299811 100644 --- a/src/llms/__init__.py +++ b/src/llms/__init__.py @@ -4,7 +4,6 @@ """ from .base import BaseLLM -from .deepseek import DeepSeekLLM -from .openai_llm import OpenAILLM +from .llm import LLM -__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"] +__all__ = ["BaseLLM", "LLM"] diff --git a/src/llms/base.py b/src/llms/base.py index 9c6f99f..2687592 100644 --- a/src/llms/base.py +++ b/src/llms/base.py @@ -10,7 +10,7 @@ class BaseLLM(ABC): """LLM基础抽象类""" - def __init__(self, api_key: str, model_name: Optional[str] = None): + def __init__(self, api_key: str, base_url: Optional[str] = None, model_name: Optional[str] = None): """ 初始化LLM客户端 @@ -19,6 +19,7 @@ def __init__(self, api_key: str, model_name: Optional[str] = None): model_name: 模型名称,如果不指定则使用默认模型 """ self.api_key = api_key + self.base_url = base_url self.model_name = model_name @abstractmethod @@ -36,16 +37,6 @@ def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: """ pass - @abstractmethod - def get_default_model(self) -> str: - """ - 获取默认模型名称 - - Returns: - 默认模型名称 - """ - pass - def validate_response(self, response: str) -> str: """ 验证和清理响应内容 diff --git a/src/llms/deepseek.py b/src/llms/llm.py similarity index 66% rename from src/llms/deepseek.py rename to src/llms/llm.py index 9e6d96d..c2be241 100644 --- a/src/llms/deepseek.py +++ b/src/llms/llm.py @@ -1,6 +1,5 @@ """ -DeepSeek LLM实现 -使用DeepSeek API进行文本生成 +支持OpenAI接口格式的通用LLM实现 """ import os @@ -8,36 +7,36 @@ from openai import OpenAI from .base import BaseLLM - -class DeepSeekLLM(BaseLLM): - """DeepSeek LLM实现类""" +class LLM(BaseLLM): + """通用LLM实现类""" - def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None): + def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None, model_name: Optional[str] = None): """ - 初始化DeepSeek客户端 + 初始化LLM客户端 Args: - api_key: DeepSeek API密钥,如果不提供则从环境变量读取 + api_key: API密钥,如果不提供则从环境变量读取API_KEY + base_url: API基础URL,默认使用DeepSeek的URL model_name: 模型名称,默认使用deepseek-chat """ if api_key is None: - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供") + raise ValueError("API Key未找到!请在config.py或.env文件中设置API_KEY。") - super().__init__(api_key, model_name) + if base_url is None: + raise ValueError("Base URL未找到!请在config.py或.env文件中设置BASE_URL。") + + super().__init__(api_key, base_url, model_name) # 初始化OpenAI客户端,使用DeepSeek的endpoint self.client = OpenAI( api_key=self.api_key, - base_url="https://api.deepseek.com" + base_url=self.base_url ) - - self.default_model = model_name or self.get_default_model() + if self.model_name: + self.default_model = self.model_name + else: + raise ValueError("模型名称未找到!请在config.py或.env文件中设置MODEL_NAME。") - def get_default_model(self) -> str: - """获取默认模型名称""" - return "deepseek-chat" def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: """ @@ -78,7 +77,7 @@ def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: return "" except Exception as e: - print(f"DeepSeek API调用错误: {str(e)}") + print(f"{self.base_url} API调用错误: {str(e)}") raise e def get_model_info(self) -> Dict[str, Any]: @@ -89,7 +88,6 @@ def get_model_info(self) -> Dict[str, Any]: 模型信息字典 """ return { - "provider": "DeepSeek", "model": self.default_model, - "api_base": "https://api.deepseek.com" + "base_url": self.base_url, } diff --git a/src/llms/openai_llm.py b/src/llms/openai_llm.py deleted file mode 100644 index fff0e79..0000000 --- a/src/llms/openai_llm.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -OpenAI LLM实现 -使用OpenAI API进行文本生成 -""" - -import os -from typing import Optional, Dict, Any -from openai import OpenAI -from .base import BaseLLM - - -class OpenAILLM(BaseLLM): - """OpenAI LLM实现类""" - - def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None): - """ - 初始化OpenAI客户端 - - Args: - api_key: OpenAI API密钥,如果不提供则从环境变量读取 - model_name: 模型名称,默认使用gpt-4o-mini - """ - if api_key is None: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供") - - super().__init__(api_key, model_name) - - # 初始化OpenAI客户端 - self.client = OpenAI(api_key=self.api_key) - self.default_model = model_name or self.get_default_model() - - def get_default_model(self) -> str: - """获取默认模型名称""" - return "gpt-4o-mini" - - def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: - """ - 调用OpenAI API生成回复 - - Args: - system_prompt: 系统提示词 - user_prompt: 用户输入 - **kwargs: 其他参数,如temperature、max_tokens等 - - Returns: - OpenAI生成的回复文本 - """ - try: - # 构建消息 - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - - # 设置默认参数 - params = { - "model": self.default_model, - "messages": messages, - "temperature": kwargs.get("temperature", 0.7), - "max_tokens": kwargs.get("max_tokens", 4000) - } - - # 调用API - response = self.client.chat.completions.create(**params) - - # 提取回复内容 - if response.choices and response.choices[0].message: - content = response.choices[0].message.content - return self.validate_response(content) - else: - return "" - - except Exception as e: - print(f"OpenAI API调用错误: {str(e)}") - raise e - - def get_model_info(self) -> Dict[str, Any]: - """ - 获取当前模型信息 - - Returns: - 模型信息字典 - """ - return { - "provider": "OpenAI", - "model": self.default_model, - "api_base": "https://api.openai.com" - } diff --git a/src/utils/config.py b/src/utils/config.py index 3a1ab5d..403774a 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -7,19 +7,17 @@ from dataclasses import dataclass from typing import Optional +from openai import base_url + @dataclass class Config: """配置类""" - # API密钥 - deepseek_api_key: Optional[str] = None - openai_api_key: Optional[str] = None + base_url: Optional[str] = None + api_key: Optional[str] = None tavily_api_key: Optional[str] = None - # 模型配置 - default_llm_provider: str = "deepseek" # deepseek 或 openai - deepseek_model: str = "deepseek-chat" - openai_model: str = "gpt-4o-mini" + model: str = "" # 搜索配置 max_search_results: int = 3 @@ -37,12 +35,8 @@ class Config: def validate(self) -> bool: """验证配置""" # 检查必需的API密钥 - if self.default_llm_provider == "deepseek" and not self.deepseek_api_key: - print("错误: DeepSeek API Key未设置") - return False - - if self.default_llm_provider == "openai" and not self.openai_api_key: - print("错误: OpenAI API Key未设置") + if not self.api_key: + print("错误: API Key未设置") return False if not self.tavily_api_key: @@ -64,12 +58,10 @@ def from_file(cls, config_file: str) -> "Config": spec.loader.exec_module(config_module) return cls( - deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None), - openai_api_key=getattr(config_module, "OPENAI_API_KEY", None), + base_url=getattr(config_module, "BASE_URL", None), + api_key=getattr(config_module, "API_KEY", None), tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None), - default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"), - deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"), - openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), + model=getattr(config_module, "MODEL_NAME", "deepseek-chat"), max_search_results=getattr(config_module, "SEARCH_RESULTS_PER_QUERY", 3), search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240), max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000), @@ -91,12 +83,10 @@ def from_file(cls, config_file: str) -> "Config": config_dict[key.strip()] = value.strip() return cls( - deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), - openai_api_key=config_dict.get("OPENAI_API_KEY"), + base_url=config_dict.get("BASE_URL"), + api_key=config_dict.get("API_KEY"), tavily_api_key=config_dict.get("TAVILY_API_KEY"), - default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), - deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), - openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), + model=config_dict.get("MODEL_NAME", "deepseek-chat"), max_search_results=int(config_dict.get("SEARCH_RESULTS_PER_QUERY", "3")), search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")), @@ -145,9 +135,8 @@ def load_config(config_file: Optional[str] = None) -> Config: def print_config(config: Config): """打印配置信息(隐藏敏感信息)""" print("\n=== 当前配置 ===") - print(f"LLM提供商: {config.default_llm_provider}") - print(f"DeepSeek模型: {config.deepseek_model}") - print(f"OpenAI模型: {config.openai_model}") + print(f"提供商: {config.base_url}") + print(f"模型: {config.model}") print(f"最大搜索结果数: {config.max_search_results}") print(f"搜索超时: {config.search_timeout}秒") print(f"最大内容长度: {config.max_content_length}") @@ -157,7 +146,6 @@ def print_config(config: Config): print(f"保存中间状态: {config.save_intermediate_states}") # 显示API密钥状态(不显示实际密钥) - print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}") - print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}") + print(f"API Key: {'已设置' if config.api_key else '未设置'}") print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}") print("==================\n")