Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion debug_gym/agents/debug_agent.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the relevant part of the code that modifies the history.

Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from debug_gym.llms.base import LLM
from debug_gym.agents.base_agent import BaseAgent, register_agent

from debug_gym.agents.history_tracker import HistoryTracker
from debug_gym.gym.envs.env import EnvInfo
from debug_gym.gym.entities import Observation
import json

@register_agent
class DebugAgent(BaseAgent):
Expand Down Expand Up @@ -116,3 +120,176 @@ def run(self, task_name=None, debug=False):
status="error",
)
raise

@register_agent
class ReplayAgent(DebugAgent):
    """Debug agent that injects a fixed critique into the episode at one step.

    The run loop resets the environment, then repeatedly prompts the LLM and
    applies its tool call. When the loop reaches the step index stored under
    the ``replay_from`` config key, the text stored under the ``critique``
    config key is appended to the observations as a user-sourced
    ``Observation`` before the LLM is prompted.
    """

    name: str = "replay_agent"

    @staticmethod
    def _score_ratio(score, max_score):
        """Format ``score / max_score`` as a percentage; ``"n/a"`` if the maximum is falsy."""
        if not max_score:
            # Guard against ZeroDivisionError when a task reports max_score == 0.
            return "n/a"
        return f"{score / max_score:.1%}"

    @staticmethod
    def _with_critique(info, critique):
        """Return a new ``EnvInfo`` copied from ``info`` with the critique as its step observation."""
        new_observation = Observation(source="user", observation=critique)
        return EnvInfo(
            step_observation=new_observation,
            all_observations=info.all_observations + [new_observation],
            # The critique is not the result of an agent action.
            action_tool_call=None,
            action_content=None,
            action_reasoning=None,
            rewrite_counter=info.rewrite_counter,
            score=info.score,
            max_score=info.max_score,
            done=info.done,
            instructions=info.instructions,
            tools=info.tools,
            dir_tree=info.dir_tree,
            current_breakpoints=info.current_breakpoints,
            eval_observation=info.eval_observation,
        )

    def run(self, task_name=None, debug=False):
        """Run one episode on ``task_name`` and report progress along the way.

        Args:
            task_name: Task identifier passed to ``env.reset`` and used as the
                ``problem_id`` in progress reports. May be ``None``.
            debug: When true, drop into ``breakpoint()`` after each LLM call.

        Returns:
            ``True`` if the environment is already done after reset, otherwise
            ``info.done`` of the last step.

        Raises:
            Exception: Any exception raised during the run is re-raised after
                an ``"error"`` progress report has been emitted.
        """
        step = 0
        info = None
        max_steps = self.config["max_steps"]
        # Replay settings do not change during a run; read them once.
        critique_step = self.config.get("replay_from", None)
        critique = self.config.get("critique", "")
        # task_name defaults to None, which cannot be sliced; str() keeps the
        # progress line working in that case.
        task_label = str(task_name)[:10]
        try:
            self.history.reset()
            info = self.env.reset(options={"task_name": task_name})
            # initial state does not have prompt and response
            self.history.step(info, None)

            if info.done is True:
                self.logger.report_progress(
                    problem_id=task_name,
                    step=1,
                    total_steps=1,
                    score=info.score,
                    max_score=info.max_score,
                    status="resolved",
                )
                return True

            self.logger.info(
                "Available tools (in LLM's tool calling format):\n"
                f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n"
            )

            highscore = info.score
            for step in range(max_steps):
                if critique_step is not None and step == critique_step:
                    # Inject the critique before prompting the LLM at this step.
                    info = self._with_critique(info, critique)
                    self.history.step(info, None)

                self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n")
                highscore = max(highscore, info.score)
                self.logger.info(
                    f"[{task_label:<10}] | Step: {step:<4} | "
                    f"Score: {info.score:>4}/{info.max_score:<4} "
                    f"({self._score_ratio(info.score, info.max_score)}) "
                    f"[Best: {highscore}]"
                )

                messages = self.build_prompt(info)
                llm_response = self.llm(messages, info.tools)

                if debug:
                    breakpoint()

                info = self.env.step(
                    llm_response.tool,
                    llm_response.response,
                    llm_response.reasoning_response,
                )
                self.history.step(info, llm_response)

                if (
                    info.done
                    or info.rewrite_counter >= self.config["max_rewrite_steps"]
                ):
                    reason = "done" if info.done else "max_rewrite_steps reached"
                    self.logger.info(
                        f"Step: {step} | Score: {info.score}/{info.max_score} "
                        f"({self._score_ratio(info.score, info.max_score)}) "
                        f"| Reason: {reason}"
                    )
                    # early stop, set current step and total steps to be the same
                    self.logger.report_progress(
                        problem_id=task_name,
                        step=step + 1,
                        total_steps=step + 1,
                        score=info.score,
                        max_score=info.max_score,
                        status="resolved" if info.done else "unresolved",
                    )
                    break
                # keep progress bar running until max_steps is reached
                self.logger.report_progress(
                    problem_id=task_name,
                    step=step + 1,
                    total_steps=max_steps + 1,
                    score=info.score,
                    max_score=info.max_score,
                    status="running",
                )
            # max_steps was reached, task was either resolved or unresolved
            self.logger.report_progress(
                problem_id=task_name,
                step=step + 1,
                total_steps=step + 1,
                score=info.score,
                max_score=info.max_score,
                status="resolved" if info.done else "unresolved",
            )
            return info.done
        except Exception:
            # report any error that happens during the run
            self.logger.report_progress(
                problem_id=task_name,
                step=step + 1,
                total_steps=step + 1,
                score=info.score if info else 0,
                max_score=info.max_score if info else 1,
                status="error",
            )
            raise

