Skip to content

Commit 1dc03ac

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for Vertex Express Mode API key in AdkApp
PiperOrigin-RevId: 825638989
1 parent f51b813 commit 1dc03ac

File tree

2 files changed

+102
-17
lines changed

2 files changed

+102
-17
lines changed

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, name: str, model: str):
5454

5555
_TEST_LOCATION = "us-central1"
5656
_TEST_PROJECT = "test-project"
57+
_TEST_API_KEY = "test-api-key"
5758
_TEST_MODEL = "gemini-2.0-flash"
5859
_TEST_USER_ID = "test_user_id"
5960
_TEST_AGENT_NAME = "test_agent"
@@ -761,6 +762,25 @@ def test_dump_event_for_json():
761762
assert base64.b64decode(part["thought_signature"]) == raw_signature
762763

763764

765+
def test_adk_app_initialization_with_api_key():
766+
importlib.reload(initializer)
767+
importlib.reload(vertexai)
768+
try:
769+
vertexai.init(api_key=_TEST_API_KEY)
770+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
771+
assert app._tmpl_attrs.get("project") is None
772+
assert app._tmpl_attrs.get("location") is None
773+
assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY
774+
assert app._tmpl_attrs.get("runner") is None
775+
app.set_up()
776+
assert app._tmpl_attrs.get("runner") is not None
777+
assert os.environ.get("GOOGLE_API_KEY") == _TEST_API_KEY
778+
assert "GOOGLE_CLOUD_LOCATION" not in os.environ
779+
assert "GOOGLE_CLOUD_PROJECT" not in os.environ
780+
finally:
781+
initializer.global_pool.shutdown(wait=True)
782+
783+
764784
@pytest.mark.usefixtures("mock_adk_version")
765785
class TestAdkAppErrors:
766786
@pytest.mark.asyncio

vertexai/agent_engines/templates/adk.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ def __init__(
521521
If not provided, a default instrumentor builder will be used.
522522
This parameter is ignored if `enable_tracing` is False.
523523
"""
524+
import os
524525
from google.cloud.aiplatform import initializer
525526

526527
adk_version = get_adk_version()
@@ -558,6 +559,7 @@ def __init__(
558559
"artifact_service_builder": artifact_service_builder,
559560
"memory_service_builder": memory_service_builder,
560561
"instrumentor_builder": instrumentor_builder,
562+
"express_mode_api_key": initializer.global_config.api_key,
561563
}
562564

563565
async def _init_session(
@@ -701,9 +703,18 @@ def set_up(self):
701703

702704
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
703705
project = self._tmpl_attrs.get("project")
704-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
706+
if project:
707+
os.environ["GOOGLE_CLOUD_PROJECT"] = project
705708
location = self._tmpl_attrs.get("location")
706-
os.environ["GOOGLE_CLOUD_LOCATION"] = location
709+
if location:
710+
os.environ["GOOGLE_CLOUD_LOCATION"] = location
711+
express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key")
712+
if express_mode_api_key and not project:
713+
os.environ["GOOGLE_API_KEY"] = express_mode_api_key
714+
# Clear location and project env vars if express mode api key is provided.
715+
os.environ.pop("GOOGLE_CLOUD_LOCATION")
716+
os.environ.pop("GOOGLE_CLOUD_PROJECT")
717+
location = None
707718

708719
# Disable content capture in custom ADK spans unless user enabled
709720
# tracing explicitly with the old flag
@@ -750,21 +761,57 @@ def set_up(self):
750761
VertexAiSessionService,
751762
)
752763

753-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
754-
project=project,
755-
location=location,
756-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
757-
)
764+
if is_version_sufficient("1.18.0"):
765+
if express_mode_api_key:
766+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
767+
express_mode_api_key=express_mode_api_key,
768+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
769+
)
770+
else:
771+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
772+
project=project,
773+
location=location,
774+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
775+
)
776+
else:
777+
if express_mode_api_key:
778+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
779+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
780+
)
781+
else:
782+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
783+
project=project,
784+
location=location,
785+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
786+
)
758787
except (ImportError, AttributeError):
759788
from google.adk.sessions.vertex_ai_session_service_g3 import (
760789
VertexAiSessionService,
761790
)
762791

763-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
764-
project=project,
765-
location=location,
766-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
767-
)
792+
if is_version_sufficient("1.18.0"):
793+
if express_mode_api_key:
794+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
795+
express_mode_api_key=express_mode_api_key,
796+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
797+
)
798+
else:
799+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
800+
project=project,
801+
location=location,
802+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
803+
)
804+
else:
805+
if express_mode_api_key:
806+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
807+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
808+
)
809+
else:
810+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
811+
project=project,
812+
location=location,
813+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
814+
)
768815

769816
else:
770817
self._tmpl_attrs["session_service"] = InMemorySessionService()
@@ -780,11 +827,29 @@ def set_up(self):
780827
VertexAiMemoryBankService,
781828
)
782829

783-
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
784-
project=project,
785-
location=location,
786-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
787-
)
830+
if is_version_sufficient("1.18.0"):
831+
if express_mode_api_key:
832+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
833+
express_mode_api_key=express_mode_api_key,
834+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
835+
)
836+
else:
837+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
838+
project=project,
839+
location=location,
840+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
841+
)
842+
else:
843+
if express_mode_api_key:
844+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
845+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
846+
)
847+
else:
848+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
849+
project=project,
850+
location=location,
851+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
852+
)
788853
except (ImportError, AttributeError):
789854
# TODO(ysian): Handle this via _g3 import for google3.
790855
pass

0 commit comments

Comments
 (0)