Skip to content

Commit 90af80c

Browse files
fix: AgentTask deadlock when on_enter awaits generate_reply that triggers another AgentTask (#5377)
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent fda016b commit 90af80c

3 files changed

Lines changed: 171 additions & 13 deletions

File tree

livekit-agents/livekit/agents/voice/agent.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,14 @@ def _handle_task_done(_: asyncio.Task[Any]) -> None:
848848

849849
# TODO(theomonnom): could the RunResult watcher & the blocked_tasks share the same logic?
850850
self.__inactive_ev.clear()
851+
suspended_handles: list[SpeechHandle | asyncio.Task[Any]] = []
852+
pending_on_enter_task: asyncio.Task[None] | None = None
851853
try:
854+
# use wait_on_enter=False to avoid deadlock: on_enter may spawn nested
855+
# AgentTasks that require user input, but session.run() can't return until
856+
# all watched handles complete — creating a circular wait.
852857
await session._update_activity(
853-
self, previous_activity="pause", blocked_tasks=blocked_tasks
858+
self, previous_activity="pause", blocked_tasks=blocked_tasks, wait_on_enter=False
854859
)
855860

856861
if not self._activity and not self.done():
@@ -860,14 +865,29 @@ def _handle_task_done(_: asyncio.Task[Any]) -> None:
860865
)
861866
)
862867

863-
# NOTE: _update_activity is calling the on_enter method, so the RunResult can capture all speeches
864868
run_state = session._global_run_state
865-
if speech_handle and run_state and not run_state.done():
866-
# make sure to not deadlock on the current speech handle
867-
run_state._unwatch_handle(speech_handle)
868-
# it is OK to call _mark_done_if_needed here, the above _update_activity will call on_enter
869-
# so handles added inside the on_enter will make sure we're not completing the run_state too early.
870-
run_state._mark_done_if_needed(None)
869+
870+
if self._activity and (on_enter_task := self._activity._on_enter_task):
871+
if run_state and not run_state.done():
872+
# watch the on_enter task as a guard so RunResult won't complete
873+
# before on_enter has registered its own speech handles
874+
run_state._watch_handle(on_enter_task)
875+
pending_on_enter_task = on_enter_task
876+
else:
877+
# no active run to guard — just wait for on_enter directly
878+
await asyncio.shield(on_enter_task)
879+
880+
# now unwatch the parent speech handle and blocked tasks that belong to the
881+
# old activity — they can't complete while this AgentTask is running, and
882+
# keeping them watched would block RunResult from completing.
883+
if run_state and not run_state.done():
884+
if speech_handle and run_state._unwatch_handle(speech_handle):
885+
suspended_handles.append(speech_handle)
886+
for task in blocked_tasks:
887+
if run_state._unwatch_handle(task):
888+
suspended_handles.append(task)
889+
if suspended_handles:
890+
run_state._mark_done_if_needed(None)
871891
except Exception:
872892
self.__inactive_ev.set()
873893
raise
@@ -883,24 +903,32 @@ def _handle_task_done(_: asyncio.Task[Any]) -> None:
883903
# run_state could have changed after self.__fut
884904
run_state = session._global_run_state
885905

906+
# re-watch the suspended handles so the resumed parent activity
907+
# is tracked by the current RunResult again
908+
if run_state and not run_state.done():
909+
for handle in suspended_handles:
910+
run_state._watch_handle(handle)
911+
912+
if pending_on_enter_task:
913+
try:
914+
await asyncio.shield(pending_on_enter_task)
915+
except BaseException:
916+
logger.exception("error in on_enter task of agent %s", self.id)
917+
886918
if session.current_agent != self:
887919
logger.warning(
888920
f"{self.__class__.__name__} completed, but the agent has changed in the meantime. "
889921
"Ignoring handoff to the previous agent, likely due to `AgentSession.update_agent` being invoked."
890922
)
891923
await old_activity.aclose()
892924
else:
893-
if speech_handle and run_state and not run_state.done():
894-
run_state._watch_handle(speech_handle)
895-
896925
merged_chat_ctx = old_agent.chat_ctx.merge(
897926
self.chat_ctx,
898927
exclude_function_call=not self._preserve_function_call_history,
899928
exclude_instructions=True,
900929
)
901930
# set the chat_ctx directly, `session._update_activity` will sync it to the rt_session if needed
902931
old_agent._chat_ctx.items[:] = merged_chat_ctx.items
903-
# await old_agent.update_chat_ctx(merged_chat_ctx)
904932

