diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 19fb6c79d4..2b0f06cece 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -54,6 +54,7 @@ def __init__(self, name: str, model: str): _TEST_LOCATION = "us-central1" _TEST_PROJECT = "test-project" +_TEST_API_KEY = "test-api-key" _TEST_MODEL = "gemini-2.0-flash" _TEST_USER_ID = "test_user_id" _TEST_AGENT_NAME = "test_agent" @@ -868,6 +869,41 @@ def test_dump_event_for_json(): assert base64.b64decode(part["thought_signature"]) == raw_signature +def test_adk_app_initialization_with_api_key(): + importlib.reload(initializer) + importlib.reload(vertexai) + try: + vertexai.init(api_key=_TEST_API_KEY) + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("project") is None + assert app._tmpl_attrs.get("location") is None + assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY + assert app._tmpl_attrs.get("runner") is None + app.set_up() + assert app._tmpl_attrs.get("runner") is not None + assert os.environ.get("GOOGLE_API_KEY") == _TEST_API_KEY + assert "GOOGLE_CLOUD_LOCATION" not in os.environ + assert "GOOGLE_CLOUD_PROJECT" not in os.environ + finally: + initializer.global_pool.shutdown(wait=True) + + +def test_adk_app_initialization_with_env_api_key(): + try: + os.environ["GOOGLE_API_KEY"] == _TEST_API_KEY + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("project") is None + assert app._tmpl_attrs.get("location") is None + assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY + assert app._tmpl_attrs.get("runner") is None + app.set_up() + assert app._tmpl_attrs.get("runner") is not None + assert "GOOGLE_CLOUD_LOCATION" not in os.environ + assert "GOOGLE_CLOUD_PROJECT" not in os.environ + finally: + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures("mock_adk_version") class TestAdkAppErrors: @pytest.mark.asyncio diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 4176b58e4a..f3c186e903 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -534,6 +534,7 @@ def __init__( If not provided, a default instrumentor builder will be used. This parameter is ignored if `enable_tracing` is False. """ + import os from google.cloud.aiplatform import initializer adk_version = get_adk_version() @@ -571,6 +572,9 @@ def __init__( "artifact_service_builder": artifact_service_builder, "memory_service_builder": memory_service_builder, "instrumentor_builder": instrumentor_builder, + "express_mode_api_key": ( + initializer.global_config.api_key or os.environ.get("GOOGLE_API_KEY") + ), } async def _init_session( @@ -708,9 +712,18 @@ def set_up(self): os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" project = self._tmpl_attrs.get("project") - os.environ["GOOGLE_CLOUD_PROJECT"] = project + if project: + os.environ["GOOGLE_CLOUD_PROJECT"] = project location = self._tmpl_attrs.get("location") - os.environ["GOOGLE_CLOUD_LOCATION"] = location + if location: + os.environ["GOOGLE_CLOUD_LOCATION"] = location + express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key") + if express_mode_api_key and not project: + os.environ["GOOGLE_API_KEY"] = express_mode_api_key + # Clear location and project env vars if express mode api key is provided. + os.environ.pop("GOOGLE_CLOUD_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_PROJECT", None) + location = None # Disable content capture in custom ADK spans unless user enabled # tracing explicitly with the old flag @@ -783,6 +796,8 @@ def set_up(self): VertexAiSessionService, ) + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=location, @@ -793,6 +808,8 @@ def set_up(self): VertexAiSessionService, ) + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=location, @@ -813,6 +830,8 @@ def set_up(self): VertexAiMemoryBankService, ) + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=location,