Skip to content

Commit 3fd6c93

Browse files
committed
feat(runner): add metadata parameter to Runner.run_async()
Add support for passing per-request metadata through the agent execution pipeline. This enables use cases like: - Passing user_id, trace_id, or session context to callbacks - Enabling memory injection in before_model_callback - Supporting request-specific context without using ContextVar workarounds Changes: - Add `metadata` field to LlmRequest model - Add `metadata` field to InvocationContext model - Add `metadata` parameter to Runner.run_async() and related methods - Propagate metadata from InvocationContext to LlmRequest in base_llm_flow - Add unit tests for metadata functionality Closes #2978
1 parent 0b1cff2 commit 3fd6c93

File tree

5 files changed

+181
-3
lines changed

5 files changed

+181
-3
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ class InvocationContext(BaseModel):
206206
canonical_tools_cache: Optional[list[BaseTool]] = None
207207
"""The cache of canonical tools for this invocation."""
208208

209+
metadata: Optional[dict[str, Any]] = None
210+
"""Per-request metadata passed from Runner.run_async().
211+
212+
This field allows passing arbitrary metadata that can be accessed during
213+
the invocation lifecycle, particularly in callbacks like before_model_callback.
214+
Common use cases include passing user_id, trace_id, memory context keys, or
215+
other request-specific context that needs to be available during processing.
216+
"""
217+
209218
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
210219
default_factory=_InvocationCostManager
211220
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def run_live(
8989
invocation_context: InvocationContext,
9090
) -> AsyncGenerator[Event, None]:
9191
"""Runs the flow using live api."""
92-
llm_request = LlmRequest()
92+
llm_request = LlmRequest(metadata=invocation_context.metadata)
9393
event_id = Event.new_id()
9494

9595
# Preprocess before calling the LLM.
@@ -380,7 +380,7 @@ async def _run_one_step_async(
380380
invocation_context: InvocationContext,
381381
) -> AsyncGenerator[Event, None]:
382382
"""One step means one LLM call."""
383-
llm_request = LlmRequest()
383+
llm_request = LlmRequest(metadata=invocation_context.metadata)
384384

