From c252a25cc15e6267e80e1c4a752fdce74eb7d36e Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Sun, 7 Dec 2025 17:12:54 -0500 Subject: [PATCH 01/14] Add built-in tool support for Cloud Pub/Sub --- contributing/samples/pubsub/README.md | 113 ++++++ contributing/samples/pubsub/__init__.py | 15 + contributing/samples/pubsub/agent.py | 80 +++++ pyproject.toml | 1 + src/google/adk/tools/pubsub/__init__.py | 18 + src/google/adk/tools/pubsub/client.py | 123 +++++++ src/google/adk/tools/pubsub/config.py | 36 ++ src/google/adk/tools/pubsub/message_tool.py | 63 ++++ src/google/adk/tools/pubsub/metadata_tool.py | 334 ++++++++++++++++++ .../adk/tools/pubsub/pubsub_credentials.py | 44 +++ src/google/adk/tools/pubsub/pubsub_toolset.py | 98 +++++ .../tools/pubsub/test_pubsub_client.py | 55 +++ .../tools/pubsub/test_pubsub_config.py | 27 ++ .../tools/pubsub/test_pubsub_credentials.py | 91 +++++ .../tools/pubsub/test_pubsub_message_tool.py | 162 +++++++++ .../tools/pubsub/test_pubsub_metadata_tool.py | 210 +++++++++++ .../tools/pubsub/test_pubsub_toolset.py | 133 +++++++ 17 files changed, 1603 insertions(+) create mode 100644 contributing/samples/pubsub/README.md create mode 100644 contributing/samples/pubsub/__init__.py create mode 100644 contributing/samples/pubsub/agent.py create mode 100644 src/google/adk/tools/pubsub/__init__.py create mode 100644 src/google/adk/tools/pubsub/client.py create mode 100644 src/google/adk/tools/pubsub/config.py create mode 100644 src/google/adk/tools/pubsub/message_tool.py create mode 100644 src/google/adk/tools/pubsub/metadata_tool.py create mode 100644 src/google/adk/tools/pubsub/pubsub_credentials.py create mode 100644 src/google/adk/tools/pubsub/pubsub_toolset.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_client.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_config.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_credentials.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_message_tool.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py create mode 100644 tests/unittests/tools/pubsub/test_pubsub_toolset.py diff --git a/contributing/samples/pubsub/README.md b/contributing/samples/pubsub/README.md new file mode 100644 index 0000000000..53ec73f221 --- /dev/null +++ b/contributing/samples/pubsub/README.md @@ -0,0 +1,113 @@ +# Pub/Sub Tools Sample + +## Introduction + +This sample agent demonstrates the Pub/Sub first-party tools in ADK, +distributed via the `google.adk.tools.pubsub` module. These tools include: + +1. `list_topics` + + Fetches Pub/Sub topics present in a GCP project. + +2. `get_topic` + + Fetches metadata about a Pub/Sub topic. + +3. `list_subscriptions` + + Fetches subscriptions present in a GCP project. + +4. `get_subscription` + + Fetches metadata about a Pub/Sub subscription. + +5. `list_schemas` + + Fetches schemas present in a GCP project. + +6. `get_schema` + + Fetches metadata about a Pub/Sub schema. + +7. `list_schema_revisions` + + Fetches revisions of a Pub/Sub schema. + +8. `get_schema_revision` + + Fetches metadata about a specific Pub/Sub schema revision. + +9. `publish_message` + + Publishes a message to a Pub/Sub topic. + +## How to use + +Set up environment variables in your `.env` file for using +[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio) +or +[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai) +for the LLM service for your agent. For example, for using Google AI Studio you +would set: + +* GOOGLE_GENAI_USE_VERTEXAI=FALSE +* GOOGLE_API_KEY={your api key} + +### With Application Default Credentials + +This mode is useful for quick development when the agent builder is the only +user interacting with the agent. The tools are run with these credentials. + +1. Create application default credentials on the machine where the agent would +be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. + +1. Set `CREDENTIALS_TYPE=None` in `agent.py` + +1. Run the agent + +### With Service Account Keys + +This mode is useful for quick development when the agent builder wants to run +the agent with service account credentials. The tools are run with these +credentials. + +1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py` + +1. Download the key file and replace `"service_account_key.json"` with the path + +1. Run the agent + +### With Interactive OAuth + +1. Follow +https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. +to get your client id and client secret. Be sure to choose "web" as your client +type. + +1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/pubsub". + +1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". + + Note: localhost here is just a hostname that you use to access the dev ui, + replace it with the actual hostname you use to access the dev ui. + +1. For 1st run, allow popup for localhost in Chrome. + +1. Configure your `.env` file to add two more variables before running the agent: + + * OAUTH_CLIENT_ID={your client id} + * OAUTH_CLIENT_SECRET={your client secret} + + Note: don't create a separate .env, instead put it to the same .env file that + stores your Vertex AI or Dev ML credentials + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent + +## Sample prompts + +* list topics in my project +* show details for topic 'my-topic' +* list subscriptions +* publish 'Hello World' to 'my-topic' diff --git a/contributing/samples/pubsub/__init__.py b/contributing/samples/pubsub/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/pubsub/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/pubsub/agent.py b/contributing/samples/pubsub/agent.py new file mode 100644 index 0000000000..3658c42604 --- /dev/null +++ b/contributing/samples/pubsub/agent.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.tools.pubsub.config import PubSubToolConfig +from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig +from google.adk.tools.pubsub.pubsub_toolset import PubSubToolset +import google.auth + +# Define the desired credential type. +# By default use Application Default Credentials (ADC) from the local +# environment, which can be set up by following +# https://cloud.google.com/docs/authentication/provide-credentials-adc. +CREDENTIALS_TYPE = None + +# Define an appropriate application name +PUBSUB_AGENT_NAME = "adk_sample_pubsub_agent" + + +# Define Pub/Sub tool config. +# You can optionally set the project_id here, or let the agent infer it from context/user input. +tool_config = PubSubToolConfig(project_id=os.getenv("GOOGLE_CLOUD_PROJECT")) + +if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: + # Initialize the tools to do interactive OAuth + # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET + # must be set + credentials_config = PubSubCredentialsConfig( + client_id=os.getenv("OAUTH_CLIENT_ID"), + client_secret=os.getenv("OAUTH_CLIENT_SECRET"), + ) +elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: + # Initialize the tools to use the credentials in the service account key. + # If this flow is enabled, make sure to replace the file path with your own + # service account key file + # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys + creds, _ = google.auth.load_credentials_from_file("service_account_key.json") + credentials_config = PubSubCredentialsConfig(credentials=creds) +else: + # Initialize the tools to use the application default credentials. + # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = PubSubCredentialsConfig( + credentials=application_default_credentials + ) + +pubsub_toolset = PubSubToolset( + credentials_config=credentials_config, pubsub_tool_config=tool_config +) + +# The variable name `root_agent` determines what your root agent is for the +# debug CLI +root_agent = LlmAgent( + model="gemini-2.0-flash", + name=PUBSUB_AGENT_NAME, + description=( + "Agent to answer questions about Pub/Sub topics, subscriptions, and" + " schemas, and publish messages." + ), + instruction="""\ + You are a cloud engineer agent with access to Google Cloud Pub/Sub tools. + Make use of those tools to answer the user's questions about topics, subscriptions, and schemas. + You can also publish messages to topics if requested. + """, + tools=[pubsub_toolset], +) diff --git a/pyproject.toml b/pyproject.toml index 06ddb04ef2..960dc8a845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool + "google-cloud-pubsub>=2.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py new file mode 100644 index 0000000000..72faf0bcc8 --- /dev/null +++ b/src/google/adk/tools/pubsub/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pubsub_credentials import PubSubCredentialsConfig +from .pubsub_toolset import PubSubToolset + +__all__ = ["PubSubCredentialsConfig", "PubSubToolset"] diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py new file mode 100644 index 0000000000..bdf85ecc5b --- /dev/null +++ b/src/google/adk/tools/pubsub/client.py @@ -0,0 +1,123 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import Credentials +from google.cloud import pubsub_v1 + +from ... import version + +USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}" + + +def get_publisher_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> pubsub_v1.PublisherClient: + """Get a Pub/Sub Publisher client. + + Args: + credentials: The credentials to use for the request. + user_agent: The user agent to use for the request. + + Returns: + A Pub/Sub Publisher client. + """ + + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = ClientInfo(user_agent=" ".join(user_agents)) + + publisher_client = pubsub_v1.PublisherClient( + credentials=credentials, + client_info=client_info, + ) + + return publisher_client + + +def get_subscriber_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> pubsub_v1.SubscriberClient: + """Get a Pub/Sub Subscriber client. + + Args: + credentials: The credentials to use for the request. + user_agent: The user agent to use for the request. + + Returns: + A Pub/Sub Subscriber client. + """ + + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = ClientInfo(user_agent=" ".join(user_agents)) + + subscriber_client = pubsub_v1.SubscriberClient( + credentials=credentials, + client_info=client_info, + ) + + return subscriber_client + + +def get_schema_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> pubsub_v1.SchemaServiceClient: + """Get a Pub/Sub Schema Service client. + + Args: + credentials: The credentials to use for the request. + user_agent: The user agent to use for the request. + + Returns: + A Pub/Sub Schema Service client. + """ + + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = ClientInfo(user_agent=" ".join(user_agents)) + + schema_client = pubsub_v1.SchemaServiceClient( + credentials=credentials, + client_info=client_info, + ) + + return schema_client diff --git a/src/google/adk/tools/pubsub/config.py b/src/google/adk/tools/pubsub/config.py new file mode 100644 index 0000000000..a91c62931e --- /dev/null +++ b/src/google/adk/tools/pubsub/config.py @@ -0,0 +1,36 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel +from pydantic import ConfigDict + +from ...utils.feature_decorator import experimental + + +@experimental('Config defaults may have breaking change in the future.') +class PubSubToolConfig(BaseModel): + """Configuration for Pub/Sub tools.""" + + # Forbid any fields not defined in the model + model_config = ConfigDict(extra='forbid') + + project_id: Optional[str] = None + """GCP project ID to use for the Pub/Sub operations. + + If not set, the project ID will be inferred from the environment or credentials. + """ diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py new file mode 100644 index 0000000000..cf83296d84 --- /dev/null +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -0,0 +1,63 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from google.auth.credentials import Credentials + +from . import client +from .config import PubSubToolConfig + + +def publish_message( + topic_name: str, + message: str, + credentials: Credentials, + settings: PubSubToolConfig, + attributes: Optional[dict[str, str]] = None, + ordering_key: Optional[str] = None, +) -> dict: + """Publish a message to a Pub/Sub topic. + + Args: + topic_name (str): The Pub/Sub topic name (e.g. projects/my-project/topics/my-topic). + message (str): The message content to publish. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + attributes (Optional[dict[str, str]]): Optional attributes to attach to the message. + ordering_key (Optional[str]): Optional ordering key for the message. + + Returns: + dict: Dictionary with the message_id of the published message. + """ + try: + publisher_client = client.get_publisher_client( + credentials=credentials, + user_agent=[settings.project_id, "publish_message"], + ) + + data = message.encode("utf-8") + future = publisher_client.publish( + topic_name, data, ordering_key=ordering_key, **(attributes or {}) + ) + message_id = future.result() + + return {"message_id": message_id} + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } diff --git a/src/google/adk/tools/pubsub/metadata_tool.py b/src/google/adk/tools/pubsub/metadata_tool.py new file mode 100644 index 0000000000..22f3860c0d --- /dev/null +++ b/src/google/adk/tools/pubsub/metadata_tool.py @@ -0,0 +1,334 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.auth.credentials import Credentials + +from . import client +from .config import PubSubToolConfig + + +def list_topics( + project_id: str, credentials: Credentials, settings: PubSubToolConfig +) -> list[str]: + """List Pub/Sub topics in a Google Cloud project. + + Args: + project_id (str): The Google Cloud project id. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + list[str]: List of the Pub/Sub topic names present in the project. + """ + try: + publisher_client = client.get_publisher_client( + credentials=credentials, + user_agent=[settings.project_id, "list_topics"], + ) + + project_path = f"projects/{project_id}" + topics = [] + for topic in publisher_client.list_topics( + request={"project": project_path} + ): + topics.append(topic.name) + return topics + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_topic( + topic_name: str, + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Get metadata information about a Pub/Sub topic. + + Args: + topic_name (str): The Pub/Sub topic name (e.g. projects/my-project/topics/my-topic). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Dictionary representing the properties of the topic. + """ + try: + publisher_client = client.get_publisher_client( + credentials=credentials, + user_agent=[settings.project_id, "get_topic"], + ) + topic = publisher_client.get_topic(request={"topic": topic_name}) + + return { + "name": topic.name, + "labels": dict(topic.labels), + "kms_key_name": topic.kms_key_name, + "schema_settings": ( + str(topic.schema_settings) if topic.schema_settings else None + ), + "message_storage_policy": ( + str(topic.message_storage_policy) + if topic.message_storage_policy + else None + ), + } + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def list_subscriptions( + project_id: str, credentials: Credentials, settings: PubSubToolConfig +) -> list[str]: + """List Pub/Sub subscriptions in a Google Cloud project. + + Args: + project_id (str): The Google Cloud project id. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + list[str]: List of the Pub/Sub subscription names present in the project. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "list_subscriptions"], + ) + + project_path = f"projects/{project_id}" + subscriptions = [] + for subscription in subscriber_client.list_subscriptions( + request={"project": project_path} + ): + subscriptions.append(subscription.name) + return subscriptions + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_subscription( + subscription_name: str, + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Get metadata information about a Pub/Sub subscription. + + Args: + subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Dictionary representing the properties of the subscription. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "get_subscription"], + ) + subscription = subscriber_client.get_subscription( + request={"subscription": subscription_name} + ) + + return { + "name": subscription.name, + "topic": subscription.topic, + "push_config": ( + str(subscription.push_config) if subscription.push_config else None + ), + "ack_deadline_seconds": subscription.ack_deadline_seconds, + "retain_acked_messages": subscription.retain_acked_messages, + "message_retention_duration": ( + str(subscription.message_retention_duration) + if subscription.message_retention_duration + else None + ), + "labels": dict(subscription.labels), + "enable_message_ordering": subscription.enable_message_ordering, + "expiration_policy": ( + str(subscription.expiration_policy) + if subscription.expiration_policy + else None + ), + "filter": subscription.filter, + "dead_letter_policy": ( + str(subscription.dead_letter_policy) + if subscription.dead_letter_policy + else None + ), + "retry_policy": ( + str(subscription.retry_policy) + if subscription.retry_policy + else None + ), + "detached": subscription.detached, + } + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def list_schemas( + project_id: str, credentials: Credentials, settings: PubSubToolConfig +) -> list[str]: + """List Pub/Sub schemas in a Google Cloud project. + + Args: + project_id (str): The Google Cloud project id. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + list[str]: List of the Pub/Sub schema names present in the project. + """ + try: + schema_client = client.get_schema_client( + credentials=credentials, + user_agent=[settings.project_id, "list_schemas"], + ) + + project_path = f"projects/{project_id}" + schemas = [] + for schema in schema_client.list_schemas(request={"parent": project_path}): + schemas.append(schema.name) + return schemas + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_schema( + schema_name: str, + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Get metadata information about a Pub/Sub schema. + + Args: + schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Dictionary representing the properties of the schema. + """ + try: + schema_client = client.get_schema_client( + credentials=credentials, + user_agent=[settings.project_id, "get_schema"], + ) + schema = schema_client.get_schema(request={"name": schema_name}) + + return { + "name": schema.name, + "type": str(schema.type_), + "definition": schema.definition, + "revision_id": schema.revision_id, + "revision_create_time": str(schema.revision_create_time), + } + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def list_schema_revisions( + schema_name: str, + credentials: Credentials, + settings: PubSubToolConfig, +) -> list[str]: + """List revisions of a Pub/Sub schema. + + Args: + schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + list[str]: List of the Pub/Sub schema revision IDs. + """ + try: + schema_client = client.get_schema_client( + credentials=credentials, + user_agent=[settings.project_id, "list_schema_revisions"], + ) + + revisions = [] + for schema in schema_client.list_schema_revisions( + request={"name": schema_name} + ): + revisions.append(schema.revision_id) + return revisions + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_schema_revision( + schema_name: str, + revision_id: str, + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Get metadata information about a specific Pub/Sub schema revision. + + Args: + schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). + revision_id (str): The revision ID of the schema. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Dictionary representing the properties of the schema revision. + """ + try: + schema_client = client.get_schema_client( + credentials=credentials, + user_agent=[settings.project_id, "get_schema_revision"], + ) + # The get_schema method can take a revision ID appended to the name + # Format: projects/{project}/schemas/{schema}@{revision} + name_with_revision = f"{schema_name}@{revision_id}" + schema = schema_client.get_schema(request={"name": name_with_revision}) + + return { + "name": schema.name, + "type": str(schema.type_), + "definition": schema.definition, + "revision_id": schema.revision_id, + "revision_create_time": str(schema.revision_create_time), + } + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } diff --git a/src/google/adk/tools/pubsub/pubsub_credentials.py b/src/google/adk/tools/pubsub/pubsub_credentials.py new file mode 100644 index 0000000000..7729f1ea18 --- /dev/null +++ b/src/google/adk/tools/pubsub/pubsub_credentials.py @@ -0,0 +1,44 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pydantic import model_validator + +from ...utils.feature_decorator import experimental +from .._google_credentials import BaseGoogleCredentialsConfig + +PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache" +PUBSUB_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/pubsub"] + + +@experimental +class PubSubCredentialsConfig(BaseGoogleCredentialsConfig): + """Pub/Sub Credentials Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ + + @model_validator(mode="after") + def __post_init__(self) -> PubSubCredentialsConfig: + """Populate default scope if scopes is None.""" + super().__post_init__() + + if not self.scopes: + self.scopes = PUBSUB_DEFAULT_SCOPE + + # Set the token cache key + self._token_cache_key = PUBSUB_TOKEN_CACHE_KEY + + return self diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py new file mode 100644 index 0000000000..402d2abcbd --- /dev/null +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -0,0 +1,98 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from google.adk.agents.readonly_context import ReadonlyContext +from typing_extensions import override + +from . import message_tool +from . import metadata_tool +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from ...tools.google_tool import GoogleTool +from ...utils.feature_decorator import experimental +from .config import PubSubToolConfig +from .pubsub_credentials import PubSubCredentialsConfig + + +@experimental +class PubSubToolset(BaseToolset): + """Pub/Sub Toolset contains tools for interacting with Pub/Sub topics and subscriptions.""" + + def __init__( + self, + *, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + credentials_config: Optional[PubSubCredentialsConfig] = None, + pubsub_tool_config: Optional[PubSubToolConfig] = None, + ): + super().__init__(tool_filter=tool_filter) + self._credentials_config = credentials_config + self._tool_settings = ( + pubsub_tool_config if pubsub_tool_config else PubSubToolConfig() + ) + + def _is_tool_selected( + self, tool: BaseTool, readonly_context: ReadonlyContext + ) -> bool: + if self.tool_filter is None: + return True + + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + + return False + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[BaseTool]: + """Get tools from the toolset.""" + all_tools = [ + GoogleTool( + func=func, + credentials_config=self._credentials_config, + tool_settings=self._tool_settings, + ) + for func in [ + metadata_tool.list_topics, + metadata_tool.get_topic, + metadata_tool.list_subscriptions, + metadata_tool.get_subscription, + metadata_tool.list_schemas, + metadata_tool.get_schema, + metadata_tool.list_schema_revisions, + metadata_tool.get_schema_revision, + message_tool.publish_message, + ] + ] + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self): + pass diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py new file mode 100644 index 0000000000..f778e8dacf --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools.pubsub import client +from google.cloud import pubsub_v1 +from google.oauth2.credentials import Credentials + + +@mock.patch("google.cloud.pubsub_v1.PublisherClient") +def test_get_publisher_client(mock_publisher_client): + """Test get_publisher_client factory.""" + mock_creds = mock.Mock(spec=Credentials) + client.get_publisher_client(credentials=mock_creds) + + mock_publisher_client.assert_called_once() + _, kwargs = mock_publisher_client.call_args + assert kwargs["credentials"] == mock_creds + assert "client_info" in kwargs + + +@mock.patch("google.cloud.pubsub_v1.SubscriberClient") +def test_get_subscriber_client(mock_subscriber_client): + """Test get_subscriber_client factory.""" + mock_creds = mock.Mock(spec=Credentials) + client.get_subscriber_client(credentials=mock_creds) + + mock_subscriber_client.assert_called_once() + _, kwargs = mock_subscriber_client.call_args + assert kwargs["credentials"] == mock_creds + assert "client_info" in kwargs + + +@mock.patch("google.cloud.pubsub_v1.SchemaServiceClient") +def test_get_schema_client(mock_schema_client): + """Test get_schema_client factory.""" + mock_creds = mock.Mock(spec=Credentials) + client.get_schema_client(credentials=mock_creds) + + mock_schema_client.assert_called_once() + _, kwargs = mock_schema_client.call_args + assert kwargs["credentials"] == mock_creds + assert "client_info" in kwargs diff --git a/tests/unittests/tools/pubsub/test_pubsub_config.py b/tests/unittests/tools/pubsub/test_pubsub_config.py new file mode 100644 index 0000000000..2e2628df4c --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_config.py @@ -0,0 +1,27 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.tools.pubsub.config import PubSubToolConfig + + +def test_pubsub_tool_config_init(): + """Test PubSubToolConfig initialization.""" + config = PubSubToolConfig(project_id="my-project") + assert config.project_id == "my-project" + + +def test_pubsub_tool_config_default(): + """Test PubSubToolConfig default initialization.""" + config = PubSubToolConfig() + assert config.project_id is None diff --git a/tests/unittests/tools/pubsub/test_pubsub_credentials.py b/tests/unittests/tools/pubsub/test_pubsub_credentials.py new file mode 100644 index 0000000000..7c19586a0b --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_credentials.py @@ -0,0 +1,91 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools.pubsub.pubsub_credentials import PUBSUB_DEFAULT_SCOPE +from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig +from google.auth.credentials import Credentials +import google.oauth2.credentials +import pytest + + +class TestPubSubCredentials: + """Test suite for PubSub credentials configuration validation. + + This class tests the credential configuration logic that ensures + either existing credentials or client ID/secret pairs are provided. + """ + + def test_pubsub_credentials_config_client_id_secret(self): + """Test PubSubCredentialsConfig with client_id and client_secret. + + Ensures that when client_id and client_secret are provided, the config + object is created with the correct attributes. + """ + config = PubSubCredentialsConfig(client_id="abc", client_secret="def") + assert config.client_id == "abc" + assert config.client_secret == "def" + assert config.scopes == PUBSUB_DEFAULT_SCOPE + assert config.credentials is None + + def test_pubsub_credentials_config_existing_creds(self): + """Test PubSubCredentialsConfig with existing generic credentials. + + Ensures that when a generic Credentials object is provided, it is + stored correctly. + """ + mock_creds = mock.create_autospec(Credentials, instance=True) + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.credentials == mock_creds + assert config.client_id is None + assert config.client_secret is None + + def test_pubsub_credentials_config_oauth2_creds(self): + """Test PubSubCredentialsConfig with existing OAuth2 credentials. + + Ensures that when a google.oauth2.credentials.Credentials object is + provided, the client_id, client_secret, and scopes are extracted + from the credentials object. + """ + mock_creds = mock.create_autospec( + google.oauth2.credentials.Credentials, instance=True + ) + mock_creds.client_id = "oauth_client_id" + mock_creds.client_secret = "oauth_client_secret" + mock_creds.scopes = ["fake_scope"] + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.client_id == "oauth_client_id" + assert config.client_secret == "oauth_client_secret" + assert config.scopes == ["fake_scope"] + + def test_pubsub_credentials_config_validation_errors(self): + """Test PubSubCredentialsConfig validation errors. + + Ensures that ValueError is raised under the following conditions: + - No arguments are provided. + - Only client_id is provided. + - Both credentials and client_id/client_secret are provided. + """ + with pytest.raises(ValueError): + PubSubCredentialsConfig() + + with pytest.raises(ValueError): + PubSubCredentialsConfig(client_id="abc") + + mock_creds = mock.create_autospec(Credentials, instance=True) + with pytest.raises(ValueError): + PubSubCredentialsConfig( + credentials=mock_creds, client_id="abc", client_secret="def" + ) diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py new file mode 100644 index 0000000000..ddc074d9d1 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -0,0 +1,162 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.pubsub import client as pubsub_client_lib +from google.adk.tools.pubsub import message_tool +from google.adk.tools.pubsub.config import PubSubToolConfig +from google.cloud import pubsub_v1 +from google.oauth2.credentials import Credentials + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message(mock_get_publisher_client, mock_publish): + """Test publish_message tool invocation.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.Mock() + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, message, mock_credentials, tool_settings + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_with_ordering_key( + mock_get_publisher_client, mock_publish +): + """Test publish_message tool invocation with ordering_key.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + ordering_key = "key1" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.Mock() + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + ordering_key=ordering_key, + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + # Verify ordering_key was passed + _, kwargs = mock_publisher_client.publish.call_args + assert kwargs["ordering_key"] == ordering_key + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_with_attributes( + mock_get_publisher_client, mock_publish +): + """Test publish_message tool invocation with attributes.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + attributes = {"key1": "value1", "key2": "value2"} + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_future = mock.Mock() + mock_future.result.return_value = "message_id" + mock_publisher_client.publish.return_value = mock_future + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + attributes=attributes, + ) + + assert result["message_id"] == "message_id" + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() + + # Verify attributes were passed + _, kwargs = mock_publisher_client.publish.call_args + assert kwargs["key1"] == "value1" + assert kwargs["key2"] == "value2" + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_publish_message_exception(mock_get_publisher_client, mock_publish): + """Test publish_message tool invocation when exception occurs.""" + topic_name = "projects/my_project_id/topics/my_topic" + message = "Hello World" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + # Simulate an exception during publish + mock_publisher_client.publish.side_effect = Exception("Publish failed") + + result = message_tool.publish_message( + topic_name, + message, + mock_credentials, + tool_settings, + ) + + assert result["status"] == "ERROR" + assert "Publish failed" in result["error_details"] + mock_get_publisher_client.assert_called_once() + mock_publisher_client.publish.assert_called_once() diff --git a/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py b/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py new file mode 100644 index 0000000000..3d8058f8d4 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py @@ -0,0 +1,210 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.pubsub import client as pubsub_client_lib +from google.adk.tools.pubsub import metadata_tool +from google.adk.tools.pubsub.config import PubSubToolConfig +from google.cloud import pubsub_v1 +from google.oauth2.credentials import Credentials + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "list_topics", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_list_topics(mock_get_publisher_client, mock_list_topics): + """Test list_topics tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id=project) + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + mock_publisher_client.list_topics.return_value = [ + mock.Mock(name="projects/my_project_id/topics/topic1"), + mock.Mock(name="projects/my_project_id/topics/topic2"), + ] + # Fix the mock names to return the string name when accessed + mock_publisher_client.list_topics.return_value[0].name = "topic1" + mock_publisher_client.list_topics.return_value[1].name = "topic2" + + result = metadata_tool.list_topics(project, mock_credentials, tool_settings) + assert result == ["topic1", "topic2"] + mock_get_publisher_client.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.PublisherClient, "get_topic", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) +def test_get_topic(mock_get_publisher_client, mock_get_topic): + """Test get_topic tool invocation.""" + topic_name = "projects/my_project_id/topics/my_topic" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_publisher_client = mock.create_autospec( + pubsub_v1.PublisherClient, instance=True + ) + mock_get_publisher_client.return_value = mock_publisher_client + + mock_topic = mock.Mock() + mock_topic.name = topic_name + mock_topic.labels = {"key": "value"} + mock_topic.kms_key_name = "key_name" + mock_topic.schema_settings = "schema_settings" + mock_topic.message_storage_policy = "storage_policy" + + mock_publisher_client.get_topic.return_value = mock_topic + + result = metadata_tool.get_topic(topic_name, mock_credentials, tool_settings) + + assert result["name"] == topic_name + assert result["labels"] == {"key": "value"} + mock_get_publisher_client.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object( + pubsub_v1.SubscriberClient, "list_subscriptions", autospec=True +) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_list_subscriptions( + mock_get_subscriber_client, mock_list_subscriptions +): + """Test list_subscriptions tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id=project) + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + mock_subscriber_client.list_subscriptions.return_value = [ + mock.Mock(name="projects/my_project_id/subscriptions/sub1"), + mock.Mock(name="projects/my_project_id/subscriptions/sub2"), + ] + mock_subscriber_client.list_subscriptions.return_value[0].name = "sub1" + mock_subscriber_client.list_subscriptions.return_value[1].name = "sub2" + + result = metadata_tool.list_subscriptions( + project, mock_credentials, tool_settings + ) + assert result == ["sub1", "sub2"] + mock_get_subscriber_client.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object( + pubsub_v1.SubscriberClient, "get_subscription", autospec=True +) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_get_subscription(mock_get_subscriber_client, mock_get_subscription): + """Test get_subscription tool invocation.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_subscription = mock.Mock() + mock_subscription.name = subscription_name + mock_subscription.topic = "projects/my_project_id/topics/my_topic" + mock_subscription.push_config = "push_config" + mock_subscription.ack_deadline_seconds = 10 + mock_subscription.retain_acked_messages = True + mock_subscription.message_retention_duration = "duration" + mock_subscription.labels = {"key": "value"} + mock_subscription.enable_message_ordering = True + mock_subscription.expiration_policy = "expiration" + mock_subscription.filter = "filter" + mock_subscription.dead_letter_policy = "dead_letter" + mock_subscription.retry_policy = "retry" + mock_subscription.detached = False + + mock_subscriber_client.get_subscription.return_value = mock_subscription + + result = metadata_tool.get_subscription( + subscription_name, mock_credentials, tool_settings + ) + + assert result["name"] == subscription_name + assert result["topic"] == "projects/my_project_id/topics/my_topic" + mock_get_subscriber_client.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.SchemaServiceClient, "list_schemas", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_schema_client", autospec=True) +def test_list_schemas(mock_get_schema_client, mock_list_schemas): + """Test list_schemas tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id=project) + + mock_schema_client = mock.create_autospec( + pubsub_v1.SchemaServiceClient, instance=True + ) + mock_get_schema_client.return_value = mock_schema_client + mock_schema_client.list_schemas.return_value = [ + mock.Mock(name="projects/my_project_id/schemas/schema1"), + mock.Mock(name="projects/my_project_id/schemas/schema2"), + ] + mock_schema_client.list_schemas.return_value[0].name = "schema1" + mock_schema_client.list_schemas.return_value[1].name = "schema2" + + result = metadata_tool.list_schemas(project, mock_credentials, tool_settings) + assert result == ["schema1", "schema2"] + mock_get_schema_client.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_v1.SchemaServiceClient, "get_schema", autospec=True) +@mock.patch.object(pubsub_client_lib, "get_schema_client", autospec=True) +def test_get_schema(mock_get_schema_client, mock_get_schema): + """Test get_schema tool invocation.""" + schema_name = "projects/my_project_id/schemas/my_schema" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_schema_client = mock.create_autospec( + pubsub_v1.SchemaServiceClient, instance=True + ) + mock_get_schema_client.return_value = mock_schema_client + + mock_schema = mock.Mock() + mock_schema.name = schema_name + mock_schema.type_ = "AVRO" + mock_schema.definition = "definition" + mock_schema.revision_id = "revision_id" + mock_schema.revision_create_time = "time" + + mock_schema_client.get_schema.return_value = mock_schema + + result = metadata_tool.get_schema( + schema_name, mock_credentials, tool_settings + ) + + assert result["name"] == schema_name + assert result["type"] == "AVRO" + mock_get_schema_client.assert_called_once() diff --git a/tests/unittests/tools/pubsub/test_pubsub_toolset.py b/tests/unittests/tools/pubsub/test_pubsub_toolset.py new file mode 100644 index 0000000000..567d434a23 --- /dev/null +++ b/tests/unittests/tools/pubsub/test_pubsub_toolset.py @@ -0,0 +1,133 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.pubsub import PubSubCredentialsConfig +from google.adk.tools.pubsub import PubSubToolset +from google.adk.tools.pubsub.config import PubSubToolConfig +import pytest + + +@pytest.mark.asyncio +async def test_pubsub_toolset_tools_default(): + """Test default PubSub toolset. + + This test verifies the behavior of the PubSub toolset when no filter is + specified. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = PubSubToolset( + credentials_config=credentials_config, pubsub_tool_config=None + ) + # Verify that the tool config is initialized to default values. + assert isinstance(toolset._tool_settings, PubSubToolConfig) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == PubSubToolConfig().__dict__ # pylint: disable=protected-access + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 9 + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set([ + "list_topics", + "get_topic", + "list_subscriptions", + "get_subscription", + "list_schemas", + "get_schema", + "list_schema_revisions", + "get_schema_revision", + "publish_message", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param(["list_topics", "get_topic"], id="topic-metadata"), + pytest.param( + ["list_subscriptions", "get_subscription"], + id="subscription-metadata", + ), + pytest.param(["publish_message"], id="publish"), + ], +) +@pytest.mark.asyncio +async def test_pubsub_toolset_tools_selective(selected_tools): + """Test PubSub toolset with filter. + + This test verifies the behavior of the PubSub toolset when filter is + specified. A use case for this would be when the agent builder wants to + use only a subset of the tools provided by the toolset. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = PubSubToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(selected_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "publish_message"], + ["publish_message"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools): + """Test PubSub toolset with filter. + + This test verifies the behavior of the PubSub toolset when filter is + specified with an unknown tool. + """ + credentials_config = PubSubCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = PubSubToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names From 37a662dd04601caa3632690d659faecf747fce36 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Sun, 7 Dec 2025 18:37:41 -0500 Subject: [PATCH 02/14] Remove admin operations and add some subscribe-side operations --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - contributing/samples/pubsub/README.md | 41 +-- contributing/samples/pubsub/agent.py | 7 +- src/google/adk/tools/pubsub/__init__.py | 11 + src/google/adk/tools/pubsub/client.py | 32 -- src/google/adk/tools/pubsub/message_tool.py | 93 +++++ src/google/adk/tools/pubsub/metadata_tool.py | 334 ------------------ src/google/adk/tools/pubsub/pubsub_toolset.py | 11 +- .../tools/pubsub/test_pubsub_client.py | 14 +- .../tools/pubsub/test_pubsub_message_tool.py | 153 ++++++++ .../tools/pubsub/test_pubsub_metadata_tool.py | 210 ----------- .../tools/pubsub/test_pubsub_toolset.py | 19 +- 13 files changed, 277 insertions(+), 650 deletions(-) delete mode 100644 src/google/adk/tools/pubsub/metadata_tool.py delete mode 100644 tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/contributing/samples/pubsub/README.md b/contributing/samples/pubsub/README.md index 53ec73f221..507902abca 100644 --- a/contributing/samples/pubsub/README.md +++ b/contributing/samples/pubsub/README.md @@ -5,41 +5,17 @@ This sample agent demonstrates the Pub/Sub first-party tools in ADK, distributed via the `google.adk.tools.pubsub` module. These tools include: -1. `list_topics` +1. `publish_message` - Fetches Pub/Sub topics present in a GCP project. - -2. `get_topic` - - Fetches metadata about a Pub/Sub topic. - -3. `list_subscriptions` - - Fetches subscriptions present in a GCP project. - -4. `get_subscription` - - Fetches metadata about a Pub/Sub subscription. - -5. `list_schemas` - - Fetches schemas present in a GCP project. - -6. `get_schema` - - Fetches metadata about a Pub/Sub schema. - -7. `list_schema_revisions` - - Fetches revisions of a Pub/Sub schema. + Publishes a message to a Pub/Sub topic. -8. `get_schema_revision` +2. `pull_messages` - Fetches metadata about a specific Pub/Sub schema revision. + Pulls messages from a Pub/Sub subscription. -9. `publish_message` +3. `acknowledge_messages` - Publishes a message to a Pub/Sub topic. + Acknowledges messages on a Pub/Sub subscription. ## How to use @@ -107,7 +83,6 @@ type. ## Sample prompts -* list topics in my project -* show details for topic 'my-topic' -* list subscriptions * publish 'Hello World' to 'my-topic' +* pull messages from 'my-subscription' +* acknowledge message 'ack-id' from 'my-subscription' diff --git a/contributing/samples/pubsub/agent.py b/contributing/samples/pubsub/agent.py index 3658c42604..400471b09e 100644 --- a/contributing/samples/pubsub/agent.py +++ b/contributing/samples/pubsub/agent.py @@ -68,13 +68,12 @@ model="gemini-2.0-flash", name=PUBSUB_AGENT_NAME, description=( - "Agent to answer questions about Pub/Sub topics, subscriptions, and" - " schemas, and publish messages." + "Agent to publish, pull, and acknowledge messages from Google Cloud" + " Pub/Sub." ), instruction="""\ You are a cloud engineer agent with access to Google Cloud Pub/Sub tools. - Make use of those tools to answer the user's questions about topics, subscriptions, and schemas. - You can also publish messages to topics if requested. + You can publish messages to topics, pull messages from subscriptions, and acknowledge messages. """, tools=[pubsub_toolset], ) diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py index 72faf0bcc8..0e57a1cc59 100644 --- a/src/google/adk/tools/pubsub/__init__.py +++ b/src/google/adk/tools/pubsub/__init__.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Pub/Sub Tools (Experimental). + +Pub/Sub Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. Better handling of base64 encoding for published messages. +2. A richer subscribe-side API that reflects how users may want to pull/ack + messages. +""" + from .pubsub_credentials import PubSubCredentialsConfig from .pubsub_toolset import PubSubToolset diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index bdf85ecc5b..7cd61f40dc 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -89,35 +89,3 @@ def get_subscriber_client( ) return subscriber_client - - -def get_schema_client( - *, - credentials: Credentials, - user_agent: Optional[Union[str, List[str]]] = None, -) -> pubsub_v1.SchemaServiceClient: - """Get a Pub/Sub Schema Service client. - - Args: - credentials: The credentials to use for the request. - user_agent: The user agent to use for the request. - - Returns: - A Pub/Sub Schema Service client. - """ - - user_agents = [USER_AGENT] - if user_agent: - if isinstance(user_agent, str): - user_agents.append(user_agent) - else: - user_agents.extend([ua for ua in user_agent if ua]) - - client_info = ClientInfo(user_agent=" ".join(user_agents)) - - schema_client = pubsub_v1.SchemaServiceClient( - credentials=credentials, - client_info=client_info, - ) - - return schema_client diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index cf83296d84..539c78d09b 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import List from typing import Optional from google.auth.credentials import Credentials @@ -61,3 +62,95 @@ def publish_message( "status": "ERROR", "error_details": str(ex), } + + +def pull_messages( + subscription_name: str, + credentials: Credentials, + settings: PubSubToolConfig, + max_messages: int = 1, + auto_ack: bool = False, +) -> dict: + """Pull messages from a Pub/Sub subscription. + + Args: + subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + max_messages (int): The maximum number of messages to pull. Defaults to 1. + auto_ack (bool): Whether to automatically acknowledge the messages. Defaults to False. + + Returns: + dict: Dictionary with the list of pulled messages. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "pull_messages"], + ) + + response = subscriber_client.pull( + subscription=subscription_name, + max_messages=max_messages, + ) + + messages = [] + ack_ids = [] + for received_message in response.received_messages: + messages.append({ + "message_id": received_message.message.message_id, + "data": received_message.message.data.decode("utf-8"), + "attributes": dict(received_message.message.attributes), + "publish_time": str(received_message.message.publish_time), + "ack_id": received_message.ack_id, + }) + ack_ids.append(received_message.ack_id) + + if auto_ack and ack_ids: + subscriber_client.acknowledge( + subscription=subscription_name, + ack_ids=ack_ids, + ) + + return {"messages": messages} + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def acknowledge_messages( + subscription_name: str, + ack_ids: List[str], + credentials: Credentials, + settings: PubSubToolConfig, +) -> dict: + """Acknowledge messages on a Pub/Sub subscription. + + Args: + subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). + ack_ids (List[str]): List of acknowledgment IDs to acknowledge. + credentials (Credentials): The credentials to use for the request. + settings (PubSubToolConfig): The Pub/Sub tool settings. + + Returns: + dict: Status of the operation. + """ + try: + subscriber_client = client.get_subscriber_client( + credentials=credentials, + user_agent=[settings.project_id, "acknowledge_messages"], + ) + + subscriber_client.acknowledge( + subscription=subscription_name, + ack_ids=ack_ids, + ) + + return {"status": "SUCCESS"} + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } diff --git a/src/google/adk/tools/pubsub/metadata_tool.py b/src/google/adk/tools/pubsub/metadata_tool.py deleted file mode 100644 index 22f3860c0d..0000000000 --- a/src/google/adk/tools/pubsub/metadata_tool.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from google.auth.credentials import Credentials - -from . import client -from .config import PubSubToolConfig - - -def list_topics( - project_id: str, credentials: Credentials, settings: PubSubToolConfig -) -> list[str]: - """List Pub/Sub topics in a Google Cloud project. - - Args: - project_id (str): The Google Cloud project id. - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - list[str]: List of the Pub/Sub topic names present in the project. - """ - try: - publisher_client = client.get_publisher_client( - credentials=credentials, - user_agent=[settings.project_id, "list_topics"], - ) - - project_path = f"projects/{project_id}" - topics = [] - for topic in publisher_client.list_topics( - request={"project": project_path} - ): - topics.append(topic.name) - return topics - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def get_topic( - topic_name: str, - credentials: Credentials, - settings: PubSubToolConfig, -) -> dict: - """Get metadata information about a Pub/Sub topic. - - Args: - topic_name (str): The Pub/Sub topic name (e.g. projects/my-project/topics/my-topic). - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - dict: Dictionary representing the properties of the topic. - """ - try: - publisher_client = client.get_publisher_client( - credentials=credentials, - user_agent=[settings.project_id, "get_topic"], - ) - topic = publisher_client.get_topic(request={"topic": topic_name}) - - return { - "name": topic.name, - "labels": dict(topic.labels), - "kms_key_name": topic.kms_key_name, - "schema_settings": ( - str(topic.schema_settings) if topic.schema_settings else None - ), - "message_storage_policy": ( - str(topic.message_storage_policy) - if topic.message_storage_policy - else None - ), - } - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def list_subscriptions( - project_id: str, credentials: Credentials, settings: PubSubToolConfig -) -> list[str]: - """List Pub/Sub subscriptions in a Google Cloud project. - - Args: - project_id (str): The Google Cloud project id. - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - list[str]: List of the Pub/Sub subscription names present in the project. - """ - try: - subscriber_client = client.get_subscriber_client( - credentials=credentials, - user_agent=[settings.project_id, "list_subscriptions"], - ) - - project_path = f"projects/{project_id}" - subscriptions = [] - for subscription in subscriber_client.list_subscriptions( - request={"project": project_path} - ): - subscriptions.append(subscription.name) - return subscriptions - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def get_subscription( - subscription_name: str, - credentials: Credentials, - settings: PubSubToolConfig, -) -> dict: - """Get metadata information about a Pub/Sub subscription. - - Args: - subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - dict: Dictionary representing the properties of the subscription. - """ - try: - subscriber_client = client.get_subscriber_client( - credentials=credentials, - user_agent=[settings.project_id, "get_subscription"], - ) - subscription = subscriber_client.get_subscription( - request={"subscription": subscription_name} - ) - - return { - "name": subscription.name, - "topic": subscription.topic, - "push_config": ( - str(subscription.push_config) if subscription.push_config else None - ), - "ack_deadline_seconds": subscription.ack_deadline_seconds, - "retain_acked_messages": subscription.retain_acked_messages, - "message_retention_duration": ( - str(subscription.message_retention_duration) - if subscription.message_retention_duration - else None - ), - "labels": dict(subscription.labels), - "enable_message_ordering": subscription.enable_message_ordering, - "expiration_policy": ( - str(subscription.expiration_policy) - if subscription.expiration_policy - else None - ), - "filter": subscription.filter, - "dead_letter_policy": ( - str(subscription.dead_letter_policy) - if subscription.dead_letter_policy - else None - ), - "retry_policy": ( - str(subscription.retry_policy) - if subscription.retry_policy - else None - ), - "detached": subscription.detached, - } - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def list_schemas( - project_id: str, credentials: Credentials, settings: PubSubToolConfig -) -> list[str]: - """List Pub/Sub schemas in a Google Cloud project. - - Args: - project_id (str): The Google Cloud project id. - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - list[str]: List of the Pub/Sub schema names present in the project. - """ - try: - schema_client = client.get_schema_client( - credentials=credentials, - user_agent=[settings.project_id, "list_schemas"], - ) - - project_path = f"projects/{project_id}" - schemas = [] - for schema in schema_client.list_schemas(request={"parent": project_path}): - schemas.append(schema.name) - return schemas - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def get_schema( - schema_name: str, - credentials: Credentials, - settings: PubSubToolConfig, -) -> dict: - """Get metadata information about a Pub/Sub schema. - - Args: - schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - dict: Dictionary representing the properties of the schema. - """ - try: - schema_client = client.get_schema_client( - credentials=credentials, - user_agent=[settings.project_id, "get_schema"], - ) - schema = schema_client.get_schema(request={"name": schema_name}) - - return { - "name": schema.name, - "type": str(schema.type_), - "definition": schema.definition, - "revision_id": schema.revision_id, - "revision_create_time": str(schema.revision_create_time), - } - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def list_schema_revisions( - schema_name: str, - credentials: Credentials, - settings: PubSubToolConfig, -) -> list[str]: - """List revisions of a Pub/Sub schema. - - Args: - schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - list[str]: List of the Pub/Sub schema revision IDs. - """ - try: - schema_client = client.get_schema_client( - credentials=credentials, - user_agent=[settings.project_id, "list_schema_revisions"], - ) - - revisions = [] - for schema in schema_client.list_schema_revisions( - request={"name": schema_name} - ): - revisions.append(schema.revision_id) - return revisions - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def get_schema_revision( - schema_name: str, - revision_id: str, - credentials: Credentials, - settings: PubSubToolConfig, -) -> dict: - """Get metadata information about a specific Pub/Sub schema revision. - - Args: - schema_name (str): The Pub/Sub schema name (e.g. projects/my-project/schemas/my-schema). - revision_id (str): The revision ID of the schema. - credentials (Credentials): The credentials to use for the request. - settings (PubSubToolConfig): The Pub/Sub tool settings. - - Returns: - dict: Dictionary representing the properties of the schema revision. - """ - try: - schema_client = client.get_schema_client( - credentials=credentials, - user_agent=[settings.project_id, "get_schema_revision"], - ) - # The get_schema method can take a revision ID appended to the name - # Format: projects/{project}/schemas/{schema}@{revision} - name_with_revision = f"{schema_name}@{revision_id}" - schema = schema_client.get_schema(request={"name": name_with_revision}) - - return { - "name": schema.name, - "type": str(schema.type_), - "definition": schema.definition, - "revision_id": schema.revision_id, - "revision_create_time": str(schema.revision_create_time), - } - except Exception as ex: - return { - "status": "ERROR", - "error_details": str(ex), - } diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py index 402d2abcbd..394cd89073 100644 --- a/src/google/adk/tools/pubsub/pubsub_toolset.py +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -22,7 +22,6 @@ from typing_extensions import override from . import message_tool -from . import metadata_tool from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate @@ -75,15 +74,9 @@ async def get_tools( tool_settings=self._tool_settings, ) for func in [ - metadata_tool.list_topics, - metadata_tool.get_topic, - metadata_tool.list_subscriptions, - metadata_tool.get_subscription, - metadata_tool.list_schemas, - metadata_tool.get_schema, - metadata_tool.list_schema_revisions, - metadata_tool.get_schema_revision, message_tool.publish_message, + message_tool.pull_messages, + message_tool.acknowledge_messages, ] ] diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py index f778e8dacf..b37f8b9b94 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_client.py +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -30,6 +30,8 @@ def test_get_publisher_client(mock_publisher_client): assert kwargs["credentials"] == mock_creds assert "client_info" in kwargs + assert "client_info" in kwargs + @mock.patch("google.cloud.pubsub_v1.SubscriberClient") def test_get_subscriber_client(mock_subscriber_client): @@ -41,15 +43,3 @@ def test_get_subscriber_client(mock_subscriber_client): _, kwargs = mock_subscriber_client.call_args assert kwargs["credentials"] == mock_creds assert "client_info" in kwargs - - -@mock.patch("google.cloud.pubsub_v1.SchemaServiceClient") -def test_get_schema_client(mock_schema_client): - """Test get_schema_client factory.""" - mock_creds = mock.Mock(spec=Credentials) - client.get_schema_client(credentials=mock_creds) - - mock_schema_client.assert_called_once() - _, kwargs = mock_schema_client.call_args - assert kwargs["credentials"] == mock_creds - assert "client_info" in kwargs diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py index ddc074d9d1..e4d001276a 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -160,3 +160,156 @@ def test_publish_message_exception(mock_get_publisher_client, mock_publish): assert "Publish failed" in result["error_details"] mock_get_publisher_client.assert_called_once() mock_publisher_client.publish.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages(mock_get_subscriber_client): + """Test pull_messages tool invocation.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_response = mock.Mock() + mock_message = mock.Mock() + mock_message.message.message_id = "123" + mock_message.message.data = b"Hello" + mock_message.message.attributes = {"key": "value"} + mock_message.message.publish_time = "2023-01-01T00:00:00Z" + mock_message.ack_id = "ack_123" + mock_response.received_messages = [mock_message] + mock_subscriber_client.pull.return_value = mock_response + + result = message_tool.pull_messages( + subscription_name, mock_credentials, tool_settings + ) + + assert len(result["messages"]) == 1 + assert result["messages"][0]["message_id"] == "123" + assert result["messages"][0]["data"] == "Hello" + assert result["messages"][0]["attributes"] == {"key": "value"} + assert result["messages"][0]["ack_id"] == "ack_123" + + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.pull.assert_called_once_with( + subscription=subscription_name, max_messages=1 + ) + mock_subscriber_client.acknowledge.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages_auto_ack(mock_get_subscriber_client): + """Test pull_messages tool invocation with auto_ack.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_response = mock.Mock() + mock_message = mock.Mock() + mock_message.message.message_id = "123" + mock_message.message.data = b"Hello" + mock_message.message.attributes = {} + mock_message.message.publish_time = "2023-01-01T00:00:00Z" + mock_message.ack_id = "ack_123" + mock_response.received_messages = [mock_message] + mock_subscriber_client.pull.return_value = mock_response + + result = message_tool.pull_messages( + subscription_name, + mock_credentials, + tool_settings, + max_messages=5, + auto_ack=True, + ) + + assert len(result["messages"]) == 1 + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.pull.assert_called_once_with( + subscription=subscription_name, max_messages=5 + ) + mock_subscriber_client.acknowledge.assert_called_once_with( + subscription=subscription_name, ack_ids=["ack_123"] + ) + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_pull_messages_exception(mock_get_subscriber_client): + """Test pull_messages tool invocation when exception occurs.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_subscriber_client.pull.side_effect = Exception("Pull failed") + + result = message_tool.pull_messages( + subscription_name, mock_credentials, tool_settings + ) + + assert result["status"] == "ERROR" + assert "Pull failed" in result["error_details"] + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_acknowledge_messages(mock_get_subscriber_client): + """Test acknowledge_messages tool invocation.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + ack_ids = ["ack1", "ack2"] + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + result = message_tool.acknowledge_messages( + subscription_name, ack_ids, mock_credentials, tool_settings + ) + + assert result["status"] == "SUCCESS" + mock_get_subscriber_client.assert_called_once() + mock_subscriber_client.acknowledge.assert_called_once_with( + subscription=subscription_name, ack_ids=ack_ids + ) + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) +def test_acknowledge_messages_exception(mock_get_subscriber_client): + """Test acknowledge_messages tool invocation when exception occurs.""" + subscription_name = "projects/my_project_id/subscriptions/my_sub" + ack_ids = ["ack1"] + mock_credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = PubSubToolConfig(project_id="my_project_id") + + mock_subscriber_client = mock.create_autospec( + pubsub_v1.SubscriberClient, instance=True + ) + mock_get_subscriber_client.return_value = mock_subscriber_client + + mock_subscriber_client.acknowledge.side_effect = Exception("Ack failed") + + result = message_tool.acknowledge_messages( + subscription_name, ack_ids, mock_credentials, tool_settings + ) + + assert result["status"] == "ERROR" + assert "Ack failed" in result["error_details"] diff --git a/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py b/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py deleted file mode 100644 index 3d8058f8d4..0000000000 --- a/tests/unittests/tools/pubsub/test_pubsub_metadata_tool.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -from unittest import mock - -from google.adk.tools.pubsub import client as pubsub_client_lib -from google.adk.tools.pubsub import metadata_tool -from google.adk.tools.pubsub.config import PubSubToolConfig -from google.cloud import pubsub_v1 -from google.oauth2.credentials import Credentials - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(pubsub_v1.PublisherClient, "list_topics", autospec=True) -@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) -def test_list_topics(mock_get_publisher_client, mock_list_topics): - """Test list_topics tool invocation.""" - project = "my_project_id" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id=project) - - mock_publisher_client = mock.create_autospec( - pubsub_v1.PublisherClient, instance=True - ) - mock_get_publisher_client.return_value = mock_publisher_client - mock_publisher_client.list_topics.return_value = [ - mock.Mock(name="projects/my_project_id/topics/topic1"), - mock.Mock(name="projects/my_project_id/topics/topic2"), - ] - # Fix the mock names to return the string name when accessed - mock_publisher_client.list_topics.return_value[0].name = "topic1" - mock_publisher_client.list_topics.return_value[1].name = "topic2" - - result = metadata_tool.list_topics(project, mock_credentials, tool_settings) - assert result == ["topic1", "topic2"] - mock_get_publisher_client.assert_called_once() - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(pubsub_v1.PublisherClient, "get_topic", autospec=True) -@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True) -def test_get_topic(mock_get_publisher_client, mock_get_topic): - """Test get_topic tool invocation.""" - topic_name = "projects/my_project_id/topics/my_topic" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id="my_project_id") - - mock_publisher_client = mock.create_autospec( - pubsub_v1.PublisherClient, instance=True - ) - mock_get_publisher_client.return_value = mock_publisher_client - - mock_topic = mock.Mock() - mock_topic.name = topic_name - mock_topic.labels = {"key": "value"} - mock_topic.kms_key_name = "key_name" - mock_topic.schema_settings = "schema_settings" - mock_topic.message_storage_policy = "storage_policy" - - mock_publisher_client.get_topic.return_value = mock_topic - - result = metadata_tool.get_topic(topic_name, mock_credentials, tool_settings) - - assert result["name"] == topic_name - assert result["labels"] == {"key": "value"} - mock_get_publisher_client.assert_called_once() - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object( - pubsub_v1.SubscriberClient, "list_subscriptions", autospec=True -) -@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) -def test_list_subscriptions( - mock_get_subscriber_client, mock_list_subscriptions -): - """Test list_subscriptions tool invocation.""" - project = "my_project_id" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id=project) - - mock_subscriber_client = mock.create_autospec( - pubsub_v1.SubscriberClient, instance=True - ) - mock_get_subscriber_client.return_value = mock_subscriber_client - mock_subscriber_client.list_subscriptions.return_value = [ - mock.Mock(name="projects/my_project_id/subscriptions/sub1"), - mock.Mock(name="projects/my_project_id/subscriptions/sub2"), - ] - mock_subscriber_client.list_subscriptions.return_value[0].name = "sub1" - mock_subscriber_client.list_subscriptions.return_value[1].name = "sub2" - - result = metadata_tool.list_subscriptions( - project, mock_credentials, tool_settings - ) - assert result == ["sub1", "sub2"] - mock_get_subscriber_client.assert_called_once() - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object( - pubsub_v1.SubscriberClient, "get_subscription", autospec=True -) -@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True) -def test_get_subscription(mock_get_subscriber_client, mock_get_subscription): - """Test get_subscription tool invocation.""" - subscription_name = "projects/my_project_id/subscriptions/my_sub" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id="my_project_id") - - mock_subscriber_client = mock.create_autospec( - pubsub_v1.SubscriberClient, instance=True - ) - mock_get_subscriber_client.return_value = mock_subscriber_client - - mock_subscription = mock.Mock() - mock_subscription.name = subscription_name - mock_subscription.topic = "projects/my_project_id/topics/my_topic" - mock_subscription.push_config = "push_config" - mock_subscription.ack_deadline_seconds = 10 - mock_subscription.retain_acked_messages = True - mock_subscription.message_retention_duration = "duration" - mock_subscription.labels = {"key": "value"} - mock_subscription.enable_message_ordering = True - mock_subscription.expiration_policy = "expiration" - mock_subscription.filter = "filter" - mock_subscription.dead_letter_policy = "dead_letter" - mock_subscription.retry_policy = "retry" - mock_subscription.detached = False - - mock_subscriber_client.get_subscription.return_value = mock_subscription - - result = metadata_tool.get_subscription( - subscription_name, mock_credentials, tool_settings - ) - - assert result["name"] == subscription_name - assert result["topic"] == "projects/my_project_id/topics/my_topic" - mock_get_subscriber_client.assert_called_once() - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(pubsub_v1.SchemaServiceClient, "list_schemas", autospec=True) -@mock.patch.object(pubsub_client_lib, "get_schema_client", autospec=True) -def test_list_schemas(mock_get_schema_client, mock_list_schemas): - """Test list_schemas tool invocation.""" - project = "my_project_id" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id=project) - - mock_schema_client = mock.create_autospec( - pubsub_v1.SchemaServiceClient, instance=True - ) - mock_get_schema_client.return_value = mock_schema_client - mock_schema_client.list_schemas.return_value = [ - mock.Mock(name="projects/my_project_id/schemas/schema1"), - mock.Mock(name="projects/my_project_id/schemas/schema2"), - ] - mock_schema_client.list_schemas.return_value[0].name = "schema1" - mock_schema_client.list_schemas.return_value[1].name = "schema2" - - result = metadata_tool.list_schemas(project, mock_credentials, tool_settings) - assert result == ["schema1", "schema2"] - mock_get_schema_client.assert_called_once() - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(pubsub_v1.SchemaServiceClient, "get_schema", autospec=True) -@mock.patch.object(pubsub_client_lib, "get_schema_client", autospec=True) -def test_get_schema(mock_get_schema_client, mock_get_schema): - """Test get_schema tool invocation.""" - schema_name = "projects/my_project_id/schemas/my_schema" - mock_credentials = mock.create_autospec(Credentials, instance=True) - tool_settings = PubSubToolConfig(project_id="my_project_id") - - mock_schema_client = mock.create_autospec( - pubsub_v1.SchemaServiceClient, instance=True - ) - mock_get_schema_client.return_value = mock_schema_client - - mock_schema = mock.Mock() - mock_schema.name = schema_name - mock_schema.type_ = "AVRO" - mock_schema.definition = "definition" - mock_schema.revision_id = "revision_id" - mock_schema.revision_create_time = "time" - - mock_schema_client.get_schema.return_value = mock_schema - - result = metadata_tool.get_schema( - schema_name, mock_credentials, tool_settings - ) - - assert result["name"] == schema_name - assert result["type"] == "AVRO" - mock_get_schema_client.assert_called_once() diff --git a/tests/unittests/tools/pubsub/test_pubsub_toolset.py b/tests/unittests/tools/pubsub/test_pubsub_toolset.py index 567d434a23..0534eb9a9d 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_toolset.py +++ b/tests/unittests/tools/pubsub/test_pubsub_toolset.py @@ -41,19 +41,13 @@ async def test_pubsub_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 9 + assert len(tools) == 3 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ - "list_topics", - "get_topic", - "list_subscriptions", - "get_subscription", - "list_schemas", - "get_schema", - "list_schema_revisions", - "get_schema_revision", "publish_message", + "pull_messages", + "acknowledge_messages", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names @@ -63,12 +57,9 @@ async def test_pubsub_toolset_tools_default(): "selected_tools", [ pytest.param([], id="None"), - pytest.param(["list_topics", "get_topic"], id="topic-metadata"), - pytest.param( - ["list_subscriptions", "get_subscription"], - id="subscription-metadata", - ), pytest.param(["publish_message"], id="publish"), + pytest.param(["pull_messages"], id="pull"), + pytest.param(["acknowledge_messages"], id="ack"), ], ) @pytest.mark.asyncio From ab777729759af7493b06f1146251e8062bba31e2 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Sun, 7 Dec 2025 19:07:46 -0500 Subject: [PATCH 03/14] Cache Pub/Sub clients --- src/google/adk/tools/pubsub/client.py | 60 ++++++++++++++++++ src/google/adk/tools/pubsub/message_tool.py | 14 ++++- .../tools/pubsub/test_pubsub_client.py | 62 +++++++++++++++++++ .../tools/pubsub/test_pubsub_message_tool.py | 3 + 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index 7cd61f40dc..0bafd2b8c2 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -27,20 +27,50 @@ USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}" +import time + +_publisher_client_cache = {} +_CACHE_TTL = 1800 # 30 minutes + + def get_publisher_client( *, credentials: Credentials, user_agent: Optional[Union[str, List[str]]] = None, + publisher_options: Optional[pubsub_v1.types.PublisherOptions] = None, ) -> pubsub_v1.PublisherClient: """Get a Pub/Sub Publisher client. Args: credentials: The credentials to use for the request. user_agent: The user agent to use for the request. + publisher_options: The publisher options to use for the request. Returns: A Pub/Sub Publisher client. """ + global _publisher_client_cache + current_time = time.time() + + # Clean up expired entries + _publisher_client_cache = { + k: v for k, v in _publisher_client_cache.items() if v[1] > current_time + } + + user_agents_key = None + if user_agent: + if isinstance(user_agent, str): + user_agents_key = (user_agent,) + else: + user_agents_key = tuple(user_agent) + + # Use object identity for credentials and publisher_options as they might not be hashable by value + key = (credentials, user_agents_key, publisher_options) + + if key in _publisher_client_cache: + client, expiration = _publisher_client_cache[key] + if expiration > current_time: + return client user_agents = [USER_AGENT] if user_agent: @@ -54,11 +84,17 @@ def get_publisher_client( publisher_client = pubsub_v1.PublisherClient( credentials=credentials, client_info=client_info, + publisher_options=publisher_options, ) + _publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL) + return publisher_client +_subscriber_client_cache = {} + + def get_subscriber_client( *, credentials: Credentials, @@ -73,6 +109,28 @@ def get_subscriber_client( Returns: A Pub/Sub Subscriber client. """ + global _subscriber_client_cache + current_time = time.time() + + # Clean up expired entries + _subscriber_client_cache = { + k: v for k, v in _subscriber_client_cache.items() if v[1] > current_time + } + + user_agents_key = None + if user_agent: + if isinstance(user_agent, str): + user_agents_key = (user_agent,) + else: + user_agents_key = tuple(user_agent) + + # Use object identity for credentials as they might not be hashable by value + key = (credentials, user_agents_key) + + if key in _subscriber_client_cache: + client, expiration = _subscriber_client_cache[key] + if expiration > current_time: + return client user_agents = [USER_AGENT] if user_agent: @@ -88,4 +146,6 @@ def get_subscriber_client( client_info=client_info, ) + _subscriber_client_cache[key] = (subscriber_client, current_time + _CACHE_TTL) + return subscriber_client diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 539c78d09b..3e8aef4260 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -18,6 +18,7 @@ from typing import Optional from google.auth.credentials import Credentials +from google.cloud import pubsub_v1 from . import client from .config import PubSubToolConfig @@ -45,15 +46,22 @@ def publish_message( dict: Dictionary with the message_id of the published message. """ try: + publisher_options = None + publish_kwargs = attributes or {} + if ordering_key: + publish_kwargs["ordering_key"] = ordering_key + publisher_options = pubsub_v1.types.PublisherOptions( + enable_message_ordering=True + ) + publisher_client = client.get_publisher_client( credentials=credentials, user_agent=[settings.project_id, "publish_message"], + publisher_options=publisher_options, ) data = message.encode("utf-8") - future = publisher_client.publish( - topic_name, data, ordering_key=ordering_key, **(attributes or {}) - ) + future = publisher_client.publish(topic_name, data, **publish_kwargs) message_id = future.result() return {"message_id": message_id} diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py index b37f8b9b94..53ddd13447 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_client.py +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -30,9 +30,47 @@ def test_get_publisher_client(mock_publisher_client): assert kwargs["credentials"] == mock_creds assert "client_info" in kwargs + +@mock.patch("google.cloud.pubsub_v1.PublisherClient") +def test_get_publisher_client_with_options(mock_publisher_client): + """Test get_publisher_client factory with options.""" + mock_creds = mock.Mock(spec=Credentials) + mock_options = mock.Mock(spec=pubsub_v1.types.PublisherOptions) + client.get_publisher_client( + credentials=mock_creds, publisher_options=mock_options + ) + + mock_publisher_client.assert_called_once() + _, kwargs = mock_publisher_client.call_args + assert kwargs["credentials"] == mock_creds + assert kwargs["publisher_options"] == mock_options assert "client_info" in kwargs +@mock.patch("google.cloud.pubsub_v1.PublisherClient") +def test_get_publisher_client_caching(mock_publisher_client): + """Test get_publisher_client caching behavior.""" + # Configure mock to return different instances + mock_publisher_client.side_effect = [mock.Mock(), mock.Mock()] + + mock_creds = mock.Mock(spec=Credentials) + + # First call - should create client + client1 = client.get_publisher_client(credentials=mock_creds) + mock_publisher_client.assert_called_once() + + # Second call with same args - should return cached client + client2 = client.get_publisher_client(credentials=mock_creds) + assert client1 is client2 + mock_publisher_client.assert_called_once() # Still called only once + + # Call with different args - should create new client + mock_creds2 = mock.Mock(spec=Credentials) + client3 = client.get_publisher_client(credentials=mock_creds2) + assert client3 is not client1 + assert mock_publisher_client.call_count == 2 + + @mock.patch("google.cloud.pubsub_v1.SubscriberClient") def test_get_subscriber_client(mock_subscriber_client): """Test get_subscriber_client factory.""" @@ -43,3 +81,27 @@ def test_get_subscriber_client(mock_subscriber_client): _, kwargs = mock_subscriber_client.call_args assert kwargs["credentials"] == mock_creds assert "client_info" in kwargs + + +@mock.patch("google.cloud.pubsub_v1.SubscriberClient") +def test_get_subscriber_client_caching(mock_subscriber_client): + """Test get_subscriber_client caching behavior.""" + # Configure mock to return different instances + mock_subscriber_client.side_effect = [mock.Mock(), mock.Mock()] + + mock_creds = mock.Mock(spec=Credentials) + + # First call - should create client + client1 = client.get_subscriber_client(credentials=mock_creds) + mock_subscriber_client.assert_called_once() + + # Second call with same args - should return cached client + client2 = client.get_subscriber_client(credentials=mock_creds) + assert client1 is client2 + mock_subscriber_client.assert_called_once() # Still called only once + + # Call with different args - should create new client + mock_creds2 = mock.Mock(spec=Credentials) + client3 = client.get_subscriber_client(credentials=mock_creds2) + assert client3 is not client1 + assert mock_subscriber_client.call_count == 2 diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py index e4d001276a..043c8862d1 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -84,6 +84,9 @@ def test_publish_message_with_ordering_key( assert result["message_id"] == "message_id" mock_get_publisher_client.assert_called_once() + _, kwargs = mock_get_publisher_client.call_args + assert kwargs["publisher_options"].enable_message_ordering is True + mock_publisher_client.publish.assert_called_once() # Verify ordering_key was passed From eb7dc7ba60977812f96aa843bf199809e7e1a0b3 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Mon, 8 Dec 2025 07:32:44 -0500 Subject: [PATCH 04/14] Fix creation/management of clients --- src/google/adk/tools/pubsub/client.py | 15 +++++---------- src/google/adk/tools/pubsub/message_tool.py | 14 +++++++------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index 0bafd2b8c2..a6cde8ca43 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -21,6 +21,7 @@ from google.api_core.gapic_v1.client_info import ClientInfo from google.auth.credentials import Credentials from google.cloud import pubsub_v1 +from google.cloud.pubsub_v1.types import BatchSettings from ... import version @@ -52,11 +53,6 @@ def get_publisher_client( global _publisher_client_cache current_time = time.time() - # Clean up expired entries - _publisher_client_cache = { - k: v for k, v in _publisher_client_cache.items() if v[1] > current_time - } - user_agents_key = None if user_agent: if isinstance(user_agent, str): @@ -81,10 +77,14 @@ def get_publisher_client( client_info = ClientInfo(user_agent=" ".join(user_agents)) + # Since we syncrhonously publish messages, we want to disable batching to + # remove any delay. + custom_batch_settings = BatchSettings(max_messages=1) publisher_client = pubsub_v1.PublisherClient( credentials=credentials, client_info=client_info, publisher_options=publisher_options, + batch_settings=custom_batch_settings, ) _publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL) @@ -112,11 +112,6 @@ def get_subscriber_client( global _subscriber_client_cache current_time = time.time() - # Clean up expired entries - _subscriber_client_cache = { - k: v for k, v in _subscriber_client_cache.items() if v[1] > current_time - } - user_agents_key = None if user_agent: if isinstance(user_agent, str): diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 3e8aef4260..48425536e8 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -29,8 +29,8 @@ def publish_message( message: str, credentials: Credentials, settings: PubSubToolConfig, - attributes: Optional[dict[str, str]] = None, - ordering_key: Optional[str] = None, + attributes: Optional[dict[str, str]] = {}, + ordering_key: Optional[str] = "", ) -> dict: """Publish a message to a Pub/Sub topic. @@ -46,14 +46,12 @@ def publish_message( dict: Dictionary with the message_id of the published message. """ try: - publisher_options = None - publish_kwargs = attributes or {} if ordering_key: - publish_kwargs["ordering_key"] = ordering_key publisher_options = pubsub_v1.types.PublisherOptions( enable_message_ordering=True ) - + else: + publisher_options = pubsub_v1.types.PublisherOptions() publisher_client = client.get_publisher_client( credentials=credentials, user_agent=[settings.project_id, "publish_message"], @@ -61,7 +59,9 @@ def publish_message( ) data = message.encode("utf-8") - future = publisher_client.publish(topic_name, data, **publish_kwargs) + future = publisher_client.publish( + topic_name, data=data, ordering_key=ordering_key, **attributes + ) message_id = future.result() return {"message_id": message_id} From 2d3535a67fa6ccf93c860f432bec2e74fcf1228b Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Mon, 8 Dec 2025 07:53:35 -0500 Subject: [PATCH 05/14] Update src/google/adk/tools/pubsub/message_tool.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/tools/pubsub/message_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 48425536e8..08768f92ac 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -109,7 +109,7 @@ def pull_messages( "message_id": received_message.message.message_id, "data": received_message.message.data.decode("utf-8"), "attributes": dict(received_message.message.attributes), - "publish_time": str(received_message.message.publish_time), + "publish_time": received_message.message.publish_time.ToDatetime().isoformat(), "ack_id": received_message.ack_id, }) ack_ids.append(received_message.ack_id) From ea4ca5f3ee9560cd2e7374efcebad22364d0261f Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Mon, 8 Dec 2025 07:56:50 -0500 Subject: [PATCH 06/14] Add locking to client --- src/google/adk/tools/pubsub/client.py | 86 +++++++++++++++------------ 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index a6cde8ca43..3274046d63 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -28,9 +28,11 @@ USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}" +import threading import time _publisher_client_cache = {} +_publisher_client_lock = threading.Lock() _CACHE_TTL = 1800 # 30 minutes @@ -63,36 +65,38 @@ def get_publisher_client( # Use object identity for credentials and publisher_options as they might not be hashable by value key = (credentials, user_agents_key, publisher_options) - if key in _publisher_client_cache: - client, expiration = _publisher_client_cache[key] - if expiration > current_time: - return client + with _publisher_client_lock: + if key in _publisher_client_cache: + client, expiration = _publisher_client_cache[key] + if expiration > current_time: + return client - user_agents = [USER_AGENT] - if user_agent: - if isinstance(user_agent, str): - user_agents.append(user_agent) - else: - user_agents.extend([ua for ua in user_agent if ua]) + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) - client_info = ClientInfo(user_agent=" ".join(user_agents)) + client_info = ClientInfo(user_agent=" ".join(user_agents)) - # Since we syncrhonously publish messages, we want to disable batching to - # remove any delay. - custom_batch_settings = BatchSettings(max_messages=1) - publisher_client = pubsub_v1.PublisherClient( - credentials=credentials, - client_info=client_info, - publisher_options=publisher_options, - batch_settings=custom_batch_settings, - ) + # Since we syncrhonously publish messages, we want to disable batching to + # remove any delay. + custom_batch_settings = BatchSettings(max_messages=1) + publisher_client = pubsub_v1.PublisherClient( + credentials=credentials, + client_info=client_info, + publisher_options=publisher_options, + batch_settings=custom_batch_settings, + ) - _publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL) + _publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL) - return publisher_client + return publisher_client _subscriber_client_cache = {} +_subscriber_client_lock = threading.Lock() def get_subscriber_client( @@ -122,25 +126,29 @@ def get_subscriber_client( # Use object identity for credentials as they might not be hashable by value key = (credentials, user_agents_key) - if key in _subscriber_client_cache: - client, expiration = _subscriber_client_cache[key] - if expiration > current_time: - return client + with _subscriber_client_lock: + if key in _subscriber_client_cache: + client, expiration = _subscriber_client_cache[key] + if expiration > current_time: + return client - user_agents = [USER_AGENT] - if user_agent: - if isinstance(user_agent, str): - user_agents.append(user_agent) - else: - user_agents.extend([ua for ua in user_agent if ua]) + user_agents = [USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) - client_info = ClientInfo(user_agent=" ".join(user_agents)) + client_info = ClientInfo(user_agent=" ".join(user_agents)) - subscriber_client = pubsub_v1.SubscriberClient( - credentials=credentials, - client_info=client_info, - ) + subscriber_client = pubsub_v1.SubscriberClient( + credentials=credentials, + client_info=client_info, + ) - _subscriber_client_cache[key] = (subscriber_client, current_time + _CACHE_TTL) + _subscriber_client_cache[key] = ( + subscriber_client, + current_time + _CACHE_TTL, + ) - return subscriber_client + return subscriber_client From fed112f7c3fad22323cb190f79de2141b82e1512 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Mon, 8 Dec 2025 07:58:13 -0500 Subject: [PATCH 07/14] Formatting fixes --- src/google/adk/tools/pubsub/message_tool.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 08768f92ac..54043ac46b 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -109,7 +109,9 @@ def pull_messages( "message_id": received_message.message.message_id, "data": received_message.message.data.decode("utf-8"), "attributes": dict(received_message.message.attributes), - "publish_time": received_message.message.publish_time.ToDatetime().isoformat(), + "publish_time": ( + received_message.message.publish_time.ToDatetime().isoformat() + ), "ack_id": received_message.ack_id, }) ack_ids.append(received_message.ack_id) From 37a38a4dcadb04a0af0ec584e6f611204a63cd2a Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Mon, 8 Dec 2025 08:23:21 -0500 Subject: [PATCH 08/14] Fix timestamp handling --- src/google/adk/tools/pubsub/message_tool.py | 4 +--- tests/unittests/tools/pubsub/test_pubsub_message_tool.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 54043ac46b..a9a5d00668 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -109,9 +109,7 @@ def pull_messages( "message_id": received_message.message.message_id, "data": received_message.message.data.decode("utf-8"), "attributes": dict(received_message.message.attributes), - "publish_time": ( - received_message.message.publish_time.ToDatetime().isoformat() - ), + "publish_time": received_message.message.publish_time.rfc3339(), "ack_id": received_message.ack_id, }) ack_ids.append(received_message.ack_id) diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py index 043c8862d1..fc2421944d 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -183,7 +183,9 @@ def test_pull_messages(mock_get_subscriber_client): mock_message.message.message_id = "123" mock_message.message.data = b"Hello" mock_message.message.attributes = {"key": "value"} - mock_message.message.publish_time = "2023-01-01T00:00:00Z" + mock_publish_time = mock.Mock() + mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" + mock_message.message.publish_time = mock_publish_time mock_message.ack_id = "ack_123" mock_response.received_messages = [mock_message] mock_subscriber_client.pull.return_value = mock_response @@ -223,7 +225,9 @@ def test_pull_messages_auto_ack(mock_get_subscriber_client): mock_message.message.message_id = "123" mock_message.message.data = b"Hello" mock_message.message.attributes = {} - mock_message.message.publish_time = "2023-01-01T00:00:00Z" + mock_publish_time = mock.Mock() + mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" + mock_message.message.publish_time = mock_publish_time mock_message.ack_id = "ack_123" mock_response.received_messages = [mock_message] mock_subscriber_client.pull.return_value = mock_response From 74bb828bea43da7b5439ab487d58398840ffe1ff Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Thu, 11 Dec 2025 09:59:22 -0500 Subject: [PATCH 09/14] Better error messages; clean up clients --- src/google/adk/tools/pubsub/client.py | 11 +++++++++++ src/google/adk/tools/pubsub/message_tool.py | 14 +++++++++++--- src/google/adk/tools/pubsub/pubsub_toolset.py | 4 +++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index 3274046d63..5637deb149 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -152,3 +152,14 @@ def get_subscriber_client( ) return subscriber_client + + +def cleanup_clients(): + """Clean up all cached Pub/Sub clients.""" + global _publisher_client_cache, _subscriber_client_cache + + with _publisher_client_lock: + _publisher_client_cache.clear() + + with _subscriber_client_lock: + _subscriber_client_cache.clear() diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index a9a5d00668..cf46167441 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -68,7 +68,9 @@ def publish_message( except Exception as ex: return { "status": "ERROR", - "error_details": str(ex), + "error_details": ( + f"Failed to publish message to topic '{topic_name}': {repr(ex)}" + ), } @@ -124,7 +126,10 @@ def pull_messages( except Exception as ex: return { "status": "ERROR", - "error_details": str(ex), + "error_details": ( + f"Failed to pull messages from subscription '{subscription_name}':" + f" {repr(ex)}" + ), } @@ -160,5 +165,8 @@ def acknowledge_messages( except Exception as ex: return { "status": "ERROR", - "error_details": str(ex), + "error_details": ( + "Failed to acknowledge messages on subscription" + f" '{subscription_name}': {repr(ex)}" + ), } diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py index 394cd89073..65e9da304c 100644 --- a/src/google/adk/tools/pubsub/pubsub_toolset.py +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -21,6 +21,7 @@ from google.adk.agents.readonly_context import ReadonlyContext from typing_extensions import override +from . import client from . import message_tool from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset @@ -88,4 +89,5 @@ async def get_tools( @override async def close(self): - pass + """Clean up resources used by the toolset.""" + client.cleanup_clients() From f81bd86478b9809bf5fa7d2d9725fc3c74d8eb1c Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Thu, 11 Dec 2025 10:53:28 -0500 Subject: [PATCH 10/14] Fix parameters and tests --- src/google/adk/tools/pubsub/client.py | 8 +++--- src/google/adk/tools/pubsub/message_tool.py | 28 ++++++++++++++----- .../tools/pubsub/test_pubsub_client.py | 17 +++++++++++ 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index 5637deb149..aab384687d 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -62,8 +62,8 @@ def get_publisher_client( else: user_agents_key = tuple(user_agent) - # Use object identity for credentials and publisher_options as they might not be hashable by value - key = (credentials, user_agents_key, publisher_options) + # Use object identity for credentials and publisher_options as they are not hashable + key = (id(credentials), user_agents_key, id(publisher_options)) with _publisher_client_lock: if key in _publisher_client_cache: @@ -123,8 +123,8 @@ def get_subscriber_client( else: user_agents_key = tuple(user_agent) - # Use object identity for credentials as they might not be hashable by value - key = (credentials, user_agents_key) + # Use object identity for credentials as they are not hashable + key = (id(credentials), user_agents_key) with _subscriber_client_lock: if key in _subscriber_client_cache: diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index cf46167441..688bfd6330 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import base64 from typing import List from typing import Optional @@ -29,8 +30,8 @@ def publish_message( message: str, credentials: Credentials, settings: PubSubToolConfig, - attributes: Optional[dict[str, str]] = {}, - ordering_key: Optional[str] = "", + attributes: Optional[dict[str, str]] = None, + ordering_key: Optional[str] = None, ) -> dict: """Publish a message to a Pub/Sub topic. @@ -39,8 +40,8 @@ def publish_message( message (str): The message content to publish. credentials (Credentials): The credentials to use for the request. settings (PubSubToolConfig): The Pub/Sub tool settings. - attributes (Optional[dict[str, str]]): Optional attributes to attach to the message. - ordering_key (Optional[str]): Optional ordering key for the message. + attributes (Optional[dict[str, str]]): Attributes to attach to the message. + ordering_key (str): Ordering key for the message. Returns: dict: Dictionary with the message_id of the published message. @@ -58,9 +59,13 @@ def publish_message( publisher_options=publisher_options, ) - data = message.encode("utf-8") + message_bytes = message.encode("utf-8") + future = publisher_client.publish( - topic_name, data=data, ordering_key=ordering_key, **attributes + topic_name, + data=message_bytes, + ordering_key=ordering_key or "", + **(attributes or {}), ) message_id = future.result() @@ -107,9 +112,18 @@ def pull_messages( messages = [] ack_ids = [] for received_message in response.received_messages: + # Try to decode as UTF-8, fall back to base64 for binary data + try: + message_data = received_message.message.data.decode("utf-8") + except UnicodeDecodeError: + # If UTF-8 decoding fails, encode as base64 string + message_data = base64.b64encode(received_message.message.data).decode( + "ascii" + ) + messages.append({ "message_id": received_message.message.message_id, - "data": received_message.message.data.decode("utf-8"), + "data": message_data, "attributes": dict(received_message.message.attributes), "publish_time": received_message.message.publish_time.rfc3339(), "ack_id": received_message.ack_id, diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py index 53ddd13447..c4dccdec39 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_client.py +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -17,6 +17,19 @@ from google.adk.tools.pubsub import client from google.cloud import pubsub_v1 from google.oauth2.credentials import Credentials +import pytest + + +@pytest.fixture(autouse=True) +def cleanup_pubsub_clients(): + """Automatically clean up Pub/Sub client caches after each test. + + This fixture runs automatically for all tests in this file, + ensuring that client caches are cleared between tests to prevent + state leakage and ensure test isolation. + """ + yield + client.cleanup_clients() @mock.patch("google.cloud.pubsub_v1.PublisherClient") @@ -29,6 +42,8 @@ def test_get_publisher_client(mock_publisher_client): _, kwargs = mock_publisher_client.call_args assert kwargs["credentials"] == mock_creds assert "client_info" in kwargs + assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings) + assert kwargs["batch_settings"].max_messages == 1 @mock.patch("google.cloud.pubsub_v1.PublisherClient") @@ -45,6 +60,8 @@ def test_get_publisher_client_with_options(mock_publisher_client): assert kwargs["credentials"] == mock_creds assert kwargs["publisher_options"] == mock_options assert "client_info" in kwargs + assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings) + assert kwargs["batch_settings"].max_messages == 1 @mock.patch("google.cloud.pubsub_v1.PublisherClient") From 7c66289d3a546e626a81c9a6d10dcd7f48dc2f9c Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Thu, 11 Dec 2025 11:04:30 -0500 Subject: [PATCH 11/14] Fix documentation --- src/google/adk/tools/pubsub/message_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 688bfd6330..0720765e9c 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -40,8 +40,8 @@ def publish_message( message (str): The message content to publish. credentials (Credentials): The credentials to use for the request. settings (PubSubToolConfig): The Pub/Sub tool settings. - attributes (Optional[dict[str, str]]): Attributes to attach to the message. - ordering_key (str): Ordering key for the message. + attributes (Optional[dict[str, str]]): Optional attributes to attach to the message. + ordering_key (Optional[str]): Optional ordering key for the message. Returns: dict: Dictionary with the message_id of the published message. From 7c72b6f18f93de73b044fe0d2795832f4f96c7c8 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Fri, 12 Dec 2025 13:54:16 +0000 Subject: [PATCH 12/14] Formatting and various fixes --- src/google/adk/features/_feature_registry.py | 4 +++ src/google/adk/tools/pubsub/__init__.py | 3 +- src/google/adk/tools/pubsub/client.py | 24 ++++++------- src/google/adk/tools/pubsub/config.py | 7 ++-- src/google/adk/tools/pubsub/message_tool.py | 35 ++++++++++--------- .../adk/tools/pubsub/pubsub_credentials.py | 5 +-- src/google/adk/tools/pubsub/pubsub_toolset.py | 23 ++++++------ .../tools/pubsub/test_pubsub_client.py | 28 ++++++++------- .../tools/pubsub/test_pubsub_message_tool.py | 34 ++++++++++-------- .../tools/pubsub/test_pubsub_toolset.py | 19 ++++++---- 10 files changed, 99 insertions(+), 83 deletions(-) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 46b56eb6d9..cb3fa882d1 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -32,6 +32,7 @@ class FeatureName(str, Enum): GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" + PUBSUB_TOOLSET = "PUBSUB_TOOLSET" SPANNER_TOOLSET = "SPANNER_TOOLSET" SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS" @@ -90,6 +91,9 @@ class FeatureConfig: FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.WIP, default_on=False ), + FeatureName.PUBSUB_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.SPANNER_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py index 0e57a1cc59..9625155f06 100644 --- a/src/google/adk/tools/pubsub/__init__.py +++ b/src/google/adk/tools/pubsub/__init__.py @@ -23,7 +23,8 @@ messages. """ +from .config import PubSubToolConfig from .pubsub_credentials import PubSubCredentialsConfig from .pubsub_toolset import PubSubToolset -__all__ = ["PubSubCredentialsConfig", "PubSubToolset"] +__all__ = ["PubSubCredentialsConfig", "PubSubToolConfig", "PubSubToolset"] diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index aab384687d..d5f04c3e87 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -14,9 +14,8 @@ from __future__ import annotations -from typing import List -from typing import Optional -from typing import Union +import threading +import time from google.api_core.gapic_v1.client_info import ClientInfo from google.auth.credentials import Credentials @@ -27,20 +26,17 @@ USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}" - -import threading -import time +_CACHE_TTL = 1800 # 30 minutes _publisher_client_cache = {} _publisher_client_lock = threading.Lock() -_CACHE_TTL = 1800 # 30 minutes def get_publisher_client( *, credentials: Credentials, - user_agent: Optional[Union[str, List[str]]] = None, - publisher_options: Optional[pubsub_v1.types.PublisherOptions] = None, + user_agent: str | list[str] | None = None, + publisher_options: pubsub_v1.types.PublisherOptions | None = None, ) -> pubsub_v1.PublisherClient: """Get a Pub/Sub Publisher client. @@ -76,11 +72,11 @@ def get_publisher_client( if isinstance(user_agent, str): user_agents.append(user_agent) else: - user_agents.extend([ua for ua in user_agent if ua]) + user_agents.extend(ua for ua in user_agent if ua) client_info = ClientInfo(user_agent=" ".join(user_agents)) - # Since we syncrhonously publish messages, we want to disable batching to + # Since we synchronously publish messages, we want to disable batching to # remove any delay. custom_batch_settings = BatchSettings(max_messages=1) publisher_client = pubsub_v1.PublisherClient( @@ -102,7 +98,7 @@ def get_publisher_client( def get_subscriber_client( *, credentials: Credentials, - user_agent: Optional[Union[str, List[str]]] = None, + user_agent: str | list[str] | None = None, ) -> pubsub_v1.SubscriberClient: """Get a Pub/Sub Subscriber client. @@ -137,7 +133,7 @@ def get_subscriber_client( if isinstance(user_agent, str): user_agents.append(user_agent) else: - user_agents.extend([ua for ua in user_agent if ua]) + user_agents.extend(ua for ua in user_agent if ua) client_info = ClientInfo(user_agent=" ".join(user_agents)) @@ -162,4 +158,6 @@ def cleanup_clients(): _publisher_client_cache.clear() with _subscriber_client_lock: + for client, _ in _subscriber_client_cache.values(): + client.close() _subscriber_client_cache.clear() diff --git a/src/google/adk/tools/pubsub/config.py b/src/google/adk/tools/pubsub/config.py index a91c62931e..eb48a1f7f4 100644 --- a/src/google/adk/tools/pubsub/config.py +++ b/src/google/adk/tools/pubsub/config.py @@ -14,8 +14,6 @@ from __future__ import annotations -from typing import Optional - from pydantic import BaseModel from pydantic import ConfigDict @@ -29,8 +27,9 @@ class PubSubToolConfig(BaseModel): # Forbid any fields not defined in the model model_config = ConfigDict(extra='forbid') - project_id: Optional[str] = None + project_id: str | None = None """GCP project ID to use for the Pub/Sub operations. - If not set, the project ID will be inferred from the environment or credentials. + If not set, the project ID will be inferred from the environment or + credentials. """ diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index 0720765e9c..d500de131d 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -14,10 +14,6 @@ from __future__ import annotations -import base64 -from typing import List -from typing import Optional - from google.auth.credentials import Credentials from google.cloud import pubsub_v1 @@ -30,22 +26,26 @@ def publish_message( message: str, credentials: Credentials, settings: PubSubToolConfig, - attributes: Optional[dict[str, str]] = None, - ordering_key: Optional[str] = None, + attributes: dict[str, str] | None = None, + ordering_key: str | None = None, ) -> dict: """Publish a message to a Pub/Sub topic. Args: - topic_name (str): The Pub/Sub topic name (e.g. projects/my-project/topics/my-topic). + topic_name (str): The Pub/Sub topic name (e.g. + projects/my-project/topics/my-topic). message (str): The message content to publish. credentials (Credentials): The credentials to use for the request. settings (PubSubToolConfig): The Pub/Sub tool settings. - attributes (Optional[dict[str, str]]): Optional attributes to attach to the message. - ordering_key (Optional[str]): Optional ordering key for the message. + attributes (dict[str, str] | None): Attributes to attach to the message. + ordering_key (str | None): Ordering key for the message. Returns: dict: Dictionary with the message_id of the published message. """ + if attributes is None: + attributes = {} + try: if ordering_key: publisher_options = pubsub_v1.types.PublisherOptions( @@ -60,16 +60,14 @@ def publish_message( ) message_bytes = message.encode("utf-8") - future = publisher_client.publish( topic_name, data=message_bytes, ordering_key=ordering_key or "", **(attributes or {}), ) - message_id = future.result() - return {"message_id": message_id} + return {"message_id": future.result()} except Exception as ex: return { "status": "ERROR", @@ -89,11 +87,13 @@ def pull_messages( """Pull messages from a Pub/Sub subscription. Args: - subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). + subscription_name (str): The Pub/Sub subscription name (e.g. + projects/my-project/subscriptions/my-sub). credentials (Credentials): The credentials to use for the request. settings (PubSubToolConfig): The Pub/Sub tool settings. max_messages (int): The maximum number of messages to pull. Defaults to 1. - auto_ack (bool): Whether to automatically acknowledge the messages. Defaults to False. + auto_ack (bool): Whether to automatically acknowledge the messages. + Defaults to False. Returns: dict: Dictionary with the list of pulled messages. @@ -149,15 +149,16 @@ def pull_messages( def acknowledge_messages( subscription_name: str, - ack_ids: List[str], + ack_ids: list[str], credentials: Credentials, settings: PubSubToolConfig, ) -> dict: """Acknowledge messages on a Pub/Sub subscription. Args: - subscription_name (str): The Pub/Sub subscription name (e.g. projects/my-project/subscriptions/my-sub). - ack_ids (List[str]): List of acknowledgment IDs to acknowledge. + subscription_name (str): The Pub/Sub subscription name (e.g. + projects/my-project/subscriptions/my-sub). + ack_ids (list[str]): List of acknowledgment IDs to acknowledge. credentials (Credentials): The credentials to use for the request. settings (PubSubToolConfig): The Pub/Sub tool settings. diff --git a/src/google/adk/tools/pubsub/pubsub_credentials.py b/src/google/adk/tools/pubsub/pubsub_credentials.py index 7729f1ea18..8db75e7abf 100644 --- a/src/google/adk/tools/pubsub/pubsub_credentials.py +++ b/src/google/adk/tools/pubsub/pubsub_credentials.py @@ -16,14 +16,15 @@ from pydantic import model_validator -from ...utils.feature_decorator import experimental +from ...features import experimental +from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache" PUBSUB_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/pubsub"] -@experimental +@experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG) class PubSubCredentialsConfig(BaseGoogleCredentialsConfig): """Pub/Sub Credentials Configuration for Google API tools (Experimental). diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py index 65e9da304c..6e530b4877 100644 --- a/src/google/adk/tools/pubsub/pubsub_toolset.py +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -14,34 +14,31 @@ from __future__ import annotations -from typing import List -from typing import Optional -from typing import Union - from google.adk.agents.readonly_context import ReadonlyContext from typing_extensions import override -from . import client -from . import message_tool +from ...features import experimental +from ...features import FeatureName from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate from ...tools.google_tool import GoogleTool -from ...utils.feature_decorator import experimental +from . import client +from . import message_tool from .config import PubSubToolConfig from .pubsub_credentials import PubSubCredentialsConfig -@experimental +@experimental(FeatureName.PUBSUB_TOOLSET) class PubSubToolset(BaseToolset): """Pub/Sub Toolset contains tools for interacting with Pub/Sub topics and subscriptions.""" def __init__( self, *, - tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, - credentials_config: Optional[PubSubCredentialsConfig] = None, - pubsub_tool_config: Optional[PubSubToolConfig] = None, + tool_filter: ToolPredicate | list[str] | None = None, + credentials_config: PubSubCredentialsConfig | None = None, + pubsub_tool_config: PubSubToolConfig | None = None, ): super().__init__(tool_filter=tool_filter) self._credentials_config = credentials_config @@ -65,8 +62,8 @@ def _is_tool_selected( @override async def get_tools( - self, readonly_context: Optional[ReadonlyContext] = None - ) -> List[BaseTool]: + self, readonly_context: ReadonlyContext | None = None + ) -> list[BaseTool]: """Get tools from the toolset.""" all_tools = [ GoogleTool( diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py index c4dccdec39..95d38287c7 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_client.py +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -32,10 +32,10 @@ def cleanup_pubsub_clients(): client.cleanup_clients() -@mock.patch("google.cloud.pubsub_v1.PublisherClient") +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) def test_get_publisher_client(mock_publisher_client): """Test get_publisher_client factory.""" - mock_creds = mock.Mock(spec=Credentials) + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) client.get_publisher_client(credentials=mock_creds) mock_publisher_client.assert_called_once() @@ -46,11 +46,13 @@ def test_get_publisher_client(mock_publisher_client): assert kwargs["batch_settings"].max_messages == 1 -@mock.patch("google.cloud.pubsub_v1.PublisherClient") +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) def test_get_publisher_client_with_options(mock_publisher_client): """Test get_publisher_client factory with options.""" - mock_creds = mock.Mock(spec=Credentials) - mock_options = mock.Mock(spec=pubsub_v1.types.PublisherOptions) + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_options = mock.create_autospec( + pubsub_v1.types.PublisherOptions, instance=True, spec_set=True + ) client.get_publisher_client( credentials=mock_creds, publisher_options=mock_options ) @@ -64,13 +66,13 @@ def test_get_publisher_client_with_options(mock_publisher_client): assert kwargs["batch_settings"].max_messages == 1 -@mock.patch("google.cloud.pubsub_v1.PublisherClient") +@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) def test_get_publisher_client_caching(mock_publisher_client): """Test get_publisher_client caching behavior.""" # Configure mock to return different instances mock_publisher_client.side_effect = [mock.Mock(), mock.Mock()] - mock_creds = mock.Mock(spec=Credentials) + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) # First call - should create client client1 = client.get_publisher_client(credentials=mock_creds) @@ -82,16 +84,16 @@ def test_get_publisher_client_caching(mock_publisher_client): mock_publisher_client.assert_called_once() # Still called only once # Call with different args - should create new client - mock_creds2 = mock.Mock(spec=Credentials) + mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True) client3 = client.get_publisher_client(credentials=mock_creds2) assert client3 is not client1 assert mock_publisher_client.call_count == 2 -@mock.patch("google.cloud.pubsub_v1.SubscriberClient") +@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True) def test_get_subscriber_client(mock_subscriber_client): """Test get_subscriber_client factory.""" - mock_creds = mock.Mock(spec=Credentials) + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) client.get_subscriber_client(credentials=mock_creds) mock_subscriber_client.assert_called_once() @@ -100,13 +102,13 @@ def test_get_subscriber_client(mock_subscriber_client): assert "client_info" in kwargs -@mock.patch("google.cloud.pubsub_v1.SubscriberClient") +@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True) def test_get_subscriber_client_caching(mock_subscriber_client): """Test get_subscriber_client caching behavior.""" # Configure mock to return different instances mock_subscriber_client.side_effect = [mock.Mock(), mock.Mock()] - mock_creds = mock.Mock(spec=Credentials) + mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) # First call - should create client client1 = client.get_subscriber_client(credentials=mock_creds) @@ -118,7 +120,7 @@ def test_get_subscriber_client_caching(mock_subscriber_client): mock_subscriber_client.assert_called_once() # Still called only once # Call with different args - should create new client - mock_creds2 = mock.Mock(spec=Credentials) + mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True) client3 = client.get_subscriber_client(credentials=mock_creds2) assert client3 is not client1 assert mock_subscriber_client.call_count == 2 diff --git a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py index fc2421944d..2a935f41e2 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_message_tool.py +++ b/tests/unittests/tools/pubsub/test_pubsub_message_tool.py @@ -20,8 +20,11 @@ from google.adk.tools.pubsub import client as pubsub_client_lib from google.adk.tools.pubsub import message_tool from google.adk.tools.pubsub.config import PubSubToolConfig +from google.api_core import future from google.cloud import pubsub_v1 +from google.cloud.pubsub_v1 import types from google.oauth2.credentials import Credentials +from google.protobuf import timestamp_pb2 @mock.patch.dict(os.environ, {}, clear=True) @@ -39,7 +42,7 @@ def test_publish_message(mock_get_publisher_client, mock_publish): ) mock_get_publisher_client.return_value = mock_publisher_client - mock_future = mock.Mock() + mock_future = mock.create_autospec(future.Future, instance=True) mock_future.result.return_value = "message_id" mock_publisher_client.publish.return_value = mock_future @@ -70,7 +73,7 @@ def test_publish_message_with_ordering_key( ) mock_get_publisher_client.return_value = mock_publisher_client - mock_future = mock.Mock() + mock_future = mock.create_autospec(future.Future, instance=True) mock_future.result.return_value = "message_id" mock_publisher_client.publish.return_value = mock_future @@ -112,7 +115,7 @@ def test_publish_message_with_attributes( ) mock_get_publisher_client.return_value = mock_publisher_client - mock_future = mock.Mock() + mock_future = mock.create_autospec(future.Future, instance=True) mock_future.result.return_value = "message_id" mock_publisher_client.publish.return_value = mock_future @@ -178,12 +181,12 @@ def test_pull_messages(mock_get_subscriber_client): ) mock_get_subscriber_client.return_value = mock_subscriber_client - mock_response = mock.Mock() - mock_message = mock.Mock() + mock_response = mock.create_autospec(types.PullResponse, instance=True) + mock_message = mock.MagicMock() mock_message.message.message_id = "123" mock_message.message.data = b"Hello" mock_message.message.attributes = {"key": "value"} - mock_publish_time = mock.Mock() + mock_publish_time = mock.MagicMock() mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" mock_message.message.publish_time = mock_publish_time mock_message.ack_id = "ack_123" @@ -194,11 +197,14 @@ def test_pull_messages(mock_get_subscriber_client): subscription_name, mock_credentials, tool_settings ) - assert len(result["messages"]) == 1 - assert result["messages"][0]["message_id"] == "123" - assert result["messages"][0]["data"] == "Hello" - assert result["messages"][0]["attributes"] == {"key": "value"} - assert result["messages"][0]["ack_id"] == "ack_123" + expected_message = { + "message_id": "123", + "data": "Hello", + "attributes": {"key": "value"}, + "publish_time": "2023-01-01T00:00:00Z", + "ack_id": "ack_123", + } + assert result["messages"] == [expected_message] mock_get_subscriber_client.assert_called_once() mock_subscriber_client.pull.assert_called_once_with( @@ -220,12 +226,12 @@ def test_pull_messages_auto_ack(mock_get_subscriber_client): ) mock_get_subscriber_client.return_value = mock_subscriber_client - mock_response = mock.Mock() - mock_message = mock.Mock() + mock_response = mock.create_autospec(types.PullResponse, instance=True) + mock_message = mock.MagicMock() mock_message.message.message_id = "123" mock_message.message.data = b"Hello" mock_message.message.attributes = {} - mock_publish_time = mock.Mock() + mock_publish_time = mock.MagicMock() mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z" mock_message.message.publish_time = mock_publish_time mock_message.ack_id = "ack_123" diff --git a/tests/unittests/tools/pubsub/test_pubsub_toolset.py b/tests/unittests/tools/pubsub/test_pubsub_toolset.py index 0534eb9a9d..4750db1204 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_toolset.py +++ b/tests/unittests/tools/pubsub/test_pubsub_toolset.py @@ -42,14 +42,14 @@ async def test_pubsub_toolset_tools_default(): assert tools is not None assert len(tools) == 3 - assert all([isinstance(tool, GoogleTool) for tool in tools]) + assert all(isinstance(tool, GoogleTool) for tool in tools) expected_tool_names = set([ "publish_message", "pull_messages", "acknowledge_messages", ]) - actual_tool_names = set([tool.name for tool in tools]) + actual_tool_names = {tool.name for tool in tools} assert actual_tool_names == expected_tool_names @@ -69,6 +69,9 @@ async def test_pubsub_toolset_tools_selective(selected_tools): This test verifies the behavior of the PubSub toolset when filter is specified. A use case for this would be when the agent builder wants to use only a subset of the tools provided by the toolset. + + Args: + selected_tools: The list of tools to select from the toolset. """ credentials_config = PubSubCredentialsConfig( client_id="abc", client_secret="def" @@ -80,10 +83,10 @@ async def test_pubsub_toolset_tools_selective(selected_tools): assert tools is not None assert len(tools) == len(selected_tools) - assert all([isinstance(tool, GoogleTool) for tool in tools]) + assert all(isinstance(tool, GoogleTool) for tool in tools) expected_tool_names = set(selected_tools) - actual_tool_names = set([tool.name for tool in tools]) + actual_tool_names = {tool.name for tool in tools} assert actual_tool_names == expected_tool_names @@ -104,6 +107,10 @@ async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools): This test verifies the behavior of the PubSub toolset when filter is specified with an unknown tool. + + Args: + selected_tools: The list of tools to select from the toolset. + returned_tools: The list of tools that are expected to be returned. """ credentials_config = PubSubCredentialsConfig( client_id="abc", client_secret="def" @@ -117,8 +124,8 @@ async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools): assert tools is not None assert len(tools) == len(returned_tools) - assert all([isinstance(tool, GoogleTool) for tool in tools]) + assert all(isinstance(tool, GoogleTool) for tool in tools) expected_tool_names = set(returned_tools) - actual_tool_names = set([tool.name for tool in tools]) + actual_tool_names = {tool.name for tool in tools} assert actual_tool_names == expected_tool_names From 295943c8b6c493846b434b36ac2caa8ccbed96ff Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Fri, 12 Dec 2025 13:59:29 +0000 Subject: [PATCH 13/14] Fix import order --- src/google/adk/tools/pubsub/pubsub_toolset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py index 6e530b4877..2cc084ae87 100644 --- a/src/google/adk/tools/pubsub/pubsub_toolset.py +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -17,14 +17,14 @@ from google.adk.agents.readonly_context import ReadonlyContext from typing_extensions import override +from . import client +from . import message_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate from ...tools.google_tool import GoogleTool -from . import client -from . import message_tool from .config import PubSubToolConfig from .pubsub_credentials import PubSubCredentialsConfig From 84cab8150fb13791be6b2ec97f286829393a7175 Mon Sep 17 00:00:00 2001 From: Kamal Aboul-Hosn Date: Tue, 16 Dec 2025 01:40:56 +0000 Subject: [PATCH 14/14] Format/various fixes --- pyproject.toml | 2 +- src/google/adk/tools/pubsub/client.py | 2 + src/google/adk/tools/pubsub/message_tool.py | 22 ++- .../adk/tools/pubsub/pubsub_credentials.py | 2 +- src/google/adk/tools/pubsub/pubsub_toolset.py | 9 + .../tools/pubsub/test_pubsub_client.py | 22 ++- .../tools/pubsub/test_pubsub_credentials.py | 170 +++++++++++------- 7 files changed, 148 insertions(+), 81 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 30f41e47ff..7a6031c5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool - "google-cloud-pubsub>=2.0.0", # For Pub/Sub Tool + "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription diff --git a/src/google/adk/tools/pubsub/client.py b/src/google/adk/tools/pubsub/client.py index d5f04c3e87..b04c9ae7f5 100644 --- a/src/google/adk/tools/pubsub/client.py +++ b/src/google/adk/tools/pubsub/client.py @@ -155,6 +155,8 @@ def cleanup_clients(): global _publisher_client_cache, _subscriber_client_cache with _publisher_client_lock: + for client, _ in _publisher_client_cache.values(): + client.transport.close() _publisher_client_cache.clear() with _subscriber_client_lock: diff --git a/src/google/adk/tools/pubsub/message_tool.py b/src/google/adk/tools/pubsub/message_tool.py index d500de131d..182b48c0bd 100644 --- a/src/google/adk/tools/pubsub/message_tool.py +++ b/src/google/adk/tools/pubsub/message_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +import base64 + from google.auth.credentials import Credentials from google.cloud import pubsub_v1 @@ -77,10 +79,20 @@ def publish_message( } +def _decode_message_data(data: bytes) -> str: + """Decodes message data, trying UTF-8 and falling back to base64.""" + try: + return data.decode("utf-8") + except UnicodeDecodeError: + # If UTF-8 decoding fails, encode as base64 string + return base64.b64encode(data).decode("ascii") + + def pull_messages( subscription_name: str, credentials: Credentials, settings: PubSubToolConfig, + *, max_messages: int = 1, auto_ack: bool = False, ) -> dict: @@ -112,15 +124,7 @@ def pull_messages( messages = [] ack_ids = [] for received_message in response.received_messages: - # Try to decode as UTF-8, fall back to base64 for binary data - try: - message_data = received_message.message.data.decode("utf-8") - except UnicodeDecodeError: - # If UTF-8 decoding fails, encode as base64 string - message_data = base64.b64encode(received_message.message.data).decode( - "ascii" - ) - + message_data = _decode_message_data(received_message.message.data) messages.append({ "message_id": received_message.message.message_id, "data": message_data, diff --git a/src/google/adk/tools/pubsub/pubsub_credentials.py b/src/google/adk/tools/pubsub/pubsub_credentials.py index 8db75e7abf..ed04b9e0d7 100644 --- a/src/google/adk/tools/pubsub/pubsub_credentials.py +++ b/src/google/adk/tools/pubsub/pubsub_credentials.py @@ -21,7 +21,7 @@ from .._google_credentials import BaseGoogleCredentialsConfig PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache" -PUBSUB_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/pubsub"] +PUBSUB_DEFAULT_SCOPE = ("https://www.googleapis.com/auth/pubsub",) @experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG) diff --git a/src/google/adk/tools/pubsub/pubsub_toolset.py b/src/google/adk/tools/pubsub/pubsub_toolset.py index 2cc084ae87..9f7fb0ed4f 100644 --- a/src/google/adk/tools/pubsub/pubsub_toolset.py +++ b/src/google/adk/tools/pubsub/pubsub_toolset.py @@ -40,6 +40,15 @@ def __init__( credentials_config: PubSubCredentialsConfig | None = None, pubsub_tool_config: PubSubToolConfig | None = None, ): + """Initializes the PubSubToolset. + + Args: + tool_filter: A predicate or list of tool names to filter the tools in + the toolset. If None, all tools are included. + credentials_config: The credentials configuration to use for + authenticating with Google Cloud. + pubsub_tool_config: The configuration for the Pub/Sub tools. + """ super().__init__(tool_filter=tool_filter) self._credentials_config = credentials_config self._tool_settings = ( diff --git a/tests/unittests/tools/pubsub/test_pubsub_client.py b/tests/unittests/tools/pubsub/test_pubsub_client.py index 95d38287c7..fec9b3798d 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_client.py +++ b/tests/unittests/tools/pubsub/test_pubsub_client.py @@ -20,6 +20,14 @@ import pytest +# Save original Pub/Sub classes before patching. +# This is necessary because create_autospec cannot be used on a mock object, +# and mock.patch.object(..., autospec=True) replaces the class with a mock. +# We need the original class to create spec'd mocks in side_effect. +ORIG_PUBLISHER = pubsub_v1.PublisherClient +ORIG_SUBSCRIBER = pubsub_v1.SubscriberClient + + @pytest.fixture(autouse=True) def cleanup_pubsub_clients(): """Automatically clean up Pub/Sub client caches after each test. @@ -69,10 +77,11 @@ def test_get_publisher_client_with_options(mock_publisher_client): @mock.patch.object(pubsub_v1, "PublisherClient", autospec=True) def test_get_publisher_client_caching(mock_publisher_client): """Test get_publisher_client caching behavior.""" - # Configure mock to return different instances - mock_publisher_client.side_effect = [mock.Mock(), mock.Mock()] - mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_publisher_client.side_effect = [ + mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True), + mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True), + ] # First call - should create client client1 = client.get_publisher_client(credentials=mock_creds) @@ -105,10 +114,11 @@ def test_get_subscriber_client(mock_subscriber_client): @mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True) def test_get_subscriber_client_caching(mock_subscriber_client): """Test get_subscriber_client caching behavior.""" - # Configure mock to return different instances - mock_subscriber_client.side_effect = [mock.Mock(), mock.Mock()] - mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True) + mock_subscriber_client.side_effect = [ + mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True), + mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True), + ] # First call - should create client client1 = client.get_subscriber_client(credentials=mock_creds) diff --git a/tests/unittests/tools/pubsub/test_pubsub_credentials.py b/tests/unittests/tools/pubsub/test_pubsub_credentials.py index 7c19586a0b..11a5d5dea7 100644 --- a/tests/unittests/tools/pubsub/test_pubsub_credentials.py +++ b/tests/unittests/tools/pubsub/test_pubsub_credentials.py @@ -21,71 +21,113 @@ import pytest -class TestPubSubCredentials: - """Test suite for PubSub credentials configuration validation. +"""Test suite for PubSub credentials configuration validation. - This class tests the credential configuration logic that ensures - either existing credentials or client ID/secret pairs are provided. +This class tests the credential configuration logic that ensures +either existing credentials or client ID/secret pairs are provided. +""" + + +def test_pubsub_credentials_config_client_id_secret(): + """Test PubSubCredentialsConfig with client_id and client_secret. + + Ensures that when client_id and client_secret are provided, the config + object is created with the correct attributes. + """ + config = PubSubCredentialsConfig(client_id="abc", client_secret="def") + assert config.client_id == "abc" + assert config.client_secret == "def" + assert config.scopes == PUBSUB_DEFAULT_SCOPE + assert config.credentials is None + + +def test_pubsub_credentials_config_existing_creds(): + """Test PubSubCredentialsConfig with existing generic credentials. + + Ensures that when a generic Credentials object is provided, it is + stored correctly. + """ + mock_creds = mock.create_autospec(Credentials, instance=True) + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.credentials == mock_creds + assert config.client_id is None + assert config.client_secret is None + + +def test_pubsub_credentials_config_oauth2_creds(): + """Test PubSubCredentialsConfig with existing OAuth2 credentials. + + Ensures that when a google.oauth2.credentials.Credentials object is + provided, the client_id, client_secret, and scopes are extracted + from the credentials object. """ + mock_creds = mock.create_autospec( + google.oauth2.credentials.Credentials, instance=True + ) + mock_creds.client_id = "oauth_client_id" + mock_creds.client_secret = "oauth_client_secret" + mock_creds.scopes = ["fake_scope"] + config = PubSubCredentialsConfig(credentials=mock_creds) + assert config.client_id == "oauth_client_id" + assert config.client_secret == "oauth_client_secret" + assert config.scopes == ["fake_scope"] - def test_pubsub_credentials_config_client_id_secret(self): - """Test PubSubCredentialsConfig with client_id and client_secret. - - Ensures that when client_id and client_secret are provided, the config - object is created with the correct attributes. - """ - config = PubSubCredentialsConfig(client_id="abc", client_secret="def") - assert config.client_id == "abc" - assert config.client_secret == "def" - assert config.scopes == PUBSUB_DEFAULT_SCOPE - assert config.credentials is None - - def test_pubsub_credentials_config_existing_creds(self): - """Test PubSubCredentialsConfig with existing generic credentials. - - Ensures that when a generic Credentials object is provided, it is - stored correctly. - """ - mock_creds = mock.create_autospec(Credentials, instance=True) - config = PubSubCredentialsConfig(credentials=mock_creds) - assert config.credentials == mock_creds - assert config.client_id is None - assert config.client_secret is None - - def test_pubsub_credentials_config_oauth2_creds(self): - """Test PubSubCredentialsConfig with existing OAuth2 credentials. - - Ensures that when a google.oauth2.credentials.Credentials object is - provided, the client_id, client_secret, and scopes are extracted - from the credentials object. - """ - mock_creds = mock.create_autospec( - google.oauth2.credentials.Credentials, instance=True + +@pytest.mark.parametrize( + "credentials, client_id, client_secret", + [ + # No arguments provided + (None, None, None), + # Only client_id is provided + (None, "abc", None), + ], +) +def test_pubsub_credentials_config_validation_errors( + credentials, client_id, client_secret +): + """Test PubSubCredentialsConfig validation errors. + + Ensures that ValueError is raised when invalid combinations of credentials + and client ID/secret are provided. + + Args: + credentials: The credentials object to pass. + client_id: The client ID to pass. + client_secret: The client secret to pass. + """ + with pytest.raises( + ValueError, + match=( + "Must provide either credentials or client_id and client_secret pair." + ), + ): + PubSubCredentialsConfig( + credentials=credentials, + client_id=client_id, + client_secret=client_secret, + ) + + +def test_pubsub_credentials_config_both_credentials_and_client_provided(): + """Test PubSubCredentialsConfig validation errors. + + Ensures that ValueError is raised when invalid combinations of credentials + and client ID/secret are provided. + + Args: + credentials: The credentials object to pass. + client_id: The client ID to pass. + client_secret: The client secret to pass. + """ + with pytest.raises( + ValueError, + match=( + "Cannot provide both existing credentials and" + " client_id/client_secret/scopes." + ), + ): + PubSubCredentialsConfig( + credentials=mock.create_autospec(Credentials, instance=True), + client_id="abc", + client_secret="def", ) - mock_creds.client_id = "oauth_client_id" - mock_creds.client_secret = "oauth_client_secret" - mock_creds.scopes = ["fake_scope"] - config = PubSubCredentialsConfig(credentials=mock_creds) - assert config.client_id == "oauth_client_id" - assert config.client_secret == "oauth_client_secret" - assert config.scopes == ["fake_scope"] - - def test_pubsub_credentials_config_validation_errors(self): - """Test PubSubCredentialsConfig validation errors. - - Ensures that ValueError is raised under the following conditions: - - No arguments are provided. - - Only client_id is provided. - - Both credentials and client_id/client_secret are provided. - """ - with pytest.raises(ValueError): - PubSubCredentialsConfig() - - with pytest.raises(ValueError): - PubSubCredentialsConfig(client_id="abc") - - mock_creds = mock.create_autospec(Credentials, instance=True) - with pytest.raises(ValueError): - PubSubCredentialsConfig( - credentials=mock_creds, client_id="abc", client_secret="def" - )