|
| 1 | +import sys |
1 | 2 | from typing import Annotated |
2 | 3 | from typing import AsyncGenerator |
3 | 4 |
|
|
6 | 7 | from google.adk.agents.invocation_context import InvocationContext |
7 | 8 | from google.adk.agents.parallel_agent import _create_branch_ctx_for_sub_agent |
8 | 9 | from google.adk.agents.parallel_agent import _merge_agent_run |
| 10 | +from google.adk.agents.parallel_agent import _merge_agent_run_pre_3_11 |
9 | 11 | from google.adk.events import Event |
10 | 12 | from google.adk.flows.llm_flows.contents import _should_include_event_in_context |
11 | 13 | from google.genai import types |
12 | 14 | from pydantic import Field |
13 | 15 | from pydantic import RootModel |
14 | 16 | from typing_extensions import override |
15 | 17 |
|
| 18 | +from ..utils.context_utils import Aclosing |
| 19 | + |
16 | 20 |
|
17 | 21 | class MapAgent(BaseAgent): |
18 | 22 | sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field( |
@@ -63,10 +67,16 @@ async def _run_async_impl( |
63 | 67 | for prompt, agent in zip(prompts, sub_agents) |
64 | 68 | ] |
65 | 69 |
|
66 | | - async for event in _merge_agent_run( |
67 | | - [ctx.agent.run_async(ctx) for ctx in contexts] |
68 | | - ): |
69 | | - yield event |
| 70 | + agent_runs = [ctx.agent.run_async(ctx) for ctx in contexts] |
| 71 | + |
| 72 | + merge_func = ( |
| 73 | + _merge_agent_run |
| 74 | + if sys.version_info >= (3, 11) |
| 75 | + else _merge_agent_run_pre_3_11 |
| 76 | + ) |
| 77 | + async with Aclosing(merge_func(agent_runs)) as agen: |
| 78 | + async for event in agen: |
| 79 | + yield event |
70 | 80 |
|
71 | 81 | def _extract_input_prompts( |
72 | 82 | self, ctx: InvocationContext |
|
0 commit comments