385385
# Preprocess before calling the LLM.
386386
async with Aclosing(

src/google/adk/models/llm_request.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
from typing import Any
1819
from typing import Optional
1920
from typing import Union
2021

@@ -99,6 +100,15 @@ class LlmRequest(BaseModel):
99100
the full history.
100101
"""
101102

103+
metadata: Optional[dict[str, Any]] = None
104+
"""Per-request metadata for callbacks and custom processing.
105+
106+
This field allows passing arbitrary metadata from the Runner.run_async()
107+
call to callbacks like before_model_callback. This is useful for passing
108+
request-specific context such as user_id, trace_id, or memory context keys
109+
that need to be available during model invocation.
110+
"""
111+
102112
def append_instructions(
103113
self, instructions: Union[list[str], types.Content]
104114
) -> list[types.Content]:

src/google/adk/runners.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ async def run_async(
400400
new_message: Optional[types.Content] = None,
401401
state_delta: Optional[dict[str, Any]] = None,
402402
run_config: Optional[RunConfig] = None,
403+
metadata: Optional[dict[str, Any]] = None,
403404
) -> AsyncGenerator[Event, None]:
404405
"""Main entry method to run the agent in this runner.
405406
@@ -417,6 +418,9 @@ async def run_async(
417418
new_message: A new message to append to the session.
418419
state_delta: Optional state changes to apply to the session.
419420
run_config: The run config for the agent.
421+
metadata: Optional per-request metadata that will be passed to callbacks.
422+
This allows passing request-specific context such as user_id, trace_id,
423+
or memory context keys to before_model_callback and other callbacks.
420424
421425
Yields:
422426
The events generated by the agent.
@@ -426,13 +430,16 @@ async def run_async(
426430
new_message are None.
427431
"""
428432
run_config = run_config or RunConfig()
433+
# Create a shallow copy to isolate from caller's modifications
434+
metadata = metadata.copy() if metadata else None
429435

430436
if new_message and not new_message.role:
431437
new_message.role = 'user'
432438

433439
async def _run_with_trace(
434440
new_message: Optional[types.Content] = None,
435441
invocation_id: Optional[str] = None,
442+
metadata: Optional[dict[str, Any]] = None,
436443
) -> AsyncGenerator[Event, None]:
437444
with tracer.start_as_current_span('invocation'):
438445
session = await self.session_service.get_session(
@@ -463,6 +470,7 @@ async def _run_with_trace(
463470
invocation_id=invocation_id,
464471
run_config=run_config,
465472
state_delta=state_delta,
473+
metadata=metadata,
466474
)
467475
if invocation_context.end_of_agents.get(
468476
invocation_context.agent.name
@@ -476,6 +484,7 @@ async def _run_with_trace(
476484
new_message=new_message, # new_message is not None.
477485
run_config=run_config,
478486
state_delta=state_delta,
487+
metadata=metadata,
479488
)
480489

481490
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
@@ -502,7 +511,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
502511
self.app, session, self.session_service
503512
)
504513

505-
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
514+
async with Aclosing(
515+
_run_with_trace(new_message, invocation_id, metadata)
516+
) as agen:
506517
async for event in agen:
507518
yield event
508519

@@ -1186,6 +1197,7 @@ async def _setup_context_for_new_invocation(
11861197
new_message: types.Content,
11871198
run_config: RunConfig,
11881199
state_delta: Optional[dict[str, Any]],
1200+
metadata: Optional[dict[str, Any]] = None,
11891201
) -> InvocationContext:
11901202
"""Sets up the context for a new invocation.
11911203
@@ -1194,6 +1206,7 @@ async def _setup_context_for_new_invocation(
11941206
new_message: The new message to process and append to the session.
11951207
run_config: The run config of the agent.
11961208
state_delta: Optional state changes to apply to the session.
1209+
metadata: Optional per-request metadata to pass to callbacks.
11971210
11981211
Returns:
11991212
The invocation context for the new invocation.
@@ -1203,6 +1216,7 @@ async def _setup_context_for_new_invocation(
12031216
session,
12041217
new_message=new_message,
12051218
run_config=run_config,
1219+
metadata=metadata,
12061220
)
12071221
# Step 2: Handle new message, by running callbacks and appending to
12081222
# session.
@@ -1225,6 +1239,7 @@ async def _setup_context_for_resumed_invocation(
12251239
invocation_id: Optional[str],
12261240
run_config: RunConfig,
12271241
state_delta: Optional[dict[str, Any]],
1242+
metadata: Optional[dict[str, Any]] = None,
12281243
) -> InvocationContext:
12291244
"""Sets up the context for a resumed invocation.
12301245
@@ -1234,6 +1249,7 @@ async def _setup_context_for_resumed_invocation(
12341249
invocation_id: The invocation id to resume.
12351250
run_config: The run config of the agent.
12361251
state_delta: Optional state changes to apply to the session.
1252+
metadata: Optional per-request metadata to pass to callbacks.
12371253
12381254
Returns:
12391255
The invocation context for the resumed invocation.
@@ -1259,6 +1275,7 @@ async def _setup_context_for_resumed_invocation(
12591275
new_message=user_message,
12601276
run_config=run_config,
12611277
invocation_id=invocation_id,
1278+
metadata=metadata,
12621279
)
12631280
# Step 3: Maybe handle new message.
12641281
if new_message:
@@ -1303,6 +1320,7 @@ def _new_invocation_context(
13031320
new_message: Optional[types.Content] = None,
13041321
live_request_queue: Optional[LiveRequestQueue] = None,
13051322
run_config: Optional[RunConfig] = None,
1323+
metadata: Optional[dict[str, Any]] = None,
13061324
) -> InvocationContext:
13071325
"""Creates a new invocation context.
13081326
@@ -1312,6 +1330,7 @@ def _new_invocation_context(
13121330
new_message: The new message for the context.
13131331
live_request_queue: The live request queue for the context.
13141332
run_config: The run config for the context.
1333+
metadata: Optional per-request metadata for the context.
13151334
13161335
Returns:
13171336
The new invocation context.
@@ -1343,6 +1362,7 @@ def _new_invocation_context(
13431362
live_request_queue=live_request_queue,
13441363
run_config=run_config,
13451364
resumability_config=self.resumability_config,
1365+
metadata=metadata,
13461366
)
13471367

13481368
def _new_invocation_context_for_live(

tests/unittests/test_runners.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
2929
from google.adk.cli.utils.agent_loader import AgentLoader
3030
from google.adk.events.event import Event
31+
from google.adk.models.llm_request import LlmRequest
32+
from google.adk.models.llm_response import LlmResponse
3133
from google.adk.plugins.base_plugin import BasePlugin
3234
from google.adk.runners import Runner
3335
from google.adk.sessions.in_memory_session_service import InMemorySessionService
@@ -1038,5 +1040,142 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
10381040
assert "actual_name" in runner._app_name_alignment_hint
10391041

10401042

1043+
class TestRunnerMetadata:
1044+
"""Tests for Runner metadata parameter functionality."""
1045+
1046+
def setup_method(self):
1047+
"""Set up test fixtures."""
1048+
self.session_service = InMemorySessionService()
1049+
self.artifact_service = InMemoryArtifactService()
1050+
self.root_agent = MockLlmAgent("root_agent")
1051+
self.runner = Runner(
1052+
app_name="test_app",
1053+
agent=self.root_agent,
1054+
session_service=self.session_service,
1055+
artifact_service=self.artifact_service,
1056+
)
1057+
1058+
def test_new_invocation_context_with_metadata(self):
1059+
"""Test that _new_invocation_context correctly passes metadata."""
1060+
mock_session = Session(
1061+
id=TEST_SESSION_ID,
1062+
app_name=TEST_APP_ID,
1063+
user_id=TEST_USER_ID,
1064+
events=[],
1065+
)
1066+
1067+
test_metadata = {"user_id": "test123", "trace_id": "trace456"}
1068+
invocation_context = self.runner._new_invocation_context(
1069+
mock_session, metadata=test_metadata
1070+
)
1071+
1072+
assert invocation_context.metadata == test_metadata
1073+
assert invocation_context.metadata["user_id"] == "test123"
1074+
assert invocation_context.metadata["trace_id"] == "trace456"
1075+
1076+
def test_new_invocation_context_without_metadata(self):
1077+
"""Test that _new_invocation_context works without metadata."""
1078+
mock_session = Session(
1079+
id=TEST_SESSION_ID,
1080+
app_name=TEST_APP_ID,
1081+
user_id=TEST_USER_ID,
1082+
events=[],
1083+
)
1084+
1085+
invocation_context = self.runner._new_invocation_context(mock_session)
1086+
1087+
assert invocation_context.metadata is None
1088+
1089+
@pytest.mark.asyncio
1090+
async def test_run_async_passes_metadata_to_invocation_context(self):
1091+
"""Test that run_async correctly passes metadata to before_model_callback."""
1092+
# Capture metadata received in callback
1093+
captured_metadata = None
1094+
1095+
def before_model_callback(callback_context, llm_request):
1096+
nonlocal captured_metadata
1097+
captured_metadata = llm_request.metadata
1098+
# Return a response to skip actual LLM call
1099+
return LlmResponse(
1100+
content=types.Content(
1101+
role="model", parts=[types.Part(text="Test response")]
1102+
)
1103+
)
1104+
1105+
# Create agent with before_model_callback
1106+
agent_with_callback = LlmAgent(
1107+
name="callback_agent",
1108+
model="gemini-2.0-flash",
1109+
before_model_callback=before_model_callback,
1110+
)
1111+
1112+
runner_with_callback = Runner(
1113+
app_name="test_app",
1114+
agent=agent_with_callback,
1115+
session_service=self.session_service,
1116+
artifact_service=self.artifact_service,
1117+
)
1118+
1119+
session = await self.session_service.create_session(
1120+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1121+
)
1122+
1123+
test_metadata = {"experiment_id": "exp-001", "variant": "B"}
1124+
1125+
async for event in runner_with_callback.run_async(
1126+
user_id=TEST_USER_ID,
1127+
session_id=TEST_SESSION_ID,
1128+
new_message=types.Content(
1129+
role="user", parts=[types.Part(text="Hello")]
1130+
),
1131+
metadata=test_metadata,
1132+
):
1133+
pass
1134+
1135+
# Verify metadata was passed to before_model_callback
1136+
assert captured_metadata is not None
1137+
assert captured_metadata == test_metadata
1138+
assert captured_metadata["experiment_id"] == "exp-001"
1139+
assert captured_metadata["variant"] == "B"
1140+
1141+
def test_metadata_field_in_invocation_context(self):
1142+
"""Test that InvocationContext model accepts metadata field."""
1143+
mock_session = Session(
1144+
id=TEST_SESSION_ID,
1145+
app_name=TEST_APP_ID,
1146+
user_id=TEST_USER_ID,
1147+
events=[],
1148+
)
1149+
1150+
test_metadata = {"key1": "value1", "key2": 123}
1151+
1152+
# This should not raise a validation error
1153+
invocation_context = InvocationContext(
1154+
session_service=self.session_service,
1155+
invocation_id="test_inv_id",
1156+
agent=self.root_agent,
1157+
session=mock_session,
1158+
metadata=test_metadata,
1159+
)
1160+
1161+
assert invocation_context.metadata == test_metadata
1162+
1163+
def test_metadata_field_in_llm_request(self):
1164+
"""Test that LlmRequest model accepts metadata field."""
1165+
test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}}
1166+
1167+
llm_request = LlmRequest(metadata=test_metadata)
1168+
1169+
assert llm_request.metadata == test_metadata
1170+
assert llm_request.metadata["context_key"] == "ctx123"
1171+
assert llm_request.metadata["user_info"]["name"] == "test"
1172+
1173+
def test_llm_request_without_metadata(self):
1174+
"""Test that LlmRequest works without metadata."""
1175+
llm_request = LlmRequest()
1176+
1177+
assert llm_request.metadata is None
1178+
1179+
10411180
if __name__ == "__main__":
10421181
pytest.main([__file__])

0 commit comments

Comments
 (0)