Skip to content

Commit f49186c

Browse files
committed
feat(runner): add metadata parameter to Runner.run_async()
Add support for passing per-request metadata through the agent execution pipeline. This enables use cases like: - Passing user_id, trace_id, or session context to callbacks - Enabling memory injection in before_model_callback - Supporting request-specific context without using ContextVar workarounds Changes: - Add `metadata` field to LlmRequest model - Add `metadata` field to InvocationContext model - Add `metadata` parameter to Runner.run_async() and related methods - Propagate metadata from InvocationContext to LlmRequest in base_llm_flow - Add unit tests for metadata functionality Closes #2978
1 parent 0b1cff2 commit f49186c

File tree

5 files changed

+179
-3
lines changed

5 files changed

+179
-3
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ class InvocationContext(BaseModel):
206206
canonical_tools_cache: Optional[list[BaseTool]] = None
207207
"""The cache of canonical tools for this invocation."""
208208

209+
metadata: Optional[dict[str, Any]] = None
210+
"""Per-request metadata passed from Runner.run_async().
211+
212+
This field allows passing arbitrary metadata that can be accessed during
213+
the invocation lifecycle, particularly in callbacks like before_model_callback.
214+
Common use cases include passing user_id, trace_id, memory context keys, or
215+
other request-specific context that needs to be available during processing.
216+
"""
217+
209218
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
210219
default_factory=_InvocationCostManager
211220
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def run_live(
8989
invocation_context: InvocationContext,
9090
) -> AsyncGenerator[Event, None]:
9191
"""Runs the flow using live api."""
92-
llm_request = LlmRequest()
92+
llm_request = LlmRequest(metadata=invocation_context.metadata)
9393
event_id = Event.new_id()
9494

9595
# Preprocess before calling the LLM.
@@ -380,7 +380,7 @@ async def _run_one_step_async(
380380
invocation_context: InvocationContext,
381381
) -> AsyncGenerator[Event, None]:
382382
"""One step means one LLM call."""
383-
llm_request = LlmRequest()
383+
llm_request = LlmRequest(metadata=invocation_context.metadata)
384384

