@@ -211,18 +211,15 @@ def _construct_gaia_sys_msgs(self):
211211
212212 return user_sys_msg , assistant_sys_msg
213213
214+ import threading
214215 def step (
215- self , assistant_msg : BaseMessage
216+ self , assistant_msg : BaseMessage , stop_event : Optional [ threading . Event ] = None
216217 ) -> Tuple [ChatAgentResponse , ChatAgentResponse ]:
217- user_response = self .user_agent .step (assistant_msg )
218+ user_response = self .user_agent .step (assistant_msg , stop_event )
218219 if user_response .terminated or user_response .msgs is None :
219220 return (
220- ChatAgentResponse (msgs = [], terminated = False , info = {}),
221- ChatAgentResponse (
222- msgs = [],
223- terminated = user_response .terminated ,
224- info = user_response .info ,
225- ),
221+ ChatAgentResponse (msgs = [assistant_msg ], terminated = False , info = {}),
222+ user_response
226223 )
227224 user_msg = self ._reduce_message_options (user_response .msgs )
228225
@@ -244,16 +241,12 @@ def step(
244241 """
245242
246243 # process assistant's response
247- assistant_response = self .assistant_agent .step (modified_user_msg )
244+ assistant_response = self .assistant_agent .step (modified_user_msg , stop_event )
248245 if assistant_response .terminated or assistant_response .msgs is None :
249246 return (
247+ assistant_response ,
250248 ChatAgentResponse (
251- msgs = [],
252- terminated = assistant_response .terminated ,
253- info = assistant_response .info ,
254- ),
255- ChatAgentResponse (
256- msgs = [user_msg ], terminated = False , info = user_response .info
249+ msgs = [modified_user_msg ], terminated = False , info = user_response .info
257250 ),
258251 )
259252 assistant_msg = self ._reduce_message_options (assistant_response .msgs )
0 commit comments