@@ -520,7 +520,6 @@ async def _init_session(
520520 ):
521521 """Initializes the session, and returns the session id."""
522522 from google .adk .events .event import Event
523- import random
524523
525524 session_state = None
526525 if request .authorizations :
@@ -529,14 +528,9 @@ async def _init_session(
529528 auth = _Authorization (** auth )
530529 session_state [f"temp:{ auth_id } " ] = auth .access_token
531530
532- if request .session_id :
533- session_id = request .session_id
534- else :
535- session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
536531 session = await session_service .create_session (
537532 app_name = self ._tmpl_attrs .get ("app_name" ),
538533 user_id = request .user_id ,
539- session_id = session_id ,
540534 state = session_state ,
541535 )
542536 if not session :
@@ -554,7 +548,7 @@ async def _init_session(
554548 saved_version = await artifact_service .save_artifact (
555549 app_name = self ._tmpl_attrs .get ("app_name" ),
556550 user_id = request .user_id ,
557- session_id = session_id ,
551+ session_id = session . id ,
558552 filename = artifact .file_name ,
559553 artifact = version_data .data ,
560554 )
@@ -904,44 +898,60 @@ async def async_stream_query(
904898 def streaming_agent_run_with_events (self , request_json : str ):
905899 import json
906900 from google .genai import types
901+ from google .genai .errors import ClientError
907902
908903 event_queue = queue .Queue (maxsize = 1 )
909904
910905 async def _invoke_agent_async ():
911906 request = _StreamRunRequest (** json .loads (request_json ))
912907 if not self ._tmpl_attrs .get ("in_memory_runner" ):
913908 self .set_up ()
909+ if not self ._tmpl_attrs .get ("runner" ):
910+ self .set_up ()
914911 # Prepare the in-memory session.
915912 if not self ._tmpl_attrs .get ("in_memory_artifact_service" ):
916913 self .set_up ()
914+ if not self ._tmpl_attrs .get ("artifact_service" ):
915+ self .set_up ()
917916 if not self ._tmpl_attrs .get ("in_memory_session_service" ):
918917 self .set_up ()
919- session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
920- artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
921- # Try to get the session, if it doesn't exist, create a new one.
922- session = None
918+ if not self ._tmpl_attrs .get ("session_service" ):
919+ self .set_up ()
923920 if request .session_id :
921+ session_service = self ._tmpl_attrs .get ("session_service" )
922+ artifact_service = self ._tmpl_attrs .get ("artifact_service" )
923+ runner = self ._tmpl_attrs .get ("runner" )
924924 try :
925925 session = await session_service .get_session (
926926 app_name = self ._tmpl_attrs .get ("app_name" ),
927927 user_id = request .user_id ,
928928 session_id = request .session_id ,
929929 )
930- except RuntimeError :
931- pass
932- if not session :
933- # Fall back to create session if the session is not found.
934- session = await self ._init_session (
935- session_service = session_service ,
936- artifact_service = artifact_service ,
937- request = request ,
930+ except ClientError :
931+ # Fall back to create session if the session is not found.
932+ # Specifying session_id on creation is not supported,
933+ # so session id will be regenerated.
934+ session = await self ._init_session (
935+ session_service = session_service ,
936+ artifact_service = artifact_service ,
937+ request = request ,
938+ )
939+ else :
940+ # Not providing a session ID will create a new in-memory session.
941+ session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
942+ artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
943+ runner = self ._tmpl_attrs .get ("in_memory_runner" )
944+ session = await session_service .create_session (
945+ app_name = self ._tmpl_attrs .get ("app_name" ),
946+ user_id = request .user_id ,
947+ session_id = request .session_id ,
938948 )
939949 if not session :
940950 raise RuntimeError ("Session initialization failed." )
941951 # Run the agent.
942952 message_for_agent = types .Content (** request .message )
943953 try :
944- for event in self . _tmpl_attrs . get ( "in_memory_runner" ) .run (
954+ for event in runner .run (
945955 user_id = request .user_id ,
946956 session_id = session .id ,
947957 new_message = message_for_agent ,
0 commit comments