Skip to content

Commit 6c92088

Browse files
Merge pull request #6 from asappresearch/async_support
Async support
2 parents 545e858 + 0e9fcf2 commit 6c92088

File tree

2 files changed

+243
-1
lines changed

2 files changed

+243
-1
lines changed

josh_train/async_josh.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import asyncio
import copy
from typing import Optional, Callable, Awaitable, Any, Tuple, List

import numpy as np
4+
5+
class AsyncJOSH:
    """Asynchronous tree search over agent/user conversation turns.

    The search keeps a binary tree of ``Node`` objects, each wrapping one agent
    state.  Every call to :meth:`step` branches the live leaves (while the
    beam allows it), awaits an agent turn for every leaf concurrently,
    collapses the tree onto the first leaf that earned a reward, and then
    awaits a user turn for every remaining leaf.

    Collaborator interfaces, as used by this class:

    * ``rewards`` supports ``len()``, ``delete_reward(x)`` and
      ``is_reward(recent_actions)`` returning ``(got_reward, rewards_to_delete)``;
      if it also exposes ``is_reward_async`` that coroutine is awaited instead.
    * ``agent_step(agent=..., model=..., tokenizer=..., env=...)`` is awaitable
      and returns ``(agent, pass_to_customer)``.
    * ``user_step(user, agent)`` is awaitable and returns
      ``(agent, end_conversation)``.
    * ``add_error_message(agent)`` is a plain (non-awaited) call whose return
      value replaces the leaf's agent.
    """

    def __init__(self, rewards, agent_step: Callable, user_step: Callable, add_error_message: Callable,
                 root_agent, user, beam_size=8, max_turn_tries=10, agent_model=None,
                 agent_tokenizer=None, agent_env=None, debug=False):
        """Store the collaborators and create the root node from ``root_agent``.

        Args:
            rewards: Reward container (see class docstring for the methods used).
            agent_step: Awaitable callable that advances one agent.
            user_step: Awaitable callable that advances the user for one agent.
            add_error_message: Callable applied to an agent that did not finish
                its turn within ``max_turn_tries`` attempts.
            root_agent: Agent stored on the root node.
            user: Object forwarded as the first argument of ``user_step``.
            beam_size: Leaves are only branched while doubling them stays
                within this limit.
            max_turn_tries: Maximum agent attempts per call to ``step_agent``.
            agent_model: Forwarded to ``agent_step`` as ``model``.
            agent_tokenizer: Forwarded to ``agent_step`` as ``tokenizer``.
            agent_env: Forwarded to ``agent_step`` as ``env``.
            debug: When true, progress messages are printed.
        """
        self.agent_step = agent_step
        self.user_step = user_step
        self.add_error_message = add_error_message
        self.root = Node(root_agent, None)

        self.current_reward = 0.0
        self.beam_size = beam_size
        self.training_examples = []
        self.max_turn_tries = max_turn_tries
        self.rewards = rewards
        # Remembered so current_reward can be expressed as a fraction achieved.
        self.num_total_rewards = len(rewards)
        self.golden_agent = None
        self.agent_model = agent_model
        self.agent_tokenizer = agent_tokenizer
        self.agent_env = agent_env
        self.user = user
        self.debug = debug

    def set_root_agent(self, agent):
        """Replace the agent stored on the current root node."""
        self.root.agent = agent

    def set_success_path(self, success_node):
        """Set ``is_successful`` on ``success_node`` and all of its ancestors."""
        node = success_node
        # Iterative parent walk: same result as recursing up to the root.
        while node is not None:
            node.is_successful = True
            node = node.parent

    def set_golden_path(self, success_node):
        """Set ``is_golden_path`` on ``success_node`` and all of its ancestors."""
        node = success_node
        while node is not None:
            node.is_golden_path = True
            node = node.parent

    async def step_user(self):
        """Await one user turn for every live leaf, concurrently.

        Returns:
            ``True`` when no live leaves exist, either before the user turns
            run or after them; ``False`` otherwise.
        """
        leaves = self.root.get_leaves()
        if len(leaves) == 0:
            return True
        if self.debug:
            print(f'Running {len(leaves)} users')

        # All user turns are independent, so run them together.
        await asyncio.gather(*(self._process_user_leaf(leaf) for leaf in leaves))

        # A user may have ended conversations, which removes those leaves.
        return len(self.root.get_leaves()) == 0

    async def _process_user_leaf(self, leaf):
        """Await ``user_step`` for one leaf and store its results on the leaf."""
        leaf.agent, end_conversation = await self.user_step(self.user, leaf.agent)
        leaf.conversation_over = end_conversation

    async def step_agent(self):
        """Await agent turns for every live leaf and collapse the tree on reward.

        Each leaf is stepped repeatedly, up to ``max_turn_tries`` attempts,
        until its turn is reported finished.  The first leaf that earns a
        reward becomes the new root: the paths to all rewarded leaves are
        marked successful, the path to the new root is marked golden, the
        whole tree is recorded into ``training_examples`` (skipping entries
        already present) and the earned reward is deleted from ``rewards``.

        Returns:
            ``True`` when there were no live leaves or no rewards remain,
            ``False`` otherwise.
        """
        # An ndarray is needed so leaves can be selected by an index array.
        leaves = np.array(self.root.get_leaves())
        if len(leaves) == 0:
            return True

        if self.debug:
            print(f'Running {len(leaves)} agents')

        done = np.zeros(len(leaves), dtype=bool)
        collapse_root_to = None
        rewards_to_delete = None
        successful_leaves = []

        for attempt in range(self.max_turn_tries):
            unfinished_leaf_indices = np.flatnonzero(~done)
            if len(unfinished_leaf_indices) == 0:
                break
            unfinished_leaves = leaves[unfinished_leaf_indices]

            # gather() returns results in the order the coroutines were
            # passed, so they line up with unfinished_leaf_indices.
            results = await asyncio.gather(
                *(self._process_agent_leaf(lf) for lf in unfinished_leaves)
            )
            out_of_tries = attempt + 1 == self.max_turn_tries

            for leaf_idx, result in zip(unfinished_leaf_indices, results):
                turn, turn_finished, got_reward, rw_to_delete = result

                if turn_finished:
                    done[leaf_idx] = True

                if got_reward:
                    successful_leaves.append(turn)
                    # Only the first rewarded leaf becomes the new root.
                    if collapse_root_to is None:
                        if self.debug:
                            print(f'🌟 Got reward')
                        collapse_root_to = turn
                        rewards_to_delete = copy.deepcopy(rw_to_delete)
                        # Earning the last outstanding reward fixes the
                        # golden agent.
                        if len(self.rewards) == 1:
                            self.golden_agent = turn.agent

                # The leaf used up every attempt without finishing its turn.
                if out_of_tries and not turn_finished:
                    turn.agent = self.add_error_message(turn.agent)

        if collapse_root_to is not None:
            if self.debug:
                print(f'🪓👷 Collapsing tree')
            # Mark every rewarded leaf and its ancestors as successful.
            for leaf in successful_leaves:
                self.set_success_path(leaf)

            self.rewards.delete_reward(rewards_to_delete)
            self.set_golden_path(collapse_root_to)
            # Record the tree before it is discarded by the collapse.
            for ex in self.root.get_tree():
                if ex not in self.training_examples:
                    self.training_examples.append(ex)
            self.root = collapse_root_to
            self.root.parent = None
            self.root.is_successful = False

        if self.num_total_rewards != 0:
            self.current_reward = (self.num_total_rewards - len(self.rewards)) / self.num_total_rewards
        else:
            self.current_reward = 0.0

        return len(self.rewards) == 0

    async def _process_agent_leaf(self, leaf):
        """Await one agent attempt for ``leaf`` and check it for a reward.

        Returns:
            ``(leaf, turn_finished, got_reward, rw_to_delete)``.  A failed
            ``agent_step`` leaves ``leaf.agent`` untouched and counts as a
            finished turn.
        """
        try:
            leaf.agent, pass_to_customer = await self.agent_step(
                agent=leaf.agent,
                model=self.agent_model,
                tokenizer=self.agent_tokenizer,
                env=self.agent_env,
            )
        except Exception as e:
            # Best effort: a failing agent step ends this leaf's turn instead
            # of aborting the sibling leaves being gathered alongside it.
            if self.debug:
                print(f"Error in agent step: {e}")
            pass_to_customer = None

        turn_finished = True if pass_to_customer is None else pass_to_customer

        got_reward, rw_to_delete = await self._check_rewards(leaf.agent)

        return leaf, turn_finished, got_reward, rw_to_delete

    async def _check_rewards(self, agent):
        """Check ``agent.recent_actions`` against the rewards.

        Awaits ``rewards.is_reward_async`` when the rewards object has it,
        otherwise calls the synchronous ``rewards.is_reward``.
        """
        if hasattr(self.rewards, 'is_reward_async'):
            return await self.rewards.is_reward_async(agent.recent_actions)
        return self.rewards.is_reward(agent.recent_actions)

    def expand_tree(self):
        """Give every live leaf two children while the beam size allows it.

        Both children start from deep copies of the parent's agent.  Nothing
        is added when doubling the leaves would exceed ``beam_size``.
        """
        leaves = self.root.get_leaves()
        if len(leaves) * 2 > self.beam_size:
            if self.debug:
                print(f'🎄 Tree at maximum size')
            return

        # Leaves whose conversation the user ended are not extended.
        open_leaves = [leaf for leaf in leaves if not leaf.conversation_over]
        if self.debug:
            print(f'🌲 Expanding tree to {len(open_leaves)*2} leaves')
        for leaf in open_leaves:
            leaf.left = Node(copy.deepcopy(leaf.agent), parent=leaf)
            leaf.right = Node(copy.deepcopy(leaf.agent), parent=leaf)

    async def step(self):
        """Run one search round: expand the tree, step agents, then users.

        The user step is skipped when the agent step reports completion.

        Returns:
            ``(current_reward, all_done)``.
        """
        self.expand_tree()

        all_done = await self.step_agent()
        if not all_done:
            all_done = await self.step_user()

        return self.current_reward, all_done
