6 changes: 3 additions & 3 deletions README.md
@@ -42,7 +42,7 @@ Then, edit this file with your endpoint and credentials. You can choose one of t
- For `az login` or Managed Identity authentication on Azure, remove `api_key` and include `scope` instead.

> [!WARNING]
> When using open-sourced LLMs, e.g., via vLLM, you need to correctly set up `HF_TOKEN` required by the tokenizer.
> When using open-sourced LLMs, e.g., via vLLM, you need to correctly set up `HF_TOKEN` required by the tokenizer. You can also provide `tokenizer_kwargs` in your `llm.yaml` entry (for example `trust_remote_code: true`) to control how the Hugging Face tokenizer is instantiated.

By default, `debug-gym` looks for the LLM config file at `$HOME/.config/debug_gym/llm.yaml`. You can change this behavior by exporting the environment variable `LLM_CONFIG_FILE_PATH` or by setting `llm_config_file_path` in your script config file (see [Running Baselines](#3-running-baselines)).
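As a rough illustration of the `tokenizer_kwargs` option mentioned in the warning above, the sketch below parses a hypothetical `llm.yaml` entry with PyYAML (which debug-gym already uses for config loading) and resolves the config path the same way. Field names other than `tokenizer_kwargs` (such as `model`, `endpoint`, and `api_key`) are assumptions for illustration; check them against the template shipped with the repository.

```python
import os

import yaml  # PyYAML

# Hypothetical llm.yaml entry for a vLLM-served Hugging Face model.
# Field names besides `tokenizer_kwargs` are illustrative assumptions.
example_entry = """
my-vllm-model:
  model: meta-llama/Llama-3.1-8B-Instruct
  endpoint: http://localhost:8000/v1
  api_key: EMPTY
  tokenizer_kwargs:
    trust_remote_code: true
"""

# Resolve the config path the same way debug-gym does by default.
config_path = os.environ.get(
    "LLM_CONFIG_FILE_PATH",
    os.path.expanduser("~/.config/debug_gym/llm.yaml"),
)

entries = yaml.safe_load(example_entry)
print(entries["my-vllm-model"]["tokenizer_kwargs"])  # {'trust_remote_code': True}
print(f"debug-gym would look for the real file at: {config_path}")
```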

@@ -65,7 +65,7 @@ debug_gym

`debug_gym.agents` are LLM-based debugging agents that use `debug_gym.gym` to interact with code repositories, seeking the information needed to fix potential bugs. At each interaction step, the agent takes as input a text observation describing the environment and tool states and is expected to generate a command; the environment then responds with a new text observation describing the state change caused by that command.
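Conceptually, the interaction loop works as in the minimal sketch below. The environment methods and the agent class are hypothetical stand-ins for illustration only, not the actual `debug_gym` API.

```python
# Minimal sketch of the observation -> command -> observation loop.
# `EchoAgent` and the `env` interface are hypothetical stand-ins, not debug_gym classes.
class EchoAgent:
    def act(self, observation: str) -> str:
        # An LLM-backed agent would prompt the model with the observation
        # (plus history) and return the next tool command as text.
        return "eval"


def run_episode(env, agent, max_steps: int = 10) -> None:
    observation = env.reset()  # initial text description of the repo and tool states
    for _ in range(max_steps):
        command = agent.act(observation)       # agent generates a command
        observation, done = env.step(command)  # env describes the resulting state change
        if done:
            break
```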

`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, and Anthropic.
`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, Hugging Face/vLLM deployments (via an OpenAI-compatible endpoint), and Anthropic. For Hugging Face models served through vLLM, the tokenizer's chat template is applied automatically to ensure token counting and truncation match the hosted model.
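A minimal sketch of what applying the chat template means in practice, using the `transformers` library directly; this is illustrative only and not necessarily the exact code path debug-gym uses, and the model name is just an example (gated repositories additionally require `HF_TOKEN`).

```python
from transformers import AutoTokenizer

# Count tokens the way the hosted model would see them: apply the Hugging Face
# chat template to the messages, then tokenize the rendered prompt.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # example model name
)
messages = [{"role": "user", "content": "List the failing tests."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(len(tokenizer.encode(prompt)))  # approximate count as the vLLM server would compute it
```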

> [!WARNING]
> `debug-gym` has limited support on non-Linux platforms. Interactive terminal sessions using PTY (pseudo-terminal) in Docker are not fully supported on macOS or Windows. As a result, the `pdb` tool (see [2.1. Environment and Tools](#21-environment-and-tools)) only works on Linux.
@@ -202,7 +202,7 @@ In addition to all [built-in Jinja filters](https://jinja.palletsprojects.com/en
{{ info.tools | to_pretty_json }}
```

- **`trim_message`**: Trims a string to fit within a token or character limit, also filtering out non-UTF8 characters. This is helpful for ensuring that large outputs (such as directory trees or evaluation results) do not exceed the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed:
- **`trim_message`**: Trims a string to approximately fit within a token or character limit while filtering non-UTF8 characters. This helps keep large outputs (such as directory trees or evaluation results) within the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed:
- **`max_length`**: The maximum number of tokens to keep in the message. If the message exceeds this length, it will be trimmed.
- **`max_length_percentage`**: Instead of specifying an absolute number, you can provide a percentage (e.g., `0.1` for 10%) of the LLM's context window. The message will be trimmed to fit within this percentage of the model's maximum context length.
- **`where`**: Specifies where to trim the message if it exceeds the limit. The default is `"middle"`, which trims from the middle of the message. Other options are `start` or `end`.
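A rough sketch of how such a template filter behaves, using a toy whitespace tokenizer and a simplified trim in place of the real ones (the `fake_*` helper below is an illustrative stand-in, not debug-gym code):

```python
from jinja2 import Environment


def fake_trim_message(message, max_length=None, max_length_percentage=0, where="middle"):
    # Toy version: one "token" per whitespace-separated word, ellipsis marks the cut.
    words = message.split()
    if max_length is None or len(words) <= max_length:
        return message
    if where == "end":
        return " ".join(words[:max_length]) + " …"
    if where == "start":
        return "… " + " ".join(words[-max_length:])
    half = max_length // 2
    return " ".join(words[:half]) + " … " + " ".join(words[-half:])


env = Environment()
env.filters["trim_message"] = fake_trim_message
template = env.from_string("{{ eval_output | trim_message(max_length=6, where='middle') }}")
print(template.render(eval_output="one two three four five six seven eight nine ten"))
# -> "one two three … eight nine ten"
```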
10 changes: 4 additions & 6 deletions debug_gym/agents/base_agent.py
@@ -8,10 +8,10 @@
from jinja2 import Environment, Template

from debug_gym.agents.history_tracker import HistoryTracker, build_history_prompt
from debug_gym.agents.utils import trim
from debug_gym.gym.envs.env import RepoEnv
from debug_gym.gym.utils import filter_non_utf8
from debug_gym.llms.base import LLM
from debug_gym.llms.utils import trim
from debug_gym.logger import DebugGymLogger

AGENT_REGISTRY = {}
@@ -119,7 +119,7 @@ def to_pretty_json(value):

def trim_message(
self,
message,
message: str,
count_tokens=None,
max_length=None,
max_length_percentage=0,
@@ -140,10 +140,8 @@

if count_tokens is None or max_length is None or max_length <= 0:
return message
tokens = count_tokens(message)
if tokens > max_length:
return trim(message, max_length, count_tokens=count_tokens, where=where)
return message

return trim(message, max_length, count_tokens=count_tokens, where=where)

def _load_system_prompt_template(self) -> Template | None:
"""Load system prompt template from config if specified and register custom filters.
178 changes: 0 additions & 178 deletions debug_gym/agents/utils.py
@@ -5,184 +5,6 @@
import yaml


def trim(text: str, max_tokens: int, count_tokens: callable, where: str = "middle"):
"""Trim text to fit within max_tokens by working directly at the token level."""
if max_tokens <= 0:
return ""

nb_tokens = count_tokens(text)
if nb_tokens <= max_tokens:
return text

ellipsis = "…" # assume ellipsis is a single token
available_tokens = max_tokens - 1 # account for ellipsis

def find_char_position_for_tokens(
target_tokens: int, from_start: bool = True
) -> int:
"""Binary search to find character position that gives approximately target_tokens."""
left, right = 0, len(text)
best_pos = left if from_start else right

while left <= right:
mid = (left + right) // 2
test_text = text[:mid] if from_start else text[mid:]
test_tokens = count_tokens(test_text)
if test_tokens <= target_tokens:
best_pos = mid
if from_start:
left = mid + 1
else:
right = mid - 1
else:
if from_start:
right = mid - 1
else:
left = mid + 1
return best_pos

if where == "end":
# Keep the beginning, trim the end
trim_point = find_char_position_for_tokens(available_tokens, from_start=True)
return text[:trim_point] + ellipsis
elif where == "start":
# Keep the end, trim the beginning
trim_point = find_char_position_for_tokens(available_tokens, from_start=False)
return ellipsis + text[trim_point:]
elif where == "middle":
# Keep both ends, trim the middle
half_tokens = available_tokens // 2

# Find how much we can keep from the start
start_chars = find_char_position_for_tokens(half_tokens, from_start=True)

# Find how much we can keep from the end with remaining tokens
remaining_tokens = available_tokens - count_tokens(text[:start_chars])
end_chars = find_char_position_for_tokens(remaining_tokens, from_start=False)

return text[:start_chars] + ellipsis + text[end_chars:]
else:
raise ValueError(f"Invalid value for `where`: {where!r}.")


def get_message_tokens(message, count_tokens):
"""Count tokens in a message."""
message_content = str(message.get("content", message.get("tool_calls", message)))
return count_tokens(message_content)


def trim_prompt_messages(
messages: list[dict], context_length: int, count_tokens: callable
):
"""
Trim message on the assistant-tool pair level to fit context length.

Strategy:
1. Keep the system message (assert if system itself is too long)
2. Keep as many most recent (assistant, tool) pairs as possible
3. Only when we can keep all (assistant, tool) pairs, keep the user message

Args:
messages: List of message dicts with 'role' and 'content'/'tool_calls' keys
context_length: Maximum number of tokens allowed
count_tokens: Function to count tokens in a string

Returns:
Trimmed list of messages that fit within context_length
"""
assert len(messages) > 0, "messages should not be empty"
assert messages[-1]["role"] in [
"user",
"tool",
], "the last message should be from the user or the tool"
assert context_length >= 0, "context_length should be non-negative"

# Calculate token count for all messages
message_tokens = [get_message_tokens(msg, count_tokens) for msg in messages]
total_tokens = sum(message_tokens)

# If we're already within limit, return as-is
if total_tokens <= context_length:
return messages

# Find system message
system_msg_idx = 0 if messages[0]["role"] == "system" else None
system_tokens = message_tokens[0] if system_msg_idx is not None else 0

# Assert system message fits within context
assert (
system_tokens <= context_length
), f"System message tokens exceed context length: {system_tokens} > {context_length}!"

# Find user message
user_msg_idx = None
for i, msg in enumerate(messages):
if msg["role"] == "user":
user_msg_idx = i
break

# Find all (assistant, tool) pairs by going backwards
assistant_tool_pairs = []
i = len(messages) - 1
while i >= 0:
if (
messages[i]["role"] == "tool"
and i > 0
and messages[i - 1]["role"] == "assistant"
):
assistant_tool_pairs.append((i - 1, i)) # (assistant_idx, tool_idx)
i -= 2
else:
i -= 1

# Start building result with system message
result = []
remaining_tokens = context_length

if system_msg_idx is not None:
result.append(messages[system_msg_idx])
remaining_tokens -= system_tokens

# Add as many recent (assistant, tool) pairs as possible
included_pairs = []
for assistant_idx, tool_idx in assistant_tool_pairs:
pair_tokens = message_tokens[assistant_idx] + message_tokens[tool_idx]
if pair_tokens <= remaining_tokens:
included_pairs.append((assistant_idx, tool_idx))
remaining_tokens -= pair_tokens
else:
break

# Only include user message if we can fit all (assistant, tool) pairs
include_user = False
if len(included_pairs) == len(assistant_tool_pairs) and user_msg_idx is not None:
user_tokens = message_tokens[user_msg_idx]
if user_tokens <= remaining_tokens:
include_user = True

# Build final result
if include_user:
result.append(messages[user_msg_idx])

# Sort by assistant index to maintain chronological order
included_pairs.sort(key=lambda pair: pair[0])
for assistant_idx, tool_idx in included_pairs:
result.append(messages[assistant_idx])
result.append(messages[tool_idx])

assert (
len(result) > 0
), f"After trimming, no messages fit within context length: {context_length}!"

# Verify final token count
final_tokens = sum(get_message_tokens(msg, count_tokens) for msg in result)
assert (
final_tokens <= context_length
), f"After trimming, the message length still exceeds: {final_tokens} > {context_length}!"

return result
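A worked sketch of the pair-level trimming strategy described in the docstring. The import path below assumes this helper moves to `debug_gym.llms.utils` alongside `trim`; its new location is not shown in this diff and may differ.

```python
from debug_gym.llms.utils import trim_prompt_messages  # assumed post-PR location


def count_tokens(text: str) -> int:
    return len(text.split())  # toy counter: one "token" per word


messages = [
    {"role": "system", "content": "You are a debugging agent."},
    {"role": "user", "content": "Fix the failing test in module foo."},
    {"role": "assistant", "content": "Let me inspect the stack trace first."},
    {"role": "tool", "content": "Traceback: AssertionError in test_foo line 42"},
    {"role": "assistant", "content": "Patching the off-by-one error now."},
    {"role": "tool", "content": "All tests passed."},
]

# With a tight budget, the system message and the most recent (assistant, tool)
# pair are kept; the older pair and the user message are dropped.
trimmed = trim_prompt_messages(messages, context_length=15, count_tokens=count_tokens)
print([m["role"] for m in trimmed])  # ['system', 'assistant', 'tool']
```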


def load_config():
parser = argparse.ArgumentParser()
parser.add_argument("config_file", help="path to config file")
1 change: 1 addition & 0 deletions debug_gym/llms/__init__.py
@@ -1,5 +1,6 @@
from debug_gym.llms.anthropic import AnthropicLLM
from debug_gym.llms.azure_openai import AzureOpenAILLM
from debug_gym.llms.base import LLM
from debug_gym.llms.huggingface import HuggingFaceLLM
from debug_gym.llms.human import Human
from debug_gym.llms.openai import OpenAILLM
30 changes: 12 additions & 18 deletions debug_gym/llms/anthropic.py
@@ -41,24 +41,18 @@ def client(self):
self._client = Anthropic(api_key=self.config.api_key)
return self._client

def tokenize(self, text: str) -> list[str]:
raise NotImplementedError("Tokenization is not supported by Anthropic.")
def tokenize(self, messages: list[dict]) -> list[list[str]]:
"""Tokenization is not directly supported by Anthropic.
This method returns empty token lists as a placeholder."""
raise NotImplementedError("Direct tokenization is not supported by Anthropic.")

def count_tokens(self, messages: list[dict] | str) -> int:
"""Count the number of tokens in a text using the Anthropic API."""
if isinstance(messages, str):
messages = [
{"role": "user", "content": [{"type": "text", "text": messages}]}
]

def count_tokens(self, text: str) -> int:
"""Count the number of tokens in a text using the Anthropic API.
Dump content to JSON for cases such as:
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "id123",
"content": "results",
}
],
}
"""
messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
try:
response = self.client.messages.count_tokens(
model=self.tokenizer_name, messages=messages
@@ -67,7 +61,7 @@ def count_tokens(self, text: str) -> int:
except Exception as e:
self.logger.warning(
f"Error calling Claude token count API: {e!r}. "
f"The message was: {messages}. Will return 0 tokens."
f"The messages were: {messages}. Will return 0 tokens."
)
return 0
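For reference, a standalone sketch of the underlying SDK call this wrapper makes; it requires the `anthropic` package and an `ANTHROPIC_API_KEY` in the environment, and the model name is illustrative.

```python
from anthropic import Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment

response = client.messages.count_tokens(
    model="claude-3-5-sonnet-latest",  # illustrative model name
    messages=[{"role": "user", "content": [{"type": "text", "text": "Hello, world"}]}],
)
print(response.input_tokens)  # server-side token count for the request
```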