def build_history_prompt(
    history: HistoryTracker, llm: LLM, reset_prompt_history_after_rewrite: bool = False
):
    """Render the tracked history as a flat list of LLM messages.

    Args:
        history: Tracker whose ``get()`` returns the recorded steps and the
            matching prompt/response entries.
        llm: Object whose ``format_tool_call_history(step, response)`` returns
            the messages for one step.
        reset_prompt_history_after_rewrite: When true, start from the earliest
            step whose ``rewrite_counter`` equals that of the newest step
            instead of from the beginning.

    Returns:
        The concatenated messages for every retained step, in order.
    """
    steps, responses = history.get()

    start = 0
    if reset_prompt_history_after_rewrite:
        # First index sharing the newest step's rewrite counter; stays 0 when
        # the history is empty.
        start = next(
            (
                idx
                for idx, step_info in enumerate(steps)
                if step_info.rewrite_counter == steps[-1].rewrite_counter
            ),
            0,
        )

    messages = []
    for step_info, response in zip(steps[start:], responses[start:]):
        messages.extend(llm.format_tool_call_history(step_info, response))
    return messages

def insert_critique_into_history(
    history: HistoryTracker, critique: str, step_id: int = None
):
    """Return a clone of ``history`` whose step ``step_id`` carries ``critique``.

    The step at ``step_id`` is replaced in place (later steps are not
    shifted): its action fields are cleared and its step observation becomes
    the critique. ``history`` itself is not reassigned; all edits are made on
    the clone.

    Args:
        history: Tracker to copy; must support ``clone()``, ``len()`` and a
            ``memory`` sequence of steps.
        critique: Text to store as the step observation.
        step_id: Index of the step to replace. ``None`` selects the last step.

    Returns:
        The modified clone, or an unmodified clone when ``step_id`` is
        negative or not smaller than ``len(history)`` (which includes the
        empty-history case).
    """
    new_history = history.clone()
    if step_id is None:
        step_id = len(history) - 1
    if not 0 <= step_id < len(history):
        return new_history

    # NOTE(review): relies on the stored step objects exposing ``clone()`` —
    # confirm against the step type kept in ``HistoryTracker.memory``.
    new_step = new_history.memory[step_id].clone()
    new_step.action_reasoning = None
    new_step.action_content = None
    new_step.action_tool_call = None
    # Wrap the critique the same way ReplayAgent.run does, so the step
    # observation is an Observation rather than a bare string.
    new_step.step_observation = Observation(source="user", observation=critique)

    new_history.memory = (
        new_history.memory[:step_id] + [new_step] + new_history.memory[step_id + 1:]
    )
    return new_history
5 changes: 5 additions & 0 deletions scripts/config_swesmith.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ solution_agent:
grep_agent:
agent_type: "rewrite_agent"
tools: ["grep", "view", "rewrite", "listdir", "eval"]

# Settings for the replay agent, which injects a fixed critique at one step.
replay_agent:
tools: ["pdb", "view", "rewrite", "listdir", "eval"]
# Zero-based step index at which the critique is injected.
replay_from: 0
# Text injected as a user-sourced observation at the `replay_from` step.
critique: "You are an awesome debugger"
1 change: 0 additions & 1 deletion scripts/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class AgentTimeoutException(BaseException):
"""Custom exception to handle timeouts in agent
execution. Inherits from BaseException to ensure
it is not caught by agent exception handling."""

pass


Expand Down