197+
198+
# Node class from original JOSH implementation
199+
class Node:
    """One node of the binary conversation tree, wrapping a single agent.

    Attributes are plain fields: ``agent`` (the wrapped agent), ``parent``
    (``None`` for a root), ``left`` / ``right`` (children, ``None`` until
    assigned), and the flags ``conversation_over``, ``is_successful`` and
    ``is_golden_path``, which all start as ``False``.
    """

    def __init__(self, agent, parent: Optional["Node"]=None):
        """Create a childless node for ``agent`` under ``parent``."""
        self.agent = agent
        self.parent = parent
        self.left = None
        self.right = None
        self.conversation_over = False
        self.is_successful = False
        self.is_golden_path = False

    def get_leaves(self):
        """Return the live leaves under this node, left subtree first.

        A childless node counts as a leaf only while its conversation is not
        over; a childless node whose conversation is over yields nothing.
        """
        children = [child for child in (self.left, self.right) if child]
        if not children:
            return [] if self.conversation_over else [self]

        found = []
        for child in children:
            found.extend(child.get_leaves())
        return found

    def get_tree(self):
        """Return one tuple per node of this subtree, in pre-order.

        Each tuple is ``(messages, is_successful, is_golden_path)`` where
        ``messages`` is a deep copy of ``agent.messages_internal`` passed
        through ``trim_user_msg``.
        """
        messages = trim_user_msg(copy.deepcopy(self.agent.messages_internal))
        entries = [(messages, self.is_successful, self.is_golden_path)]

        for child in (self.left, self.right):
            if child:
                entries.extend(child.get_tree())

        return entries
