Skip to content

Commit 13faa27

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
feat: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 826664694
1 parent e600277 commit 13faa27

File tree

2 files changed

+61
-39
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+61
-39
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,6 @@ async def _init_session(
568568
):
569569
"""Initializes the session, and returns the session id."""
570570
from google.adk.events.event import Event
571-
import random
572571

573572
session_state = None
574573
if request.authorizations:
@@ -577,14 +576,9 @@ async def _init_session(
577576
auth = _Authorization(**auth)
578577
session_state[f"temp:{auth_id}"] = auth.access_token
579578

580-
if request.session_id:
581-
session_id = request.session_id
582-
else:
583-
session_id = f"temp_session_{random.randbytes(8).hex()}"
584579
session = await session_service.create_session(
585580
app_name=self._tmpl_attrs.get("app_name"),
586581
user_id=request.user_id,
587-
session_id=session_id,
588582
state=session_state,
589583
)
590584
if not session:
@@ -602,7 +596,7 @@ async def _init_session(
602596
saved_version = await artifact_service.save_artifact(
603597
app_name=self._tmpl_attrs.get("app_name"),
604598
user_id=request.user_id,
605-
session_id=session_id,
599+
session_id=session.id,
606600
filename=artifact.file_name,
607601
artifact=version_data.data,
608602
)
@@ -998,43 +992,61 @@ async def streaming_agent_run_with_events(self, request_json: str):
998992

999993
import json
1000994
from google.genai import types
995+
from google.genai.errors import ClientError
1001996

1002997
request = _StreamRunRequest(**json.loads(request_json))
1003998
if not self._tmpl_attrs.get("in_memory_runner"):
1004999
self.set_up()
1000+
if not self._tmpl_attrs.get("runner"):
1001+
self.set_up()
10051002
# Prepare the in-memory session.
10061003
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10071004
self.set_up()
1005+
if not self._tmpl_attrs.get("artifact_service"):
1006+
self.set_up()
10081007
if not self._tmpl_attrs.get("in_memory_session_service"):
10091008
self.set_up()
1010-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1011-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1009+
if not self._tmpl_attrs.get("session_service"):
1010+
self.set_up()
10121011
app = self._tmpl_attrs.get("app")
1012+
10131013
# Try to get the session, if it doesn't exist, create a new one.
1014-
session = None
10151014
if request.session_id:
1015+
session_service = self._tmpl_attrs.get("session_service")
1016+
artifact_service = self._tmpl_attrs.get("artifact_service")
1017+
runner = self._tmpl_attrs.get("runner")
10161018
try:
10171019
session = await session_service.get_session(
10181020
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10191021
user_id=request.user_id,
10201022
session_id=request.session_id,
10211023
)
1022-
except RuntimeError:
1023-
pass
1024-
if not session:
1025-
# Fall back to create session if the session is not found.
1026-
session = await self._init_session(
1027-
session_service=session_service,
1028-
artifact_service=artifact_service,
1029-
request=request,
1024+
except ClientError:
1025+
# Fall back to create session if the session is not found.
1026+
# Specifying session_id on creation is not supported,
1027+
# so session id will be regenerated.
1028+
session = await self._init_session(
1029+
session_service=session_service,
1030+
artifact_service=artifact_service,
1031+
request=request,
1032+
)
1033+
else:
1034+
# Not providing a session ID will create a new in-memory session.
1035+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1036+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1037+
runner = self._tmpl_attrs.get("in_memory_runner")
1038+
session = await session_service.create_session(
1039+
app_name=self._tmpl_attrs.get("app_name"),
1040+
user_id=request.user_id,
1041+
session_id=request.session_id,
10301042
)
10311043
if not session:
10321044
raise RuntimeError("Session initialization failed.")
10331045

10341046
# Run the agent
10351047
message_for_agent = types.Content(**request.message)
10361048
try:
1037-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
1049+
async for event in runner.run_async(
10381050
user_id=request.user_id,
10391051
session_id=session.id,
10401052
new_message=message_for_agent,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ async def _init_session(
520520
):
521521
"""Initializes the session, and returns the session id."""
522522
from google.adk.events.event import Event
523-
import random
524523

525524
session_state = None
526525
if request.authorizations:
@@ -529,14 +528,9 @@ async def _init_session(
529528
auth = _Authorization(**auth)
530529
session_state[f"temp:{auth_id}"] = auth.access_token
531530

532-
if request.session_id:
533-
session_id = request.session_id
534-
else:
535-
session_id = f"temp_session_{random.randbytes(8).hex()}"
536531
session = await session_service.create_session(
537532
app_name=self._tmpl_attrs.get("app_name"),
538533
user_id=request.user_id,
539-
session_id=session_id,
540534
state=session_state,
541535
)
542536
if not session:
@@ -554,7 +548,7 @@ async def _init_session(
554548
saved_version = await artifact_service.save_artifact(
555549
app_name=self._tmpl_attrs.get("app_name"),
556550
user_id=request.user_id,
557-
session_id=session_id,
551+
session_id=session.id,
558552
filename=artifact.file_name,
559553
artifact=version_data.data,
560554
)
@@ -904,44 +898,60 @@ async def async_stream_query(
904898
def streaming_agent_run_with_events(self, request_json: str):
905899
import json
906900
from google.genai import types
901+
from google.genai.errors import ClientError
907902

908903
event_queue = queue.Queue(maxsize=1)
909904

910905
async def _invoke_agent_async():
911906
request = _StreamRunRequest(**json.loads(request_json))
912907
if not self._tmpl_attrs.get("in_memory_runner"):
913908
self.set_up()
909+
if not self._tmpl_attrs.get("runner"):
910+
self.set_up()
914911
# Prepare the in-memory session.
915912
if not self._tmpl_attrs.get("in_memory_artifact_service"):
916913
self.set_up()
914+
if not self._tmpl_attrs.get("artifact_service"):
915+
self.set_up()
917916
if not self._tmpl_attrs.get("in_memory_session_service"):
918917
self.set_up()
919-
session_service = self._tmpl_attrs.get("in_memory_session_service")
920-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
921-
# Try to get the session, if it doesn't exist, create a new one.
922-
session = None
918+
if not self._tmpl_attrs.get("session_service"):
919+
self.set_up()
923920
if request.session_id:
921+
session_service = self._tmpl_attrs.get("session_service")
922+
artifact_service = self._tmpl_attrs.get("artifact_service")
923+
runner = self._tmpl_attrs.get("runner")
924924
try:
925925
session = await session_service.get_session(
926926
app_name=self._tmpl_attrs.get("app_name"),
927927
user_id=request.user_id,
928928
session_id=request.session_id,
929929
)
930-
except RuntimeError:
931-
pass
932-
if not session:
933-
# Fall back to create session if the session is not found.
934-
session = await self._init_session(
935-
session_service=session_service,
936-
artifact_service=artifact_service,
937-
request=request,
930+
except ClientError:
931+
# Fall back to create session if the session is not found.
932+
# Specifying session_id on creation is not supported,
933+
# so session id will be regenerated.
934+
session = await self._init_session(
935+
session_service=session_service,
936+
artifact_service=artifact_service,
937+
request=request,
938+
)
939+
else:
940+
# Not providing a session ID will create a new in-memory session.
941+
session_service = self._tmpl_attrs.get("in_memory_session_service")
942+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
943+
runner = self._tmpl_attrs.get("in_memory_runner")
944+
session = await session_service.create_session(
945+
app_name=self._tmpl_attrs.get("app_name"),
946+
user_id=request.user_id,
947+
session_id=request.session_id,
938948
)
939949
if not session:
940950
raise RuntimeError("Session initialization failed.")
941951
# Run the agent.
942952
message_for_agent = types.Content(**request.message)
943953
try:
944-
for event in self._tmpl_attrs.get("in_memory_runner").run(
954+
for event in runner.run(
945955
user_id=request.user_id,
946956
session_id=session.id,
947957
new_message=message_for_agent,

0 commit comments

Comments
 (0)