3333from google .genai import types
3434from google .genai .types import Part
3535from pydantic import BaseModel
36+ from pytest import fixture
3637from pytest import mark
3738
3839from .. import testing_utils
@@ -59,112 +60,147 @@ def change_state_callback(callback_context: CallbackContext):
5960 print ('change_state_callback: ' , callback_context .state )
6061
6162
62- @mark .asyncio
63- async def test_agent_tool_inherits_parent_app_name (monkeypatch ):
64- parent_app_name = 'parent_app'
65- captured : dict [str , str ] = {}
66-
67- class RecordingSessionService (InMemorySessionService ):
68-
69- async def create_session (
70- self ,
71- * ,
72- app_name : str ,
73- user_id : str ,
74- state : Optional [dict [str , Any ]] = None ,
75- session_id : Optional [str ] = None ,
76- ):
77- captured ['session_app_name' ] = app_name
78- return await super ().create_session (
79- app_name = app_name ,
80- user_id = user_id ,
81- state = state ,
82- session_id = session_id ,
83- )
84-
85- monkeypatch .setattr (
86- 'google.adk.sessions.in_memory_session_service.InMemorySessionService' ,
87- RecordingSessionService ,
88- )
89-
63+ @fixture
64+ def agent_tool_setup_factory (monkeypatch ):
9065 async def _empty_async_generator ():
9166 if False :
9267 yield None
9368
94- class StubRunner :
69+ async def _create_setup (
70+ * ,
71+ parent_app_name : str ,
72+ parent_session_id : Optional [str ] | None = None ,
73+ capture_runner_app_name : bool = False ,
74+ capture_session_app_name : bool = False ,
75+ capture_child_session_id : bool = False ,
76+ ):
77+ captured : dict [str , Any ] = {}
78+
79+ class RecordingSessionService (InMemorySessionService ):
80+
81+ async def create_session (
82+ self ,
83+ * ,
84+ app_name : str ,
85+ user_id : str ,
86+ state : Optional [dict [str , Any ]] = None ,
87+ session_id : Optional [str ] = None ,
88+ ):
89+ if capture_session_app_name :
90+ captured ['session_app_name' ] = app_name
91+ if capture_child_session_id :
92+ captured ['child_session_id' ] = session_id
93+ return await super ().create_session (
94+ app_name = app_name ,
95+ user_id = user_id ,
96+ state = state ,
97+ session_id = session_id ,
98+ )
9599
96- def __init__ (
97- self ,
98- * ,
99- app_name : str ,
100- agent : Agent ,
101- artifact_service ,
102- session_service ,
103- memory_service ,
104- credential_service ,
105- plugins ,
106- ):
107- del artifact_service , memory_service , credential_service
108- captured ['runner_app_name' ] = app_name
109- self .agent = agent
110- self .session_service = session_service
111- self .plugin_manager = PluginManager (plugins = plugins )
112- self .app_name = app_name
113-
114- def run_async (
115- self ,
116- * ,
117- user_id : str ,
118- session_id : str ,
119- invocation_id : Optional [str ] = None ,
120- new_message : Optional [types .Content ] = None ,
121- state_delta : Optional [dict [str , Any ]] = None ,
122- run_config : Optional [RunConfig ] = None ,
123- ):
124- del (
125- user_id ,
126- session_id ,
127- invocation_id ,
128- new_message ,
129- state_delta ,
130- run_config ,
131- )
132- return _empty_async_generator ()
133-
134- async def close (self ):
135- """Mock close method."""
136- pass
137-
138- monkeypatch .setattr ('google.adk.runners.Runner' , StubRunner )
100+ monkeypatch .setattr (
101+ 'google.adk.sessions.in_memory_session_service.InMemorySessionService' ,
102+ RecordingSessionService ,
103+ )
139104
140- tool_agent = Agent (
141- name = 'tool_agent' ,
142- model = 'test-model' ,
143- )
144- agent_tool = AgentTool (agent = tool_agent )
145- root_agent = Agent (
146- name = 'root_agent' ,
147- model = 'test-model' ,
148- tools = [agent_tool ],
149- )
105+ class StubRunner :
106+
107+ def __init__ (
108+ self ,
109+ * ,
110+ app_name : str ,
111+ agent : Agent ,
112+ artifact_service ,
113+ session_service ,
114+ memory_service ,
115+ credential_service ,
116+ plugins ,
117+ ):
118+ del artifact_service , memory_service , credential_service
119+ if capture_runner_app_name :
120+ captured ['runner_app_name' ] = app_name
121+ self .agent = agent
122+ self .session_service = session_service
123+ self .plugin_manager = PluginManager (plugins = plugins )
124+ self .app_name = app_name
125+
126+ def run_async (
127+ self ,
128+ * ,
129+ user_id : str ,
130+ session_id : str ,
131+ invocation_id : Optional [str ] = None ,
132+ new_message : Optional [types .Content ] = None ,
133+ state_delta : Optional [dict [str , Any ]] = None ,
134+ run_config : Optional [RunConfig ] = None ,
135+ ):
136+ del (
137+ user_id ,
138+ session_id ,
139+ invocation_id ,
140+ new_message ,
141+ state_delta ,
142+ run_config ,
143+ )
144+ return _empty_async_generator ()
150145
151- artifact_service = InMemoryArtifactService ()
152- parent_session_service = InMemorySessionService ()
153- parent_session = await parent_session_service .create_session (
154- app_name = parent_app_name ,
155- user_id = 'user' ,
156- )
157- invocation_context = InvocationContext (
158- artifact_service = artifact_service ,
159- session_service = parent_session_service ,
160- memory_service = InMemoryMemoryService (),
161- plugin_manager = PluginManager (),
162- invocation_id = 'invocation-id' ,
163- agent = root_agent ,
164- session = parent_session ,
165- run_config = RunConfig (),
146+ async def close (self ):
147+ """Mock close method."""
148+ pass
149+
150+ monkeypatch .setattr ('google.adk.runners.Runner' , StubRunner )
151+
152+ tool_agent = Agent (
153+ name = 'tool_agent' ,
154+ model = 'test-model' ,
155+ )
156+ agent_tool = AgentTool (agent = tool_agent )
157+ root_agent = Agent (
158+ name = 'root_agent' ,
159+ model = 'test-model' ,
160+ tools = [agent_tool ],
161+ )
162+
163+ artifact_service = InMemoryArtifactService ()
164+ parent_session_service = InMemorySessionService ()
165+ parent_session = await parent_session_service .create_session (
166+ app_name = parent_app_name ,
167+ user_id = 'user' ,
168+ session_id = parent_session_id ,
169+ )
170+ invocation_context = InvocationContext (
171+ artifact_service = artifact_service ,
172+ session_service = parent_session_service ,
173+ memory_service = InMemoryMemoryService (),
174+ plugin_manager = PluginManager (),
175+ invocation_id = 'invocation-id' ,
176+ agent = root_agent ,
177+ session = parent_session ,
178+ run_config = RunConfig (),
179+ )
180+ tool_context = ToolContext (invocation_context )
181+
182+ return {
183+ 'agent_tool' : agent_tool ,
184+ 'tool_context' : tool_context ,
185+ 'captured' : captured ,
186+ }
187+
188+ return _create_setup
189+
190+
191+ @mark .asyncio
192+ async def test_agent_tool_inherits_parent_app_name (agent_tool_setup_factory ):
193+ parent_app_name = 'parent_app'
194+
195+ setup = await agent_tool_setup_factory (
196+ parent_app_name = parent_app_name ,
197+ capture_runner_app_name = True ,
198+ capture_session_app_name = True ,
166199 )
167- tool_context = ToolContext (invocation_context )
200+
201+ agent_tool = setup ['agent_tool' ]
202+ tool_context = setup ['tool_context' ]
203+ captured = setup ['captured' ]
168204
169205 assert tool_context ._invocation_context .app_name == parent_app_name
170206
@@ -177,6 +213,32 @@ async def close(self):
177213 assert captured ['session_app_name' ] == parent_app_name
178214
179215
216+ @mark .asyncio
217+ async def test_agent_tool_passes_parent_session_id (agent_tool_setup_factory ):
218+ """Test that the parent session ID is passed to the child session."""
219+ parent_app_name = 'parent_app'
220+ parent_session_id = 'parent-session-123'
221+ setup = await agent_tool_setup_factory (
222+ parent_app_name = parent_app_name ,
223+ parent_session_id = parent_session_id ,
224+ capture_child_session_id = True ,
225+ )
226+
227+ agent_tool = setup ['agent_tool' ]
228+ tool_context = setup ['tool_context' ]
229+ captured = setup ['captured' ]
230+
231+ assert tool_context ._invocation_context .session .id == parent_session_id
232+
233+ await agent_tool .run_async (
234+ args = {'request' : 'hello' },
235+ tool_context = tool_context ,
236+ )
237+
238+ # Verify that the parent session ID was passed to the child session
239+ assert captured ['child_session_id' ] == parent_session_id
240+
241+
180242def test_no_schema ():
181243 mock_model = testing_utils .MockModel .create (
182244 responses = [
0 commit comments