Skip to content

Commit 37200bc

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 824600367
1 parent 92d8b2a commit 37200bc

File tree

1 file changed

+33
-17
lines changed
  • vertexai/agent_engines/templates

1 file changed

+33
-17
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -559,14 +559,9 @@ async def _init_session(
559559
auth = _Authorization(**auth)
560560
session_state[f"temp:{auth_id}"] = auth.access_token
561561

562-
if request.session_id:
563-
session_id = request.session_id
564-
else:
565-
session_id = f"temp_session_{random.randbytes(8).hex()}"
566562
session = await session_service.create_session(
567563
app_name=self._tmpl_attrs.get("app_name"),
568564
user_id=request.user_id,
569-
session_id=session_id,
570565
state=session_state,
571566
)
572567
if not session:
@@ -1012,43 +1007,64 @@ async def streaming_agent_run_with_events(self, request_json: str):
10121007

10131008
import json
10141009
from google.genai import types
1010+
from google.cloud.aiplatform import base
1011+
from google.api_core.exceptions import FailedPrecondition
1012+
1013+
_LOGGER = base.Logger(__name__)
10151014

10161015
request = _StreamRunRequest(**json.loads(request_json))
1016+
_LOGGER.info("request: %s, type: %s", request, type(request))
10171017
if not self._tmpl_attrs.get("in_memory_runner"):
10181018
self.set_up()
10191019
# Prepare the in-memory session.
10201020
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10211021
self.set_up()
10221022
if not self._tmpl_attrs.get("in_memory_session_service"):
10231023
self.set_up()
1024-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1025-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1024+
session_service = self._tmpl_attrs.get("session_service")
1025+
artifact_service = self._tmpl_attrs.get("artifact_service")
10261026
app = self._tmpl_attrs.get("app")
1027+
runner = self._tmpl_attrs.get("runner")
10271028
# Try to get the session, if it doesn't exist, create a new one.
1028-
session = None
10291029
if request.session_id:
10301030
try:
10311031
session = await session_service.get_session(
10321032
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10331033
user_id=request.user_id,
10341034
session_id=request.session_id,
10351035
)
1036-
except RuntimeError:
1037-
pass
1038-
if not session:
1039-
# Fall back to create session if the session is not found.
1040-
session = await self._init_session(
1041-
session_service=session_service,
1042-
artifact_service=artifact_service,
1043-
request=request,
1036+
except FailedPrecondition:
1037+
_LOGGER.info("Session not found, creating a new one.")
1038+
# Fall back to create session if the session is not found.
1039+
session = await self._init_session(
1040+
session_service=session_service,
1041+
artifact_service=artifact_service,
1042+
request=request,
1043+
)
1044+
except Exception as e:
1045+
_LOGGER.error("Failed to get session: %s, type: %s", e, type(e))
1046+
raise e
1047+
1048+
else:
1049+
# Not providing a session ID will create a new in-memory session.
1050+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1051+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1052+
runner = self._tmpl_attrs.get("in_memory_runner")
1053+
session = await session_service.create_session(
1054+
app_name=self._tmpl_attrs.get("app_name"),
1055+
user_id=request.user_id,
1056+
session_id=request.session_id,
10441057
)
10451058
if not session:
10461059
raise RuntimeError("Session initialization failed.")
10471060

1061+
_LOGGER.info("session: %s, type: %s", session, type(session))
1062+
_LOGGER.info("session.id: %s, type: %s", session.id, type(session.id))
1063+
10481064
# Run the agent
10491065
message_for_agent = types.Content(**request.message)
10501066
try:
1051-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
1067+
async for event in runner.run_async(
10521068
user_id=request.user_id,
10531069
session_id=session.id,
10541070
new_message=message_for_agent,

0 commit comments

Comments
 (0)