Skip to content

Commit 48d5b0a

Browse files
Pass parent session ID to child sessions in AgentTool
When AgentTool creates a child session for a sub-agent, it now passes the parent session ID to maintain session continuity across agent boundaries. This ensures that child agents can properly track and reference their parent session context. Changes: - Added session_id parameter to create_session call in agent_tool.py - Added test_agent_tool_passes_parent_session_id to verify the behavior Testing: - All 16 unit tests pass - New test specifically validates session ID propagation
1 parent 71b3289 commit 48d5b0a

File tree

2 files changed

+161
-98
lines changed

2 files changed

+161
-98
lines changed

src/google/adk/tools/agent_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ async def run_async(
167167
app_name=child_app_name,
168168
user_id=tool_context._invocation_context.user_id,
169169
state=state_dict,
170+
session_id=tool_context._invocation_context.session.id,
170171
)
171172

172173
last_content = None

tests/unittests/tools/test_agent_tool.py

Lines changed: 160 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from google.genai import types
3434
from google.genai.types import Part
3535
from pydantic import BaseModel
36+
from pytest import fixture
3637
from pytest import mark
3738

3839
from .. import testing_utils
@@ -59,112 +60,147 @@ def change_state_callback(callback_context: CallbackContext):
5960
print('change_state_callback: ', callback_context.state)
6061

6162

62-
@mark.asyncio
63-
async def test_agent_tool_inherits_parent_app_name(monkeypatch):
64-
parent_app_name = 'parent_app'
65-
captured: dict[str, str] = {}
66-
67-
class RecordingSessionService(InMemorySessionService):
68-
69-
async def create_session(
70-
self,
71-
*,
72-
app_name: str,
73-
user_id: str,
74-
state: Optional[dict[str, Any]] = None,
75-
session_id: Optional[str] = None,
76-
):
77-
captured['session_app_name'] = app_name
78-
return await super().create_session(
79-
app_name=app_name,
80-
user_id=user_id,
81-
state=state,
82-
session_id=session_id,
83-
)
84-
85-
monkeypatch.setattr(
86-
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
87-
RecordingSessionService,
88-
)
89-
63+
@fixture
64+
def agent_tool_setup_factory(monkeypatch):
9065
async def _empty_async_generator():
9166
if False:
9267
yield None
9368

94-
class StubRunner:
69+
async def _create_setup(
70+
*,
71+
parent_app_name: str,
72+
parent_session_id: Optional[str] | None = None,
73+
capture_runner_app_name: bool = False,
74+
capture_session_app_name: bool = False,
75+
capture_child_session_id: bool = False,
76+
):
77+
captured: dict[str, Any] = {}
78+
79+
class RecordingSessionService(InMemorySessionService):
80+
81+
async def create_session(
82+
self,
83+
*,
84+
app_name: str,
85+
user_id: str,
86+
state: Optional[dict[str, Any]] = None,
87+
session_id: Optional[str] = None,
88+
):
89+
if capture_session_app_name:
90+
captured['session_app_name'] = app_name
91+
if capture_child_session_id:
92+
captured['child_session_id'] = session_id
93+
return await super().create_session(
94+
app_name=app_name,
95+
user_id=user_id,
96+
state=state,
97+
session_id=session_id,
98+
)
9599

96-
def __init__(
97-
self,
98-
*,
99-
app_name: str,
100-
agent: Agent,
101-
artifact_service,
102-
session_service,
103-
memory_service,
104-
credential_service,
105-
plugins,
106-
):
107-
del artifact_service, memory_service, credential_service
108-
captured['runner_app_name'] = app_name
109-
self.agent = agent
110-
self.session_service = session_service
111-
self.plugin_manager = PluginManager(plugins=plugins)
112-
self.app_name = app_name
113-
114-
def run_async(
115-
self,
116-
*,
117-
user_id: str,
118-
session_id: str,
119-
invocation_id: Optional[str] = None,
120-
new_message: Optional[types.Content] = None,
121-
state_delta: Optional[dict[str, Any]] = None,
122-
run_config: Optional[RunConfig] = None,
123-
):
124-
del (
125-
user_id,
126-
session_id,
127-
invocation_id,
128-
new_message,
129-
state_delta,
130-
run_config,
131-
)
132-
return _empty_async_generator()
133-
134-
async def close(self):
135-
"""Mock close method."""
136-
pass
137-
138-
monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
100+
monkeypatch.setattr(
101+
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
102+
RecordingSessionService,
103+
)
139104

