Skip to content

Commit 8f9cd4e

Browse files
committed
feat: add termination event to step() function
1 parent 0372137 commit 8f9cd4e

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

owl/utils/enhanced_role_playing.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,15 @@ def _construct_gaia_sys_msgs(self):
211211

212212
return user_sys_msg, assistant_sys_msg
213213

214+
import threading
214215
def step(
215-
self, assistant_msg: BaseMessage
216+
self, assistant_msg: BaseMessage, stop_event: Optional[threading.Event] = None
216217
) -> Tuple[ChatAgentResponse, ChatAgentResponse]:
217-
user_response = self.user_agent.step(assistant_msg)
218+
user_response = self.user_agent.step(assistant_msg, stop_event)
218219
if user_response.terminated or user_response.msgs is None:
219220
return (
220-
ChatAgentResponse(msgs=[], terminated=False, info={}),
221-
ChatAgentResponse(
222-
msgs=[],
223-
terminated=user_response.terminated,
224-
info=user_response.info,
225-
),
221+
ChatAgentResponse(msgs=[assistant_msg], terminated=False, info={}),
222+
user_response
226223
)
227224
user_msg = self._reduce_message_options(user_response.msgs)
228225

@@ -244,16 +241,12 @@ def step(
244241
"""
245242

246243
# process assistant's response
247-
assistant_response = self.assistant_agent.step(modified_user_msg)
244+
assistant_response = self.assistant_agent.step(modified_user_msg, stop_event)
248245
if assistant_response.terminated or assistant_response.msgs is None:
249246
return (
247+
assistant_response,
250248
ChatAgentResponse(
251-
msgs=[],
252-
terminated=assistant_response.terminated,
253-
info=assistant_response.info,
254-
),
255-
ChatAgentResponse(
256-
msgs=[user_msg], terminated=False, info=user_response.info
249+
msgs=[modified_user_msg], terminated=False, info=user_response.info
257250
),
258251
)
259252
assistant_msg = self._reduce_message_options(assistant_response.msgs)

0 commit comments

Comments
 (0)