Skip to content

Commit 071f89f

Browse files
authored
direct tool call - interrupt not allowed (#1097)
1 parent 2147920 commit 071f89f

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/strands/agent/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from ..tools.registry import ToolRegistry
5656
from ..tools.structured_output._structured_output_context import StructuredOutputContext
5757
from ..tools.watcher import ToolWatcher
58-
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
58+
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent
5959
from ..types.agent import AgentInput
6060
from ..types.content import ContentBlock, Message, Messages
6161
from ..types.exceptions import ContextWindowOverflowException
@@ -166,7 +166,9 @@ def caller(
166166

167167
async def acall() -> ToolResult:
168168
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
169-
_ = event
169+
if isinstance(event, ToolInterruptEvent):
170+
self._agent._interrupt_state.deactivate()
171+
raise RuntimeError("cannot raise interrupt in direct tool call")
170172

171173
return tool_results[0]
172174

tests/strands/agent/test_agent.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2054,7 +2054,31 @@ def test_agent_structured_output_interrupt(user):
20542054
agent.structured_output(type(user), "invalid")
20552055

20562056

2057-
def test_agent_tool_caller_interrupt(user):
2057+
def test_agent_tool_caller_interrupt():
2058+
@strands.tool(context=True)
2059+
def test_tool(tool_context):
2060+
tool_context.interrupt("test-interrupt")
2061+
2062+
agent = Agent(tools=[test_tool])
2063+
2064+
exp_message = r"cannot raise interrupt in direct tool call"
2065+
with pytest.raises(RuntimeError, match=exp_message):
2066+
agent.tool.test_tool(agent=agent)
2067+
2068+
tru_state = agent._interrupt_state.to_dict()
2069+
exp_state = {
2070+
"activated": False,
2071+
"context": {},
2072+
"interrupts": {},
2073+
}
2074+
assert tru_state == exp_state
2075+
2076+
tru_messages = agent.messages
2077+
exp_messages = []
2078+
assert tru_messages == exp_messages
2079+
2080+
2081+
def test_agent_tool_caller_interrupt_activated():
20582082
agent = Agent()
20592083
agent._interrupt_state.activated = True
20602084

0 commit comments

Comments
 (0)