2 changes: 2 additions & 0 deletions examples/apo/README.md
@@ -0,0 +1,2 @@
# APO Example
Uses the CFPO optimizer to mutate prompts.
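The core of the example is the server-side loop in `apo.py`: publish a candidate prompt as a named resource, queue GSM8K tasks for connected clients, and score the prompt by the rewards of the returned rollouts. A condensed sketch of that loop (the `score_prompt` helper and the sample task below are illustrative only; see `apo.py` for the full flow):

```python
import asyncio
from agentlightning.server import AgentLightningServer
from agentlightning.types import NamedResources, PromptTemplate

async def score_prompt(prompt: str) -> float:
    server = AgentLightningServer(host="127.0.0.1", port=9997)
    await server.start()

    # Publish the candidate prompt so clients can render it with each question.
    resources: NamedResources = {
        "prompt": PromptTemplate(template=prompt + "\n{question}", engine="f-string")
    }
    await server.update_resources(resources)

    # Queue one task and wait for a client rollout; its reward scores the prompt.
    task_id = await server.queue_task(sample={"question": "What is 3 + 4?"}, mode="train")
    rollout = await server.poll_completed_rollout(task_id, timeout=30)

    await server.stop()
    return rollout.final_reward
```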
85 changes: 85 additions & 0 deletions examples/apo/Vllm.py
@@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .base import LLM_Model
from vllm import LLM, SamplingParams
from typing import List, Union, Optional
import os
import logging

class VllmModel(LLM_Model):
def __init__(
self,
model_path: Optional[str] = None,
max_tokens: int = 256,
stop: str = '',
repetition_penalty: float = 1.0,
logger: Optional[logging.Logger] = None,
):
"""
Initialize the VLLM model.

Args:
model_path (Optional[str]): Path to the model. Defaults to None.
max_tokens (int): Maximum number of tokens to generate. Defaults to 256.
stop (str): Stop sequence for generation. Defaults to ''.
            repetition_penalty (float): Penalty for repetition. Defaults to 1.0.
logger (Optional[logging.Logger]): Logger object for logging messages.
"""
self.logger = logger

# Initialize the VLLM model
self.llm = LLM(model=model_path)
self.max_tokens = max_tokens
self.stop = stop
self.repetition_penalty = repetition_penalty

def inference(
self,
prompt: Union[str, List[str]],
use_batch_acceleration: bool = True,
desc: str = '',
) -> Union[str, List[str]]:
"""
Perform inference using the VLLM model.

Args:
prompt (Union[str, List[str]]): Input prompt(s) for the model.
use_batch_acceleration (bool): Whether to use batch acceleration. Defaults to True.
desc (str): Description of the inference task for logging.

Returns:
Union[str, List[str]]: Generated output(s) from the model.
"""
# Log the inference call
if self.logger:
self.logger.info(f"VLLM | {desc}")

# Configure sampling parameters
sampling_params = SamplingParams(
temperature=0,
repetition_penalty=self.repetition_penalty,
top_p=0.1,
max_tokens=self.max_tokens,
stop=self.stop,
)

if use_batch_acceleration and isinstance(prompt, list):
batch_size = 512
gen_output_list = []

for start_idx in range(0, len(prompt), batch_size):
end_idx = start_idx + batch_size
sub_gen_input_list = prompt[start_idx:end_idx]
sub_gen_output_list = self.llm.generate(sub_gen_input_list, sampling_params, use_tqdm=False)
gen_output_list.extend(sub_gen_output_list)

return [item.outputs[0].text for item in gen_output_list]

        elif not use_batch_acceleration and isinstance(prompt, str):
            output = self.llm.generate(prompt, sampling_params, use_tqdm=False)
            return output[0].outputs[0].text

        # Guard against mismatched prompt type / batch flag combinations.
        raise ValueError("Pass a list of prompts with use_batch_acceleration=True, or a single string with it disabled.")

if __name__ == "__main__":
llm = VllmModel("/home/aiscuser/Phi-3-mini-4k-instruct", max_tokens=512, stop='\n\n', repetition_penalty=1.0)
print(llm.inference("Hello, how are you?", use_batch_acceleration=False))
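    # Hypothetical batched usage: with a list input and use_batch_acceleration=True,
    # prompts are generated in chunks of 512 and one completion is returned per prompt.
    batch_prompts = ["Q: What is 12 * 7? A:", "Q: What is 15 - 9? A:"]
    print(llm.inference(batch_prompts, use_batch_acceleration=True, desc="batched demo"))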
26 changes: 26 additions & 0 deletions examples/apo/api.py
@@ -0,0 +1,26 @@
def call_api(prompt):
import os
import anthropic

client = anthropic.Anthropic(
# defaults to os.environ.get("ANTHROPIC_API_KEY")
api_key=os.environ.get("ANTHROPIC_API_KEY"),
)

message = client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=20000,
temperature=1,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
)
return message.content[0].text
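
# Hypothetical usage (mirrors how apo.py wires this function into the prompt
# optimizer as its mutation backend; requires the ANTHROPIC_API_KEY environment
# variable to be set):
#
#   from po_agent.agent import POAgent
#   optimizer = POAgent(task_intention="solve a mathematical reasoning task",
#                       optimizer_api=call_api)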
220 changes: 197 additions & 23 deletions examples/apo/apo.py
@@ -3,44 +3,218 @@
from agentlightning.server import AgentLightningServer
from agentlightning.types import NamedResources, PromptTemplate

import asyncio
import datasets
import logging
import os
import random
import re

from tqdm import tqdm
from typing import List, Dict, Tuple, Optional

# from cfpo import POAgents
from po_agent.agent import POAgent

def call_api(user_prompt: str = "") -> str:
import anthropic
print("Calling API with user prompt:", user_prompt)

client = anthropic.Anthropic(
# defaults to os.environ.get("ANTHROPIC_API_KEY")
api_key=os.environ.get("ANTHROPIC_API_KEY"),
)

message = client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=20000,
temperature=1,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt
}
]
}
]
)
return message.content[0].text


class GSM8KDataLoader:
def __init__(
self,
train_size: int = 100,
minibatch_size: int = 5,
valid_size: int = 200,
test_size: int = -1,
answer_marker: str = "The answer is",
):
"""
        Initialize the GSM8K data loader and build the train/validation/test splits.
"""
data_dir = "openai/gsm8k"
self.train_size = train_size
self.minibatch_size = minibatch_size
self.valid_size = valid_size
self.test_size = test_size
self.answer_marker = answer_marker
self.dataset = self.load_task_dataset(data_dir)
self.train_set, self.valid_set, self.test_set = self.dataset


def load_task_dataset(self, data_dir) -> Tuple[List[Dict], List[Dict], List[Dict]]:
"""
Load and preprocess the GSM8K dataset.
"""
dataset = datasets.load_dataset(path=data_dir, name='main')
train_examples = self._pre_process(dataset["train"])
test_examples = self._pre_process(dataset['test'])

# Split dataset into train, validation, and test sets
test_set = test_examples if self.test_size == -1 else test_examples[:self.test_size]

train_size = len(train_examples)
if self.valid_size > train_size:
raise ValueError("valid_size is greater than the number of train examples.")

valid_indices = random.sample(range(train_size), self.valid_size)
valid_set = [train_examples[i] for i in valid_indices]

remaining_train_indices = [i for i in range(train_size) if i not in valid_indices]
if self.train_size == -1:
train_set = [train_examples[i] for i in remaining_train_indices]
else:
if self.train_size > len(remaining_train_indices):
raise ValueError("train_size is greater than the remaining number of train examples after validation set selection.")
train_set = [train_examples[i] for i in random.sample(remaining_train_indices, self.train_size)]

return train_set, valid_set, test_set

def _pre_process(self, dataset) -> List[Dict]:
"""
Preprocess the dataset.
"""
out_doc = []
for doc in dataset:
label = doc['answer'].split('####')[-1].strip()
text = doc['answer'].split('####')[0].strip()

lines = text.split('\n')
processed_lines = [f"{line.strip()}." if not line.strip().endswith('.') else line.strip() for line in lines]
processed_text = ' '.join(processed_lines).strip()

answer = f"{processed_text} {self.answer_marker} {label}."
question = re.sub(r'\s+', ' ', doc['question'])
answer = re.sub(r'\s+', ' ', answer)

out_doc.append({"question": question, "answer": answer})
return out_doc

def sample_minibatch(self) -> List[Dict]:
"""
Sample a minibatch from the training set.
"""
minibatch = random.sample(self.train_set, k=min(self.minibatch_size, len(self.train_set)))
return minibatch
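
# Illustrative usage of the loader (hypothetical; not exercised elsewhere in this file):
# with the defaults above, 200 validation examples are drawn from the GSM8K train split,
# 100 training examples are sampled from the remainder, and sample_minibatch() returns
# 5 of them at a time.
#
#   loader = GSM8KDataLoader()
#   len(loader.train_set), len(loader.valid_set)   # -> (100, 200)
#   batch = loader.sample_minibatch()              # 5 {"question", "answer"} dicts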


async def example_apo():
"""
An example of how a prompt optimization works.
"""
gsm8k_dataloader = GSM8KDataLoader()
prompt_optimizer = POAgent(task_intention="solve a reasoning task and answer the following mathematical problem",
optimizer_api=call_api)

server = AgentLightningServer(host="127.0.0.1", port=9997)
await server.start()

prompt_candidates = [
"You are a helpful assistant.",
"You are a knowledgeable AI.",
"You are a friendly chatbot.",
]
prompt_candidates_and_reward = [("Please solve the following question: ", None)]

# for prompt, _ in prompt_candidates_and_reward:
# task_id_list = []
# # 1. The optimization algorithm updates the prompt template
# print(f"\n[Algo] Updating prompt template to: '{prompt}'")
# prompt = prompt + "\n{question}" if "\n{question}" not in prompt else prompt
# resources: NamedResources = {"prompt": PromptTemplate(template=prompt, engine="f-string")}
# # How the resource is used fully depends on the client implementation.
# await server.update_resources(resources)

# minibatch = gsm8k_dataloader.sample_minibatch()
# print(f"[Algo] Sampled {len(minibatch)} tasks from the GSM8K dataset for this round.")

prompt_and_rewards = []
# # 2. Get the results of prompt in this minibatch
# querys, ground_truths, query_outputs, scores = [], [], [], []
# for data in minibatch:
# print(data)
# print("[Algo] Queuing task for clients...")
# task_id = await server.queue_task(sample=data, mode='train')
# print(f"[Algo] Task '{task_id}' is now available for clients.")
# task_id_list.append(task_id)

# for task_id in task_id_list:
# rollout = await server.poll_completed_rollout(task_id, timeout=30)
# assert rollout, "Expected a completed rollout from the client."
# print(f"[Algo] Received Result: {rollout}")
# querys.append(rollout.metadata["query"])
# ground_truths.append(rollout.metadata["ground_truth"])
# query_outputs.append(rollout.metadata["query_output"])
# scores.append(rollout.final_reward)

# prompt_candidates = prompt_optimizer.diagnosing(prompt, querys, query_outputs, ground_truths, scores)[0]
# print(f"[Algo] Found {len(prompt_candidates)} prompt candidates for the next round.")

# # avg_reward = sum(scores) / len(scores) if scores else 0
# # print(f"[Algo] Average reward for prompt '{prompt}': {avg_reward}")
# for candidate in prompt_candidates:
# print(f"[Algo] Candidate prompt: '{candidate[0]}'")
# prompt_candidates_and_reward.append((candidate[0], None))

for i, (prompt, eval_score) in enumerate(prompt_candidates_and_reward):
if eval_score is not None:
continue
task_id_list = []
# 1. The optimization algorithm updates the prompt template
print(f"\n[Algo] Updating prompt template to: '{prompt}'")
resources: NamedResources = {"system_prompt": PromptTemplate(template=prompt, engine="f-string")}
prompt = prompt + "\n{question}" if "\n{question}" not in prompt else prompt
resources: NamedResources = {"prompt": PromptTemplate(template=prompt, engine="f-string")}
# How the resource is used fully depends on the client implementation.
await server.update_resources(resources)

        # 2. Evaluate this prompt on the GSM8K validation set
        valid_set = gsm8k_dataloader.valid_set
        print(f"[Algo] Evaluating this prompt on {len(valid_set)} validation tasks.")
querys, ground_truths, query_outputs, scores = [], [], [], []
for data in valid_set:
print(data)
print("[Algo] Queuing task for clients...")
task_id = await server.queue_task(sample=data, mode='train')
print(f"[Algo] Task '{task_id}' is now available for clients.")
task_id_list.append(task_id)

for task_id in task_id_list:
rollout = await server.poll_completed_rollout(task_id, timeout=30)
assert rollout, "Expected a completed rollout from the client."
print(f"[Algo] Received Result: {rollout}")
querys.append(rollout.metadata["query"])
ground_truths.append(rollout.metadata["ground_truth"])
query_outputs.append(rollout.metadata["query_output"])
scores.append(rollout.final_reward)

prompt_candidates_and_reward[i] = (prompt, sum(scores) / len(scores) if scores else 0)
print(f"[Algo] Average reward for prompt '{prompt}': {prompt_candidates_and_reward[i][1]}")



await server.stop()