385385
# Preprocess before calling the LLM.
386386
async with Aclosing(

src/google/adk/models/llm_request.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
from typing import Any
1819
from typing import Optional
1920
from typing import Union
2021

@@ -99,6 +100,15 @@ class LlmRequest(BaseModel):
99100
the full history.
100101
"""
101102

103+
metadata: Optional[dict[str, Any]] = None
104+
"""Per-request metadata for callbacks and custom processing.
105+
106+
This field allows passing arbitrary metadata from the Runner.run_async()
107+
call to callbacks like before_model_callback. This is useful for passing
108+
request-specific context such as user_id, trace_id, or memory context keys
109+
that need to be available during model invocation.
110+
"""
111+
102112
def append_instructions(
103113
self, instructions: Union[list[str], types.Content]
104114
) -> list[types.Content]:

src/google/adk/runners.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ async def run_async(
400400
new_message: Optional[types.Content] = None,
401401
state_delta: Optional[dict[str, Any]] = None,
402402
run_config: Optional[RunConfig] = None,
403+
metadata: Optional[dict[str, Any]] = None,
403404
) -> AsyncGenerator[Event, None]:
404405
"""Main entry method to run the agent in this runner.
405406
@@ -417,6 +418,9 @@ async def run_async(
417418
new_message: A new message to append to the session.
418419
state_delta: Optional state changes to apply to the session.
419420
run_config: The run config for the agent.
421+
metadata: Optional per-request metadata that will be passed to callbacks.
422+
This allows passing request-specific context such as user_id, trace_id,
423+
or memory context keys to before_model_callback and other callbacks.
420424
421425
Yields:
422426
The events generated by the agent.
@@ -433,6 +437,7 @@ async def run_async(
433437
async def _run_with_trace(
434438
new_message: Optional[types.Content] = None,
435439
invocation_id: Optional[str] = None,
440+
metadata: Optional[dict[str, Any]] = None,
436441
) -> AsyncGenerator[Event, None]:
437442
with tracer.start_as_current_span('invocation'):
438443
session = await self.session_service.get_session(
@@ -463,6 +468,7 @@ async def _run_with_trace(
463468
invocation_id=invocation_id,
464469
run_config=run_config,
465470
state_delta=state_delta,
471+
metadata=metadata,
466472
)
467473
if invocation_context.end_of_agents.get(
468474
invocation_context.agent.name
@@ -476,6 +482,7 @@ async def _run_with_trace(
476482
new_message=new_message, # new_message is not None.
477483
run_config=run_config,
478484
state_delta=state_delta,
485+
metadata=metadata,
479486
)
480487

481488
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
@@ -502,7 +509,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
502509
self.app, session, self.session_service
503510
)
504511

505-
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
512+
async with Aclosing(
513+
_run_with_trace(new_message, invocation_id, metadata)
514+
) as agen:
506515
async for event in agen:
507516
yield event
508517

@@ -1186,6 +1195,7 @@ async def _setup_context_for_new_invocation(
11861195
new_message: types.Content,
11871196
run_config: RunConfig,
11881197
state_delta: Optional[dict[str, Any]],
1198+
metadata: Optional[dict[str, Any]] = None,
11891199
) -> InvocationContext:
11901200
"""Sets up the context for a new invocation.
11911201
@@ -1194,6 +1204,7 @@ async def _setup_context_for_new_invocation(
11941204
new_message: The new message to process and append to the session.
11951205
run_config: The run config of the agent.
11961206
state_delta: Optional state changes to apply to the session.
1207+
metadata: Optional per-request metadata to pass to callbacks.
11971208
11981209
Returns:
11991210
The invocation context for the new invocation.
@@ -1203,6 +1214,7 @@ async def _setup_context_for_new_invocation(
12031214
session,
12041215
new_message=new_message,
12051216
run_config=run_config,
1217+
metadata=metadata,
12061218
)
12071219
# Step 2: Handle new message, by running callbacks and appending to
12081220
# session.
@@ -1225,6 +1237,7 @@ async def _setup_context_for_resumed_invocation(
12251237
invocation_id: Optional[str],
12261238
run_config: RunConfig,
12271239
state_delta: Optional[dict[str, Any]],
1240+
metadata: Optional[dict[str, Any]] = None,
12281241
) -> InvocationContext:
12291242
"""Sets up the context for a resumed invocation.
12301243
@@ -1234,6 +1247,7 @@ async def _setup_context_for_resumed_invocation(
12341247
invocation_id: The invocation id to resume.
12351248
run_config: The run config of the agent.
12361249
state_delta: Optional state changes to apply to the session.
1250+
metadata: Optional per-request metadata to pass to callbacks.
12371251
12381252
Returns:
12391253
The invocation context for the resumed invocation.
@@ -1259,6 +1273,7 @@ async def _setup_context_for_resumed_invocation(
12591273
new_message=user_message,
12601274
run_config=run_config,
12611275
invocation_id=invocation_id,
1276+
metadata=metadata,
12621277
)
12631278
# Step 3: Maybe handle new message.
12641279
if new_message:
@@ -1303,6 +1318,7 @@ def _new_invocation_context(
13031318
new_message: Optional[types.Content] = None,
13041319
live_request_queue: Optional[LiveRequestQueue] = None,
13051320
run_config: Optional[RunConfig] = None,
1321+
metadata: Optional[dict[str, Any]] = None,
13061322
) -> InvocationContext:
13071323
"""Creates a new invocation context.
13081324
@@ -1312,6 +1328,7 @@ def _new_invocation_context(
13121328
new_message: The new message for the context.
13131329
live_request_queue: The live request queue for the context.
13141330
run_config: The run config for the context.
1331+
metadata: Optional per-request metadata for the context.
13151332
13161333
Returns:
13171334
The new invocation context.
@@ -1343,6 +1360,7 @@ def _new_invocation_context(
13431360
live_request_queue=live_request_queue,
13441361
run_config=run_config,
13451362
resumability_config=self.resumability_config,
1363+
metadata=metadata,
13461364
)
13471365

13481366
def _new_invocation_context_for_live(

tests/unittests/test_runners.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
2929
from google.adk.cli.utils.agent_loader import AgentLoader
3030
from google.adk.events.event import Event
31+
from google.adk.models.llm_request import LlmRequest
32+
from google.adk.models.llm_response import LlmResponse
3133
from google.adk.plugins.base_plugin import BasePlugin
3234
from google.adk.runners import Runner
3335
from google.adk.sessions.in_memory_session_service import InMemorySessionService
@@ -1038,5 +1040,142 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
10381040
assert "actual_name" in runner._app_name_alignment_hint
10391041

10401042

1043+
class TestRunnerMetadata:
1044+
"""Tests for Runner metadata parameter functionality."""
1045+
1046+
def setup_method(self):
1047+
"""Set up test fixtures."""
1048+
self.session_service = InMemorySessionService()
1049+
self.artifact_service = InMemoryArtifactService()
1050+
self.root_agent = MockLlmAgent("root_agent")
1051+
self.runner = Runner(
1052+
app_name="test_app",
1053+
agent=self.root_agent,
1054+
session_service=self.session_service,
1055+
artifact_service=self.artifact_service,
1056+
)
1057+
1058+
def test_new_invocation_context_with_metadata(self):
1059+
"""Test that _new_invocation_context correctly passes metadata."""
1060+
mock_session = Session(
1061+
id=TEST_SESSION_ID,
1062+
app_name=TEST_APP_ID,
1063+
user_id=TEST_USER_ID,
1064+
events=[],
1065+
)
1066+
1067+
test_metadata = {"user_id": "test123", "trace_id": "trace456"}
1068+
invocation_context = self.runner._new_invocation_context(
1069+
mock_session, metadata=test_metadata
1070+
)
1071+
1072+
assert invocation_context.metadata == test_metadata
1073+
assert invocation_context.metadata["user_id"] == "test123"
1074+
assert invocation_context.metadata["trace_id"] == "trace456"
1075+
1076+
def test_new_invocation_context_without_metadata(self):
1077+
"""Test that _new_invocation_context works without metadata."""
1078+
mock_session = Session(
1079+
id=TEST_SESSION_ID,
1080+
app_name=TEST_APP_ID,
1081+
user_id=TEST_USER_ID,
1082+
events=[],
1083+
)
1084+
1085+
invocation_context = self.runner._new_invocation_context(mock_session)
1086+
1087+
assert invocation_context.metadata is None
1088+
1089+
@pytest.mark.asyncio
1090+
async def test_run_async_passes_metadata_to_invocation_context(self):
1091+
"""Test that run_async correctly passes metadata to before_model_callback."""
1092+
# Capture metadata received in callback
1093+
captured_metadata = None
1094+
1095+
def before_model_callback(callback_context, llm_request):
1096+
nonlocal captured_metadata
1097+
captured_metadata = llm_request.metadata
1098+
# Return a response to skip actual LLM call
1099+
return LlmResponse(
1100+
content=types.Content(
1101+
role="model", parts=[types.Part(text="Test response")]
1102+
)
1103+
)
1104+
1105+
# Create agent with before_model_callback
1106+
agent_with_callback = LlmAgent(
1107+
name="callback_agent",
1108+
model="gemini-2.0-flash",
1109+
before_model_callback=before_model_callback,
1110+
)
1111+
1112+
runner_with_callback = Runner(
1113+
app_name="test_app",
1114+
agent=agent_with_callback,
1115+
session_service=self.session_service,
1116+
artifact_service=self.artifact_service,
1117+
)
1118+
1119+
session = await self.session_service.create_session(
1120+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1121+
)
1122+
1123+
test_metadata = {"experiment_id": "exp-001", "variant": "B"}
1124+
1125+
async for event in runner_with_callback.run_async(
1126+
user_id=TEST_USER_ID,
1127+
session_id=TEST_SESSION_ID,
1128+
new_message=types.Content(
1129+
role="user", parts=[types.Part(text="Hello")]
1130+
),
1131+
metadata=test_metadata,
1132+
):
1133+
pass
1134+
1135+
# Verify metadata was passed to before_model_callback
1136+
assert captured_metadata is not None
1137+
assert captured_metadata == test_metadata
1138+
assert captured_metadata["experiment_id"] == "exp-001"
1139+
assert captured_metadata["variant"] == "B"
1140+
1141+
def test_metadata_field_in_invocation_context(self):
1142+
"""Test that InvocationContext model accepts metadata field."""
1143+
mock_session = Session(
1144+
id=TEST_SESSION_ID,
1145+
app_name=TEST_APP_ID,
1146+
user_id=TEST_USER_ID,
1147+
events=[],
1148+
)
1149+
1150+
test_metadata = {"key1": "value1", "key2": 123}
1151+
1152+
# This should not raise a validation error
1153+
invocation_context = InvocationContext(
1154+
session_service=self.session_service,
1155+
invocation_id="test_inv_id",
1156+
agent=self.root_agent,
1157+
session=mock_session,
1158+
metadata=test_metadata,
1159+
)
1160+
1161+
assert invocation_context.metadata == test_metadata
1162+
1163+
def test_metadata_field_in_llm_request(self):
1164+
"""Test that LlmRequest model accepts metadata field."""
1165+
test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}}
1166+
1167+
llm_request = LlmRequest(metadata=test_metadata)
1168+
1169+
assert llm_request.metadata == test_metadata
1170+
assert llm_request.metadata["context_key"] == "ctx123"
1171+
assert llm_request.metadata["user_info"]["name"] == "test"
1172+
1173+
def test_llm_request_without_metadata(self):
1174+
"""Test that LlmRequest works without metadata."""
1175+
llm_request = LlmRequest()
1176+
1177+
assert llm_request.metadata is None
1178+
1179+
10411180
if __name__ == "__main__":
10421181
pytest.main([__file__])

0 commit comments

Comments
 (0)