diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/contributing/samples/pubsub/README.md b/contributing/samples/pubsub/README.md new file mode 100644 index 0000000000..507902abca --- /dev/null +++ b/contributing/samples/pubsub/README.md @@ -0,0 +1,88 @@ +# Pub/Sub Tools Sample + +## Introduction + +This sample agent demonstrates the Pub/Sub first-party tools in ADK, +distributed via the `google.adk.tools.pubsub` module. These tools include: + +1. `publish_message` + + Publishes a message to a Pub/Sub topic. + +2. `pull_messages` + + Pulls messages from a Pub/Sub subscription. + +3. `acknowledge_messages` + + Acknowledges messages on a Pub/Sub subscription. + +## How to use + +Set up environment variables in your `.env` file for using +[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio) +or +[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai) +for the LLM service for your agent. For example, for using Google AI Studio you +would set: + +* GOOGLE_GENAI_USE_VERTEXAI=FALSE +* GOOGLE_API_KEY={your api key} + +### With Application Default Credentials + +This mode is useful for quick development when the agent builder is the only +user interacting with the agent. The tools are run with these credentials. + +1. Create application default credentials on the machine where the agent would +be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. + +1. Set `CREDENTIALS_TYPE=None` in `agent.py` + +1. Run the agent + +### With Service Account Keys + +This mode is useful for quick development when the agent builder wants to run +the agent with service account credentials. The tools are run with these +credentials. + +1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py` + +1. Download the key file and replace `"service_account_key.json"` with the path + +1. Run the agent + +### With Interactive OAuth + +1. Follow +https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. +to get your client id and client secret. Be sure to choose "web" as your client +type. + +1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/pubsub". + +1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". + + Note: localhost here is just a hostname that you use to access the dev ui, + replace it with the actual hostname you use to access the dev ui. + +1. For 1st run, allow popup for localhost in Chrome. + +1. Configure your `.env` file to add two more variables before running the agent: + + * OAUTH_CLIENT_ID={your client id} + * OAUTH_CLIENT_SECRET={your client secret} + + Note: don't create a separate .env, instead put it to the same .env file that + stores your Vertex AI or Dev ML credentials + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent + +## Sample prompts + +* publish 'Hello World' to 'my-topic' +* pull messages from 'my-subscription' +* acknowledge message 'ack-id' from 'my-subscription' diff --git a/contributing/samples/pubsub/__init__.py b/contributing/samples/pubsub/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/pubsub/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/pubsub/agent.py b/contributing/samples/pubsub/agent.py new file mode 100644 index 0000000000..400471b09e --- /dev/null +++ b/contributing/samples/pubsub/agent.py @@ -0,0 +1,79 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.tools.pubsub.config import PubSubToolConfig +from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig +from google.adk.tools.pubsub.pubsub_toolset import PubSubToolset +import google.auth + +# Define the desired credential type. +# By default use Application Default Credentials (ADC) from the local +# environment, which can be set up by following +# https://cloud.google.com/docs/authentication/provide-credentials-adc. +CREDENTIALS_TYPE = None + +# Define an appropriate application name +PUBSUB_AGENT_NAME = "adk_sample_pubsub_agent" + + +# Define Pub/Sub tool config. +# You can optionally set the project_id here, or let the agent infer it from context/user input. +tool_config = PubSubToolConfig(project_id=os.getenv("GOOGLE_CLOUD_PROJECT")) + +if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: + # Initialize the tools to do interactive OAuth + # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET + # must be set + credentials_config = PubSubCredentialsConfig( + client_id=os.getenv("OAUTH_CLIENT_ID"), + client_secret=os.getenv("OAUTH_CLIENT_SECRET"), + ) +elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: + # Initialize the tools to use the credentials in the service account key. + # If this flow is enabled, make sure to replace the file path with your own + # service account key file + # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys + creds, _ = google.auth.load_credentials_from_file("service_account_key.json") + credentials_config = PubSubCredentialsConfig(credentials=creds) +else: + # Initialize the tools to use the application default credentials. + # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = PubSubCredentialsConfig( + credentials=application_default_credentials + ) + +pubsub_toolset = PubSubToolset( + credentials_config=credentials_config, pubsub_tool_config=tool_config +) + +# The variable name `root_agent` determines what your root agent is for the +# debug CLI +root_agent = LlmAgent( + model="gemini-2.0-flash", + name=PUBSUB_AGENT_NAME, + description=( + "Agent to publish, pull, and acknowledge messages from Google Cloud" + " Pub/Sub." + ), + instruction="""\ + You are a cloud engineer agent with access to Google Cloud Pub/Sub tools. + You can publish messages to topics, pull messages from subscriptions, and acknowledge messages. + """, + tools=[pubsub_toolset], +) diff --git a/pyproject.toml b/pyproject.toml index 72444fe5d7..7a6031c5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool + "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 46b56eb6d9..cb3fa882d1 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -32,6 +32,7 @@ class FeatureName(str, Enum): GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" + PUBSUB_TOOLSET = "PUBSUB_TOOLSET" SPANNER_TOOLSET = "SPANNER_TOOLSET" SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS" @@ -90,6 +91,9 @@ class FeatureConfig: FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.WIP, default_on=False ), + FeatureName.PUBSUB_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.SPANNER_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py new file mode 100644 index 0000000000..9625155f06 --- /dev/null +++ b/src/google/adk/tools/pubsub/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pub/Sub Tools (Experimental). + +Pub/Sub Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. Better handling of base64 encoding for published messages. +2. A richer subscribe-side API that reflects how users may want to pull/ack + messages. +""" + +from .config import PubSubToolConfig +from .pubsub_credentials import PubSubCredentialsConfig +from .pubsub_toolset import PubSubToolset + +__all__ = ["PubSubCredentialsConfig", "PubSubToolConfig", "PubSubToolset"] diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py new file mode 100644 index 0000000000..b04c9ae7f5 --- /dev/null +++ b/src/google/adk/tools/pubsub/client.py @@ -0,0 +1,165 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +import time + +from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import Credentials +from google.cloud import pubsub_v1 +from google.cloud.pubsub_v1.types import BatchSettings + +from ... import version + +USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}" + +_CACHE_TTL = 1800 # 30 minutes + +_publisher_client_cache = {} +_publisher_client_lock = threading.Lock() + + +def get_publisher_client( + *, + credentials: Credentials, + user_agent: str | list[str] | None = None, + publisher_options: pubsub_v1.types.PublisherOptions | None = None, +) -> pubsub_v1.PublisherClient: + """Get a Pub/Sub Publisher client. + + Args: + credentials: The credentials to use for the request. + user_agent: The user agent to use for the request. + publisher_options: The publisher options to use for the request. + + Returns: + A Pub/Sub Publisher client. + """ + global _publisher_client_cache + current_time = time.time() + + user_agents_key = None + if user_agent: + if isinstance(user_agent, str): + user_agents_key = (user_agent,) + else: + user_agents_key = tuple(user_agent) + + # Use object identity for credentials and publisher_options as they are not hashable + key = (id(credentials), user_agents_key, id(publisher_options)) + + with _publisher_client_lock: + if key in _publisher_client_cache: + client, expiration = _publisher_client_cache[key] + if expiration > current_time: + return client + + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend(ua for ua in user_agent if ua) + + client_info = ClientInfo(user_agent=" ".join(user_agents)) + + # Since we synchronously publish messages, we want to disable batching to + # remove any delay. + custom_batch_settings = BatchSettings(max_messages=1) + publisher_client = pubsub_v1.PublisherClient( + credentials=credentials, + client_info=client_info, + publisher_options=publisher_options, + batch_settings=custom_batch_settings, + ) + + _publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL) + + return publisher_client + + +_subscriber_client_cache = {} +_subscriber_client_lock = threading.Lock() + + +def get_subscriber_client( + *, + credentials: Credentials, + user_agent: str | list[str] | None = None, +) -> pubsub_v1.SubscriberClient: + """Get a Pub/Sub Subscriber client. + + Args: + credentials: The credentials to use for the request. + user_agent: The user agent to use for the request. + + Returns: + A Pub/Sub Subscriber client. + """ + global _subscriber_client_cache + current_time = time.time() + + user_agents_key = None + if user_agent: + if isinstance(user_agent, str): + user_agents_key = (user_agent,) + else: + user_agents_key = tuple(user_agent) + + # Use object identity for credentials as they are not hashable + key = (id(credentials), user_agents_key) + + with _subscriber_client_lock: + if key in _subscriber_client_cache: + client, expiration = _subscriber_client_cache[key] + if expiration > current_time: + return client + + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend(ua for ua in user_agent if ua) + + client_info = ClientInfo(user_agent=" ".join(user_agents)) + + subscriber_client = pubsub_v1.SubscriberClient( + credentials=credentials, + client_info=client_info, + ) + + _subscriber_client_cache[key] = ( + subscriber_client, + current_time + _CACHE_TTL, + ) + + return subscriber_client + + +def cleanup_clients(): + """Clean up all cached Pub/Sub clients.""" + global _publisher_client_cache, _subscriber_client_cache + + with _publisher_client_lock: + for client, _ in _publisher_client_cache.values(): + client.transport.close() + _publisher_client_cache.clear() + + with _subscriber_client_lock: + for client, _ in _subscriber_client_cache.values(): + client.close() + _subscriber_client_cache.clear() diff --git a/src/google/adk/tools/pubsub/config.py b/src/google/adk/tools/pubsub/config.py new file mode 100644 index 0000000000..eb48a1f7f4 --- /dev/null +++ b/src/google/adk/tools/pubsub/config.py @@ -0,0 +1,35 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pydantic import BaseModel +from pydantic import ConfigDict + +from ...utils.feature_decorator import experimental + + +@experimental('Config defaults may have breaking change in the future.') +class PubSubToolConfig(BaseModel): + """Configuration for Pub/Sub tools.""" + + # Forbid any fields not defined in the model + model_config = ConfigDict(extra='forbid') + + project_id: str | None = None + """GCP project ID to use for the Pub/Sub operations. + + If not set, the project ID will be inferred from the environment or + credentials. + """ diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py new file mode 100644 index 0000000000..182b48c0bd --- /dev/null +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 + +from google.auth.credentials import Credentials +from google.cloud import pubsub_v1 + +from . import client +from .config import PubSubToolConfig + + +def publish_message( + topic_name: str, + message: str, + credentials: Credentials, + settings: PubSubToolConfig, + attributes: dict[str, str] | None = None, + ordering_key: str | None = None, +) -> dict: + """Publish a message to a Pub/Sub topic. + + Args: + topic_name (str): The Pub/Sub topic name (e.g. + projects/my-project/topics/my-topic). + message (str): The message content to publish. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + attributes (dict[str, str] | None): Attributes to attach to the message. + ordering_key (str | None): Ordering key for the message. + + Returns: + dict: Dictionary with the message_id of the published message. + """ + if attributes is None: + attributes = {} + + try: + if ordering_key: + publisher_options = pubsub_v1.types.PublisherOptions( + enable_message_ordering=True + ) + else: + publisher_options = pubsub_v1.types.PublisherOptions() + publisher_client = client.get_publisher_client( + credentials=credentials, + user_agent=[settings.project_id, "publish_message"], + publisher_options=publisher_options, + ) + + message_bytes = message.encode("utf-8") + future = publisher_client.publish( + topic_name, + data=message_bytes, + ordering_key=ordering_key or "", + **(attributes or {}), + ) + + return {"message_id": future.result()} + except Exception as ex: + return { + "status": "ERROR", + "error_details": ( + f"Failed to publish message to topic '{topic_name}': {repr(ex)}" + ), + } + + +def _decode_message_data(data: bytes) -> str: + """Decodes message data, trying UTF-8 and falling back to base64.""" + try: + return data.decode("utf-8") + except UnicodeDecodeError: + # If UTF-8 decoding fails, encode as base64 string + return base64.b64encode(data).decode("ascii") + + +def pull_messages( + subscription_name: str, + credentials: Credentials, + settings: PubSubToolConfig, + *, + max_messages: int = 1, + auto_ack: bool = False, +) -> dict: + """Pull messages from a Pub/Sub subscription. + + Args: + subscription_name (str): The Pub/Sub subscription name (e.g. + projects/my-project/subscriptions/my-sub). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + max_messages (int): The maximum number of messages to pull. Defaults to 1. + auto_ack (bool): Whether to automatically acknowledge the messages. + Defaults to False. + + Returns: + dict: Dictionary with the list of pulled messages. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "pull_messages"], + ) + + response = subscriber_client.pull( + subscription=subscription_name, + max_messages=max_messages, + ) + + messages = [] + ack_ids = [] + for received_message in response.received_messages: + message_data = _decode_message_data(received_message.message.data) + messages.append({ + "message_id": received_message.message.message_id, + "data": message_data, + "attributes": dict(received_message.message.attributes), + "publish_time": received_message.message.publish_time.rfc3339(), + "ack_id": received_message.ack_id, + }) + ack_ids.append(received_message.ack_id) + + if auto_ack and ack_ids: + subscriber_client.acknowledge( + subscription=subscription_name, + ack_ids=ack_ids, + ) + + return {"messages": messages} + except Exception as ex: + return { + "status": "ERROR", + "error_details": ( + f"Failed to pull messages from subscription '{subscription_name}':" + f" {repr(ex)}" + ), + } + + +def acknowledge_messages( + subscription_name: str, + ack_ids: list[str], + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Acknowledge messages on a Pub/Sub subscription. + + Args: + subscription_name (str): The Pub/Sub subscription name (e.g. + projects/my-project/subscriptions/my-sub). + ack_ids (list[str]): List of acknowledgment IDs to acknowledge. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Status of the operation. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "acknowledge_messages"], + ) + + subscriber_client.acknowledge( + subscription=subscription_name, + ack_ids=ack_ids, + ) + + return {"status": "SUCCESS"} + except Exception as ex: + return { + "status": "ERROR", + "error_details": ( + "Failed to acknowledge messages on subscription" + f" '{subscription_name}': {repr(ex)}" + ), + } diff --git a/src/google/adk/tools/pubsub/pubsub_credentials.py b/src/google/adk/tools/pubsub/pubsub_credentials.py new file mode 100644 index 0000000000..ed04b9e0d7 --- /dev/null +++ b/src/google/adk/tools/pubsub/pubsub_credentials.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pydantic import model_validator + +from ...features import experimental +from ...features import FeatureName +from .._google_credentials import BaseGoogleCredentialsConfig + +PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache" +PUBSUB_DEFAULT_SCOPE = ("https://www.googleapis.com/auth/pubsub",) + + +@experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG) +class PubSubCredentialsConfig(BaseGoogleCredentialsConfig): + """Pub/Sub Credentials Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ + + @model_validator(mode="after") + def __post_init__(self) -> PubSubCredentialsConfig: + """Populate default scope if scopes is None.""" + super().__post_init__() + + if not self.scopes: + self.scopes = PUBSUB_DEFAULT_SCOPE + + # Set the token cache key + self._token_cache_key = PUBSUB_TOKEN_CACHE_KEY + + return self diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py new file mode 100644 index 0000000000..9f7fb0ed4f --- /dev/null +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.agents.readonly_context import ReadonlyContext +from typing_extensions import override + +from . import client +from . import message_tool +from ...features import experimental +from ...features import FeatureName +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from ...tools.google_tool import GoogleTool +from .config import PubSubToolConfig +from .pubsub_credentials import PubSubCredentialsConfig + + +@experimental(FeatureName.PUBSUB_TOOLSET) +class PubSubToolset(BaseToolset): + """Pub/Sub Toolset contains tools for interacting with Pub/Sub topics and subscriptions.""" + + def __init__( + self, + *, + tool_filter: ToolPredicate | list[str] | None = None, + credentials_config: PubSubCredentialsConfig | None = None, + pubsub_tool_config: PubSubToolConfig | None = None, + ): + """Initializes the PubSubToolset. + + Args: + tool_filter: A predicate or list of tool names to filter the tools in + the toolset. If None, all tools are included. + credentials_config: The credentials configuration to use for + authenticating with Google Cloud. + pubsub_tool_config: The configuration for the Pub/Sub tools. + """ + super().__init__(tool_filter=tool_filter) + self._credentials_config = credentials_config + self._tool_settings = ( + pubsub_tool_config if pubsub_tool_config else PubSubToolConfig() + ) + + def _is_tool_selected( + self, tool: BaseTool, readonly_context: ReadonlyContext + ) -> bool: + if self.tool_filter is None: + return True + + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + + return False + + @override + async def get_tools( + self, readonly_context: ReadonlyContext | None = None + ) -> list[BaseTool]: + """Get tools from the toolset.""" + all_tools = [ + GoogleTool( + func=func, + credentials_config=self._credentials_config, + tool_settings=self._tool_settings, + ) + for func in [ + message_tool.publish_message, + message_tool.pull_messages, + message_tool.acknowledge_messages, + ] + ] + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self): + """Clean up resources used by the toolset.""" + client.cleanup_clients() diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py new file mode 100644 index 0000000000..fec9b3798d --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -0,0 +1,136 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools.pubsub import client +from google.cloud import pubsub_v1 +from google.oauth2.credentials import Credentials +import pytest + + +# Save original Pub/Sub classes before patching. +# This is necessary because create_autospec cannot be used on a mock object, +# and mock.patch.object(..., autospec=True) replaces the class with a mock. +# We need the original class to create spec'd mocks in side_effect. +ORIG_PUBLISHER = pubsub_v1.PublisherClient +ORIG_SUBSCRIBER = pubsub_v1.SubscriberClient + + +@pytest.fixture(autouse=True) +def cleanup_pubsub_clients(): + """Automatically clean up Pub/Sub client caches after each test. + + This fixture runs automatically for all tests in this file, + ensuring that client caches are cleared between tests to prevent + state leakage and ensure test isolation. + """ + yield + client.cleanup_clients() + + +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) +def test_get_publisher_client(mock_publisher_client): + """Test get_publisher_client factory.""" + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + client.get_publisher_client(credentials=mock_creds) + + mock_publisher_client.assert_called_once() + _, kwargs = mock_publisher_client.call_args + assert kwargs["credentials"] == mock_creds + assert "client_info" in kwargs + assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings) + assert kwargs["batch_settings"].max_messages == 1 + + +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) +def test_get_publisher_client_with_options(mock_publisher_client): + """Test get_publisher_client factory with options.""" + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_options = mock.create_autospec( + pubsub_v1.types.PublisherOptions, instance=True, spec_set=True + ) + client.get_publisher_client( + credentials=mock_creds, publisher_options=mock_options + ) + + mock_publisher_client.assert_called_once() + _, kwargs = mock_publisher_client.call_args + assert kwargs["credentials"] == mock_creds + assert kwargs["publisher_options"] == mock_options + assert "client_info" in kwargs + assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings) + assert kwargs["batch_settings"].max_messages == 1 + + +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) +def test_get_publisher_client_caching(mock_publisher_client): + """Test get_publisher_client caching behavior.""" + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_publisher_client.side_effect = [ + mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True), + mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True), + ] + + # First call - should create client + client1 = client.get_publisher_client(credentials=mock_creds) + mock_publisher_client.assert_called_once() + + # Second call with same args - should return cached client + client2 = client.get_publisher_client(credentials=mock_creds) + assert client1 is client2 + mock_publisher_client.assert_called_once() # Still called only once + + # Call with different args - should create new client + mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True) + client3 = client.get_publisher_client(credentials=mock_creds2) + assert client3 is not client1 + assert mock_publisher_client.call_count == 2 + + +@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True) +def test_get_subscriber_client(mock_subscriber_client): + """Test get_subscriber_client factory.""" + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + client.get_subscriber_client(credentials=mock_creds) + + mock_subscriber_client.assert_called_once() + _, kwargs = mock_subscriber_client.call_args + assert kwargs["credentials"] == mock_creds + assert "client_info" in kwargs + + +@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True) +def test_get_subscriber_client_caching(mock_subscriber_client): + """Test get_subscriber_client caching behavior.""" + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_subscriber_client.side_effect = [ + mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True), + mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True), + ] + + # First call - should create client + client1 = client.get_subscriber_client(credentials=mock_creds) + mock_subscriber_client.assert_called_once() + + # Second call with same args - should return cached client + client2 = client.get_subscriber_client(credentials=mock_creds) + assert client1 is client2 + mock_subscriber_client.assert_called_once() # Still called only once + + # Call with different args - should create new client + mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True) + client3 = client.get_subscriber_client(credentials=mock_creds2) + assert client3 is not client1 + assert mock_subscriber_client.call_count == 2 diff --git a/tests/unittests/tools/pubsub/test_pubsub_config.py b/tests/unittests/tools/pubsub/test_pubsub_config.py new file mode 100644 index 0000000000..2e2628df4c --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_config.py @@ -0,0 +1,27 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.tools.pubsub.config import PubSubToolConfig + + +def test_pubsub_tool_config_init(): + """Test PubSubToolConfig initialization.""" + config = PubSubToolConfig(project_id="my-project") + assert config.project_id == "my-project" + + +def test_pubsub_tool_config_default(): + """Test PubSubToolConfig default initialization.""" + config = PubSubToolConfig() + assert config.project_id is None diff --git a/tests/unittests/tools/pubsub/test_pubsub_credentials.py b/tests/unittests/tools/pubsub/test_pubsub_credentials.py new file mode 100644 index 0000000000..11a5d5dea7 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_credentials.py @@ -0,0 +1,133 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools.pubsub.pubsub_credentials import PUBSUB_DEFAULT_SCOPE +from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig +from google.auth.credentials import Credentials +import google.oauth2.credentials +import pytest + + +"""Test suite for PubSub credentials configuration validation. + +This class tests the credential configuration logic that ensures +either existing credentials or client ID/secret pairs are provided. +""" + + +def test_pubsub_credentials_config_client_id_secret(): + """Test PubSubCredentialsConfig with client_id and client_secret. + + Ensures that when client_id and client_secret are provided, the config + object is created with the correct attributes. + """ + config = PubSubCredentialsConfig(client_id="abc", client_secret="def") + assert config.client_id == "abc" + assert config.client_secret == "def" + assert config.scopes == PUBSUB_DEFAULT_SCOPE + assert config.credentials is None + + +def test_pubsub_credentials_config_existing_creds(): + """Test PubSubCredentialsConfig with existing generic credentials. + + Ensures that when a generic Credentials object is provided, it is + stored correctly. + """ + mock_creds = mock.create_autospec(Credentials, instance=True) + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.credentials == mock_creds + assert config.client_id is None + assert config.client_secret is None + + +def test_pubsub_credentials_config_oauth2_creds(): + """Test PubSubCredentialsConfig with existing OAuth2 credentials. + + Ensures that when a google.oauth2.credentials.Credentials object is + provided, the client_id, client_secret, and scopes are extracted + from the credentials object. + """ + mock_creds = mock.create_autospec( + google.oauth2.credentials.Credentials, instance=True + ) + mock_creds.client_id = "oauth_client_id" + mock_creds.client_secret = "oauth_client_secret" + mock_creds.scopes = ["fake_scope"] + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.client_id == "oauth_client_id" + assert config.client_secret == "oauth_client_secret" + assert config.scopes == ["fake_scope"] + + +@pytest.mark.parametrize( + "credentials, client_id, client_secret", + [ + # No arguments provided + (None, None, None), + # Only client_id is provided + (None, "abc", None), + ], +) +def test_pubsub_credentials_config_validation_errors( + credentials, client_id, client_secret +): + """Test PubSubCredentialsConfig validation errors. + + Ensures that ValueError is raised when invalid combinations of credentials + and client ID/secret are provided. + + Args: + credentials: The credentials object to pass. + client_id: The client ID to pass. + client_secret: The client secret to pass. + """ + with pytest.raises( + ValueError, + match=( + "Must provide either credentials or client_id and client_secret pair." + ), + ): + PubSubCredentialsConfig( + credentials=credentials, + client_id=client_id, + client_secret=client_secret, + ) + + +def test_pubsub_credentials_config_both_credentials_and_client_provided(): + """Test PubSubCredentialsConfig validation errors. + + Ensures that ValueError is raised when invalid combinations of credentials + and client ID/secret are provided. + + Args: + credentials: The credentials object to pass. + client_id: The client ID to pass. + client_secret: The client secret to pass. + """ + with pytest.raises( + ValueError, + match=( + "Cannot provide both existing credentials and" + " client_id/client_secret/scopes." + ), + ): + PubSubCredentialsConfig( + credentials=mock.create_autospec(Credentials, instance=True), + client_id="abc", + client_secret="def", + ) diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py new file mode 100644 index 0000000000..2a935f41e2 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -0,0 +1,328 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.pubsub import client as pubsub_client_lib +from google.adk.tools.pubsub import message_tool +from google.adk.tools.pubsub.config import PubSubToolConfig +from google.api_core import future +from google.cloud import pubsub_v1 +from google.cloud.pubsub_v1 import types +from google.oauth2.credentials import Credentials +from google.protobuf import timestamp_pb2 + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message(mock_get_publisher_client, mock_publish): + """Test publish_message tool invocation.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.create_autospec(future.Future, instance=True) + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, message, mock_credentials, tool_settings + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_with_ordering_key( + mock_get_publisher_client, mock_publish +): + """Test publish_message tool invocation with ordering_key.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + ordering_key = "key1" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.create_autospec(future.Future, instance=True) + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + ordering_key=ordering_key, + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + _, kwargs = mock_get_publisher_client.call_args + assert kwargs["publisher_options"].enable_message_ordering is True + + mock_publisher_client.publish.assert_called_once() + + # Verify ordering_key was passed + _, kwargs = mock_publisher_client.publish.call_args + assert kwargs["ordering_key"] == ordering_key + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_with_attributes( + mock_get_publisher_client, mock_publish +): + """Test publish_message tool invocation with attributes.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + attributes = {"key1": "value1", "key2": "value2"} + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.create_autospec(future.Future, instance=True) + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + attributes=attributes, + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + # Verify attributes were passed + _, kwargs = mock_publisher_client.publish.call_args + assert kwargs["key1"] == "value1" + assert kwargs["key2"] == "value2" + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_exception(mock_get_publisher_client, mock_publish): + """Test publish_message tool invocation when exception occurs.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + # Simulate an exception during publish + mock_publisher_client.publish.side_effect = Exception("Publish failed") + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + ) + + assert result["status"] == "ERROR" + assert "Publish failed" in result["error_details"] + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages(mock_get_subscriber_client): + """Test pull_messages tool invocation.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_response = mock.create_autospec(types.PullResponse, instance=True) + mock_message = mock.MagicMock() + mock_message.message.message_id = "123" + mock_message.message.data = b"Hello" + mock_message.message.attributes = {"key": "value"} + mock_publish_time = mock.MagicMock() + mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" + mock_message.message.publish_time = mock_publish_time + mock_message.ack_id = "ack_123" + mock_response.received_messages = [mock_message] + mock_subscriber_client.pull.return_value = mock_response + + result = message_tool.pull_messages( + subscription_name, mock_credentials, tool_settings + ) + + expected_message = { + "message_id": "123", + "data": "Hello", + "attributes": {"key": "value"}, + "publish_time": "2023-01-01T00:00:00Z", + "ack_id": "ack_123", + } + assert result["messages"] == [expected_message] + + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.pull.assert_called_once_with( + subscription=subscription_name, max_messages=1 + ) + mock_subscriber_client.acknowledge.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages_auto_ack(mock_get_subscriber_client): + """Test pull_messages tool invocation with auto_ack.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_response = mock.create_autospec(types.PullResponse, instance=True) + mock_message = mock.MagicMock() + mock_message.message.message_id = "123" + mock_message.message.data = b"Hello" + mock_message.message.attributes = {} + mock_publish_time = mock.MagicMock() + mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" + mock_message.message.publish_time = mock_publish_time + mock_message.ack_id = "ack_123" + mock_response.received_messages = [mock_message] + mock_subscriber_client.pull.return_value = mock_response + + result = message_tool.pull_messages( + subscription_name, + mock_credentials, + tool_settings, + max_messages=5, + auto_ack=True, + ) + + assert len(result["messages"]) == 1 + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.pull.assert_called_once_with( + subscription=subscription_name, max_messages=5 + ) + mock_subscriber_client.acknowledge.assert_called_once_with( + subscription=subscription_name, ack_ids=["ack_123"] + ) + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages_exception(mock_get_subscriber_client): + """Test pull_messages tool invocation when exception occurs.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_subscriber_client.pull.side_effect = Exception("Pull failed") + + result = message_tool.pull_messages( + subscription_name, mock_credentials, tool_settings + ) + + assert result["status"] == "ERROR" + assert "Pull failed" in result["error_details"] + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_acknowledge_messages(mock_get_subscriber_client): + """Test acknowledge_messages tool invocation.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + ack_ids = ["ack1", "ack2"] + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + result = message_tool.acknowledge_messages( + subscription_name, ack_ids, mock_credentials, tool_settings + ) + + assert result["status"] == "SUCCESS" + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.acknowledge.assert_called_once_with( + subscription=subscription_name, ack_ids=ack_ids + ) + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_acknowledge_messages_exception(mock_get_subscriber_client): + """Test acknowledge_messages tool invocation when exception occurs.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + ack_ids = ["ack1"] + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_subscriber_client.acknowledge.side_effect = Exception("Ack failed") + + result = message_tool.acknowledge_messages( + subscription_name, ack_ids, mock_credentials, tool_settings + ) + + assert result["status"] == "ERROR" + assert "Ack failed" in result["error_details"] diff --git a/tests/unittests/tools/pubsub/test_pubsub_toolset.py b/tests/unittests/tools/pubsub/test_pubsub_toolset.py new file mode 100644 index 0000000000..4750db1204 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_toolset.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.pubsub import PubSubCredentialsConfig +from google.adk.tools.pubsub import PubSubToolset +from google.adk.tools.pubsub.config import PubSubToolConfig +import pytest + + +@pytest.mark.asyncio +async def test_pubsub_toolset_tools_default(): + """Test default PubSub toolset. + + This test verifies the behavior of the PubSub toolset when no filter is + specified. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = PubSubToolset( + credentials_config=credentials_config, pubsub_tool_config=None + ) + # Verify that the tool config is initialized to default values. + assert isinstance(toolset._tool_settings, PubSubToolConfig) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == PubSubToolConfig().__dict__ # pylint: disable=protected-access + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 3 + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set([ + "publish_message", + "pull_messages", + "acknowledge_messages", + ]) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param(["publish_message"], id="publish"), + pytest.param(["pull_messages"], id="pull"), + pytest.param(["acknowledge_messages"], id="ack"), + ], +) +@pytest.mark.asyncio +async def test_pubsub_toolset_tools_selective(selected_tools): + """Test PubSub toolset with filter. + + This test verifies the behavior of the PubSub toolset when filter is + specified. A use case for this would be when the agent builder wants to + use only a subset of the tools provided by the toolset. + + Args: + selected_tools: The list of tools to select from the toolset. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = PubSubToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set(selected_tools) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "publish_message"], + ["publish_message"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools): + """Test PubSub toolset with filter. + + This test verifies the behavior of the PubSub toolset when filter is + specified with an unknown tool. + + Args: + selected_tools: The list of tools to select from the toolset. + returned_tools: The list of tools that are expected to be returned. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = PubSubToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set(returned_tools) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names