231+
232+
# Helper function from original JOSH implementation
233+
def trim_user_msg(messages):
    """Return ``messages`` without its trailing run of user messages.

    Args:
        messages: Sequence of dict-like messages; a message is a user message
            when ``message.get('role') == 'user'``.

    Returns:
        A new empty list for empty input, ``messages`` itself when the last
        message is not a user message, and otherwise a slice with every
        trailing user message removed.  The result is empty when every
        message is a user message.
    """
    if len(messages) == 0:
        return []

    # Count the trailing user messages explicitly.  Reusing the loop index
    # cannot tell "stopped at a non-user message" from "ran off the front",
    # which used to keep one user message when all messages were from the user.
    trailing_user = 0
    for message in reversed(messages):
        if message.get('role') != 'user':
            break
        trailing_user += 1

    if trailing_user == 0:
        return messages
    return messages[:-trailing_user]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_requires(requires_filename: str) -> List[str]:
2525
install_requires=get_requires("requirements.txt"),
2626
include_package_data=True,
2727
setup_requires=["setuptools_scm"],
28-
version="0.1.1",
28+
version="0.1.2",
2929
classifiers=[
3030
"Programming Language :: Python :: 3",
3131
"License :: OSI Approved :: MIT License",

0 commit comments

Comments
 (0)