Skip to content

Commit dd36c04

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 824600367
1 parent a52da0b commit dd36c04

File tree

1 file changed

+26
-16
lines changed
  • vertexai/agent_engines/templates

1 file changed

+26
-16
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,9 @@ async def _init_session(
530530
auth = _Authorization(**auth)
531531
session_state[f"temp:{auth_id}"] = auth.access_token
532532

533-
if request.session_id:
534-
session_id = request.session_id
535-
else:
536-
session_id = f"temp_session_{random.randbytes(8).hex()}"
537533
session = await session_service.create_session(
538534
app_name=self._tmpl_attrs.get("app_name"),
539535
user_id=request.user_id,
540-
session_id=session_id,
541536
state=session_state,
542537
)
543538
if not session:
@@ -873,19 +868,23 @@ async def streaming_agent_run_with_events(self, request_json: str):
873868

874869
import json
875870
from google.genai import types
871+
from google.cloud.aiplatform import base
872+
873+
_LOGGER = base.Logger(__name__)
876874

877875
request = _StreamRunRequest(**json.loads(request_json))
876+
_LOGGER.info("request: %s, type: %s", request, type(request))
878877
if not self._tmpl_attrs.get("in_memory_runner"):
879878
self.set_up()
880879
# Prepare the in-memory session.
881880
if not self._tmpl_attrs.get("in_memory_artifact_service"):
882881
self.set_up()
883882
if not self._tmpl_attrs.get("in_memory_session_service"):
884883
self.set_up()
885-
session_service = self._tmpl_attrs.get("in_memory_session_service")
886-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
884+
session_service = self._tmpl_attrs.get("session_service")
885+
artifact_service = self._tmpl_attrs.get("artifact_service")
886+
runner = self._tmpl_attrs.get("runner")
887887
# Try to get the session, if it doesn't exist, create a new one.
888-
session = None
889888
if request.session_id:
890889
try:
891890
session = await session_service.get_session(
@@ -894,21 +893,32 @@ async def streaming_agent_run_with_events(self, request_json: str):
894893
session_id=request.session_id,
895894
)
896895
except RuntimeError:
897-
pass
898-
if not session:
899-
# Fall back to create session if the session is not found.
900-
session = await self._init_session(
901-
session_service=session_service,
902-
artifact_service=artifact_service,
903-
request=request,
896+
# Fall back to create session if the session is not found.
897+
session = await self._init_session(
898+
session_service=session_service,
899+
artifact_service=artifact_service,
900+
request=request,
901+
)
902+
else:
903+
# Not providing a session ID will create a new in-memory session.
904+
session_service = self._tmpl_attrs.get("in_memory_session_service")
905+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
906+
runner = self._tmpl_attrs.get("in_memory_runner")
907+
session = await session_service.create_session(
908+
app_name=self._tmpl_attrs.get("app_name"),
909+
user_id=request.user_id,
910+
session_id=request.session_id,
904911
)
905912
if not session:
906913
raise RuntimeError("Session initialization failed.")
907914

915+
_LOGGER.info("session: %s, type: %s", session, type(session))
916+
_LOGGER.info("session.id: %s, type: %s", session.id, type(session.id))
917+
908918
# Run the agent
909919
message_for_agent = types.Content(**request.message)
910920
try:
911-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
921+
async for event in runner.run_async(
912922
user_id=request.user_id,
913923
session_id=session.id,
914924
new_message=message_for_agent,

0 commit comments

Comments
 (0)