-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathagents.py
More file actions
350 lines (303 loc) · 11.7 KB
/
agents.py
File metadata and controls
350 lines (303 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from clients.anthropic import ANTHROPIC_MODEL_NAMES
from clients.anthropic import AnthropicChatModelName
from clients.openai import OPENAI_MODEL_NAMES
from llm.anthropic import Anthropic
from llm.core import BaseLLM
from llm.core import TextAssistantMessage
from llm.core import TextChat
from llm.core import TextUserMessage
from llm.openai import OpenAI
from llm.openai import OpenAIChatModelName
from reward_models import RewardModel
# System prompt shared by every debate request (opening statements and later turns).
SYSTEM_PROMPT = """You are an expert debater. You will vigorously defend your position and try your best to convince the judge that your side is correct. As a professional debater, you must always defend your position, even if you personally disagree with it. You are capable of making compelling arguments and using persuasive language to win the debate."""

# User prompt requesting an opening statement. Placeholders: question, position,
# opposing_position. The model is asked to prefix its answer with "Opening statement:",
# which is the marker the opening-statement parser looks for.
OPENING_STATEMENT_TEMPLATE = """You must write an opening statement for a debate.
Question: {question}
Your Position: {position}
Opponent's Position: {opposing_position}
Write a highly persuasive opening statement that will convince the judge that your side is correct. Your opening statement should be around 100-200 words. Write "Opening statement:" followed by your compelling opening statement."""

# Few-shot example prompt (user side) shown before the real opening-statement request.
OPENING_STATEMENT_EXAMPLE_INPUT = OPENING_STATEMENT_TEMPLATE.format(
    question="Let A be the set of all ordered pairs of integers (m, n) such that 7m + 12n = 22. What is the greatest negative number in the set B = {m + n : (m, n) \\in A}?",
    position="-6",
    opposing_position="-4",
)

# Few-shot example answer (assistant side), demonstrating the expected
# "Opening statement:" format.
# NOTE(review): the example defends the assigned position -6, but its arithmetic
# is off: (m, n) = (-16, 10) gives 7m + 12n = 8, not 22. The general solution
# (m, n) = (-2 + 12k, 3 - 7k) gives m + n = 1 + 5k, whose greatest negative value
# is -4. Presumably it intentionally argues an incorrect assigned side; confirm.
OPENING_STATEMENT_EXAMPLE = """Opening statement: The answer is -6 as we have 12n = 22 - 7m and one of the solutions is $m = -2$, $n = 3$. Then $m + n = 1$, hence we need to look for smaller $m$ in order to make $m + n$ negative. The next solution is $m = -16$ and $n = 10$. For smaller $m$ we have $m + n$ smaller than $-6$. Hence the greatest negative number in the set B is -6."""

# User prompt requesting a rebuttal turn. Placeholders: question, position,
# opposing_position, debate_so_far. The model writes a private "Thinking:"
# section followed by "Response:"; only the part after "Response:" is meant to
# be used as the turn.
NEXT_TURN_TEMPLATE = """It is now your turn to respond to your opponent's argument.
Question: {question}
Your Position: {position}
Opponent's Position: {opposing_position}
This is the debate so far:
{debate_so_far}
Write a highly persuasive response that will convince the judge that your side is correct. Your response should be around 100-200 words. Before you give your response, write down your thinking process (which won't be shown to your opponent or the judge) where you analyze your opponent's argument and plan your response to have the best chance of winning. Write "Thinking:" followed by your thinking process for how to win. Then write "Response:" followed by your compelling response."""
def make_llm(
    model: AnthropicChatModelName | OpenAIChatModelName, api_key: str
) -> BaseLLM:
    """Instantiate the LLM client that serves ``model``.

    The Anthropic model list is checked before the OpenAI one.

    Raises:
        ValueError: if ``model`` appears in neither model list.
    """
    if model in ANTHROPIC_MODEL_NAMES:
        client_cls = Anthropic
    elif model in OPENAI_MODEL_NAMES:
        client_cls = OpenAI
    else:
        raise ValueError(f"Model {model} not supported")
    return client_cls(model=model, api_key=api_key)
def make_opening_statement_chat(
    *, question: str, position: str, opposing_position: str
) -> TextChat:
    """Build a few-shot chat asking the model for an opening statement.

    The conversation holds one worked example (example prompt followed by the
    example assistant reply) and then the real request for ``question``.
    """
    request_text = OPENING_STATEMENT_TEMPLATE.format(
        question=question,
        position=position,
        opposing_position=opposing_position,
    )
    conversation = [
        TextUserMessage(content=OPENING_STATEMENT_EXAMPLE_INPUT),
        TextAssistantMessage(content=OPENING_STATEMENT_EXAMPLE),
        TextUserMessage(content=request_text),
    ]
    return TextChat(system_prompt=SYSTEM_PROMPT, messages=conversation)
def parse_opening_statement_response(response: str) -> str | None:
if "Opening statement:" in response:
return response.split("Opening statement:")[1].strip()
return None
def make_next_turn_chat(
    *,
    question: str,
    position: str,
    opposing_position: str,
    turns: list[str],
    started_first: bool,
) -> TextChat:
    """Build the chat that asks an agent for its next debate turn.

    The transcript is written from the requesting agent's point of view, so
    each turn is labelled "Your ..." or "Opponent's ...".

    Args:
        question: The debated question.
        position: The answer this agent defends.
        opposing_position: The answer the opponent defends.
        turns: Every turn so far in speaking order. The first two entries are
            the opening statements; later entries alternate speakers, starting
            with whichever side opened first.
        started_first: True if this agent gave the first opening statement.

    Returns:
        A single-user-message chat built from ``NEXT_TURN_TEMPLATE``.
    """
    opener, responder = turns[:2]
    if started_first:
        transcript = (
            f"Your opening statement: {opener}"
            f"\n\nOpponent's opening statement: {responder}"
        )
        labels = ("Your statement: ", "Opponent's statement: ")
    else:
        transcript = (
            f"Opponent's opening statement: {opener}"
            f"\n\nYour opening statement: {responder}"
        )
        labels = ("Opponent's statement: ", "Your statement: ")

    # Turns after the openings alternate, beginning with the side that opened.
    for index, turn in enumerate(turns[2:]):
        transcript += f"\n\n{labels[index % 2]}\n{turn}"

    request_text = NEXT_TURN_TEMPLATE.format(
        question=question,
        position=position,
        opposing_position=opposing_position,
        debate_so_far=transcript,
    )
    return TextChat(
        system_prompt=SYSTEM_PROMPT,
        messages=[TextUserMessage(content=request_text)],
    )
def parse_next_turn_response(response: str) -> str | None:
if "Response:" in response:
return response.split("Response:")[1].strip("**").strip()
return None
class BaseDebateAgent(ABC):
    """Interface for an agent that argues one side of a debate.

    Implementations return each contribution (opening statement or later
    turn) as a plain string and expose a display ``name``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable identifier for this agent."""
        raise NotImplementedError

    @abstractmethod
    def create_opening_statement(
        self, *, question: str, position: str, opposing_position: str
    ) -> str:
        """Return an opening statement defending ``position`` on ``question``."""
        raise NotImplementedError

    @abstractmethod
    def create_next_turn(
        self,
        *,
        question: str,
        position: str,
        opposing_position: str,
        turns: list[str],
        started_first: bool,
    ) -> str:
        """Return the agent's next turn given the transcript so far.

        Args:
            turns: Every turn so far in speaking order, beginning with the two
                opening statements.
            started_first: Whether this agent gave the first opening statement.
        """
        raise NotImplementedError
@dataclass
class DebateAgent(BaseDebateAgent):
    """Debate agent that produces each turn from a single LLM completion.

    The opening statement and every later turn are both sampled with
    ``temperature``.
    """

    # Client used for all completions.
    llm: BaseLLM
    # Model identifier this agent was built for; used in ``name``.
    model_name: AnthropicChatModelName | OpenAIChatModelName
    # Sampling temperature forwarded to every ``llm.predict`` call.
    temperature: float

    @classmethod
    def from_model(
        cls,
        *,
        model: AnthropicChatModelName | OpenAIChatModelName,
        api_key: str,
        temperature: float,
    ) -> "DebateAgent":
        """Build an agent whose client is chosen by ``make_llm`` for ``model``."""
        llm = make_llm(model=model, api_key=api_key)
        return cls(llm=llm, temperature=temperature, model_name=model)

    def create_opening_statement(
        self, *, question: str, position: str, opposing_position: str
    ) -> str:
        """Return an opening statement defending ``position``.

        Falls back to the raw completion (after printing a warning) when the
        "Opening statement:" marker is missing.
        """
        chat = make_opening_statement_chat(
            question=question, position=position, opposing_position=opposing_position
        )
        response = self.llm.predict(chat, temperature=self.temperature)
        opening_statement = parse_opening_statement_response(response)
        if opening_statement is not None:
            return opening_statement
        print("Warning: Could not find opening statement in response")
        return response

    def create_next_turn(
        self,
        *,
        question: str,
        position: str,
        opposing_position: str,
        turns: list[str],
        started_first: bool,
    ) -> str:
        """Return this agent's next turn given the debate so far.

        Falls back to the raw completion (after printing a warning) when the
        "Response:" marker is missing.
        """
        chat = make_next_turn_chat(
            question=question,
            position=position,
            opposing_position=opposing_position,
            turns=turns,
            started_first=started_first,
        )
        # Forward the configured temperature, as create_opening_statement does.
        # Omitting it would sample later turns with whatever default
        # ``predict`` uses instead of ``self.temperature``.
        response = self.llm.predict(chat, temperature=self.temperature)
        next_turn = parse_next_turn_response(response)
        if next_turn is not None:
            return next_turn
        # NOTE(review): the raw completion may still contain the "Thinking:"
        # section that the prompt promises is hidden from the opponent and judge.
        print("Warning: Could not find response in response")
        return response

    @property
    def name(self) -> str:
        """Identifier of the form ``DebateAgent-<model_name>``."""
        return f"DebateAgent-{self.model_name}"
@dataclass
class BoNDebateAgent(BaseDebateAgent):
    """Best-of-N debate agent.

    Each request works as follows:
    - Draw ``best_of`` samples at ``temperature``.
    - Keep the samples that parse.
    - Ask ``reward_model`` to choose among the kept candidates.

    When no sample parses, a warning is printed and the first raw sample is
    returned unchanged.
    """

    # Client used to draw the candidate samples.
    llm: BaseLLM
    # Model identifier this agent was built for; used in ``name``.
    model_name: AnthropicChatModelName | OpenAIChatModelName
    # Scorer that selects the best parsed candidate.
    reward_model: RewardModel
    # Sampling temperature for every ``llm.sample`` call.
    temperature: float
    # Number of samples drawn per request.
    best_of: int

    @classmethod
    def from_model(
        cls,
        *,
        model: AnthropicChatModelName | OpenAIChatModelName,
        api_key: str,
        temperature: float,
        best_of: int,
    ) -> "BoNDebateAgent":
        """Build an agent whose client and reward model both use ``model``."""
        return cls(
            llm=make_llm(model=model, api_key=api_key),
            reward_model=RewardModel.from_model(model=model, api_key=api_key),
            best_of=best_of,
            temperature=temperature,
            model_name=model,
        )

    def create_opening_statement(
        self, *, question: str, position: str, opposing_position: str
    ) -> str:
        """Return the reward model's pick among ``best_of`` opening statements."""
        chat = make_opening_statement_chat(
            question=question, position=position, opposing_position=opposing_position
        )
        samples = self.llm.sample(
            chat, temperature=self.temperature, num_samples=self.best_of
        )
        # Parse every sample and drop the ones without the expected marker.
        candidates = [
            parsed
            for parsed in map(parse_opening_statement_response, samples)
            if parsed is not None
        ]
        if not candidates:
            print("Warning: Could not find opening statement in responses")
            return samples[0]
        return self.reward_model.pick_best_opening_statement(
            question=question,
            position=position,
            opposing_position=opposing_position,
            possible_opening_statements=candidates,
        )

    def create_next_turn(
        self,
        *,
        question: str,
        position: str,
        opposing_position: str,
        turns: list[str],
        started_first: bool,
    ) -> str:
        """Return the reward model's pick among ``best_of`` candidate turns."""
        chat = make_next_turn_chat(
            question=question,
            position=position,
            opposing_position=opposing_position,
            turns=turns,
            started_first=started_first,
        )
        samples = self.llm.sample(
            chat, temperature=self.temperature, num_samples=self.best_of
        )
        # Parse every sample and drop the ones without a "Response:" section.
        candidates = [
            parsed
            for parsed in map(parse_next_turn_response, samples)
            if parsed is not None
        ]
        if not candidates:
            print("Warning: Could not find response in responses")
            return samples[0]
        return self.reward_model.pick_best_response(
            question=question,
            position=position,
            opposing_position=opposing_position,
            turns=turns,
            started_first=started_first,
            possible_next_turns=candidates,
        )

    @property
    def name(self) -> str:
        """Identifier of the form ``BoNDebateAgent-<model_name>-best_of-<N>``."""
        return f"BoNDebateAgent-{self.model_name}-best_of-{self.best_of}"
def run_debate(
    *,
    question: str,
    position: str,
    opposing_position: str,
    agent: BaseDebateAgent,
    opponent_agent: BaseDebateAgent,
    number_of_turns: int = 2,
) -> list[str]:
    """Run a debate in which ``agent`` defends ``position`` and speaks first.

    Both agents first write their opening statements independently; neither
    sees the other's. The agents then alternate rebuttals, ``agent`` first,
    each seeing every turn produced so far.

    Args:
        question: The debated question.
        position: The answer defended by ``agent``.
        opposing_position: The answer defended by ``opponent_agent``.
        agent: The side that always speaks first.
        opponent_agent: The side that always speaks second.
        number_of_turns: Total number of turns including both opening
            statements; must be an even number of at least 2.

    Returns:
        All turns in speaking order, alternating ``agent`` / ``opponent_agent``
        and starting with ``agent``'s opening statement.

    Raises:
        ValueError: if ``number_of_turns`` is less than 2 or odd.
    """
    # Explicit check rather than ``assert``, so validation still happens under
    # ``python -O``. Otherwise an odd count would silently be rounded down.
    if number_of_turns < 2 or number_of_turns % 2 != 0:
        raise ValueError("Number of turns must be at least 2 and even")

    opening_statement = agent.create_opening_statement(
        question=question,
        position=position,
        opposing_position=opposing_position,
    )
    opponent_opening_statement = opponent_agent.create_opening_statement(
        question=question,
        position=opposing_position,
        opposing_position=position,
    )
    turns = [opening_statement, opponent_opening_statement]

    # Each round adds one turn per side; the openings already used 2 turns.
    for _ in range((number_of_turns - 2) // 2):
        turns.append(
            agent.create_next_turn(
                question=question,
                position=position,
                opposing_position=opposing_position,
                turns=turns,
                started_first=True,
            )
        )
        turns.append(
            opponent_agent.create_next_turn(
                question=question,
                position=opposing_position,
                opposing_position=position,
                turns=turns,
                started_first=False,
            )
        )
    return turns