140-
tool_agent = Agent(
141-
name='tool_agent',
142-
model='test-model',
143-
)
144-
agent_tool = AgentTool(agent=tool_agent)
145-
root_agent = Agent(
146-
name='root_agent',
147-
model='test-model',
148-
tools=[agent_tool],
149-
)
105+
class StubRunner:
106+
107+
def __init__(
108+
self,
109+
*,
110+
app_name: str,
111+
agent: Agent,
112+
artifact_service,
113+
session_service,
114+
memory_service,
115+
credential_service,
116+
plugins,
117+
):
118+
del artifact_service, memory_service, credential_service
119+
if capture_runner_app_name:
120+
captured['runner_app_name'] = app_name
121+
self.agent = agent
122+
self.session_service = session_service
123+
self.plugin_manager = PluginManager(plugins=plugins)
124+
self.app_name = app_name
125+
126+
def run_async(
127+
self,
128+
*,
129+
user_id: str,
130+
session_id: str,
131+
invocation_id: Optional[str] = None,
132+
new_message: Optional[types.Content] = None,
133+
state_delta: Optional[dict[str, Any]] = None,
134+
run_config: Optional[RunConfig] = None,
135+
):
136+
del (
137+
user_id,
138+
session_id,
139+
invocation_id,
140+
new_message,
141+
state_delta,
142+
run_config,
143+
)
144+
return _empty_async_generator()
150145

151-
artifact_service = InMemoryArtifactService()
152-
parent_session_service = InMemorySessionService()
153-
parent_session = await parent_session_service.create_session(
154-
app_name=parent_app_name,
155-
user_id='user',
156-
)
157-
invocation_context = InvocationContext(
158-
artifact_service=artifact_service,
159-
session_service=parent_session_service,
160-
memory_service=InMemoryMemoryService(),
161-
plugin_manager=PluginManager(),
162-
invocation_id='invocation-id',
163-
agent=root_agent,
164-
session=parent_session,
165-
run_config=RunConfig(),
146+
async def close(self):
147+
"""Mock close method."""
148+
pass
149+
150+
monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
151+
152+
tool_agent = Agent(
153+
name='tool_agent',
154+
model='test-model',
155+
)
156+
agent_tool = AgentTool(agent=tool_agent)
157+
root_agent = Agent(
158+
name='root_agent',
159+
model='test-model',
160+
tools=[agent_tool],
161+
)
162+
163+
artifact_service = InMemoryArtifactService()
164+
parent_session_service = InMemorySessionService()
165+
parent_session = await parent_session_service.create_session(
166+
app_name=parent_app_name,
167+
user_id='user',
168+
session_id=parent_session_id,
169+
)
170+
invocation_context = InvocationContext(
171+
artifact_service=artifact_service,
172+
session_service=parent_session_service,
173+
memory_service=InMemoryMemoryService(),
174+
plugin_manager=PluginManager(),
175+
invocation_id='invocation-id',
176+
agent=root_agent,
177+
session=parent_session,
178+
run_config=RunConfig(),
179+
)
180+
tool_context = ToolContext(invocation_context)
181+
182+
return {
183+
'agent_tool': agent_tool,
184+
'tool_context': tool_context,
185+
'captured': captured,
186+
}
187+
188+
return _create_setup
189+
190+
191+
@mark.asyncio
192+
async def test_agent_tool_inherits_parent_app_name(agent_tool_setup_factory):
193+
parent_app_name = 'parent_app'
194+
195+
setup = await agent_tool_setup_factory(
196+
parent_app_name=parent_app_name,
197+
capture_runner_app_name=True,
198+
capture_session_app_name=True,
166199
)
167-
tool_context = ToolContext(invocation_context)
200+
201+
agent_tool = setup['agent_tool']
202+
tool_context = setup['tool_context']
203+
captured = setup['captured']
168204

169205
assert tool_context._invocation_context.app_name == parent_app_name
170206

@@ -177,6 +213,32 @@ async def close(self):
177213
assert captured['session_app_name'] == parent_app_name
178214

179215

216+
@mark.asyncio
217+
async def test_agent_tool_passes_parent_session_id(agent_tool_setup_factory):
218+
"""Test that the parent session ID is passed to the child session."""
219+
parent_app_name = 'parent_app'
220+
parent_session_id = 'parent-session-123'
221+
setup = await agent_tool_setup_factory(
222+
parent_app_name=parent_app_name,
223+
parent_session_id=parent_session_id,
224+
capture_child_session_id=True,
225+
)
226+
227+
agent_tool = setup['agent_tool']
228+
tool_context = setup['tool_context']
229+
captured = setup['captured']
230+
231+
assert tool_context._invocation_context.session.id == parent_session_id
232+
233+
await agent_tool.run_async(
234+
args={'request': 'hello'},
235+
tool_context=tool_context,
236+
)
237+
238+
# Verify that the parent session ID was passed to the child session
239+
assert captured['child_session_id'] == parent_session_id
240+
241+
180242
def test_no_schema():
181243
mock_model = testing_utils.MockModel.create(
182244
responses=[

0 commit comments

Comments
 (0)