1+ import copy
2+ import numpy as np
3+ from typing import Optional , Callable , Awaitable , Any , Tuple , List
4+
5+ class AsyncJOSH :
6+ def __init__ (self , rewards , agent_step : Callable , user_step : Callable , add_error_message : Callable ,
7+ root_agent , user , beam_size = 8 , max_turn_tries = 10 , agent_model = None ,
8+ agent_tokenizer = None , agent_env = None , debug = False ):
9+ self .agent_step = agent_step
10+ self .user_step = user_step
11+ self .add_error_message = add_error_message
12+ self .root = Node (root_agent , None )
13+
14+ self .current_reward = 0.0
15+ self .beam_size = beam_size
16+ self .training_examples = []
17+ self .max_turn_tries = max_turn_tries
18+ self .rewards = rewards
19+ self .num_total_rewards = len (rewards )
20+ self .golden_agent = None
21+ self .agent_model = agent_model
22+ self .agent_tokenizer = agent_tokenizer
23+ self .agent_env = agent_env
24+ self .user = user
25+ self .debug = debug
26+
27+ def set_root_agent (self , agent ):
28+ self .root .agent = agent
29+
30+ def set_success_path (self , success_node ):
31+ success_node .is_successful = True
32+ if success_node .parent is None :
33+ return
34+ self .set_success_path (success_node .parent )
35+ return
36+
37+ def set_golden_path (self , success_node ):
38+ success_node .is_golden_path = True
39+ if success_node .parent is None :
40+ return
41+ self .set_golden_path (success_node .parent )
42+ return
43+
44+ async def step_user (self ):
45+ leaves = np .array (self .root .get_leaves ())
46+ if len (leaves )== 0 :
47+ return True
48+ if self .debug :
49+ print (f'Running { len (leaves )} users' )
50+
51+ # Use asyncio.gather to run all user steps concurrently
52+ tasks = []
53+ for leaf in leaves :
54+ tasks .append (self ._process_user_leaf (leaf ))
55+
56+ await asyncio .gather (* tasks )
57+
58+ leaves = np .array (self .root .get_leaves ())
59+ return len (leaves )== 0
60+
61+ async def _process_user_leaf (self , leaf ):
62+ # Call the user_step function asynchronously
63+ leaf .agent , end_conversation = await self .user_step (self .user , leaf .agent )
64+ leaf .conversation_over = end_conversation
65+
66+ async def step_agent (self ):
67+ leaves = np .array (self .root .get_leaves ())
68+ if len (leaves )== 0 :
69+ return True
70+
71+ # Step for each leaf
72+ if self .debug :
73+ print (f'Running { len (leaves )} agents' )
74+
75+ count = 0
76+ done = np .array ([False ]* len (leaves ))
77+ training_examples = []
78+ collapse_root_to = None
79+ successful_leaves = []
80+
81+ while count < self .max_turn_tries :
82+ unfinished_leaf_indices = np .where (done == False )[0 ]
83+ if len (unfinished_leaf_indices )== 0 :
84+ break
85+
86+ unfinished_leaves = leaves [unfinished_leaf_indices ]
87+
88+ # Process all unfinished leaves concurrently
89+ tasks = []
90+ for lf in unfinished_leaves :
91+ tasks .append (self ._process_agent_leaf (lf ))
92+
93+ turn_finished_results = await asyncio .gather (* tasks )
94+
95+ # Update done status based on results
96+ for idx , (turn , turn_finished , got_reward , rw_to_delete ) in enumerate (turn_finished_results ):
97+ leaf_idx = unfinished_leaf_indices [idx ]
98+
99+ if turn_finished :
100+ done [leaf_idx ] = True
101+
102+ if got_reward :
103+ successful_leaves .append (turn )
104+ if not collapse_root_to :
105+ if self .debug :
106+ print (f'🌟 Got reward' )
107+ collapse_root_to = unfinished_leaves [idx ]
108+ rewards_to_delete = copy .deepcopy (rw_to_delete )
109+ if len (self .rewards )== 1 :
110+ self .golden_agent = turn .agent
111+
112+ if count + 1 == self .max_turn_tries and not turn_finished :
113+ turn .agent = self .add_error_message (turn .agent )
114+
115+ count += 1
116+
117+ if collapse_root_to :
118+ if self .debug :
119+ print (f'🪓👷 Collapsing tree' )
120+ # set the descendence of all successful leaves as successful
121+ for leaf in successful_leaves :
122+ self .set_success_path (leaf )
123+
124+ self .rewards .delete_reward (rewards_to_delete )
125+ self .set_golden_path (collapse_root_to )
126+ training_examples = self .root .get_tree ()
127+ for ex in training_examples :
128+ if ex not in self .training_examples :
129+ self .training_examples .append (ex )
130+ self .root = collapse_root_to
131+ self .root .parent = None
132+ self .root .is_successful = False
133+ leaves = [self .root ]
134+
135+ if self .num_total_rewards != 0 :
136+ self .current_reward = (self .num_total_rewards - len (self .rewards ))/ self .num_total_rewards
137+ else :
138+ self .current_reward = 0.0
139+
140+ all_done = len (self .rewards ) == 0
141+ return all_done
142+
143+ async def _process_agent_leaf (self , leaf ):
144+ # Call the agent_step function asynchronously
145+ try :
146+ leaf .agent , pass_to_customer = await self .agent_step (
147+ agent = leaf .agent ,
148+ model = self .agent_model ,
149+ tokenizer = self .agent_tokenizer ,
150+ env = self .agent_env
151+ )
152+ except Exception as e :
153+ if self .debug :
154+ print (f"Error in agent step: { e } " )
155+ pass_to_customer = None
156+
157+ turn_finished = True if pass_to_customer is None else pass_to_customer
158+
159+ # Check for rewards
160+ got_reward , rw_to_delete = await self ._check_rewards (leaf .agent )
161+
162+ return leaf , turn_finished , got_reward , rw_to_delete
163+
164+ async def _check_rewards (self , agent ):
165+ # Handle both synchronous and asynchronous reward checking
166+ if hasattr (self .rewards , 'is_reward_async' ):
167+ return await self .rewards .is_reward_async (agent .recent_actions )
168+ else :
169+ return self .rewards .is_reward (agent .recent_actions )
170+
171+ def expand_tree (self ):
172+ leaves = np .array (self .root .get_leaves ())
173+ make_more_leaves = len (leaves )* 2 <= self .beam_size
174+
175+ # Add messages to each leaf
176+ if make_more_leaves :
177+ if self .debug :
178+ print (f'🌲 Expanding tree to { len ([l for l in leaves if not l .conversation_over ])* 2 } leaves' )
179+ for leaf in leaves :
180+ # If the user ended the conversation, kill the leaf and keep going
181+ if leaf .conversation_over :
182+ continue
183+ # Extend leaves
184+ leaf .left = Node (copy .deepcopy (leaf .agent ), parent = leaf )
185+ leaf .right = Node (copy .deepcopy (leaf .agent ), parent = leaf )
186+ elif self .debug :
187+ print (f'🎄 Tree at maximum size' )
188+
189+ async def step (self ):
190+ self .expand_tree ()
191+
192+ all_done = await self .step_agent ()
193+ if not all_done :
194+ all_done = await self .step_user ()
195+
196+ return self .current_reward , all_done
197+
198+ # Node class from original JOSH implementation
199+ class Node :
200+ def __init__ (self , agent , parent : Optional ["Node" ]= None ):
201+ self .agent = agent
202+ self .conversation_over = False
203+ self .parent = parent
204+ self .left = None
205+ self .right = None
206+ self .is_successful = False
207+ self .is_golden_path = False
208+
209+ def get_leaves (self ):
210+ if not self .left and not self .right and not self .conversation_over : # If the node is a leaf
211+ return [self ]
212+
213+ leaves = []
214+
215+ if self .left :
216+ leaves .extend (self .left .get_leaves ())
217+ if self .right :
218+ leaves .extend (self .right .get_leaves ())
219+
220+ return leaves
221+
222+ def get_tree (self ):
223+ tree = [(trim_user_msg (copy .deepcopy (self .agent .messages_internal )), self .is_successful , self .is_golden_path )]
224+
225+ if self .left :
226+ tree .extend (self .left .get_tree ())
227+ if self .right :
228+ tree .extend (self .right .get_tree ())
229+
230+ return tree
231+
232+ # Helper function from original JOSH implementation
233+ def trim_user_msg (messages ):
234+ if len (messages ) == 0 :
235+ return []
236+ for idx , dic in enumerate (reversed (messages )):
237+ if dic .get ('role' )== 'user' :
238+ continue
239+ break
240+ if idx == 0 :
241+ return messages
242+ return messages [:- 1 * idx ]
0 commit comments