diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py index 6d31426f..48ffd494 100644 --- a/libs/community/langchain_community/chat_models/moonshot.py +++ b/libs/community/langchain_community/chat_models/moonshot.py @@ -1,12 +1,18 @@ """Wrapper around Moonshot chat models.""" -from typing import Dict +from typing import Any, Callable, Dict, Sequence, Type, Union +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import AIMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, pre_init, ) +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel from langchain_community.chat_models import ChatOpenAI from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon @@ -172,9 +178,11 @@ def validate_environment(cls, values: Dict) -> Dict: client_params = { "api_key": values["moonshot_api_key"].get_secret_value(), - "base_url": values["base_url"] - if "base_url" in values - else MOONSHOT_SERVICE_URL_BASE, + "base_url": ( + values["base_url"] + if "base_url" in values + else MOONSHOT_SERVICE_URL_BASE + ), } if not values.get("client"): @@ -185,3 +193,22 @@ def validate_environment(cls, values: Dict) -> Dict: ).chat.completions return values + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/langchain_community/llms/moonshot.py b/libs/community/langchain_community/llms/moonshot.py index 7f204fa6..c9ffc978 100644 --- a/libs/community/langchain_community/llms/moonshot.py +++ b/libs/community/langchain_community/llms/moonshot.py @@ -39,7 +39,7 @@ def completion(self, request: Any) -> Any: class MoonshotCommon(BaseModel): """Common parameters for Moonshot LLMs.""" - client: Any + client: Any = Field(default=None) base_url: str = MOONSHOT_SERVICE_URL_BASE moonshot_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") """Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys""" diff --git a/libs/community/tests/integration_tests/chat_models/test_moonshot.py b/libs/community/tests/integration_tests/chat_models/test_moonshot.py index de4725cf..29a31ca9 100644 --- a/libs/community/tests/integration_tests/chat_models/test_moonshot.py +++ b/libs/community/tests/integration_tests/chat_models/test_moonshot.py @@ -27,5 +27,5 @@ def test_usage_metadata(self, model: BaseChatModel) -> None: def test_chat_moonshot_instantiate_with_alias() -> None: """Test MoonshotChat instantiate when using alias.""" api_key = "your-api-key" - chat = MoonshotChat(api_key=api_key) # type: ignore[call-arg] + chat = MoonshotChat(api_key=api_key) assert cast(SecretStr, chat.moonshot_api_key).get_secret_value() == api_key diff --git a/libs/community/tests/unit_tests/chat_models/test_moonshot.py b/libs/community/tests/unit_tests/chat_models/test_moonshot.py new file mode 100644 index 00000000..6938fc57 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_moonshot.py @@ -0,0 +1,14 @@ +from typing import Any + +import pytest + +from langchain_community.chat_models.moonshot import MoonshotChat + +mock_tool_list = [lambda: f"tool-id-{i}" for i in range(3)] + + +@pytest.mark.requires("openai") +def test_moonshot_bind_tools() -> None: + llm = MoonshotChat(name="moonshot") + ret: Any = llm.bind_tools(mock_tool_list) + assert len(ret.kwargs["tools"]) == 3 diff --git a/libs/community/tests/unit_tests/llms/test_moonshot.py b/libs/community/tests/unit_tests/llms/test_moonshot.py index fda1a529..116abe99 100644 --- a/libs/community/tests/unit_tests/llms/test_moonshot.py +++ b/libs/community/tests/unit_tests/llms/test_moonshot.py @@ -9,7 +9,7 @@ @pytest.mark.requires("openai") def test_moonshot_model_param() -> None: - llm = Moonshot(model="foo") # type: ignore[call-arg] + llm = Moonshot(model="foo") assert llm.model_name == "foo" llm = Moonshot(model_name="bar") # type: ignore[call-arg] assert llm.model_name == "bar"