@@ -559,14 +559,9 @@ async def _init_session(
559559 auth = _Authorization (** auth )
560560 session_state [f"temp:{ auth_id } " ] = auth .access_token
561561
562- if request .session_id :
563- session_id = request .session_id
564- else :
565- session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
566562 session = await session_service .create_session (
567563 app_name = self ._tmpl_attrs .get ("app_name" ),
568564 user_id = request .user_id ,
569- session_id = session_id ,
570565 state = session_state ,
571566 )
572567 if not session :
@@ -1012,43 +1007,64 @@ async def streaming_agent_run_with_events(self, request_json: str):
10121007
10131008 import json
10141009 from google .genai import types
1010+ from google .cloud .aiplatform import base
1011+ from google .api_core .exceptions import FailedPrecondition
1012+
1013+ _LOGGER = base .Logger (__name__ )
10151014
10161015 request = _StreamRunRequest (** json .loads (request_json ))
1016+ _LOGGER .info ("request: %s, type: %s" , request , type (request ))
10171017 if not self ._tmpl_attrs .get ("in_memory_runner" ):
10181018 self .set_up ()
10191019 # Prepare the in-memory session.
10201020 if not self ._tmpl_attrs .get ("in_memory_artifact_service" ):
10211021 self .set_up ()
10221022 if not self ._tmpl_attrs .get ("in_memory_session_service" ):
10231023 self .set_up ()
1024- session_service = self ._tmpl_attrs .get ("in_memory_session_service " )
1025- artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service " )
1024+ session_service = self ._tmpl_attrs .get ("session_service " )
1025+ artifact_service = self ._tmpl_attrs .get ("artifact_service " )
10261026 app = self ._tmpl_attrs .get ("app" )
1027+ runner = self ._tmpl_attrs .get ("runner" )
10271028 # Try to get the session, if it doesn't exist, create a new one.
1028- session = None
10291029 if request .session_id :
10301030 try :
10311031 session = await session_service .get_session (
10321032 app_name = app .name if app else self ._tmpl_attrs .get ("app_name" ),
10331033 user_id = request .user_id ,
10341034 session_id = request .session_id ,
10351035 )
1036- except RuntimeError :
1037- pass
1038- if not session :
1039- # Fall back to create session if the session is not found.
1040- session = await self ._init_session (
1041- session_service = session_service ,
1042- artifact_service = artifact_service ,
1043- request = request ,
1036+ except FailedPrecondition :
1037+ _LOGGER .info ("Session not found, creating a new one." )
1038+ # Fall back to create session if the session is not found.
1039+ session = await self ._init_session (
1040+ session_service = session_service ,
1041+ artifact_service = artifact_service ,
1042+ request = request ,
1043+ )
1044+ except Exception as e :
1045+ _LOGGER .error ("Failed to get session: %s, type: %s" , e , type (e ))
1046+ raise e
1047+
1048+ else :
1049+ # Not providing a session ID will create a new in-memory session.
1050+ session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
1051+ artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
1052+ runner = self ._tmpl_attrs .get ("in_memory_runner" )
1053+ session = await session_service .create_session (
1054+ app_name = self ._tmpl_attrs .get ("app_name" ),
1055+ user_id = request .user_id ,
1056+ session_id = request .session_id ,
10441057 )
10451058 if not session :
10461059 raise RuntimeError ("Session initialization failed." )
10471060
1061+ _LOGGER .info ("session: %s, type: %s" , session , type (session ))
1062+ _LOGGER .info ("session.id: %s, type: %s" , session .id , type (session .id ))
1063+
10481064 # Run the agent
10491065 message_for_agent = types .Content (** request .message )
10501066 try :
1051- async for event in self . _tmpl_attrs . get ( "in_memory_runner" ) .run_async (
1067+ async for event in runner .run_async (
10521068 user_id = request .user_id ,
10531069 session_id = session .id ,
10541070 new_message = message_for_agent ,
0 commit comments