From 13faa27376814f7b0a223ff9455c289a9af75288 Mon Sep 17 00:00:00 2001 From: Tongzhou Jiang Date: Fri, 31 Oct 2025 16:14:36 -0700 Subject: [PATCH] feat: Alow VertexAiSession for streaming_agent_run_with_events PiperOrigin-RevId: 826664694 --- vertexai/agent_engines/templates/adk.py | 50 ++++++++++++------- .../reasoning_engines/templates/adk.py | 50 +++++++++++-------- 2 files changed, 61 insertions(+), 39 deletions(-) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index c16aed3a1a..ca7c715075 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -568,7 +568,6 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event - import random session_state = None if request.authorizations: @@ -577,14 +576,9 @@ async def _init_session( auth = _Authorization(**auth) session_state[f"temp:{auth_id}"] = auth.access_token - if request.session_id: - session_id = request.session_id - else: - session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, state=session_state, ) if not session: @@ -602,7 +596,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, + session_id=session.id, filename=artifact.file_name, artifact=version_data.data, ) @@ -998,35 +992,53 @@ async def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types + from google.genai.errors import ClientError request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() + if not self._tmpl_attrs.get("runner"): + self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + if not self._tmpl_attrs.get("session_service"): + self.set_up() app = self._tmpl_attrs.get("app") + # Try to get the session, if it doesn't exist, create a new one. - session = None if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=app.name if app else self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except RuntimeError: - pass - if not session: - # Fall back to create session if the session is not found. - session = await self._init_session( - session_service=session_service, - artifact_service=artifact_service, - request=request, + except ClientError: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await session_service.create_session( + app_name=self._tmpl_attrs.get("app_name"), + user_id=request.user_id, + session_id=request.session_id, ) if not session: raise RuntimeError("Session initialization failed.") @@ -1034,7 +1046,7 @@ async def streaming_agent_run_with_events(self, request_json: str): # Run the agent message_for_agent = types.Content(**request.message) try: - async for event in self._tmpl_attrs.get("in_memory_runner").run_async( + async for event in runner.run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 345ff981f1..b07b1ddc98 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -520,7 +520,6 @@ async def _init_session( ): """Initializes the session, and returns the session id.""" from google.adk.events.event import Event - import random session_state = None if request.authorizations: @@ -529,14 +528,9 @@ async def _init_session( auth = _Authorization(**auth) session_state[f"temp:{auth_id}"] = auth.access_token - if request.session_id: - session_id = request.session_id - else: - session_id = f"temp_session_{random.randbytes(8).hex()}" session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, state=session_state, ) if not session: @@ -554,7 +548,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session_id, + session_id=session.id, filename=artifact.file_name, artifact=version_data.data, ) @@ -904,6 +898,7 @@ async def async_stream_query( def streaming_agent_run_with_events(self, request_json: str): import json from google.genai import types + from google.genai.errors import ClientError event_queue = queue.Queue(maxsize=1) @@ -911,37 +906,52 @@ async def _invoke_agent_async(): request = _StreamRunRequest(**json.loads(request_json)) if not self._tmpl_attrs.get("in_memory_runner"): self.set_up() + if not self._tmpl_attrs.get("runner"): + self.set_up() # Prepare the in-memory session. if not self._tmpl_attrs.get("in_memory_artifact_service"): self.set_up() + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() if not self._tmpl_attrs.get("in_memory_session_service"): self.set_up() - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - # Try to get the session, if it doesn't exist, create a new one. - session = None + if not self._tmpl_attrs.get("session_service"): + self.set_up() if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") try: session = await session_service.get_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - except RuntimeError: - pass - if not session: - # Fall back to create session if the session is not found. - session = await self._init_session( - session_service=session_service, - artifact_service=artifact_service, - request=request, + except ClientError: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await session_service.create_session( + app_name=self._tmpl_attrs.get("app_name"), + user_id=request.user_id, + session_id=request.session_id, ) if not session: raise RuntimeError("Session initialization failed.") # Run the agent. message_for_agent = types.Content(**request.message) try: - for event in self._tmpl_attrs.get("in_memory_runner").run( + for event in runner.run( user_id=request.user_id, session_id=session.id, new_message=message_for_agent,