Skip to content

Commit 5dc8cf5

Browse files
Pass parent session ID to child sessions in AgentTool
When AgentTool creates a child session for a sub-agent, it now passes the parent session ID to maintain session continuity across agent boundaries. This ensures that child agents can properly track and reference their parent session context. Changes: - Added session_id parameter to create_session call in agent_tool.py - Added test_agent_tool_passes_parent_session_id to verify the behavior Testing: - All 16 unit tests pass - New test specifically validates session ID propagation
1 parent 322dd18 commit 5dc8cf5

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

src/google/adk/tools/agent_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ async def run_async(
167167
app_name=child_app_name,
168168
user_id=tool_context._invocation_context.user_id,
169169
state=state_dict,
170+
session_id=tool_context._invocation_context.session.id,
170171
)
171172

172173
last_content = None

tests/unittests/tools/test_agent_tool.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,126 @@ async def close(self):
177177
assert captured['session_app_name'] == parent_app_name
178178

179179

180+
@mark.asyncio
181+
async def test_agent_tool_passes_parent_session_id(monkeypatch):
182+
"""Test that the parent session ID is passed to the child session."""
183+
parent_app_name = 'parent_app'
184+
parent_session_id = 'parent-session-123'
185+
captured: dict[str, str] = {}
186+
187+
class RecordingSessionService(InMemorySessionService):
188+
189+
async def create_session(
190+
self,
191+
*,
192+
app_name: str,
193+
user_id: str,
194+
state: Optional[dict[str, Any]] = None,
195+
session_id: Optional[str] = None,
196+
):
197+
captured['child_session_id'] = session_id
198+
return await super().create_session(
199+
app_name=app_name,
200+
user_id=user_id,
201+
state=state,
202+
session_id=session_id,
203+
)
204+
205+
monkeypatch.setattr(
206+
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
207+
RecordingSessionService,
208+
)
209+
210+
async def _empty_async_generator():
211+
if False:
212+
yield None
213+
214+
class StubRunner:
215+
216+
def __init__(
217+
self,
218+
*,
219+
app_name: str,
220+
agent: Agent,
221+
artifact_service,
222+
session_service,
223+
memory_service,
224+
credential_service,
225+
plugins,
226+
):
227+
del artifact_service, memory_service, credential_service
228+
self.agent = agent
229+
self.session_service = session_service
230+
self.plugin_manager = PluginManager(plugins=plugins)
231+
self.app_name = app_name
232+
233+
def run_async(
234+
self,
235+
*,
236+
user_id: str,
237+
session_id: str,
238+
invocation_id: Optional[str] = None,
239+
new_message: Optional[types.Content] = None,
240+
state_delta: Optional[dict[str, Any]] = None,
241+
run_config: Optional[RunConfig] = None,
242+
):
243+
del (
244+
user_id,
245+
session_id,
246+
invocation_id,
247+
new_message,
248+
state_delta,
249+
run_config,
250+
)
251+
return _empty_async_generator()
252+
253+
async def close(self):
254+
"""Mock close method."""
255+
pass
256+
257+
monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
258+
259+
tool_agent = Agent(
260+
name='tool_agent',
261+
model='test-model',
262+
)
263+
agent_tool = AgentTool(agent=tool_agent)
264+
root_agent = Agent(
265+
name='root_agent',
266+
model='test-model',
267+
tools=[agent_tool],
268+
)
269+
270+
artifact_service = InMemoryArtifactService()
271+
parent_session_service = InMemorySessionService()
272+
parent_session = await parent_session_service.create_session(
273+
app_name=parent_app_name,
274+
user_id='user',
275+
session_id=parent_session_id,
276+
)
277+
invocation_context = InvocationContext(
278+
artifact_service=artifact_service,
279+
session_service=parent_session_service,
280+
memory_service=InMemoryMemoryService(),
281+
plugin_manager=PluginManager(),
282+
invocation_id='invocation-id',
283+
agent=root_agent,
284+
session=parent_session,
285+
run_config=RunConfig(),
286+
)
287+
tool_context = ToolContext(invocation_context)
288+
289+
assert tool_context._invocation_context.session.id == parent_session_id
290+
291+
await agent_tool.run_async(
292+
args={'request': 'hello'},
293+
tool_context=tool_context,
294+
)
295+
296+
# Verify that the parent session ID was passed to the child session
297+
assert captured['child_session_id'] == parent_session_id
298+
299+
180300
def test_no_schema():
181301
mock_model = testing_utils.MockModel.create(
182302
responses=[

0 commit comments

Comments
 (0)