905933
await session._update_activity(
906934
old_agent, new_activity="resume", wait_on_enter=False

livekit-agents/livekit/agents/voice/run_result.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,16 @@ def _watch_handle(self, handle: SpeechHandle | asyncio.Task) -> None:
169169

170170
handle.add_done_callback(self._mark_done_if_needed)
171171

172-
def _unwatch_handle(self, handle: SpeechHandle | asyncio.Task) -> None:
172+
def _unwatch_handle(self, handle: SpeechHandle | asyncio.Task) -> bool:
173+
if handle not in self._handles:
174+
return False
175+
173176
self._handles.discard(handle)
174177
handle.remove_done_callback(self._mark_done_if_needed)
175178

176179
if isinstance(handle, SpeechHandle):
177180
handle._remove_item_added_callback(self._item_added)
181+
return True
178182

179183
def _mark_done_if_needed(self, handle: SpeechHandle | asyncio.Task | None) -> None:
180184
if isinstance(handle, SpeechHandle):

tests/test_nested_agent_task.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
5+
import pytest
6+
7+
from livekit.agents import Agent, AgentSession, AgentTask, RunContext, function_tool
8+
from livekit.agents.llm import FunctionToolCall
9+
10+
from .fake_llm import FakeLLM, FakeLLMResponse
11+
12+
13+
class InnerTask(AgentTask):
14+
"""A task that needs a second user turn to complete (user must trigger 'finish')."""
15+
16+
def __init__(self) -> None:
17+
super().__init__(instructions="inner task")
18+
19+
async def on_enter(self) -> None:
20+
self.session.generate_reply(instructions="inner_greeting")
21+
22+
@function_tool
23+
async def finish(self, ctx: RunContext) -> str:
24+
"""Called to complete the inner task."""
25+
self.complete(None)
26+
return "done"
27+
28+
29+
class OuterTask(AgentTask):
30+
"""A task whose on_enter triggers a tool call that awaits InnerTask."""
31+
32+
def __init__(self) -> None:
33+
super().__init__(instructions="outer task")
34+
35+
async def on_enter(self) -> None:
36+
await self.session.generate_reply(instructions="outer_greeting")
37+
38+
@function_tool
39+
async def start_inner(self, ctx: RunContext) -> str:
40+
"""Transitions into InnerTask."""
41+
await InnerTask()
42+
self.complete(None)
43+
return "inner completed"
44+
45+
46+
class RootAgent(Agent):
47+
def __init__(self) -> None:
48+
super().__init__(instructions="root agent")
49+
50+
@function_tool
51+
async def start_outer(self, ctx: RunContext) -> str:
52+
"""Transitions into OuterTask."""
53+
await OuterTask()
54+
return "outer completed"
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_nested_agent_task_no_deadlock():
59+
"""session.run() must return when an AgentTask hands off to a nested task
60+
that collects additional user input before completing."""
61+
llm = _build_fake_llm()
62+
async with AgentSession(llm=llm) as sess:
63+
await sess.start(RootAgent())
64+
65+
# This must not deadlock — it should return once the on_enter chain
66+
# has started, even though InnerTask is still waiting for user input.
67+
first_result = await asyncio.wait_for(sess.run(user_input="go"), timeout=5.0)
68+
assert first_result is not None
69+
70+
# Now complete InnerTask by triggering the finish tool
71+
second_result = await asyncio.wait_for(sess.run(user_input="done"), timeout=5.0)
72+
assert second_result is not None
73+
74+
75+
def _build_fake_llm() -> FakeLLM:
76+
return FakeLLM(
77+
fake_responses=[
78+
# user says "go" -> LLM calls start_outer
79+
FakeLLMResponse(
80+
input="go",
81+
content="",
82+
ttft=0,
83+
duration=0,
84+
tool_calls=[FunctionToolCall(name="start_outer", arguments="{}", call_id="call_1")],
85+
),
86+
# OuterTask on_enter generate_reply(instructions="outer_greeting")
87+
# -> LLM calls start_inner
88+
FakeLLMResponse(
89+
input="outer_greeting",
90+
content="",
91+
ttft=0,
92+
duration=0,
93+
tool_calls=[FunctionToolCall(name="start_inner", arguments="{}", call_id="call_2")],
94+
),
95+
# InnerTask on_enter generate_reply(instructions="inner_greeting")
96+
# -> LLM just says hello (no tool call yet — needs user input to finish)
97+
FakeLLMResponse(
98+
input="inner_greeting",
99+
content="hello from inner",
100+
ttft=0,
101+
duration=0,
102+
),
103+
# user says "done" -> LLM calls finish
104+
FakeLLMResponse(
105+
input="done",
106+
content="",
107+
ttft=0,
108+
duration=0,
109+
tool_calls=[FunctionToolCall(name="finish", arguments="{}", call_id="call_3")],
110+
),
111+
# after finish tool output, LLM responds to start_inner tool output
112+
FakeLLMResponse(
113+
input="inner completed",
114+
content="",
115+
ttft=0,
116+
duration=0,
117+
),
118+
# after start_outer tool output, LLM responds
119+
FakeLLMResponse(
120+
input="outer completed",
121+
content="all done",
122+
ttft=0,
123+
duration=0,
124+
),
125+
]
126+
)

0 commit comments

Comments
 (0)