From ed6f4004d324f1e5618fe2481c1eb07cdece9c0c Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sat, 4 Oct 2025 23:58:53 -0400 Subject: [PATCH 01/11] hf llm --- README.md | 4 +- debug_gym/llms/__init__.py | 1 + debug_gym/llms/base.py | 51 +++++-- debug_gym/llms/huggingface.py | 238 +++++++++++++++++++++++++++++++++ tests/llms/test_base.py | 21 ++- tests/llms/test_huggingface.py | 101 ++++++++++++++ 6 files changed, 404 insertions(+), 12 deletions(-) create mode 100644 debug_gym/llms/huggingface.py create mode 100644 tests/llms/test_huggingface.py diff --git a/README.md b/README.md index 2c1d15ba..1543798e 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Then, edit this file with your endpoint and credentials. You can choose one of t - For `az login` or Managed Identity authentication on Azure, remove `api_key` and include `scope` instead. > [!WARNING] -> When using open-sourced LLMs, e.g., via vLLM, you need to correctly setup `HF_TOKEN` required by the tokenizer. +> When using open-sourced LLMs, e.g., via vLLM, you need to correctly setup `HF_TOKEN` required by the tokenizer. You can also provide `tokenizer_kwargs` in your `llm.yaml` entry (for example `trust_remote_code: true`) to control how the Hugging Face tokenizer is instantiated. By default, `debug-gym` looks for the LLM config file at `$HOME/.config/debug_gym/llm.yaml`. You can change this behavior by exporting the environment variable `LLM_CONFIG_FILE_PATH` or by setting `llm_config_file_path` in your script config file (see [Running Baselines](#3-running-baselines)). @@ -65,7 +65,7 @@ debug_gym `debug_gym.agents` are LLM-based debugging agents that use `debug_gym.gym` to interact with code repositories to seek necessary information and thus fix potential bugs. At an interaction step, the agent takes a text observation that describes the environment states and tool states as input, it is expected to generate a command, subsequently, the environment will provide a new text observation in response, describing the state change caused by that command. -`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, and Anthropic. +`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, Hugging Face/vLLM deployments (via an OpenAI-compatible endpoint), and Anthropic. For Hugging Face models served through vLLM, the tokenizer's chat template is applied automatically to ensure token counting and truncation match the hosted model. > [!WARNING] > `debug-gym` has limited support on non-Linux platforms. Interactive terminal sessions using PTY (pseudo-terminal) in Docker are not fully supported on macOS or Windows. As a result, the `pdb` tool (see [2.1. Environment and Tools](#21-environment-and-tools)) only works on Linux. 
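As an illustration of the README changes above, an `llm.yaml` entry for a model served through vLLM might look like the following; the field names mirror the test configurations added later in this series, while the model ID, endpoint, API key, and context limit are placeholders:

```yaml
# Illustrative only -- values are placeholders, field names come from this series.
qwen-3:
  model: Qwen/Qwen3-0.6B        # name the vLLM server exposes
  tokenizer: Qwen/Qwen3-0.6B    # Hugging Face tokenizer used for counting/truncation
  context_limit: 128            # context budget, in thousands of tokens
  api_key: "<your-api-key>"
  endpoint: "http://localhost:8000/v1"
  tags: [vllm]                  # the `vllm` tag routes this entry to HuggingFaceLLM
  tokenizer_kwargs:
    trust_remote_code: true     # forwarded to AutoTokenizer.from_pretrained
```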
diff --git a/debug_gym/llms/__init__.py b/debug_gym/llms/__init__.py index b7cfa674..91e09d9f 100644 --- a/debug_gym/llms/__init__.py +++ b/debug_gym/llms/__init__.py @@ -1,5 +1,6 @@ from debug_gym.llms.anthropic import AnthropicLLM from debug_gym.llms.azure_openai import AzureOpenAILLM from debug_gym.llms.base import LLM +from debug_gym.llms.huggingface import HuggingFaceLLM from debug_gym.llms.human import Human from debug_gym.llms.openai import OpenAILLM diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index b0f14139..369d454d 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -66,6 +66,8 @@ class LLMConfig: scope: Optional[str] = None # Custom parameters to pass to generate generate_kwargs: dict = None + # Additional kwargs for tokenizer construction (e.g., trust_remote_code) + tokenizer_kwargs: dict | None = None def __post_init__(self): # Set tokenizer to model if not specified @@ -78,6 +80,8 @@ def __post_init__(self): self.tags = [] if self.generate_kwargs is None: self.generate_kwargs = {} + if self.tokenizer_kwargs is None: + self.tokenizer_kwargs = {} @dataclass @@ -250,6 +254,10 @@ def instantiate( from debug_gym.llms import AzureOpenAILLM klass = AzureOpenAILLM + elif "vllm" in tags: + from debug_gym.llms import HuggingFaceLLM + + klass = HuggingFaceLLM elif "anthropic" in tags: from debug_gym.llms import AnthropicLLM @@ -277,6 +285,34 @@ def count_tokens(self, text: str) -> int: """Count the number of tokens in a text.""" return len(self.tokenize(text)) + def _get_message_token_counts(self, messages: list[dict]) -> list[int]: + """Return per-message token counts used for context management. + + Subclasses can override this to plug in custom counting strategies + (for example, chat-template aware tokenizers). + """ + + from debug_gym.agents.utils import get_message_tokens + + return [get_message_tokens(msg, self.count_tokens) for msg in messages] + + def _trim_messages_to_context( + self, messages: list[dict], message_token_counts: list[int] | None = None + ) -> list[dict]: + """Trim messages so they fit within the model context budget. + + Args: + messages: Original message list. + message_token_counts: Optional precomputed counts aligned with messages. + + Returns: + A trimmed list of messages. + """ + + from debug_gym.agents.utils import trim_prompt_messages + + return trim_prompt_messages(messages, self.context_length, self.count_tokens) + @abstractmethod def define_tools(self, tool_call_list: list[EnvironmentTool]) -> list[dict]: """Translates the list of tools into a format that is specifically defined by each LLM. @@ -311,8 +347,6 @@ def __call__(self, messages, tools, *args, **kwargs) -> LLMResponse: should be implemented by subclasses. Returns an LLMResponse object with the prompt, response and token usage. 
""" - from debug_gym.agents.utils import get_message_tokens, trim_prompt_messages - # Add custom generation parameters from config for key, value in self.config.generate_kwargs.items(): # Only set if not already specified in the call @@ -347,12 +381,11 @@ def generate_with_drop_message_and_retry(messages, tools, **kwargs): for retry_count in range(max_retries + 1): try: # pre-truncate messages if they are too long, to avoid unnecessary retries - message_tokens = sum( - get_message_tokens(msg, self.count_tokens) for msg in messages - ) + message_token_counts = self._get_message_token_counts(messages) + message_tokens = sum(message_token_counts) if message_tokens > self.context_length * 1.2: - trimmed_messages = trim_prompt_messages( - messages, self.context_length, self.count_tokens + trimmed_messages = self._trim_messages_to_context( + messages, message_token_counts ) messages = trimmed_messages @@ -390,8 +423,8 @@ def generate_with_drop_message_and_retry(messages, tools, **kwargs): ) # Trim messages and try again - trimmed_messages = trim_prompt_messages( - messages, self.context_length, self.count_tokens + trimmed_messages = self._trim_messages_to_context( + messages, self._get_message_token_counts(messages) ) if not trimmed_messages: diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py new file mode 100644 index 00000000..dcbf972f --- /dev/null +++ b/debug_gym/llms/huggingface.py @@ -0,0 +1,238 @@ +import json +from typing import Iterable + +from transformers import AutoTokenizer + +from debug_gym.llms.openai import OpenAILLM + + +class HuggingFaceLLM(OpenAILLM): + """LLM implementation backed by a Hugging Face tokenizer.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._hf_tokenizer = None + + def _load_tokenizer(self): + if self._hf_tokenizer is None: + tokenizer_kwargs = getattr(self.config, "tokenizer_kwargs", None) or {} + try: + self._hf_tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_name, **tokenizer_kwargs + ) + except OSError as exc: + raise ValueError( + "Failed to load Hugging Face tokenizer " + f"`{self.tokenizer_name}` for model {self.model_name}." + ) from exc + + # Ensure we have a pad token to avoid downstream warnings when invoking + # the tokenizer in encode mode. 
+ if ( + getattr(self._hf_tokenizer, "pad_token", None) is None + and getattr(self._hf_tokenizer, "eos_token", None) is not None + ): + self._hf_tokenizer.pad_token = self._hf_tokenizer.eos_token + return self._hf_tokenizer + + def tokenize(self, text: str) -> list[str]: + tokenizer = self._load_tokenizer() + token_ids = tokenizer.encode(str(text), add_special_tokens=False) + + if hasattr(tokenizer, "convert_ids_to_tokens"): + try: + return tokenizer.convert_ids_to_tokens(token_ids) + except Exception: # pragma: no cover + pass + + return [str(t) for t in token_ids] + + def count_tokens(self, text) -> int: + tokenizer = self._load_tokenizer() + token_ids = tokenizer.encode(str(text), add_special_tokens=False) + return len(token_ids) + + # --- chat template helpers ------------------------------------------------- + + def _get_message_token_counts(self, messages: list[dict]) -> list[int]: + if not self._supports_chat_template(): + return super()._get_message_token_counts(messages) + + counts = self._compute_message_token_counts(messages) + if counts is None: + return super()._get_message_token_counts(messages) + return counts + + def _supports_chat_template(self) -> bool: + tokenizer = self._load_tokenizer() + return hasattr(tokenizer, "apply_chat_template") + + def _normalize_messages_for_template(self, messages: Iterable[dict]) -> list[dict]: + normalized = [] + for message in messages: + role = message.get("role", "user") + if role == "tool": + role = "assistant" + elif role not in {"system", "user", "assistant"}: + role = "user" + + content = message.get("content") + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict) and "text" in item: + parts.append(item["text"]) + else: + parts.append(str(item)) + content = "\n".join(parts) + elif content is None and message.get("tool_calls"): + content = json.dumps(message.get("tool_calls")) + else: + content = "" if content is None else str(content) + + normalized.append({"role": role, "content": content}) + return normalized + + def _chat_template_token_ids(self, messages: list[dict]): + tokenizer = self._load_tokenizer() + if not hasattr(tokenizer, "apply_chat_template"): + return None + + normalized = self._normalize_messages_for_template(messages) + try: + tokenized = tokenizer.apply_chat_template( + normalized, + tokenize=True, + add_generation_prompt=False, + ) + except TypeError: + tokenized = tokenizer.apply_chat_template(normalized, tokenize=True) + except ValueError: + return None + + if isinstance(tokenized, dict): + token_ids = tokenized.get("input_ids", []) + else: + token_ids = tokenized + + if token_ids and isinstance(token_ids[0], list): + token_ids = token_ids[0] + + return token_ids + + def _compute_message_token_counts(self, messages: list[dict]) -> list[int] | None: + normalized = self._normalize_messages_for_template(messages) + counts = [] + prev_len = 0 + + for idx in range(len(normalized)): + partial_ids = self._chat_template_token_ids(normalized[: idx + 1]) + if partial_ids is None: + return None + current_len = len(partial_ids) + counts.append(max(current_len - prev_len, 0)) + prev_len = current_len + + # Fallback in case template produced zero tokens + if not any(counts): + tokenizer = self._load_tokenizer() + return [ + len( + tokenizer.encode( + str(m.get("content", "")), add_special_tokens=False + ) + ) + for m in messages + ] + + return counts + + def _trim_messages_to_context( + self, + messages: list[dict], + message_token_counts: list[int] | None = None, + ) -> list[dict]: 
+ if not self._supports_chat_template(): + return super()._trim_messages_to_context(messages, message_token_counts) + + if message_token_counts is None: + message_token_counts = self._compute_message_token_counts(messages) + if message_token_counts is None: + return super()._trim_messages_to_context(messages, None) + + if len(messages) != len(message_token_counts): + return super()._trim_messages_to_context(messages, None) + + context_limit = self.context_length + total_tokens = sum(message_token_counts) + if total_tokens <= context_limit: + return messages + + assert messages, "messages should not be empty" + + result = [] + remaining_tokens = context_limit + + # Handle system message if present + system_idx = 0 if messages[0].get("role") == "system" else None + if system_idx is not None: + system_tokens = message_token_counts[0] + assert ( + system_tokens <= context_limit + ), f"System message tokens exceed context length: {system_tokens} > {context_limit}!" + result.append(messages[0]) + remaining_tokens -= system_tokens + + # Locate the first user message + user_msg_idx = None + for idx, msg in enumerate(messages): + if msg.get("role") == "user": + user_msg_idx = idx + break + + # Collect assistant/tool pairs starting from the end + assistant_tool_pairs: list[tuple[int, int]] = [] + i = len(messages) - 1 + while i >= 0: + if ( + messages[i].get("role") == "tool" + and i > 0 + and messages[i - 1].get("role") == "assistant" + ): + assistant_tool_pairs.append((i - 1, i)) + i -= 2 + else: + i -= 1 + + included_pairs: list[tuple[int, int]] = [] + for assistant_idx, tool_idx in assistant_tool_pairs: + pair_tokens = ( + message_token_counts[assistant_idx] + message_token_counts[tool_idx] + ) + if pair_tokens <= remaining_tokens: + included_pairs.append((assistant_idx, tool_idx)) + remaining_tokens -= pair_tokens + else: + break + + include_user = False + if ( + user_msg_idx is not None + and len(included_pairs) == len(assistant_tool_pairs) + and message_token_counts[user_msg_idx] <= remaining_tokens + ): + include_user = True + + if include_user: + result.append(messages[user_msg_idx]) + + included_pairs.sort(key=lambda pair: pair[0]) + for assistant_idx, tool_idx in included_pairs: + result.append(messages[assistant_idx]) + result.append(messages[tool_idx]) + + assert ( + result + ), f"After trimming, no messages fit within context length: {context_limit}!" 
+ + return result diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py index b48d63e4..11d8eb1d 100644 --- a/tests/llms/test_base.py +++ b/tests/llms/test_base.py @@ -5,7 +5,13 @@ from debug_gym.gym.entities import Observation from debug_gym.gym.tools.tool import EnvironmentTool -from debug_gym.llms import AnthropicLLM, AzureOpenAILLM, Human, OpenAILLM +from debug_gym.llms import ( + AnthropicLLM, + AzureOpenAILLM, + HuggingFaceLLM, + Human, + OpenAILLM, +) from debug_gym.llms.base import ( LLM, ContextLengthExceededError, @@ -47,6 +53,15 @@ "api_key": "test-api-key", "tags": ["anthropic", "claude", "claude-3.7"], }, + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + }, } ), ) @@ -64,6 +79,9 @@ def test_instantiate_llm(mock_open, logger_mock): llm = LLM.instantiate("claude-3.7", logger=logger_mock) assert isinstance(llm, AnthropicLLM) + llm = LLM.instantiate("qwen-3", logger=logger_mock) + assert isinstance(llm, HuggingFaceLLM) + llm = LLM.instantiate("human", logger=logger_mock) assert isinstance(llm, Human) @@ -225,6 +243,7 @@ def test_llm_config_initialization(): assert config.tokenizer == "llm-mock" # Default to model when tokenizer is None assert config.ignore_kwargs == [] # Default empty list assert config.tags == [] # Default empty list + assert config.tokenizer_kwargs == {} def test_llm_config_optional_fields(basic_config): diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py new file mode 100644 index 00000000..48d288db --- /dev/null +++ b/tests/llms/test_huggingface.py @@ -0,0 +1,101 @@ +from unittest.mock import MagicMock, patch + +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms import HuggingFaceLLM +from debug_gym.llms.base import LLMConfigRegistry, LLMResponse + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + } + } + ), +) +@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") +def test_huggingface_tokenizer_usage(mock_auto_tokenizer, mock_llm_config, logger_mock): + tokenizer_mock = MagicMock() + tokenizer_mock.encode.return_value = [10, 20, 30] + tokenizer_mock.convert_ids_to_tokens.return_value = ["", "", ""] + tokenizer_mock.pad_token = None + tokenizer_mock.eos_token = "" + mock_auto_tokenizer.return_value = tokenizer_mock + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + tokens = llm.tokenize("hello world") + assert tokens == ["", "", ""] + assert llm.count_tokens("hello world") == 3 + + mock_auto_tokenizer.assert_called_once_with("Qwen/Qwen3", trust_remote_code=True) + tokenizer_mock.encode.assert_called_with("hello world", add_special_tokens=False) + + # Ensure pad token fallback was applied + assert tokenizer_mock.pad_token == "" + + +@patch.object(HuggingFaceLLM, "generate") +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + } + } + ), +) 
+@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") +def test_huggingface_chat_template_usage( + mock_auto_tokenizer, mock_llm_config, mock_generate, logger_mock +): + tokenizer_mock = MagicMock() + tokenizer_mock.pad_token = None + tokenizer_mock.eos_token = "" + tokenizer_mock.convert_ids_to_tokens.side_effect = lambda ids: [ + f"<{i}>" for i in ids + ] + tokenizer_mock.apply_chat_template.side_effect = [ + {"input_ids": [[1, 2]]}, + {"input_ids": [[1, 2, 3, 4]]}, + ] + tokenizer_mock.encode.return_value = [99] + mock_auto_tokenizer.return_value = tokenizer_mock + + mock_generate.return_value = LLMResponse( + prompt=[], + response="ok", + tool=ToolCall(id="t1", name="noop", arguments={}), + prompt_token_count=5, + response_token_count=2, + ) + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + + response = llm(messages, tools=[]) + + assert response.response == "ok" + assert tokenizer_mock.apply_chat_template.call_count == 2 + mock_generate.assert_called_once() From 3c4dc4866e8844c45c1f2c8cf98e6ef554915d94 Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sun, 5 Oct 2025 00:28:32 -0400 Subject: [PATCH 02/11] minor --- debug_gym/llms/huggingface.py | 184 +++++++-------------------------- tests/llms/test_huggingface.py | 163 +++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 148 deletions(-) diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py index dcbf972f..0a08c755 100644 --- a/debug_gym/llms/huggingface.py +++ b/debug_gym/llms/huggingface.py @@ -58,9 +58,41 @@ def _get_message_token_counts(self, messages: list[dict]) -> list[int]: if not self._supports_chat_template(): return super()._get_message_token_counts(messages) - counts = self._compute_message_token_counts(messages) - if counts is None: - return super()._get_message_token_counts(messages) + tokenizer = self._load_tokenizer() + normalized = self._normalize_messages_for_template(messages) + counts: list[int] = [] + prev_len = 0 + + for idx in range(1, len(normalized) + 1): + try: + tokenized = tokenizer.apply_chat_template( + normalized[:idx], + tokenize=True, + add_generation_prompt=False, + ) + except TypeError: + tokenized = tokenizer.apply_chat_template( + normalized[:idx], tokenize=True + ) + except ValueError: + return super()._get_message_token_counts(messages) + + token_ids = ( + tokenized.get("input_ids") if isinstance(tokenized, dict) else tokenized + ) + if token_ids and isinstance(token_ids[0], list): + token_ids = token_ids[0] + + if token_ids is None: + return super()._get_message_token_counts(messages) + + current_len = len(token_ids) + if current_len == 0 and idx == len(normalized): + return super()._get_message_token_counts(messages) + + counts.append(max(current_len - prev_len, 0)) + prev_len = current_len + return counts def _supports_chat_template(self) -> bool: @@ -72,7 +104,7 @@ def _normalize_messages_for_template(self, messages: Iterable[dict]) -> list[dic for message in messages: role = message.get("role", "user") if role == "tool": - role = "assistant" + role = "user" elif role not in {"system", "user", "assistant"}: role = "user" @@ -92,147 +124,3 @@ def _normalize_messages_for_template(self, messages: Iterable[dict]) -> list[dic normalized.append({"role": role, "content": content}) return normalized - - def _chat_template_token_ids(self, messages: list[dict]): - tokenizer = self._load_tokenizer() - if not 
hasattr(tokenizer, "apply_chat_template"): - return None - - normalized = self._normalize_messages_for_template(messages) - try: - tokenized = tokenizer.apply_chat_template( - normalized, - tokenize=True, - add_generation_prompt=False, - ) - except TypeError: - tokenized = tokenizer.apply_chat_template(normalized, tokenize=True) - except ValueError: - return None - - if isinstance(tokenized, dict): - token_ids = tokenized.get("input_ids", []) - else: - token_ids = tokenized - - if token_ids and isinstance(token_ids[0], list): - token_ids = token_ids[0] - - return token_ids - - def _compute_message_token_counts(self, messages: list[dict]) -> list[int] | None: - normalized = self._normalize_messages_for_template(messages) - counts = [] - prev_len = 0 - - for idx in range(len(normalized)): - partial_ids = self._chat_template_token_ids(normalized[: idx + 1]) - if partial_ids is None: - return None - current_len = len(partial_ids) - counts.append(max(current_len - prev_len, 0)) - prev_len = current_len - - # Fallback in case template produced zero tokens - if not any(counts): - tokenizer = self._load_tokenizer() - return [ - len( - tokenizer.encode( - str(m.get("content", "")), add_special_tokens=False - ) - ) - for m in messages - ] - - return counts - - def _trim_messages_to_context( - self, - messages: list[dict], - message_token_counts: list[int] | None = None, - ) -> list[dict]: - if not self._supports_chat_template(): - return super()._trim_messages_to_context(messages, message_token_counts) - - if message_token_counts is None: - message_token_counts = self._compute_message_token_counts(messages) - if message_token_counts is None: - return super()._trim_messages_to_context(messages, None) - - if len(messages) != len(message_token_counts): - return super()._trim_messages_to_context(messages, None) - - context_limit = self.context_length - total_tokens = sum(message_token_counts) - if total_tokens <= context_limit: - return messages - - assert messages, "messages should not be empty" - - result = [] - remaining_tokens = context_limit - - # Handle system message if present - system_idx = 0 if messages[0].get("role") == "system" else None - if system_idx is not None: - system_tokens = message_token_counts[0] - assert ( - system_tokens <= context_limit - ), f"System message tokens exceed context length: {system_tokens} > {context_limit}!" 
- result.append(messages[0]) - remaining_tokens -= system_tokens - - # Locate the first user message - user_msg_idx = None - for idx, msg in enumerate(messages): - if msg.get("role") == "user": - user_msg_idx = idx - break - - # Collect assistant/tool pairs starting from the end - assistant_tool_pairs: list[tuple[int, int]] = [] - i = len(messages) - 1 - while i >= 0: - if ( - messages[i].get("role") == "tool" - and i > 0 - and messages[i - 1].get("role") == "assistant" - ): - assistant_tool_pairs.append((i - 1, i)) - i -= 2 - else: - i -= 1 - - included_pairs: list[tuple[int, int]] = [] - for assistant_idx, tool_idx in assistant_tool_pairs: - pair_tokens = ( - message_token_counts[assistant_idx] + message_token_counts[tool_idx] - ) - if pair_tokens <= remaining_tokens: - included_pairs.append((assistant_idx, tool_idx)) - remaining_tokens -= pair_tokens - else: - break - - include_user = False - if ( - user_msg_idx is not None - and len(included_pairs) == len(assistant_tool_pairs) - and message_token_counts[user_msg_idx] <= remaining_tokens - ): - include_user = True - - if include_user: - result.append(messages[user_msg_idx]) - - included_pairs.sort(key=lambda pair: pair[0]) - for assistant_idx, tool_idx in included_pairs: - result.append(messages[assistant_idx]) - result.append(messages[tool_idx]) - - assert ( - result - ), f"After trimming, no messages fit within context length: {context_limit}!" - - return result diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index 48d288db..647dabc3 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -1,3 +1,4 @@ +import json from unittest.mock import MagicMock, patch from debug_gym.gym.tools.tool import ToolCall @@ -44,6 +45,168 @@ def test_huggingface_tokenizer_usage(mock_auto_tokenizer, mock_llm_config, logge assert tokenizer_mock.pad_token == "" +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + } + } + ), +) +@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") +def test_huggingface_normalizes_messages_for_template( + mock_auto_tokenizer, mock_llm_config, logger_mock +): + tokenizer_mock = MagicMock() + tokenizer_mock.pad_token = None + tokenizer_mock.eos_token = "" + mock_auto_tokenizer.return_value = tokenizer_mock + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + raw_messages = [ + {"role": "tool", "content": "partial output"}, + { + "role": "developer", + "content": [{"text": "line1"}, {"text": "line2"}], + }, + { + "role": "assistant", + "content": None, + "tool_calls": [{"type": "function", "name": "noop", "arguments": {}}], + }, + {"role": "user", "content": None}, + ] + + normalized = llm._normalize_messages_for_template(raw_messages) + + assert normalized == [ + {"role": "user", "content": "partial output"}, + {"role": "user", "content": "line1\nline2"}, + { + "role": "assistant", + "content": json.dumps( + [{"type": "function", "name": "noop", "arguments": {}}] + ), + }, + {"role": "user", "content": ""}, + ] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + } + } + ), +) 
+@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") +def test_huggingface_chat_template_token_counts( + mock_auto_tokenizer, mock_llm_config, logger_mock +): + tokenizer_mock = MagicMock() + tokenizer_mock.pad_token = None + tokenizer_mock.eos_token = "" + + def fake_apply_chat_template(messages, tokenize=True, add_generation_prompt=False): + length = len(messages) + return list(range(length * 2)) + + tokenizer_mock.apply_chat_template.side_effect = fake_apply_chat_template + tokenizer_mock.encode.return_value = [] + mock_auto_tokenizer.return_value = tokenizer_mock + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + messages = [ + {"role": "system", "content": "Instructions"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "tool", "content": "Result"}, + ] + + counts = llm._get_message_token_counts(messages) + + assert counts == [2, 2, 2, 2] + assert tokenizer_mock.apply_chat_template.call_count == len(messages) + normalized_final = tokenizer_mock.apply_chat_template.call_args_list[-1][0][0] + assert normalized_final[-1]["role"] == "user" + assert normalized_final[-1]["content"] == "Result" + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + } + } + ), +) +@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") +def test_huggingface_chat_template_zero_token_fallback( + mock_auto_tokenizer, mock_llm_config, logger_mock +): + tokenizer_mock = MagicMock() + tokenizer_mock.pad_token = None + tokenizer_mock.eos_token = "" + tokenizer_mock.apply_chat_template.return_value = [] + tokenizer_mock.encode.side_effect = [[1, 2], [3, 4, 5]] + mock_auto_tokenizer.return_value = tokenizer_mock + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + messages = [ + {"role": "system", "content": "Instructions"}, + {"role": "user", "content": "Hello"}, + ] + + counts = llm._get_message_token_counts(messages) + + assert counts == [2, 3] + assert tokenizer_mock.encode.call_count == len(messages) + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + } + } + ), +) @patch.object(HuggingFaceLLM, "generate") @patch.object( LLMConfigRegistry, From d956a886632656a08fc2398964e7b6e575e20811 Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sun, 5 Oct 2025 00:35:53 -0400 Subject: [PATCH 03/11] Update test_huggingface.py --- tests/llms/test_huggingface.py | 185 +++++++-------------------------- 1 file changed, 38 insertions(+), 147 deletions(-) diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index 647dabc3..c70f4567 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -1,30 +1,32 @@ import json from unittest.mock import MagicMock, patch -from debug_gym.gym.tools.tool import ToolCall from debug_gym.llms import HuggingFaceLLM -from debug_gym.llms.base import LLMConfigRegistry, LLMResponse +from debug_gym.llms.base import LLMConfigRegistry +from debug_gym.llms.openai import OpenAILLM + +MODEL_REGISTRY = { + "qwen-3": { + "model": "qwen-3", + "tokenizer": "Qwen/Qwen3", + "context_limit": 4, + 
"api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + } +} @patch.object( LLMConfigRegistry, "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - "tokenizer_kwargs": {"trust_remote_code": True}, - } - } - ), + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), ) @patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_huggingface_tokenizer_usage(mock_auto_tokenizer, mock_llm_config, logger_mock): +def test_tokenize_uses_hf_tokenizer_with_pad_fallback( + mock_auto_tokenizer, mock_llm_config, logger_mock +): tokenizer_mock = MagicMock() tokenizer_mock.encode.return_value = [10, 20, 30] tokenizer_mock.convert_ids_to_tokens.return_value = ["", "", ""] @@ -40,29 +42,16 @@ def test_huggingface_tokenizer_usage(mock_auto_tokenizer, mock_llm_config, logge mock_auto_tokenizer.assert_called_once_with("Qwen/Qwen3", trust_remote_code=True) tokenizer_mock.encode.assert_called_with("hello world", add_special_tokens=False) - - # Ensure pad token fallback was applied assert tokenizer_mock.pad_token == "" @patch.object( LLMConfigRegistry, "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - } - } - ), + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), ) @patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_huggingface_normalizes_messages_for_template( +def test_normalize_messages_for_chat_template( mock_auto_tokenizer, mock_llm_config, logger_mock ): tokenizer_mock = MagicMock() @@ -104,33 +93,20 @@ def test_huggingface_normalizes_messages_for_template( @patch.object( LLMConfigRegistry, "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - } - } - ), + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), ) @patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_huggingface_chat_template_token_counts( +def test_message_token_counts_uses_chat_template( mock_auto_tokenizer, mock_llm_config, logger_mock ): tokenizer_mock = MagicMock() tokenizer_mock.pad_token = None tokenizer_mock.eos_token = "" - - def fake_apply_chat_template(messages, tokenize=True, add_generation_prompt=False): - length = len(messages) - return list(range(length * 2)) - - tokenizer_mock.apply_chat_template.side_effect = fake_apply_chat_template - tokenizer_mock.encode.return_value = [] + tokenizer_mock.apply_chat_template.side_effect = [ + {"input_ids": [[1, 2]]}, + {"input_ids": [[1, 2, 3]]}, + {"input_ids": [[1, 2, 3, 4]]}, + ] mock_auto_tokenizer.return_value = tokenizer_mock llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) @@ -138,44 +114,32 @@ def fake_apply_chat_template(messages, tokenize=True, add_generation_prompt=Fals messages = [ {"role": "system", "content": "Instructions"}, {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"}, {"role": "tool", "content": "Result"}, ] counts = llm._get_message_token_counts(messages) - assert counts == [2, 2, 2, 2] + assert 
counts == [2, 1, 1] assert tokenizer_mock.apply_chat_template.call_count == len(messages) - normalized_final = tokenizer_mock.apply_chat_template.call_args_list[-1][0][0] - assert normalized_final[-1]["role"] == "user" - assert normalized_final[-1]["content"] == "Result" + final_normalized = tokenizer_mock.apply_chat_template.call_args_list[-1][0][0] + assert final_normalized[-1]["role"] == "user" + assert final_normalized[-1]["content"] == "Result" @patch.object( LLMConfigRegistry, "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - } - } - ), + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), ) @patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_huggingface_chat_template_zero_token_fallback( - mock_auto_tokenizer, mock_llm_config, logger_mock +@patch.object(OpenAILLM, "_get_message_token_counts", return_value=[5, 6]) +def test_message_token_counts_fallbacks_to_openai_when_template_fails( + mock_super_counts, mock_auto_tokenizer, mock_llm_config, logger_mock ): tokenizer_mock = MagicMock() tokenizer_mock.pad_token = None tokenizer_mock.eos_token = "" - tokenizer_mock.apply_chat_template.return_value = [] - tokenizer_mock.encode.side_effect = [[1, 2], [3, 4, 5]] + tokenizer_mock.apply_chat_template.side_effect = ValueError("no template") mock_auto_tokenizer.return_value = tokenizer_mock llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) @@ -187,78 +151,5 @@ def test_huggingface_chat_template_zero_token_fallback( counts = llm._get_message_token_counts(messages) - assert counts == [2, 3] - assert tokenizer_mock.encode.call_count == len(messages) - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - } - } - ), -) -@patch.object(HuggingFaceLLM, "generate") -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", - "context_limit": 4, - "api_key": "test-api-key", - "endpoint": "https://test-endpoint", - "tags": ["vllm"], - "tokenizer_kwargs": {"trust_remote_code": True}, - } - } - ), -) -@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_huggingface_chat_template_usage( - mock_auto_tokenizer, mock_llm_config, mock_generate, logger_mock -): - tokenizer_mock = MagicMock() - tokenizer_mock.pad_token = None - tokenizer_mock.eos_token = "" - tokenizer_mock.convert_ids_to_tokens.side_effect = lambda ids: [ - f"<{i}>" for i in ids - ] - tokenizer_mock.apply_chat_template.side_effect = [ - {"input_ids": [[1, 2]]}, - {"input_ids": [[1, 2, 3, 4]]}, - ] - tokenizer_mock.encode.return_value = [99] - mock_auto_tokenizer.return_value = tokenizer_mock - - mock_generate.return_value = LLMResponse( - prompt=[], - response="ok", - tool=ToolCall(id="t1", name="noop", arguments={}), - prompt_token_count=5, - response_token_count=2, - ) - - llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) - - messages = [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hello"}, - ] - - response = llm(messages, tools=[]) - - assert response.response == "ok" - assert 
tokenizer_mock.apply_chat_template.call_count == 2 - mock_generate.assert_called_once() + assert counts == [5, 6] + mock_super_counts.assert_called_once_with(messages) From 86ea9bb18aec555844b1ec4f308dc0905072cc93 Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sun, 5 Oct 2025 00:41:10 -0400 Subject: [PATCH 04/11] Update huggingface.py --- debug_gym/llms/huggingface.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py index 0a08c755..ed1819f4 100644 --- a/debug_gym/llms/huggingface.py +++ b/debug_gym/llms/huggingface.py @@ -103,9 +103,7 @@ def _normalize_messages_for_template(self, messages: Iterable[dict]) -> list[dic normalized = [] for message in messages: role = message.get("role", "user") - if role == "tool": - role = "user" - elif role not in {"system", "user", "assistant"}: + if role not in {"system", "user", "assistant"}: role = "user" content = message.get("content") From ab4999c49cca35964a6de8c2070410eb52bd9a46 Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sun, 5 Oct 2025 00:53:21 -0400 Subject: [PATCH 05/11] test w real tokenizer --- pytest.ini | 4 +- tests/llms/test_huggingface.py | 72 +++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pytest.ini b/pytest.ini index 7654985f..dbde339b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,6 @@ norecursedirs = data/* asyncio_default_fixture_loop_scope = function env = - BASH_SILENCE_DEPRECATION_WARNING=1 # Suppress deprecation warnings in MacOS (zsh default) \ No newline at end of file + BASH_SILENCE_DEPRECATION_WARNING=1 # Suppress deprecation warnings in MacOS (zsh default) +markers = + hf_tokenizer: tests that require downloading a Hugging Face tokenizer and may be slow \ No newline at end of file diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index c70f4567..acf2cb49 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -1,10 +1,17 @@ import json from unittest.mock import MagicMock, patch +import pytest +from transformers import AutoTokenizer + from debug_gym.llms import HuggingFaceLLM -from debug_gym.llms.base import LLMConfigRegistry +from debug_gym.llms.base import LLMConfig, LLMConfigRegistry from debug_gym.llms.openai import OpenAILLM +# Run these tests with `pytest tests/llms/test_huggingface.py -m hf_tokenizer` +# to include the integration case that downloads the real Qwen tokenizer. 
+ + MODEL_REGISTRY = { "qwen-3": { "model": "qwen-3", @@ -17,6 +24,20 @@ } } +REAL_TOKENIZER_ID = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(scope="session") +def real_qwen3_tokenizer(): + try: + return AutoTokenizer.from_pretrained(REAL_TOKENIZER_ID, trust_remote_code=True) + except ( + OSError, + ValueError, + ImportError, + ) as exc: # pragma: no cover - network-dependent + pytest.skip(f"Unable to load tokenizer {REAL_TOKENIZER_ID}: {exc}") + @patch.object( LLMConfigRegistry, @@ -153,3 +174,52 @@ def test_message_token_counts_fallbacks_to_openai_when_template_fails( assert counts == [5, 6] mock_super_counts.assert_called_once_with(messages) + + +@pytest.mark.hf_tokenizer +def test_chat_template_counts_with_real_tokenizer(real_qwen3_tokenizer, logger_mock): + config = LLMConfig( + model="qwen-3", + tokenizer=REAL_TOKENIZER_ID, + context_limit=4, + api_key="placeholder", + endpoint="http://localhost", + tags=["vllm"], + tokenizer_kwargs={"trust_remote_code": True}, + ) + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock, llm_config=config) + llm._hf_tokenizer = real_qwen3_tokenizer + + messages = [ + {"role": "system", "content": "Instructions"}, + {"role": "user", "content": "Hello"}, + {"role": "tool", "content": "Result"}, + ] + + counts = llm._get_message_token_counts(messages) + + normalized = llm._normalize_messages_for_template(messages) + expected_counts = [] + prev_len = 0 + for idx in range(1, len(normalized) + 1): + try: + tokenized = real_qwen3_tokenizer.apply_chat_template( + normalized[:idx], tokenize=True, add_generation_prompt=False + ) + except TypeError: # pragma: no cover - version-specific + tokenized = real_qwen3_tokenizer.apply_chat_template( + normalized[:idx], tokenize=True + ) + token_ids = ( + tokenized.get("input_ids") if isinstance(tokenized, dict) else tokenized + ) + if token_ids and isinstance(token_ids[0], list): + token_ids = token_ids[0] + if token_ids is None: + pytest.skip("Tokenizer did not return token ids") + expected_counts.append(len(token_ids) - prev_len) + prev_len = len(token_ids) + + assert counts == expected_counts + assert counts[-1] > 0 From c2c8001415997f45c12bed8b70ad0871c7904898 Mon Sep 17 00:00:00 2001 From: "Xingdi (Eric) Yuan" Date: Sun, 5 Oct 2025 00:58:57 -0400 Subject: [PATCH 06/11] Update test_huggingface.py --- tests/llms/test_huggingface.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index acf2cb49..bc2d274c 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -223,3 +223,29 @@ def test_chat_template_counts_with_real_tokenizer(real_qwen3_tokenizer, logger_m assert counts == expected_counts assert counts[-1] > 0 + + +@pytest.mark.hf_tokenizer +def test_tokenize_and_count_tokens_with_real_tokenizer( + real_qwen3_tokenizer, logger_mock +): + config = LLMConfig( + model="qwen-3", + tokenizer=REAL_TOKENIZER_ID, + context_limit=4, + api_key="placeholder", + endpoint="http://localhost", + tags=["vllm"], + tokenizer_kwargs={"trust_remote_code": True}, + ) + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock, llm_config=config) + llm._hf_tokenizer = real_qwen3_tokenizer + + text = "Hello world!" 
+ hf_ids = real_qwen3_tokenizer.encode(text, add_special_tokens=False) + hf_tokens = real_qwen3_tokenizer.convert_ids_to_tokens(hf_ids) + + tokens = llm.tokenize(text) + assert tokens == hf_tokens + assert llm.count_tokens(text) == len(hf_ids) From 9a790c598b4b2bbbf4b62c4f038124128236582b Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 15 Oct 2025 14:01:50 -0700 Subject: [PATCH 07/11] Add apply_chat_template and enable_thinking options to LLMConfig and OpenAILLM --- debug_gym/llms/base.py | 4 ++ debug_gym/llms/openai.py | 19 +++++++-- tests/llms/test_openai.py | 89 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 3 deletions(-) diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index b0f14139..81a0faf6 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -57,6 +57,8 @@ class LLMConfig: api_key: Optional[str] = None endpoint: Optional[str] = None tokenizer: Optional[str] = None + apply_chat_template: Optional[bool] = False + enable_thinking: Optional[bool] = False reasoning_end_token: Optional[str] = None system_prompt_support: bool = True ignore_kwargs: List[str] = None @@ -201,6 +203,8 @@ def __init__( ) self.tokenizer_name = self.config.tokenizer self.context_length = self.config.context_limit * 1000 + self.apply_chat_template = self.config.apply_chat_template + self.enable_thinking = self.config.enable_thinking self.reasoning_end_token = self.config.reasoning_end_token self.logger.debug( diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index 7fa00051..397b29c3 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -68,9 +68,22 @@ def tokenize(self, text: str) -> list[str]: self._tk_func = tiktoken.encoding_for_model(self.tokenizer_name).encode except KeyError: try: # Try to load from transformers. 
- self._tk_func = AutoTokenizer.from_pretrained( - self.tokenizer_name - ).tokenize + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + if self.apply_chat_template: + + def _tokenize(txt): + return tokenizer.tokenize( + tokenizer.apply_chat_template( + txt, + tokenize=False, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + ) + + self._tk_func = _tokenize + else: + self._tk_func = tokenizer.tokenize except OSError: raise ValueError( f"Tokenizer `{self.tokenizer_name}` not found for model " diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index bda76b94..7787a733 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -532,3 +532,92 @@ def test_llm_without_reasoning_content_attribute( # The response should be just the regular content assert llm_response.response == "Regular response only" + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", + "apply_chat_template": False, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): + llm = OpenAILLM(model_name="qwen", logger=logger_mock) + tokens = llm.tokenize("hello world") + assert tokens == ["hello", "Ġworld"] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", + "apply_chat_template": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): + llm = OpenAILLM(model_name="qwen", logger=logger_mock) + + tokens = llm.tokenize("hello world") + + assert tokens == [ + "<|im_start|>", + "assistant", + "Ċ", + "", + "ĊĊ", + "", + "ĊĊ", + ] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", + "apply_chat_template": True, + "enable_thinking": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template_thinking(mock_llm_config, logger_mock): + llm = OpenAILLM(model_name="qwen", logger=logger_mock) + + tokens = llm.tokenize("hello world") + + assert tokens == [ + "<|im_start|>", + "assistant", + "Ċ", + ] From 1cac5fc8b71364891fe2bb817d308cd4de14af23 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 15 Oct 2025 18:32:24 -0700 Subject: [PATCH 08/11] Update llms tests --- debug_gym/llms/base.py | 6 ++ debug_gym/llms/openai.py | 3 +- tests/llms/test_base.py | 4 +- tests/llms/test_huggingface.py | 142 ++++++++++++++++++++++++++------- tests/llms/test_openai.py | 89 --------------------- 5 files changed, 122 insertions(+), 122 deletions(-) diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index b952f749..96e8c7ca 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -245,6 +245,7 @@ def instantiate( llm_config = LLMConfigRegistry.from_file(llm_config_file_path)[llm_name] tags = llm_config.tags + if "copilot openai" in tags: from debug_gym.llms.copilot import CopilotOpenAILLM @@ -254,22 +255,27 @@ def instantiate( from debug_gym.llms.copilot import CopilotClaudeLLM klass = 
CopilotClaudeLLM + elif "azure openai" in tags: from debug_gym.llms import AzureOpenAILLM klass = AzureOpenAILLM + elif "vllm" in tags: from debug_gym.llms import HuggingFaceLLM klass = HuggingFaceLLM + elif "anthropic" in tags: from debug_gym.llms import AnthropicLLM klass = AnthropicLLM + else: from debug_gym.llms import OpenAILLM klass = OpenAILLM + llm = klass(llm_name, logger=logger, llm_config=llm_config) return llm diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index 397b29c3..e0606a2e 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -67,7 +67,8 @@ def tokenize(self, text: str) -> list[str]: try: self._tk_func = tiktoken.encoding_for_model(self.tokenizer_name).encode except KeyError: - try: # Try to load from transformers. + # Try to load from transformers, mostly deprecated. Use HuggingFaceLLM for transformers models. + try: tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) if self.apply_chat_template: diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py index 11d8eb1d..98f43811 100644 --- a/tests/llms/test_base.py +++ b/tests/llms/test_base.py @@ -54,8 +54,8 @@ "tags": ["anthropic", "claude", "claude-3.7"], }, "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", "context_limit": 4, "api_key": "test-api-key", "endpoint": "https://test-endpoint", diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index bc2d274c..bf2b5345 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -12,10 +12,12 @@ # to include the integration case that downloads the real Qwen tokenizer. +HF_MODEL_ID = "Qwen/Qwen3-0.6B" + MODEL_REGISTRY = { "qwen-3": { - "model": "qwen-3", - "tokenizer": "Qwen/Qwen3", + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, "context_limit": 4, "api_key": "test-api-key", "endpoint": "https://test-endpoint", @@ -24,19 +26,17 @@ } } -REAL_TOKENIZER_ID = "Qwen/Qwen3-0.6B" - @pytest.fixture(scope="session") def real_qwen3_tokenizer(): try: - return AutoTokenizer.from_pretrained(REAL_TOKENIZER_ID, trust_remote_code=True) + return AutoTokenizer.from_pretrained(HF_MODEL_ID) except ( OSError, ValueError, ImportError, ) as exc: # pragma: no cover - network-dependent - pytest.skip(f"Unable to load tokenizer {REAL_TOKENIZER_ID}: {exc}") + pytest.skip(f"Unable to load tokenizer {HF_MODEL_ID}: {exc}") @patch.object( @@ -44,26 +44,19 @@ def real_qwen3_tokenizer(): "from_file", return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), ) -@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_tokenize_uses_hf_tokenizer_with_pad_fallback( - mock_auto_tokenizer, mock_llm_config, logger_mock -): - tokenizer_mock = MagicMock() - tokenizer_mock.encode.return_value = [10, 20, 30] - tokenizer_mock.convert_ids_to_tokens.return_value = ["", "", ""] - tokenizer_mock.pad_token = None - tokenizer_mock.eos_token = "" - mock_auto_tokenizer.return_value = tokenizer_mock - - llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) - - tokens = llm.tokenize("hello world") - assert tokens == ["", "", ""] - assert llm.count_tokens("hello world") == 3 - - mock_auto_tokenizer.assert_called_once_with("Qwen/Qwen3", trust_remote_code=True) - tokenizer_mock.encode.assert_called_with("hello world", add_special_tokens=False) - assert tokenizer_mock.pad_token == "" +def test_tokenize_uses_hf_tokenizer_with_pad_fallback(mock_llm_config, logger_mock): + tokenizer = 
AutoTokenizer.from_pretrained(HF_MODEL_ID) + tokenizer.pad_token = None + tokenizer.eos_token = "" + with patch( + "debug_gym.llms.huggingface.AutoTokenizer.from_pretrained" + ) as mock_auto_tokenizer: + mock_auto_tokenizer.return_value = tokenizer + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + assert llm.tokenize("hello world") == ["hello", "Ġworld"] + assert llm.count_tokens("hello world") == 2 + assert tokenizer.eos_token == "" + assert tokenizer.pad_token == "" @patch.object( @@ -179,8 +172,8 @@ def test_message_token_counts_fallbacks_to_openai_when_template_fails( @pytest.mark.hf_tokenizer def test_chat_template_counts_with_real_tokenizer(real_qwen3_tokenizer, logger_mock): config = LLMConfig( - model="qwen-3", - tokenizer=REAL_TOKENIZER_ID, + model=HF_MODEL_ID, + tokenizer=HF_MODEL_ID, context_limit=4, api_key="placeholder", endpoint="http://localhost", @@ -230,8 +223,8 @@ def test_tokenize_and_count_tokens_with_real_tokenizer( real_qwen3_tokenizer, logger_mock ): config = LLMConfig( - model="qwen-3", - tokenizer=REAL_TOKENIZER_ID, + model=HF_MODEL_ID, + tokenizer=HF_MODEL_ID, context_limit=4, api_key="placeholder", endpoint="http://localhost", @@ -249,3 +242,92 @@ def test_tokenize_and_count_tokens_with_real_tokenizer( tokens = llm.tokenize(text) assert tokens == hf_tokens assert llm.count_tokens(text) == len(hf_ids) + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": False, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + tokens = llm.tokenize("hello world") + assert tokens == ["hello", "Ġworld"] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + + tokens = llm.tokenize("hello world") + + assert tokens == [ + "<|im_start|>", + "assistant", + "Ċ", + "", + "ĊĊ", + "", + "ĊĊ", + ] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": True, + "enable_thinking": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template_thinking(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + + tokens = llm.tokenize("hello world") + + assert tokens == [ + "<|im_start|>", + "assistant", + "Ċ", + ] diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 7787a733..bda76b94 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -532,92 +532,3 @@ def test_llm_without_reasoning_content_attribute( # The response should be just the regular content assert llm_response.response == "Regular response only" - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all( - 
{ - "qwen": { - "model": "Qwen/Qwen3-0.6B", - "tokenizer": "Qwen/Qwen3-0.6B", - "apply_chat_template": False, - "context_limit": 4096, - "api_key": "fake", - "endpoint": "fake", - "api_version": "1", - "tags": ["vllm"], - } - } - ), -) -def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): - llm = OpenAILLM(model_name="qwen", logger=logger_mock) - tokens = llm.tokenize("hello world") - assert tokens == ["hello", "Ġworld"] - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen": { - "model": "Qwen/Qwen3-0.6B", - "tokenizer": "Qwen/Qwen3-0.6B", - "apply_chat_template": True, - "context_limit": 4096, - "api_key": "fake", - "endpoint": "fake", - "api_version": "1", - "tags": ["vllm"], - } - } - ), -) -def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): - llm = OpenAILLM(model_name="qwen", logger=logger_mock) - - tokens = llm.tokenize("hello world") - - assert tokens == [ - "<|im_start|>", - "assistant", - "Ċ", - "", - "ĊĊ", - "", - "ĊĊ", - ] - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all( - { - "qwen": { - "model": "Qwen/Qwen3-0.6B", - "tokenizer": "Qwen/Qwen3-0.6B", - "apply_chat_template": True, - "enable_thinking": True, - "context_limit": 4096, - "api_key": "fake", - "endpoint": "fake", - "api_version": "1", - "tags": ["vllm"], - } - } - ), -) -def test_hf_tokenize_apply_chat_template_thinking(mock_llm_config, logger_mock): - llm = OpenAILLM(model_name="qwen", logger=logger_mock) - - tokens = llm.tokenize("hello world") - - assert tokens == [ - "<|im_start|>", - "assistant", - "Ċ", - ] From f7320628a94664846084664b9f30bf2519e69328 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 15 Oct 2025 18:33:22 -0700 Subject: [PATCH 09/11] Add apply_chat_template and cache tokenizer for HuggingFaceLLM --- debug_gym/llms/huggingface.py | 37 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py index ed1819f4..e5cdfb65 100644 --- a/debug_gym/llms/huggingface.py +++ b/debug_gym/llms/huggingface.py @@ -20,11 +20,12 @@ def _load_tokenizer(self): self._hf_tokenizer = AutoTokenizer.from_pretrained( self.tokenizer_name, **tokenizer_kwargs ) - except OSError as exc: + except OSError: raise ValueError( - "Failed to load Hugging Face tokenizer " - f"`{self.tokenizer_name}` for model {self.model_name}." - ) from exc + f"Tokenizer `{self.tokenizer_name}` not found for model " + f"{self.model_name}, make sure you have access to " + "the model (e.g., HuggingFace API key is correctly set)." + ) # Ensure we have a pad token to avoid downstream warnings when invoking # the tokenizer in encode mode. 
@@ -36,16 +37,24 @@ def _load_tokenizer(self): return self._hf_tokenizer def tokenize(self, text: str) -> list[str]: - tokenizer = self._load_tokenizer() - token_ids = tokenizer.encode(str(text), add_special_tokens=False) - - if hasattr(tokenizer, "convert_ids_to_tokens"): - try: - return tokenizer.convert_ids_to_tokens(token_ids) - except Exception: # pragma: no cover - pass - - return [str(t) for t in token_ids] + if getattr(self, "_tk_func", None) is None: + tokenizer = self._load_tokenizer() + if self.apply_chat_template: + + def _tokenize(txt): + return tokenizer.tokenize( + tokenizer.apply_chat_template( + txt, + tokenize=False, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + ) + + self._tk_func = _tokenize + else: + self._tk_func = tokenizer.tokenize + return self._tk_func(text) def count_tokens(self, text) -> int: tokenizer = self._load_tokenizer() From 00ced7656c46560ab8192f7524e3a609f6dd8b1a Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Thu, 16 Oct 2025 11:35:00 -0700 Subject: [PATCH 10/11] Refactor imports in agents and llms modules --- debug_gym/agents/__init__.py | 3 --- debug_gym/llms/base.py | 7 +------ tests/llms/test_base.py | 2 +- tests/llms/test_huggingface.py | 3 +-- 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py index 83161b49..e69de29b 100644 --- a/debug_gym/agents/__init__.py +++ b/debug_gym/agents/__init__.py @@ -1,3 +0,0 @@ -from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent -from debug_gym.agents.rewrite_agent import RewriteAgent -from debug_gym.agents.solution_agent import AgentSolution diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index 96e8c7ca..a087346f 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -13,6 +13,7 @@ wait_random_exponential, ) +from debug_gym.agents.utils import get_message_tokens, trim_prompt_messages from debug_gym.gym.envs.env import EnvInfo from debug_gym.gym.tools.tool import EnvironmentTool, ToolCall from debug_gym.llms.constants import DEFAULT_LLM_CONFIG @@ -301,9 +302,6 @@ def _get_message_token_counts(self, messages: list[dict]) -> list[int]: Subclasses can override this to plug in custom counting strategies (for example, chat-template aware tokenizers). """ - - from debug_gym.agents.utils import get_message_tokens - return [get_message_tokens(msg, self.count_tokens) for msg in messages] def _trim_messages_to_context( @@ -318,9 +316,6 @@ def _trim_messages_to_context( Returns: A trimmed list of messages. 
""" - - from debug_gym.agents.utils import trim_prompt_messages - return trim_prompt_messages(messages, self.context_length, self.count_tokens) @abstractmethod diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py index 98f43811..75e8349f 100644 --- a/tests/llms/test_base.py +++ b/tests/llms/test_base.py @@ -582,7 +582,7 @@ def generate(self, messages, tools, **kwargs): messages = [{"role": "user", "content": "Long message"}] # Mock trim_prompt_messages to return shorter messages - with patch("debug_gym.agents.utils.trim_prompt_messages") as mock_trim: + with patch("debug_gym.llms.base.trim_prompt_messages") as mock_trim: shorter_messages = [{"role": "user", "content": "Short"}] mock_trim.return_value = shorter_messages diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index bf2b5345..a937c77a 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -4,9 +4,8 @@ import pytest from transformers import AutoTokenizer -from debug_gym.llms import HuggingFaceLLM +from debug_gym.llms import HuggingFaceLLM, OpenAILLM from debug_gym.llms.base import LLMConfig, LLMConfigRegistry -from debug_gym.llms.openai import OpenAILLM # Run these tests with `pytest tests/llms/test_huggingface.py -m hf_tokenizer` # to include the integration case that downloads the real Qwen tokenizer. From 5664b52e95a595d8803a63b9b46d016dc32f82fe Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Thu, 23 Oct 2025 15:51:58 -0400 Subject: [PATCH 11/11] Refactor tokenization methods to accept messages as input (#258) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor tokenization methods to accept messages as input * Update README.md Co-authored-by: Marc-Alexandre Côté * Update debug_gym/llms/openai.py Co-authored-by: Marc-Alexandre Côté * Specify base images for mini-nightmare and aider (#255) * specify base image for mini nightmare * switch to debug-gym:aider image * remove default image * passing base image * test workspace * formatting issues * Simplify configs. Fix issue with overwriting entrypoints. --------- Co-authored-by: Marc-Alexandre Côté * Support env variable in pod_spec_kwargs (#248) * Support env variable in pod_spec_kwargs * Ensure kubectl binaries are discoverable Add logic to set environment variables for kubectl discovery * update test_kubernetes * Refactor ShellSession to split shell_command, then add ["-c", command] keeping the command string intact * Pass Kubernetes service environment variables to support in-cluster kubectl access * Add isort and black to development dependencies * black and isort * Terminal env vars (#256) * Refactor terminal classes to allow include_os_env_vars param only for local terminal * Removed include_os_env_vars from kubernetes terminal * Set include_os_env_vars to True by default in LocalTerminal * Initialize env_vars in LocalTerminal constructor to ensure it defaults to an empty dictionary. 
* Removed redundant environment variable handling for in-cluster Kubernetes access * Simply logic * Set default base image --------- Co-authored-by: Alessandro Sordoni Co-authored-by: Xingdi (Eric) Yuan Co-authored-by: Matheus Pereira * Move trim utils to llms * Remove message normalization from HuggingFaceLLM * Adjust message truncation logic and update test assertion for token counts --------- Co-authored-by: Marc-Alexandre Côté Co-authored-by: Chinmay Singh Co-authored-by: Alessandro Sordoni Co-authored-by: Xingdi (Eric) Yuan --- README.md | 2 +- debug_gym/agents/__init__.py | 3 + debug_gym/agents/base_agent.py | 10 +- debug_gym/agents/utils.py | 178 --------------- debug_gym/gym/envs/aider.py | 5 +- debug_gym/gym/envs/env.py | 25 +- debug_gym/gym/envs/mini_nightmare.py | 4 +- debug_gym/gym/terminals/__init__.py | 3 - debug_gym/gym/terminals/docker.py | 12 +- debug_gym/gym/terminals/kubernetes.py | 47 ++-- debug_gym/gym/terminals/local.py | 23 ++ debug_gym/gym/terminals/shell_session.py | 11 +- debug_gym/gym/terminals/terminal.py | 13 +- debug_gym/llms/anthropic.py | 30 +-- debug_gym/llms/base.py | 51 ++--- debug_gym/llms/copilot.py | 16 +- debug_gym/llms/huggingface.py | 112 ++------- debug_gym/llms/human.py | 25 +- debug_gym/llms/openai.py | 43 ++-- debug_gym/llms/utils.py | 185 +++++++++++++++ pyproject.toml | 2 + scripts/config_aider.yaml | 8 +- scripts/config_mini_nightmare.yaml | 3 +- tests/agents/test_utils.py | 277 +---------------------- tests/conftest.py | 9 +- tests/gym/envs/test_aider.py | 21 ++ tests/gym/terminals/test_docker.py | 24 +- tests/gym/terminals/test_kubernetes.py | 104 ++++++++- tests/gym/test_workspace.py | 2 +- tests/llms/test_base.py | 2 +- tests/llms/test_copilot.py | 12 +- tests/llms/test_huggingface.py | 203 +++++------------ tests/llms/test_human.py | 8 +- tests/llms/test_openai.py | 32 +++ tests/llms/test_utils.py | 277 ++++++++++++++++++++++- 35 files changed, 909 insertions(+), 873 deletions(-) diff --git a/README.md b/README.md index 1543798e..6ac79648 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ In addition to all [built-in Jinja filters](https://jinja.palletsprojects.com/en {{ info.tools | to_pretty_json }} ``` -- **`trim_message`**: Trims a string to fit within a token or character limit, also filtering out non-UTF8 characters. This is helpful for ensuring that large outputs (such as directory trees or evaluation results) do not exceed the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed: + - **`trim_message`**: Trims a string to approximately fit within a token or character limit while filtering non-UTF8 characters. This helps keep large outputs (such as directory trees or evaluation results) within the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed: - **`max_length`**: The maximum number of tokens to keep in the message. If the message exceeds this length, it will be trimmed. - **`max_length_percentage`**: Instead of specifying an absolute number, you can provide a percentage (e.g., `0.1` for 10%) of the LLM's context window. The message will be trimmed to fit within this percentage of the model's maximum context length. - **`where`**: Specifies where to trim the message if it exceeds the limit. The default is `"middle"`, which trims from the middle of the message. Other options are `start` or `end`. 
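As an illustration of the arguments listed above, a hypothetical template line (the `long_output` variable is made up for this example) could be:

```
{{ long_output | trim_message(max_length_percentage=0.1, where="end") }}
```

This would keep roughly the first 10% of the model's context window worth of tokens and drop the tail of the string.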
diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py index e69de29b..83161b49 100644 --- a/debug_gym/agents/__init__.py +++ b/debug_gym/agents/__init__.py @@ -0,0 +1,3 @@ +from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent +from debug_gym.agents.rewrite_agent import RewriteAgent +from debug_gym.agents.solution_agent import AgentSolution diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index 2907f076..c9fe3d6d 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -8,10 +8,10 @@ from jinja2 import Environment, Template from debug_gym.agents.history_tracker import HistoryTracker, build_history_prompt -from debug_gym.agents.utils import trim from debug_gym.gym.envs.env import RepoEnv from debug_gym.gym.utils import filter_non_utf8 from debug_gym.llms.base import LLM +from debug_gym.llms.utils import trim from debug_gym.logger import DebugGymLogger AGENT_REGISTRY = {} @@ -119,7 +119,7 @@ def to_pretty_json(value): def trim_message( self, - message, + message: str, count_tokens=None, max_length=None, max_length_percentage=0, @@ -140,10 +140,8 @@ def trim_message( if count_tokens is None or max_length is None or max_length <= 0: return message - tokens = count_tokens(message) - if tokens > max_length: - return trim(message, max_length, count_tokens=count_tokens, where=where) - return message + + return trim(message, max_length, count_tokens=count_tokens, where=where) def _load_system_prompt_template(self) -> Template | None: """Load system prompt template from config if specified and register custom filters. diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index d86e2554..3d694237 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -5,184 +5,6 @@ import yaml -def trim(text: str, max_tokens: int, count_tokens: callable, where: str = "middle"): - """Trim text to fit within max_tokens by working directly at the token level.""" - if max_tokens <= 0: - return "" - - nb_tokens = count_tokens(text) - if nb_tokens <= max_tokens: - return text - - ellipsis = "…" # assume ellipsis is a single token - available_tokens = max_tokens - 1 # account for ellipsis - - def find_char_position_for_tokens( - target_tokens: int, from_start: bool = True - ) -> int: - """Binary search to find character position that gives approximately target_tokens.""" - left, right = 0, len(text) - best_pos = left if from_start else right - - while left <= right: - mid = (left + right) // 2 - test_text = text[:mid] if from_start else text[mid:] - test_tokens = count_tokens(test_text) - if test_tokens <= target_tokens: - best_pos = mid - if from_start: - left = mid + 1 - else: - right = mid - 1 - else: - if from_start: - right = mid - 1 - else: - left = mid + 1 - return best_pos - - if where == "end": - # Keep the beginning, trim the end - trim_point = find_char_position_for_tokens(available_tokens, from_start=True) - return text[:trim_point] + ellipsis - elif where == "start": - # Keep the end, trim the beginning - trim_point = find_char_position_for_tokens(available_tokens, from_start=False) - return ellipsis + text[trim_point:] - elif where == "middle": - # Keep both ends, trim the middle - half_tokens = available_tokens // 2 - - # Find how much we can keep from the start - start_chars = find_char_position_for_tokens(half_tokens, from_start=True) - - # Find how much we can keep from the end with remaining tokens - remaining_tokens = available_tokens - count_tokens(text[:start_chars]) - end_chars = 
find_char_position_for_tokens(remaining_tokens, from_start=False) - - return text[:start_chars] + ellipsis + text[end_chars:] - else: - raise ValueError(f"Invalid value for `where`: {where!r}.") - - -def get_message_tokens(message, count_tokens): - """Count tokens in a message.""" - message_content = str(message.get("content", message.get("tool_calls", message))) - return count_tokens(message_content) - - -def trim_prompt_messages( - messages: list[dict], context_length: int, count_tokens: callable -): - """ - Trim message on the assistant-tool pair level to fit context length. - - Strategy: - 1. Keep the system message (assert if system itself is too long) - 2. Keep as many most recent (assistant, tool) pairs as possible - 3. Only when we can keep all (assistant, tool) pairs, keep the user message - - Args: - messages: List of message dicts with 'role' and 'content'/'tool_calls' keys - context_length: Maximum number of tokens allowed - count_tokens: Function to count tokens in a string - - Returns: - Trimmed list of messages that fit within context_length - """ - assert len(messages) > 0, "messages should not be empty" - assert messages[-1]["role"] in [ - "user", - "tool", - ], "the last message should be from the user or the tool" - assert context_length >= 0, "context_length should be non-negative" - - # Calculate token count for all messages - message_tokens = [get_message_tokens(msg, count_tokens) for msg in messages] - total_tokens = sum(message_tokens) - - # If we're already within limit, return as-is - if total_tokens <= context_length: - return messages - - # Find system message - system_msg_idx = 0 if messages[0]["role"] == "system" else None - system_tokens = message_tokens[0] if system_msg_idx is not None else 0 - - # Assert system message fits within context - assert ( - system_tokens <= context_length - ), f"System message tokens exceed context length: {system_tokens} > {context_length}!" 
- - # Find user message - user_msg_idx = None - for i, msg in enumerate(messages): - if msg["role"] == "user": - user_msg_idx = i - break - - # Find all (assistant, tool) pairs by going backwards - assistant_tool_pairs = [] - i = len(messages) - 1 - while i >= 0: - if ( - messages[i]["role"] == "tool" - and i > 0 - and messages[i - 1]["role"] == "assistant" - ): - assistant_tool_pairs.append((i - 1, i)) # (assistant_idx, tool_idx) - i -= 2 - else: - i -= 1 - - # Start building result with system message - result = [] - remaining_tokens = context_length - - if system_msg_idx is not None: - result.append(messages[system_msg_idx]) - remaining_tokens -= system_tokens - - # Add as many recent (assistant, tool) pairs as possible - included_pairs = [] - for assistant_idx, tool_idx in assistant_tool_pairs: - pair_tokens = message_tokens[assistant_idx] + message_tokens[tool_idx] - if pair_tokens <= remaining_tokens: - included_pairs.append((assistant_idx, tool_idx)) - remaining_tokens -= pair_tokens - else: - break - - # Only include user message if we can fit all (assistant, tool) pairs - include_user = False - if len(included_pairs) == len(assistant_tool_pairs) and user_msg_idx is not None: - user_tokens = message_tokens[user_msg_idx] - if user_tokens <= remaining_tokens: - include_user = True - - # Build final result - if include_user: - result.append(messages[user_msg_idx]) - - # Sort by assistant index to maintain chronological order - included_pairs.sort(key=lambda pair: pair[0]) - for assistant_idx, tool_idx in included_pairs: - result.append(messages[assistant_idx]) - result.append(messages[tool_idx]) - - assert ( - len(result) > 0 - ), f"After trimming, no messages fit within context length: {context_length}!" - - # Verify final token count - final_tokens = sum(get_message_tokens(msg, count_tokens) for msg in result) - assert ( - final_tokens <= context_length - ), f"After trimming, the message length still exceeds: {final_tokens} > {context_length}!" - - return result - - def load_config(): parser = argparse.ArgumentParser() parser.add_argument("config_file", help="path to config file") diff --git a/debug_gym/gym/envs/aider.py b/debug_gym/gym/envs/aider.py index f5bc7423..26776448 100644 --- a/debug_gym/gym/envs/aider.py +++ b/debug_gym/gym/envs/aider.py @@ -66,11 +66,12 @@ def __init__( terminal: Terminal | None = None, **kwargs, ): - terminal = terminal or DockerTerminal( base_image=DOCKER_AIDER_IMAGE_NAME, logger=kwargs.get("logger"), ) + if hasattr(terminal, "base_image") and terminal.base_image is None: + terminal.base_image = DOCKER_AIDER_IMAGE_NAME super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs) @@ -104,8 +105,6 @@ def setup_workspace(self): ) self.workspace.setup_file_filters() # Use codebase's .debugignore and .debugreadonly. 
- self.set_entrypoints("python -m pytest --tb=no -s .") - def setup_terminal(self): self.logger.info(f"Configuring {self.terminal}...") diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 2d11e791..a297082f 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -231,8 +231,8 @@ def __init__( self.run_timeout = run_timeout self.dir_tree_depth = dir_tree_depth self.terminal = terminal or LocalTerminal() # TODO: default to DockerTerminal - self.entrypoint = entrypoint - self.debug_entrypoint = debug_entrypoint or entrypoint + self._entrypoint = entrypoint + self._debug_entrypoint = debug_entrypoint self.persistent_breakpoints = persistent_breakpoints self.auto_list = auto_list self.logger = logger or DebugGymLogger("debug-gym") @@ -242,7 +242,7 @@ def __init__( self.workspace = Workspace(self.terminal, logger=self.logger) self.dataset = self.load_dataset(problems) - self.set_entrypoints(self.entrypoint, self.debug_entrypoint) + self.set_entrypoints(self._entrypoint, self._debug_entrypoint) def _reset_env_state(self): """Reset the environment state to the initial state.""" @@ -257,18 +257,13 @@ def _reset_env_state(self): self.empty_event_queue() def set_entrypoints(self, entrypoint: str, debug_entrypoint: str | None = None): - if entrypoint: - self.entrypoint = self._prepare_entrypoint(entrypoint) - debug_entrypoint = debug_entrypoint or entrypoint.replace( - "python ", "python -m pdb " - ) - self.debug_entrypoint = self._prepare_entrypoint(debug_entrypoint) - if self.debug_entrypoint is not None and "-m pdb" not in self.debug_entrypoint: - self.debug_entrypoint = self.debug_entrypoint.replace( - "python ", "python -m pdb " - ) - self.entrypoint = "PYTHONPATH=$PYTHONPATH:$PWD " + self.entrypoint - self.debug_entrypoint = "PYTHONPATH=$PYTHONPATH:$PWD " + self.debug_entrypoint + debug_entrypoint = debug_entrypoint or entrypoint.replace( + "python ", "python -m pdb " + ) + self.entrypoint = self._prepare_entrypoint(entrypoint) + self.debug_entrypoint = self._prepare_entrypoint(debug_entrypoint) + # self.entrypoint = "PYTHONPATH=$PYTHONPATH:$PWD " + self.entrypoint + # self.debug_entrypoint = "PYTHONPATH=$PYTHONPATH:$PWD " + self.debug_entrypoint @staticmethod def _prepare_entrypoint(entrypoint): diff --git a/debug_gym/gym/envs/mini_nightmare.py b/debug_gym/gym/envs/mini_nightmare.py index e13a21b8..b850d087 100644 --- a/debug_gym/gym/envs/mini_nightmare.py +++ b/debug_gym/gym/envs/mini_nightmare.py @@ -82,6 +82,8 @@ def __init__( base_image=DOCKER_MINI_NIGHTMARE_IMAGE_NAME, logger=kwargs.get("logger"), ) + if hasattr(terminal, "base_image") and terminal.base_image is None: + terminal.base_image = DOCKER_MINI_NIGHTMARE_IMAGE_NAME super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs) @@ -119,8 +121,6 @@ def setup_workspace(self): ) self.workspace.setup_file_filters() # Use codebase's .debugignore and .debugreadonly. 
- self.set_entrypoints("python -m pytest --tb=no -s test.py") - def setup_terminal(self): self.logger.info(f"Configuring {self.terminal}...") diff --git a/debug_gym/gym/terminals/__init__.py b/debug_gym/gym/terminals/__init__.py index 417e0990..068a8b6a 100644 --- a/debug_gym/gym/terminals/__init__.py +++ b/debug_gym/gym/terminals/__init__.py @@ -15,7 +15,6 @@ def select_terminal( logger = logger or DebugGymLogger("debug-gym") terminal_type = terminal_config["type"] - docker_only = ["base_image", "setup_commands"] match terminal_type: case "docker": terminal_class = DockerTerminal @@ -23,8 +22,6 @@ def select_terminal( terminal_class = KubernetesTerminal case "local": terminal_class = LocalTerminal - if any(cfg in terminal_config for cfg in docker_only): - logger.warning("Ignoring Docker-only parameters for local terminal.") case _: raise ValueError(f"Unknown terminal {terminal_type}") diff --git a/debug_gym/gym/terminals/docker.py b/debug_gym/gym/terminals/docker.py index dd8f1460..d6f107f0 100644 --- a/debug_gym/gym/terminals/docker.py +++ b/debug_gym/gym/terminals/docker.py @@ -19,10 +19,10 @@ def __init__( working_dir: str | None = None, session_commands: list[str] | None = None, env_vars: dict[str, str] | None = None, - include_os_env_vars: bool = False, logger: DebugGymLogger | None = None, # Docker-specific parameters - base_image: str = "ubuntu:latest", + base_image: str | None = None, + registry: str = "", setup_commands: list[str] | None = None, **kwargs, ): @@ -39,11 +39,11 @@ def __init__( working_dir=working_dir, session_commands=session_commands, env_vars=env_vars, - include_os_env_vars=include_os_env_vars, logger=logger, **kwargs, ) self.base_image = base_image + self.registry = registry.rstrip("/") + "/" if registry else "" self.setup_commands = setup_commands or [] self.docker_client = docker.from_env(timeout=600) self._container = None @@ -133,13 +133,15 @@ def run( def setup_container(self) -> docker.models.containers.Container: # Create and start a container mounting volumes and setting environment variables - self.logger.debug(f"Setting up container with base image: {self.base_image}") + self.logger.debug( + f"Setting up container with image: {self.registry}{self.base_image}" + ) # Generate a unique container name container_name = f"debug_gym_{uuid.uuid4()}" container = self.docker_client.containers.run( name=container_name, - image=self.base_image, + image=f"{self.registry}{self.base_image}", command="sleep infinity", # Keep the container running working_dir=self.working_dir, environment=self.env_vars, diff --git a/debug_gym/gym/terminals/kubernetes.py b/debug_gym/gym/terminals/kubernetes.py index 89633c40..33f36d0b 100644 --- a/debug_gym/gym/terminals/kubernetes.py +++ b/debug_gym/gym/terminals/kubernetes.py @@ -7,6 +7,7 @@ import uuid from pathlib import Path +from jinja2 import Template from kubernetes import client, config, stream, watch from kubernetes.client.rest import ApiException from kubernetes.stream.ws_client import ERROR_CHANNEL @@ -16,6 +17,7 @@ stop_after_attempt, wait_random_exponential, ) +from yaml import dump, safe_load from debug_gym.gym.terminals.shell_session import ShellSession from debug_gym.gym.terminals.terminal import DISABLE_ECHO_COMMAND, Terminal @@ -217,15 +219,15 @@ def __init__( working_dir: str | None = None, session_commands: list[str] | None = None, env_vars: dict[str, str] | None = None, - include_os_env_vars: bool = False, logger: DebugGymLogger | None = None, - setup_commands: list[str] | None = None, # Kubernetes-specific 
parameters + setup_commands: list[str] | None = None, pod_name: str | None = None, - base_image: str = "ubuntu:latest", - registry: str = "docker.io/", + base_image: str | None = None, + registry: str = "", namespace: str = "default", kube_config: str | None = None, + kube_context: str | None = None, extra_labels: dict | None = None, pod_spec_kwargs: dict = None, **kwargs, @@ -234,7 +236,6 @@ def __init__( working_dir=working_dir, session_commands=session_commands, env_vars=env_vars, - include_os_env_vars=include_os_env_vars, logger=logger, **kwargs, ) @@ -243,26 +244,36 @@ def __init__( self.setup_commands = setup_commands or [] self.namespace = namespace self.kubernetes_kwargs = kwargs # e.g., nodeSelector, tolerations - self.registry = registry.rstrip("/") + self.registry = registry.rstrip("/") + "/" if registry else "" self._pod_name = pod_name self.pod_spec_kwargs = pod_spec_kwargs or {} user = _clean_for_kubernetes(os.environ.get("USER", "unknown")) - self.labels = {"app": "debug-gym", "component": "terminal", "user": user} | ( - extra_labels or {} - ) + self.labels = {"app": "dbg-gym", "user": user} | (extra_labels or {}) self._pod = None # Initialize Kubernetes client self.kube_config = kube_config + self.kube_context = kube_context if self.kube_config == "incluster": self.kube_config = None config.load_incluster_config() + # For in-cluster kubectl access, pass Kubernetes service environment variables + # This enables kubectl to auto-discover the service account credentials + for key in ("KUBERNETES_SERVICE_HOST", "KUBERNETES_SERVICE_PORT"): + if key in os.environ: + self.env_vars.setdefault(key, os.environ[key]) else: self.kube_config = self.kube_config or os.environ.get( "KUBECONFIG", "~/.kube/config" ) self.kube_config = os.path.expanduser(self.kube_config) - config.load_kube_config(self.kube_config) + config.load_kube_config(self.kube_config, self.kube_context) + self.env_vars.setdefault("KUBECONFIG", self.kube_config) + + # Ensure helper binaries such as kubectl can be discovered even when + # host environment variables are not inherited. + if "PATH" in os.environ: + self.env_vars.setdefault("PATH", os.environ["PATH"]) self.k8s_client = client.CoreV1Api() atexit.register(self.close) @@ -315,9 +326,9 @@ def pod(self): @property def default_shell_command(self) -> list[str]: """Expects the pod to have bash installed.""" - kubeconfig = f"--kubeconfig {self.kube_config}" if self.kube_config else "" + kubeconfig = f"--kubeconfig {self.kube_config} " if self.kube_config else "" bash_cmd = "/bin/bash --noprofile --norc --noediting" - return f"kubectl {kubeconfig} exec -it {self.pod.name} -n {self.pod.namespace} -- {bash_cmd}" + return f"kubectl {kubeconfig}exec -it {self.pod.name} -c main -n {self.pod.namespace} -- {bash_cmd}" def new_shell_session(self): if not self.pod.is_running(): @@ -427,9 +438,15 @@ def setup_pod(self) -> None: self._pod_name or f"dbg-gym.{self.task_name}.{str(uuid.uuid4())[:8]}" ) self.logger.debug( - f"Setting up pod {pod_name} with base image: {self.base_image}" + f"Setting up pod {pod_name} with image: {self.registry}{self.base_image}" ) + # Render pod_spec_kwargs as a Jinja2 template, replace variables, then load as dict. + pod_spec_yaml = dump(self.pod_spec_kwargs) + pod_spec_template = Template(pod_spec_yaml) + rendered_yaml = pod_spec_template.render(os.environ) + pod_spec_kwargs = safe_load(rendered_yaml) + # Create pod specification for Kubernetes. 
pod_body = { "apiVersion": "v1", @@ -445,7 +462,7 @@ def setup_pod(self) -> None: "containers": [ { "name": "main", - "image": f"{self.registry}/{self.base_image}", + "image": f"{self.registry}{self.base_image}", "imagePullPolicy": "IfNotPresent", "command": ["/bin/bash"], "args": ["-c", "sleep infinity"], @@ -462,7 +479,7 @@ def setup_pod(self) -> None: }, } ], - **self.pod_spec_kwargs, # e.g., nodeSelector, tolerations + **pod_spec_kwargs, # e.g., nodeSelector, tolerations }, } diff --git a/debug_gym/gym/terminals/local.py b/debug_gym/gym/terminals/local.py index 3c168e81..2aabc91d 100644 --- a/debug_gym/gym/terminals/local.py +++ b/debug_gym/gym/terminals/local.py @@ -5,10 +5,33 @@ from debug_gym.gym.terminals.shell_session import ShellSession from debug_gym.gym.terminals.terminal import Terminal +from debug_gym.logger import DebugGymLogger class LocalTerminal(Terminal): + def __init__( + self, + working_dir: str | None = None, + session_commands: list[str] | None = None, + env_vars: dict[str, str] | None = None, + logger: DebugGymLogger | None = None, + # Local-specific parameters + include_os_env_vars: bool = True, + **kwargs, + ): + env_vars = env_vars or {} + if include_os_env_vars: + env_vars = env_vars | dict(os.environ) + + super().__init__( + working_dir=working_dir, + session_commands=session_commands, + env_vars=env_vars, + logger=logger, + **kwargs, + ) + @property def working_dir(self): """Lazy initialization of the working directory.""" diff --git a/debug_gym/gym/terminals/shell_session.py b/debug_gym/gym/terminals/shell_session.py index 5205b28a..7526e856 100644 --- a/debug_gym/gym/terminals/shell_session.py +++ b/debug_gym/gym/terminals/shell_session.py @@ -70,10 +70,15 @@ def start(self, command=None, read_until=None): # Prepare entrypoint, combining session commands and command if provided # For example: `bin/bash -c "session_command1 && session_command2 && pdb"` - entrypoint = self.shell_command if command: command = " && ".join(self.session_commands + [command]) - entrypoint = f'{self.shell_command} -c "{command}"' + # Build command list: split shell_command, then add ["-c", command] + # Keep the command string intact so constructs like $(which ...) 
reach the target shell + cmd_list = shlex.split(self.shell_command) + ["-c", command] + entrypoint = f"{self.shell_command} -c {command!r}" + else: + cmd_list = shlex.split(self.shell_command) + entrypoint = self.shell_command self.logger.debug(f"Starting {self} with entrypoint: {entrypoint}") @@ -91,7 +96,7 @@ def start(self, command=None, read_until=None): termios.tcsetattr(_client, termios.TCSANOW, attrs) self.process = subprocess.Popen( - shlex.split(entrypoint), + cmd_list, env=self.env_vars, cwd=self.working_dir, stdin=_client, diff --git a/debug_gym/gym/terminals/terminal.py b/debug_gym/gym/terminals/terminal.py index 280083fe..1bef3361 100644 --- a/debug_gym/gym/terminals/terminal.py +++ b/debug_gym/gym/terminals/terminal.py @@ -1,5 +1,4 @@ import atexit -import os import tempfile from abc import ABC, abstractmethod from pathlib import Path @@ -14,18 +13,15 @@ class Terminal(ABC): def __init__( self, - working_dir: str = None, - session_commands: list[str] = None, - env_vars: dict[str, str] = None, - include_os_env_vars: bool = True, + working_dir: str | None = None, + session_commands: list[str] | None = None, + env_vars: dict[str, str] | None = None, logger: DebugGymLogger | None = None, **kwargs, ): self.logger = logger or DebugGymLogger("debug-gym") self.session_commands = session_commands or [] self.env_vars = env_vars or {} - if include_os_env_vars: - self.env_vars = self.env_vars | dict(os.environ) # Clean up output by disabling terminal prompt and colors self.env_vars["NO_COLOR"] = "1" # disable colors self.env_vars["PYTHONSTARTUP"] = "" # prevent Python from loading startup files @@ -35,6 +31,9 @@ def __init__( self._working_dir = working_dir self.sessions = [] + if kwargs: + self.logger.warning(f"Ignoring unknown parameters: {kwargs}") + @property def working_dir(self): """Lazy initialization of the working directory.""" diff --git a/debug_gym/llms/anthropic.py b/debug_gym/llms/anthropic.py index 947da94d..d19e6124 100644 --- a/debug_gym/llms/anthropic.py +++ b/debug_gym/llms/anthropic.py @@ -41,24 +41,18 @@ def client(self): self._client = Anthropic(api_key=self.config.api_key) return self._client - def tokenize(self, text: str) -> list[str]: - raise NotImplementedError("Tokenization is not supported by Anthropic.") + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Tokenization is not directly supported by Anthropic. + This method returns empty token lists as a placeholder.""" + raise NotImplementedError("Direct tokenization is not supported by Anthropic.") + + def count_tokens(self, messages: list[dict] | str) -> int: + """Count the number of tokens in a text using the Anthropic API.""" + if isinstance(messages, str): + messages = [ + {"role": "user", "content": [{"type": "text", "text": messages}]} + ] - def count_tokens(self, text: str) -> int: - """Count the number of tokens in a text using the Anthropic API. - Dump content to JSON for cases such as: - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "id123", - "content": "results", - } - ], - } - """ - messages = [{"role": "user", "content": [{"type": "text", "text": text}]}] try: response = self.client.messages.count_tokens( model=self.tokenizer_name, messages=messages @@ -67,7 +61,7 @@ def count_tokens(self, text: str) -> int: except Exception as e: self.logger.warning( f"Error calling Claude token count API: {e!r}. " - f"The message was: {messages}. Will return 0 tokens." + f"The messages were: {messages}. Will return 0 tokens." 
) return 0 diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index a087346f..2dfe9b32 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -13,11 +13,10 @@ wait_random_exponential, ) -from debug_gym.agents.utils import get_message_tokens, trim_prompt_messages from debug_gym.gym.envs.env import EnvInfo from debug_gym.gym.tools.tool import EnvironmentTool, ToolCall from debug_gym.llms.constants import DEFAULT_LLM_CONFIG -from debug_gym.llms.utils import print_messages +from debug_gym.llms.utils import print_messages, trim_prompt_messages from debug_gym.logger import DebugGymLogger # Set logging level down to WARNING for endpoint queries. @@ -288,35 +287,30 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse: pass @abstractmethod - def tokenize(self, text: str) -> list[str]: - """Abstract method to tokenize a text.""" - pass - - def count_tokens(self, text: str) -> int: - """Count the number of tokens in a text.""" - return len(self.tokenize(text)) + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Abstract method to tokenize messages. - def _get_message_token_counts(self, messages: list[dict]) -> list[int]: - """Return per-message token counts used for context management. + Args: + messages: List of message dicts - Subclasses can override this to plug in custom counting strategies - (for example, chat-template aware tokenizers). + Returns: + List of token lists, one per message """ - return [get_message_tokens(msg, self.count_tokens) for msg in messages] + pass - def _trim_messages_to_context( - self, messages: list[dict], message_token_counts: list[int] | None = None - ) -> list[dict]: - """Trim messages so they fit within the model context budget. + def count_tokens(self, messages: list[dict] | str) -> int: + """Count the total number of tokens across all messages. Args: - messages: Original message list. - message_token_counts: Optional precomputed counts aligned with messages. + messages: List of message dicts Returns: - A trimmed list of messages. 
+ Total token count across all messages """ - return trim_prompt_messages(messages, self.context_length, self.count_tokens) + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + tokenized = self.tokenize(messages) + return sum(len(tokens) for tokens in tokenized) @abstractmethod def define_tools(self, tool_call_list: list[EnvironmentTool]) -> list[dict]: @@ -386,11 +380,10 @@ def generate_with_drop_message_and_retry(messages, tools, **kwargs): for retry_count in range(max_retries + 1): try: # pre-truncate messages if they are too long, to avoid unnecessary retries - message_token_counts = self._get_message_token_counts(messages) - message_tokens = sum(message_token_counts) - if message_tokens > self.context_length * 1.2: - trimmed_messages = self._trim_messages_to_context( - messages, message_token_counts + message_tokens = self.count_tokens(messages) + if message_tokens > self.context_length: + trimmed_messages = trim_prompt_messages( + messages, self.context_length, self.count_tokens ) messages = trimmed_messages @@ -428,8 +421,8 @@ def generate_with_drop_message_and_retry(messages, tools, **kwargs): ) # Trim messages and try again - trimmed_messages = self._trim_messages_to_context( - messages, self._get_message_token_counts(messages) + trimmed_messages = trim_prompt_messages( + messages, self.context_length, self.count_tokens ) if not trimmed_messages: diff --git a/debug_gym/llms/copilot.py b/debug_gym/llms/copilot.py index eb6eff9c..ee8075c6 100644 --- a/debug_gym/llms/copilot.py +++ b/debug_gym/llms/copilot.py @@ -177,15 +177,23 @@ def _create_copilot_client(self) -> OpenAI: timeout=None, ) - def tokenize(self, text: str) -> list[str]: + def tokenize(self, messages: list[dict]) -> list[list[str]]: if getattr(self, "_tk_func", None) is None: try: - self._tk_func = tiktoken.encoding_for_model("gpt-4o").encode + encoder = tiktoken.encoding_for_model("gpt-4o") + # For tiktoken, encode returns list of ints, convert to strings + self._tk_func = lambda text: [str(t) for t in encoder.encode(text)] except KeyError: # Simple word-based tokenization as fallback - # For Claude, you might want to use tiktoken or another tokenizer self._tk_func = lambda x: x.split() - return self._tk_func(text) + + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = self._tk_func(content) + result.append(tokens) + return result def need_to_be_retried(self, exception) -> bool: # re-use the need_to_be_retried function from the parent class diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py index e5cdfb65..27fa1c6c 100644 --- a/debug_gym/llms/huggingface.py +++ b/debug_gym/llms/huggingface.py @@ -36,98 +36,26 @@ def _load_tokenizer(self): self._hf_tokenizer.pad_token = self._hf_tokenizer.eos_token return self._hf_tokenizer - def tokenize(self, text: str) -> list[str]: - if getattr(self, "_tk_func", None) is None: - tokenizer = self._load_tokenizer() - if self.apply_chat_template: - - def _tokenize(txt): - return tokenizer.tokenize( - tokenizer.apply_chat_template( - txt, - tokenize=False, - add_generation_prompt=True, - enable_thinking=self.enable_thinking, - ) - ) - - self._tk_func = _tokenize - else: - self._tk_func = tokenizer.tokenize - return self._tk_func(text) - - def count_tokens(self, text) -> int: - tokenizer = self._load_tokenizer() - token_ids = tokenizer.encode(str(text), add_special_tokens=False) - return len(token_ids) - - # --- chat template helpers 
------------------------------------------------- - - def _get_message_token_counts(self, messages: list[dict]) -> list[int]: - if not self._supports_chat_template(): - return super()._get_message_token_counts(messages) - + def tokenize(self, messages: list[dict]) -> list[list[str]]: tokenizer = self._load_tokenizer() - normalized = self._normalize_messages_for_template(messages) - counts: list[int] = [] - prev_len = 0 - - for idx in range(1, len(normalized) + 1): - try: - tokenized = tokenizer.apply_chat_template( - normalized[:idx], - tokenize=True, - add_generation_prompt=False, - ) - except TypeError: - tokenized = tokenizer.apply_chat_template( - normalized[:idx], tokenize=True - ) - except ValueError: - return super()._get_message_token_counts(messages) - token_ids = ( - tokenized.get("input_ids") if isinstance(tokenized, dict) else tokenized + if self.apply_chat_template: + # When applying chat template, tokenize all messages together + # then return as a single list + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, ) - if token_ids and isinstance(token_ids[0], list): - token_ids = token_ids[0] - - if token_ids is None: - return super()._get_message_token_counts(messages) - - current_len = len(token_ids) - if current_len == 0 and idx == len(normalized): - return super()._get_message_token_counts(messages) - - counts.append(max(current_len - prev_len, 0)) - prev_len = current_len - - return counts - - def _supports_chat_template(self) -> bool: - tokenizer = self._load_tokenizer() - return hasattr(tokenizer, "apply_chat_template") - - def _normalize_messages_for_template(self, messages: Iterable[dict]) -> list[dict]: - normalized = [] - for message in messages: - role = message.get("role", "user") - if role not in {"system", "user", "assistant"}: - role = "user" - - content = message.get("content") - if isinstance(content, list): - parts = [] - for item in content: - if isinstance(item, dict) and "text" in item: - parts.append(item["text"]) - else: - parts.append(str(item)) - content = "\n".join(parts) - elif content is None and message.get("tool_calls"): - content = json.dumps(message.get("tool_calls")) - else: - content = "" if content is None else str(content) - - normalized.append({"role": role, "content": content}) - return normalized + tokens = tokenizer.tokenize(text) + # Return as list with single element (all tokens together) + return [tokens] + else: + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg["content"]) + tokens = tokenizer.tokenize(content) + result.append(tokens) + return result diff --git a/debug_gym/llms/human.py b/debug_gym/llms/human.py index 7557c8fa..28b34176 100644 --- a/debug_gym/llms/human.py +++ b/debug_gym/llms/human.py @@ -471,12 +471,21 @@ def __init__( if prompt_toolkit_available: self._history = InMemoryHistory() - def tokenize(self, text: str) -> list[str]: - """Tokenizes a text by splitting it by spaces.""" - return text.split() - - def count_tokens(self, text: str) -> int: - return len(self.tokenize(text)) + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Tokenizes messages by splitting content by spaces.""" + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = content.split() + result.append(tokens) + return result + + def count_tokens(self, messages: list[dict] | str) -> int: + """Count tokens across all messages.""" + if 
isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + tokenized = self.tokenize(messages) + return sum(len(tokens) for tokens in tokenized) def define_tools(self, tool_call_list: list[EnvironmentTool]) -> list[dict]: available_commands = [] @@ -615,6 +624,6 @@ def __call__(self, messages, tools, *args, **kwargs) -> LLMResponse: prompt=messages, response=action, tool=tool_call, - prompt_token_count=self.count_tokens(json.dumps(messages)), - response_token_count=self.count_tokens(action), + prompt_token_count=self.count_tokens(messages), + response_token_count=self.count_tokens([{"tool_calls": action}]), ) diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index e0606a2e..ed0b9006 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -62,36 +62,25 @@ def client(self): ) return self._client - def tokenize(self, text: str) -> list[str]: + def tokenize(self, messages: list[dict]) -> list[list[str]]: if getattr(self, "_tk_func", None) is None: try: - self._tk_func = tiktoken.encoding_for_model(self.tokenizer_name).encode + encoder = tiktoken.encoding_for_model(self.tokenizer_name) + # For tiktoken, encode returns list of ints, we need to convert to list of "tokens" + self._tk_func = lambda text: [str(t) for t in encoder.encode(text)] except KeyError: - # Try to load from transformers, mostly deprecated. Use HuggingFaceLLM for transformers models. - try: - tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) - if self.apply_chat_template: - - def _tokenize(txt): - return tokenizer.tokenize( - tokenizer.apply_chat_template( - txt, - tokenize=False, - add_generation_prompt=True, - enable_thinking=self.enable_thinking, - ) - ) - - self._tk_func = _tokenize - else: - self._tk_func = tokenizer.tokenize - except OSError: - raise ValueError( - f"Tokenizer `{self.tokenizer_name}` not found for model " - f"{self.model_name}, make sure you have access to " - "the model (e.g., HuggingFace API key is correctly set)." - ) - return self._tk_func(text) + raise ValueError( + f"Tokenizer `{self.tokenizer_name}` not found for model " + f"{self.model_name}. If using Hugging Face models, please " + f"set tag `vllm` to load the HuggingFaceLLM class instead." + ) + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = self._tk_func(content) + result.append(tokens) + return result def need_to_be_retried(self, exception) -> bool: # List of fully qualified names of RateLimitError exceptions from various libraries diff --git a/debug_gym/llms/utils.py b/debug_gym/llms/utils.py index 5bbcd0cf..8a5c4cd2 100644 --- a/debug_gym/llms/utils.py +++ b/debug_gym/llms/utils.py @@ -1,6 +1,191 @@ from debug_gym.logger import DebugGymLogger, log_with_color +def get_message_tokens(message, count_tokens): + """Count tokens in a single message. 
+ + Args: + message: A single message dict + count_tokens: Function that takes a list of messages and returns total token count + + Returns: + Token count for this message + """ + return count_tokens([message]) + + +def trim(text: str, max_tokens: int, count_tokens: callable, where: str = "middle"): + """Trim text to fit within max_tokens by working directly at the token level.""" + if max_tokens <= 0: + return "" + + nb_tokens = count_tokens(text) + if nb_tokens <= max_tokens: + return text + + ellipsis = "…" # assume ellipsis is a single token + available_tokens = max_tokens - 1 # account for ellipsis + + def find_char_position_for_tokens( + target_tokens: int, from_start: bool = True + ) -> int: + """Binary search to find character position that gives approximately target_tokens.""" + left, right = 0, len(text) + best_pos = left if from_start else right + + while left <= right: + mid = (left + right) // 2 + test_text = text[:mid] if from_start else text[mid:] + test_tokens = count_tokens(test_text) + if test_tokens <= target_tokens: + best_pos = mid + if from_start: + left = mid + 1 + else: + right = mid - 1 + else: + if from_start: + right = mid - 1 + else: + left = mid + 1 + return best_pos + + if where == "end": + # Keep the beginning, trim the end + trim_point = find_char_position_for_tokens(available_tokens, from_start=True) + return text[:trim_point] + ellipsis + elif where == "start": + # Keep the end, trim the beginning + trim_point = find_char_position_for_tokens(available_tokens, from_start=False) + return ellipsis + text[trim_point:] + elif where == "middle": + # Keep both ends, trim the middle + half_tokens = available_tokens // 2 + + # Find how much we can keep from the start + start_chars = find_char_position_for_tokens(half_tokens, from_start=True) + + # Find how much we can keep from the end with remaining tokens + remaining_tokens = available_tokens - count_tokens(text[:start_chars]) + end_chars = find_char_position_for_tokens(remaining_tokens, from_start=False) + + return text[:start_chars] + ellipsis + text[end_chars:] + else: + raise ValueError(f"Invalid value for `where`: {where!r}.") + + +def trim_prompt_messages( + messages: list[dict], context_length: int, count_tokens: callable +): + """ + Trim message on the assistant-tool pair level to fit context length. + + Strategy: + 1. Keep the system message (assert if system itself is too long) + 2. Keep as many most recent (assistant, tool) pairs as possible + 3. 
Only when we can keep all (assistant, tool) pairs, keep the user message + + Args: + messages: List of message dicts with 'role' and 'content'/'tool_calls' keys + context_length: Maximum number of tokens allowed + count_tokens: Function to count tokens in a string + + Returns: + Trimmed list of messages that fit within context_length + """ + assert len(messages) > 0, "messages should not be empty" + assert messages[-1]["role"] in [ + "user", + "tool", + ], "the last message should be from the user or the tool" + assert context_length >= 0, "context_length should be non-negative" + + # Calculate token count for all messages + message_tokens = [get_message_tokens(msg, count_tokens) for msg in messages] + total_tokens = sum(message_tokens) + + # If we're already within limit, return as-is + if total_tokens <= context_length: + return messages + + # Find system message + system_msg_idx = 0 if messages[0]["role"] == "system" else None + system_tokens = message_tokens[0] if system_msg_idx is not None else 0 + + # Assert system message fits within context + assert ( + system_tokens <= context_length + ), f"System message tokens exceed context length: {system_tokens} > {context_length}!" + + # Find user message + user_msg_idx = None + for i, msg in enumerate(messages): + if msg["role"] == "user": + user_msg_idx = i + break + + # Find all (assistant, tool) pairs by going backwards + assistant_tool_pairs = [] + i = len(messages) - 1 + while i >= 0: + if ( + messages[i]["role"] == "tool" + and i > 0 + and messages[i - 1]["role"] == "assistant" + ): + assistant_tool_pairs.append((i - 1, i)) # (assistant_idx, tool_idx) + i -= 2 + else: + i -= 1 + + # Start building result with system message + result = [] + remaining_tokens = context_length + + if system_msg_idx is not None: + result.append(messages[system_msg_idx]) + remaining_tokens -= system_tokens + + # Add as many recent (assistant, tool) pairs as possible + included_pairs = [] + for assistant_idx, tool_idx in assistant_tool_pairs: + pair_tokens = message_tokens[assistant_idx] + message_tokens[tool_idx] + if pair_tokens <= remaining_tokens: + included_pairs.append((assistant_idx, tool_idx)) + remaining_tokens -= pair_tokens + else: + break + + # Only include user message if we can fit all (assistant, tool) pairs + include_user = False + if len(included_pairs) == len(assistant_tool_pairs) and user_msg_idx is not None: + user_tokens = message_tokens[user_msg_idx] + if user_tokens <= remaining_tokens: + include_user = True + + # Build final result + if include_user: + result.append(messages[user_msg_idx]) + + # Sort by assistant index to maintain chronological order + included_pairs.sort(key=lambda pair: pair[0]) + for assistant_idx, tool_idx in included_pairs: + result.append(messages[assistant_idx]) + result.append(messages[tool_idx]) + + assert ( + len(result) > 0 + ), f"After trimming, no messages fit within context length: {context_length}!" + + # Verify final token count + final_tokens = sum(get_message_tokens(msg, count_tokens) for msg in result) + assert ( + final_tokens <= context_length + ), f"After trimming, the message length still exceeds: {final_tokens} > {context_length}!" + + return result + + def print_messages(messages: list[dict], logger: DebugGymLogger): """Print messages coloring each role differently. 
Colors: diff --git a/pyproject.toml b/pyproject.toml index e8099b4b..1499685e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,6 @@ dev = [ "pytest-xdist", "pytest-timeout", "pytest-env", + "isort", + "black", ] \ No newline at end of file diff --git a/scripts/config_aider.yaml b/scripts/config_aider.yaml index 573d8510..4aa6ffdd 100644 --- a/scripts/config_aider.yaml +++ b/scripts/config_aider.yaml @@ -4,7 +4,6 @@ base: benchmark: "aider" problems: "all" # list of problems, e.g., ["wordy"], or "all" env_kwargs: { - "entrypoint": "python -m pytest -s .", "dir_tree_depth": 1, "run_timeout": 20, # shortcut features @@ -15,12 +14,7 @@ base: "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame. } terminal: { - type: "docker", # "docker" or "local" - base_image: "python:3.12-slim", - # session_commands define commands that are always executed before starting a shell session or running a single command in the terminal. - # session_commands:["conda activate aider"], - # setup_commands define commands that are executed only once when the terminal is created. This is only supported for Docker terminal. - setup_commands: ["apt update", "apt install -y git", "pip install pytest"], + type: "docker", # "docker", "kubernetes", or "local" } # LLM configs diff --git a/scripts/config_mini_nightmare.yaml b/scripts/config_mini_nightmare.yaml index c6eba799..fa97a1a1 100644 --- a/scripts/config_mini_nightmare.yaml +++ b/scripts/config_mini_nightmare.yaml @@ -4,7 +4,6 @@ base: benchmark: "mini_nightmare" problems: "all" # list of problems, e.g., ["config"], or "all" env_kwargs: { - "entrypoint": "python -m pytest --tb=no -s test.py", "dir_tree_depth": 1, "run_timeout": 30, # shortcut features @@ -16,7 +15,7 @@ base: } terminal: { - type: "docker", # "local", "docker", or "kubernetes" + type: "docker", # "docker", "kubernetes", or "local" } # LLM configs diff --git a/tests/agents/test_utils.py b/tests/agents/test_utils.py index a68b7550..bb479a21 100644 --- a/tests/agents/test_utils.py +++ b/tests/agents/test_utils.py @@ -1,183 +1,7 @@ import logging from unittest.mock import patch -import pytest - -from debug_gym.agents.utils import load_config, trim, trim_prompt_messages - - -def test_trim_prompt_messages(): - def count_tokens(text): - return len(text) - - # Test basic validation - with pytest.raises(AssertionError, match="messages should not be empty"): - trim_prompt_messages([], 5, count_tokens) - - with pytest.raises( - AssertionError, match="the last message should be from the user or the tool" - ): - messages = [ - {"role": "system", "content": "System message"}, - {"role": "assistant", "content": "Assistant message"}, - ] - trim_prompt_messages(messages, 20, count_tokens) - - with pytest.raises(AssertionError, match="context_length should be non-negative"): - messages = [{"role": "user", "content": "User message"}] - trim_prompt_messages(messages, -1, count_tokens) - - # Test system message too long - with pytest.raises( - AssertionError, match="System message tokens exceed context length" - ): - messages = [ - {"role": "system", "content": "Very long system message"}, - {"role": "user", "content": "Hi"}, - ] - trim_prompt_messages(messages, 10, count_tokens) - - # Test simple case: just user message - messages = [{"role": "user", "content": "Hello"}] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: system + user, 
fits completely - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: system + user + assistant + tool, fits completely - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 20, count_tokens) - assert result == messages - - # Test case: Keep system + most recent assistant-tool pair, drop user - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello"}, - {"role": "tool", "content": "Result"}, - ] - result = trim_prompt_messages( - messages, 14, count_tokens - ) # Just enough for sys + assistant + tool - assert result == expected - - # Test case: Multiple assistant-tool pairs, keep most recent ones - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": "Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - # Keep system + most recent assistant-tool pair only - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello2"}, - {"role": "tool", "content": "Result2"}, - ] - result = trim_prompt_messages( - messages, 16, count_tokens - ) # sys(3) + hello2(6) + result2(7) = 16 - assert result == expected - - # Test case: Can fit all assistant-tool pairs + user message - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": "Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - # All pairs fit, so include user message too - result = trim_prompt_messages(messages, 50, count_tokens) - assert result == messages - - # Test case: Can fit all assistant-tool pairs but not user message - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": "Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello1"}, - {"role": "tool", "content": "Result1"}, - {"role": "assistant", "content": "Hello2"}, - {"role": "tool", "content": "Result2"}, - ] - result = trim_prompt_messages( - messages, 29, count_tokens - ) # sys(3) + all pairs(26) = 29, no room for user(2) - assert result == expected - - # Test case: No system message - messages = [ - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 20, count_tokens) - assert result == messages - - 
# Test case: No assistant-tool pairs, just system and user - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: No system, no assistant-tool pairs, just user - messages = [{"role": "user", "content": "Hi"}] # 2 tokens - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: Tool message without preceding assistant (edge case) - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "user", "content": "Hi"}, - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == expected - - # Test case: Message with tool_calls instead of content - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - { - "role": "assistant", - "tool_calls": [{"function": {"name": "test"}}], - }, # ~30 tokens (str representation) - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 100, count_tokens) - assert len(result) == 4 # Should keep all messages if context is large enough +from debug_gym.agents.utils import load_config def test_load_config(): @@ -263,102 +87,3 @@ def test_load_config(): assert _config == expected_config assert _args.debug is True assert _args.logging_level == logging.INFO - - -def test_trim(): - def count_tokens(text): - return len(text) - - # Test basic cases - no trimming needed - assert trim("Hello world", 11, count_tokens) == "Hello world" - assert trim("Hello world", 20, count_tokens) == "Hello world" - assert trim("Hi", 2, count_tokens) == "Hi" - assert trim("Hi", 10, count_tokens) == "Hi" - assert trim("A", 1, count_tokens) == "A" # Exactly fits, no trimming needed - - # Test edge cases - assert trim("Hello world", 0, count_tokens) == "" - assert trim("", 5, count_tokens) == "" - assert trim("", 0, count_tokens) == "" - - # Test cases requiring trimming to single token (ellipsis only) - assert trim("Hello world", 1, count_tokens) == "…" - assert trim("Hi", 1, count_tokens) == "…" - assert trim("ABC", 1, count_tokens) == "…" - - # Test trimming from the middle (default behavior) - assert trim("Hello world", 5, count_tokens) == "He…ld" - assert trim("Hello world", 6, count_tokens) == "He…rld" - assert trim("Hello world", 7, count_tokens) == "Hel…rld" - assert trim("123456789", 5, count_tokens) == "12…89" - assert trim("123456789", 7, count_tokens) == "123…789" - - # Test trimming from the end - assert trim("Hello world", 5, count_tokens, where="end") == "Hell…" - assert trim("Hello world", 6, count_tokens, where="end") == "Hello…" - assert trim("Hello world", 7, count_tokens, where="end") == "Hello …" - assert trim("123456789", 5, count_tokens, where="end") == "1234…" - - # Test trimming from the start - assert trim("Hello world", 5, count_tokens, where="start") == "…orld" - assert trim("Hello world", 6, count_tokens, where="start") == "…world" - assert trim("Hello world", 7, count_tokens, where="start") == "… world" - assert trim("123456789", 5, count_tokens, where="start") == "…6789" - - # Test invalid `where` value - with pytest.raises(ValueError, match="Invalid value for `where`"): - trim("Hello world", 5, count_tokens, where="invalid") - - # Test with different token 
counter - def another_count_tokens(text): - return len(text) // 2 - - # For "1234567890" (10 chars), another_count_tokens returns 5 tokens - # Original text has 5 tokens, so no trimming needed when max_tokens >= 5 - assert trim("1234567890", 5, another_count_tokens) == "1234567890" - assert trim("1234567890", 6, another_count_tokens) == "1234567890" - - # When max_tokens < 5, trimming is needed - # With max_tokens=4, we need 3 tokens for content + 1 for ellipsis - assert ( - trim("1234567890", 4, another_count_tokens) == "123…67890" - ) # Result has 4 tokens - assert ( - trim("1234567890", 3, another_count_tokens) == "123…890" - ) # Result has 3 tokens - assert trim("1234567890", 2, another_count_tokens) == "1…890" # Result has 2 tokens - assert trim("1234567890", 1, another_count_tokens) == "1…0" # Result has 1 token - - # Test with different trimming positions using the alternative counter - assert ( - trim("1234567890", 3, another_count_tokens, where="end") == "12345…" - ) # Result has 3 tokens - assert ( - trim("1234567890", 3, another_count_tokens, where="start") == "…67890" - ) # Result has 3 tokens - - # Test edge case with very short text and alternative counter - assert trim("AB", 1, another_count_tokens) == "AB" # "AB" has 1 token, fits exactly - assert ( - trim("ABCD", 1, another_count_tokens) == "A…D" - ) # "ABCD" has 2 tokens, needs trimming to 1 - - # Test boundary conditions with precise scenarios - def word_count_tokens(text): - # Count words as tokens - return len(text.split()) - - text = "Hello world test example" # 4 words = 4 tokens - assert trim(text, 4, word_count_tokens) == text # No trimming needed - assert ( - trim(text, 3, word_count_tokens, where="middle") == "Hello … example" - ) # Should fit in 3 tokens - assert ( - trim(text, 2, word_count_tokens, where="end") == "Hello …" - ) # Should fit in 2 tokens - assert ( - trim(text, 2, word_count_tokens, where="start") == "… example" - ) # Should fit in 2 tokens - - # Test very short max_tokens with word counter - assert trim("Hello world", 1, word_count_tokens) == "…" # Only ellipsis fits diff --git a/tests/conftest.py b/tests/conftest.py index 8248360b..17ee3fdd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,8 +45,13 @@ def generate(self, messages, tools, **kwargs): response_token_count=20, ) - def tokenize(self, text): - return [c for c in text] + def tokenize(self, messages): + # Return list of token lists, one per message + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + result.append([c for c in content]) + return result def define_tools(self, tool_call_list): return tool_call_list diff --git a/tests/gym/envs/test_aider.py b/tests/gym/envs/test_aider.py index e9a464c5..6325dc05 100644 --- a/tests/gym/envs/test_aider.py +++ b/tests/gym/envs/test_aider.py @@ -1,6 +1,9 @@ +from unittest.mock import patch + import pytest from debug_gym.gym.envs import AiderBenchmarkEnv +from debug_gym.gym.terminals.docker import DockerTerminal from debug_gym.gym.terminals.local import LocalTerminal from debug_gym.gym.tools.tool import ToolCall from debug_gym.gym.tools.toolbox import Toolbox @@ -95,3 +98,21 @@ def test_steps(env): def test_instructions(env): assert env.instructions == "What time is it?" 
+ + +@patch("debug_gym.gym.envs.aider.build_docker_image") +def test_build_docker_image(mock_build_docker_image): + AiderBenchmarkEnv() + mock_build_docker_image.assert_called_once() + + +@pytest.if_docker_running +def test_reset_with_docker_terminal(): + env = AiderBenchmarkEnv() + assert isinstance(env.terminal, DockerTerminal) + + infos = env.reset(options={"task_name": "clock"}) + assert "1 failed" in infos.step_observation.observation + assert infos.max_score == 1 + assert infos.score == 0 + assert not infos.done diff --git a/tests/gym/terminals/test_docker.py b/tests/gym/terminals/test_docker.py index 1799f7bc..8540fcc0 100644 --- a/tests/gym/terminals/test_docker.py +++ b/tests/gym/terminals/test_docker.py @@ -12,7 +12,7 @@ @pytest.if_docker_running def test_docker_terminal_init(): - terminal = DockerTerminal() + terminal = DockerTerminal(base_image="ubuntu:latest") assert terminal.session_commands == [] assert terminal.env_vars == { "NO_COLOR": "1", @@ -60,7 +60,9 @@ def test_docker_terminal_init_with_params(tmp_path): ) def test_docker_terminal_run(tmp_path, command): working_dir = str(tmp_path) - docker_terminal = DockerTerminal(working_dir=working_dir) + docker_terminal = DockerTerminal( + working_dir=working_dir, base_image="ubuntu:latest" + ) success, output = docker_terminal.run(command, timeout=1) assert output == "test" assert success is True @@ -77,7 +79,7 @@ def test_docker_terminal_run(tmp_path, command): def test_terminal_multiple_session_commands(tmp_path): working_dir = str(tmp_path) session_commands = ["echo 'Hello'", "echo 'World'"] - terminal = DockerTerminal(working_dir, session_commands) + terminal = DockerTerminal(working_dir, session_commands, base_image="ubuntu:latest") status, output = terminal.run("pwd", timeout=1) assert status assert output == f"Hello\nWorld\n{working_dir}" @@ -89,7 +91,7 @@ def test_docker_terminal_session(tmp_path): # same as test_terminal_session but with DockerTerminal working_dir = str(tmp_path) command = "echo Hello World" - terminal = DockerTerminal(working_dir=working_dir) + terminal = DockerTerminal(working_dir=working_dir, base_image="ubuntu:latest") assert not terminal.sessions session = terminal.new_shell_session() @@ -113,7 +115,7 @@ def test_docker_terminal_session(tmp_path): @pytest.if_docker_running def test_terminal_sudo_command(tmp_path): working_dir = str(tmp_path) - terminal = DockerTerminal(working_dir=working_dir) + terminal = DockerTerminal(working_dir=working_dir, base_image="ubuntu:latest") success, output = terminal.run("vim --version", timeout=1) assert "vim: command not found" in output assert success is False @@ -129,7 +131,7 @@ def test_terminal_sudo_command(tmp_path): @pytest.if_docker_running def test_terminal_cleanup(tmp_path): working_dir = str(tmp_path) - terminal = DockerTerminal(working_dir=working_dir) + terminal = DockerTerminal(working_dir=working_dir, base_image="ubuntu:latest") container_name = terminal.container.name terminal.clean_up() assert terminal._container is None @@ -151,7 +153,9 @@ def test_select_terminal_docker(): def test_run_setup_commands_success(tmp_path): working_dir = str(tmp_path) setup_commands = ["touch test1.txt", "echo test > test2.txt"] - terminal = DockerTerminal(working_dir, setup_commands=setup_commands) + terminal = DockerTerminal( + working_dir, setup_commands=setup_commands, base_image="ubuntu:latest" + ) assert terminal.container is not None assert terminal.container.status == "running" _, output = terminal.run("ls", timeout=1) @@ -163,7 +167,9 @@ def 
test_run_setup_commands_failure(tmp_path): working_dir = str(tmp_path) setup_commands = ["echo install", "ls ./non_existent_dir"] with pytest.raises(ValueError, match="Failed to run setup command:*"): - terminal = DockerTerminal(working_dir, setup_commands=setup_commands) + terminal = DockerTerminal( + working_dir, setup_commands=setup_commands, base_image="ubuntu:latest" + ) terminal.container # start the container @@ -176,7 +182,7 @@ def test_copy_content(tmp_path): with open(source_file, "w") as src_file: src_file.write("Hello World") - terminal = DockerTerminal() + terminal = DockerTerminal(base_image="ubuntu:latest") # Source must be a folder. with pytest.raises(ValueError, match="Source .* must be a directory."): terminal.copy_content(source_file) diff --git a/tests/gym/terminals/test_kubernetes.py b/tests/gym/terminals/test_kubernetes.py index 7f2c13cc..acf9eb29 100644 --- a/tests/gym/terminals/test_kubernetes.py +++ b/tests/gym/terminals/test_kubernetes.py @@ -33,13 +33,26 @@ def is_kubernetes_available(): @if_kubernetes_available def test_kubernetes_terminal_init(): - terminal = KubernetesTerminal() + terminal = KubernetesTerminal(base_image="ubuntu:latest") assert terminal.session_commands == [] - assert terminal.env_vars == { + expected_base_env = { "NO_COLOR": "1", "PS1": DEFAULT_PS1, "PYTHONSTARTUP": "", } + for key, value in expected_base_env.items(): + assert terminal.env_vars[key] == value + + assert terminal.env_vars["PATH"] == os.environ.get("PATH") + if terminal.kube_config: + assert terminal.env_vars["KUBECONFIG"] == terminal.kube_config + else: + assert "KUBECONFIG" not in terminal.env_vars + + extra_env_keys = set(terminal.env_vars) - ( + set(expected_base_env) | {"PATH", "KUBECONFIG"} + ) + assert not extra_env_keys assert os.path.basename(terminal.working_dir).startswith("Terminal-") assert terminal.base_image == "ubuntu:latest" assert terminal.namespace == "default" @@ -84,7 +97,15 @@ def test_kubernetes_terminal_init_with_params(tmp_path): ) assert terminal.working_dir == working_dir assert terminal.session_commands == session_commands - assert terminal.env_vars == env_vars | {"NO_COLOR": "1", "PS1": DEFAULT_PS1} + assert terminal.env_vars["ENV_VAR"] == "value" + assert terminal.env_vars["NO_COLOR"] == "1" + assert terminal.env_vars["PS1"] == DEFAULT_PS1 + assert terminal.env_vars["PYTHONSTARTUP"] == "" + assert terminal.env_vars["PATH"] == os.environ.get("PATH") + if terminal.kube_config: + assert terminal.env_vars["KUBECONFIG"] == terminal.kube_config + else: + assert "KUBECONFIG" not in terminal.env_vars assert terminal.base_image == base_image # Create pod. 
@@ -98,6 +119,69 @@ def test_kubernetes_terminal_init_with_params(tmp_path): assert terminal._pod is None +@if_kubernetes_available +def test_kubernetes_terminal_init_with_pod_specs(tmp_path): + working_dir = str(tmp_path) + # set an environment variable to use in the pod spec + os.environ["HOSTNAME"] = "minikube" + pod_spec_kwargs = { + "affinity": { + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "kubernetes.io/hostname", + "operator": "In", + "values": ["{{HOSTNAME}}"], + } + ] + } + ] + } + } + }, + "tolerations": [ + { + "key": "kubernetes.azure.com/scalesetpriority", + "operator": "Equal", + "value": "spot", + "effect": "NoSchedule", + }, + { + "key": "CriticalAddonsOnly", + "operator": "Equal", + "value": "true", + "effect": "NoSchedule", + }, + ], + } + + terminal = KubernetesTerminal( + working_dir=working_dir, + pod_spec_kwargs=pod_spec_kwargs, + kube_context="minikube", + base_image="ubuntu:latest", + ) + + terminal.pod # Create pod. + assert ( + terminal.pod.pod_body["spec"]["tolerations"] == pod_spec_kwargs["tolerations"] + ) + # Make sure environment variable was replaced in the pod spec. + spec = terminal.pod.pod_body["spec"] + node_affinity = spec["affinity"]["nodeAffinity"] + required = node_affinity["requiredDuringSchedulingIgnoredDuringExecution"] + term = required["nodeSelectorTerms"][0] + match_expression = term["matchExpressions"][0] + assert match_expression["values"] == [os.environ["HOSTNAME"]] + + # Close pod. + terminal.close() + assert terminal._pod is None + + @if_kubernetes_available @pytest.mark.parametrize( "command", @@ -109,7 +193,7 @@ def test_kubernetes_terminal_init_with_params(tmp_path): def test_kubernetes_terminal_run(tmp_path, command): """Test running commands in the Kubernetes terminal.""" working_dir = str(tmp_path) - terminal = KubernetesTerminal(working_dir=working_dir) + terminal = KubernetesTerminal(working_dir=working_dir, base_image="ubuntu:latest") success, output = terminal.run(command, timeout=1) assert output == "test" assert success is True @@ -128,7 +212,9 @@ def test_kubernetes_terminal_run(tmp_path, command): def test_kubernetes_terminal_with_session_commands(tmp_path): working_dir = str(tmp_path) session_commands = ["echo 'Hello'", "echo 'World'"] - terminal = KubernetesTerminal(working_dir, session_commands=session_commands) + terminal = KubernetesTerminal( + working_dir, session_commands=session_commands, base_image="ubuntu:latest" + ) status, output = terminal.run("pwd", timeout=1) assert status assert output == f"Hello\nWorld\n{working_dir}" @@ -141,7 +227,7 @@ def test_kubernetes_terminal_session(tmp_path): # same as test_terminal_session but with DockerTerminal working_dir = str(tmp_path) command = "echo Hello World" - terminal = KubernetesTerminal(working_dir=working_dir) + terminal = KubernetesTerminal(working_dir=working_dir, base_image="ubuntu:latest") assert not terminal.sessions session = terminal.new_shell_session() @@ -171,7 +257,7 @@ def test_copy_content(tmp_path): with open(source_file, "w") as src_file: src_file.write("Hello World") - terminal = KubernetesTerminal() + terminal = KubernetesTerminal(base_image="ubuntu:latest") # Source must be a folder. 
with pytest.raises(ValueError, match="Source .* must be a directory."): terminal.copy_content(source_file) @@ -192,7 +278,7 @@ def test_copy_content(tmp_path): def test_kubernetes_terminal_cleanup(tmp_path): """Test cleanup functionality.""" working_dir = str(tmp_path) - terminal = KubernetesTerminal(working_dir=working_dir) + terminal = KubernetesTerminal(working_dir=working_dir, base_image="ubuntu:latest") # Test cleanup without creating pod terminal.close() @@ -218,7 +304,7 @@ def test_select_terminal_kubernetes(): def test_kubernetes_terminal_readonly_properties_after_pod_creation(): """Test that working directory cannot be changed after pod creation.""" - terminal = KubernetesTerminal() + terminal = KubernetesTerminal(base_image="ubuntu:latest") terminal.pod # Create pod. with pytest.raises( diff --git a/tests/gym/test_workspace.py b/tests/gym/test_workspace.py index 385728c2..c60345db 100644 --- a/tests/gym/test_workspace.py +++ b/tests/gym/test_workspace.py @@ -47,7 +47,7 @@ def test_reset_and_cleanup_workspace(): assert not os.path.isdir(working_dir) # Setup workspace with a remote terminal. - terminal = DockerTerminal() + terminal = DockerTerminal(base_image="ubuntu:latest") workspace = Workspace(terminal) assert workspace._tempdir is None diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py index 75e8349f..a0801a97 100644 --- a/tests/llms/test_base.py +++ b/tests/llms/test_base.py @@ -527,7 +527,7 @@ def generate(self, messages, tools, **kwargs): messages = [{"role": "user", "content": "Long message"}] # Mock trim_prompt_messages to return the same messages (no reduction) - with patch("debug_gym.agents.utils.trim_prompt_messages") as mock_trim: + with patch("debug_gym.llms.utils.trim_prompt_messages") as mock_trim: mock_trim.return_value = messages # Should raise ContextLengthExceededError, not RecursionError diff --git a/tests/llms/test_copilot.py b/tests/llms/test_copilot.py index 7a8997d2..deb78196 100644 --- a/tests/llms/test_copilot.py +++ b/tests/llms/test_copilot.py @@ -111,9 +111,11 @@ def test_tokenize_with_tiktoken(self, mock_config, logger_mock): mock_encode = MagicMock(return_value=[1, 2, 3, 4]) with patch("tiktoken.encoding_for_model") as mock_tiktoken: mock_tiktoken.return_value.encode = mock_encode - tokens = llm.tokenize("test text") + messages = [{"role": "user", "content": "test text"}] + tokens = llm.tokenize(messages) - assert tokens == [1, 2, 3, 4] + # Should return list of token lists (one per message) + assert tokens == [["1", "2", "3", "4"]] mock_tiktoken.assert_called_once_with("gpt-4o") mock_encode.assert_called_once_with("test text") @@ -125,9 +127,11 @@ def test_tokenize_fallback(self, mock_config, logger_mock): with patch( "tiktoken.encoding_for_model", side_effect=KeyError("model not found") ): - tokens = llm.tokenize("hello world test") + messages = [{"role": "user", "content": "hello world test"}] + tokens = llm.tokenize(messages) - assert tokens == ["hello", "world", "test"] + # Should return list of token lists (one per message) + assert tokens == [["hello", "world", "test"]] def test_need_to_be_retried_hmac_timestamp_error(self, mock_config, logger_mock): """Test retry logic for HMAC timestamp errors""" diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py index a937c77a..e8a126d2 100644 --- a/tests/llms/test_huggingface.py +++ b/tests/llms/test_huggingface.py @@ -25,6 +25,13 @@ } } +MODEL_REGISTRY_WITH_CHAT_TEMPLATE = { + "qwen-3": { + **MODEL_REGISTRY["qwen-3"], + "apply_chat_template": True, + }, +} + 
@pytest.fixture(scope="session") def real_qwen3_tokenizer(): @@ -52,8 +59,9 @@ def test_tokenize_uses_hf_tokenizer_with_pad_fallback(mock_llm_config, logger_mo ) as mock_auto_tokenizer: mock_auto_tokenizer.return_value = tokenizer llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) - assert llm.tokenize("hello world") == ["hello", "Ġworld"] - assert llm.count_tokens("hello world") == 2 + messages = [{"role": "user", "content": "hello world"}] + assert llm.tokenize(messages) == [["hello", "Ġworld"]] + assert llm.count_tokens(messages) == 2 assert tokenizer.eos_token == "" assert tokenizer.pad_token == "" @@ -61,111 +69,22 @@ def test_tokenize_uses_hf_tokenizer_with_pad_fallback(mock_llm_config, logger_mo @patch.object( LLMConfigRegistry, "from_file", - return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY_WITH_CHAT_TEMPLATE), ) -@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_normalize_messages_for_chat_template( - mock_auto_tokenizer, mock_llm_config, logger_mock -): - tokenizer_mock = MagicMock() - tokenizer_mock.pad_token = None - tokenizer_mock.eos_token = "" - mock_auto_tokenizer.return_value = tokenizer_mock - - llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) - - raw_messages = [ - {"role": "tool", "content": "partial output"}, - { - "role": "developer", - "content": [{"text": "line1"}, {"text": "line2"}], - }, - { - "role": "assistant", - "content": None, - "tool_calls": [{"type": "function", "name": "noop", "arguments": {}}], - }, - {"role": "user", "content": None}, - ] - - normalized = llm._normalize_messages_for_template(raw_messages) - - assert normalized == [ - {"role": "user", "content": "partial output"}, - {"role": "user", "content": "line1\nline2"}, - { - "role": "assistant", - "content": json.dumps( - [{"type": "function", "name": "noop", "arguments": {}}] - ), - }, - {"role": "user", "content": ""}, - ] - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), -) -@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -def test_message_token_counts_uses_chat_template( - mock_auto_tokenizer, mock_llm_config, logger_mock -): - tokenizer_mock = MagicMock() - tokenizer_mock.pad_token = None - tokenizer_mock.eos_token = "" - tokenizer_mock.apply_chat_template.side_effect = [ - {"input_ids": [[1, 2]]}, - {"input_ids": [[1, 2, 3]]}, - {"input_ids": [[1, 2, 3, 4]]}, - ] - mock_auto_tokenizer.return_value = tokenizer_mock - +def test_message_token_counts_uses_chat_template(mock_llm_config, logger_mock): llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) messages = [ {"role": "system", "content": "Instructions"}, - {"role": "user", "content": "Hello"}, + {"role": "user", "content": "Hello world!"}, {"role": "tool", "content": "Result"}, ] - counts = llm._get_message_token_counts(messages) - - assert counts == [2, 1, 1] - assert tokenizer_mock.apply_chat_template.call_count == len(messages) - final_normalized = tokenizer_mock.apply_chat_template.call_args_list[-1][0][0] - assert final_normalized[-1]["role"] == "user" - assert final_normalized[-1]["content"] == "Result" - - -@patch.object( - LLMConfigRegistry, - "from_file", - return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), -) -@patch("debug_gym.llms.huggingface.AutoTokenizer.from_pretrained") -@patch.object(OpenAILLM, "_get_message_token_counts", return_value=[5, 6]) -def 
test_message_token_counts_fallbacks_to_openai_when_template_fails( - mock_super_counts, mock_auto_tokenizer, mock_llm_config, logger_mock -): - tokenizer_mock = MagicMock() - tokenizer_mock.pad_token = None - tokenizer_mock.eos_token = "" - tokenizer_mock.apply_chat_template.side_effect = ValueError("no template") - mock_auto_tokenizer.return_value = tokenizer_mock - - llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) - - messages = [ - {"role": "system", "content": "Instructions"}, - {"role": "user", "content": "Hello"}, - ] - - counts = llm._get_message_token_counts(messages) + counts = llm.count_tokens(messages) - assert counts == [5, 6] - mock_super_counts.assert_called_once_with(messages) + # When using chat template, each message gets template tokens added + # The exact counts depend on the template format + assert counts == 31 @pytest.mark.hf_tokenizer @@ -185,36 +104,12 @@ def test_chat_template_counts_with_real_tokenizer(real_qwen3_tokenizer, logger_m messages = [ {"role": "system", "content": "Instructions"}, - {"role": "user", "content": "Hello"}, + {"role": "user", "content": "Hello world!"}, {"role": "tool", "content": "Result"}, ] - counts = llm._get_message_token_counts(messages) - - normalized = llm._normalize_messages_for_template(messages) - expected_counts = [] - prev_len = 0 - for idx in range(1, len(normalized) + 1): - try: - tokenized = real_qwen3_tokenizer.apply_chat_template( - normalized[:idx], tokenize=True, add_generation_prompt=False - ) - except TypeError: # pragma: no cover - version-specific - tokenized = real_qwen3_tokenizer.apply_chat_template( - normalized[:idx], tokenize=True - ) - token_ids = ( - tokenized.get("input_ids") if isinstance(tokenized, dict) else tokenized - ) - if token_ids and isinstance(token_ids[0], list): - token_ids = token_ids[0] - if token_ids is None: - pytest.skip("Tokenizer did not return token ids") - expected_counts.append(len(token_ids) - prev_len) - prev_len = len(token_ids) - - assert counts == expected_counts - assert counts[-1] > 0 + counts = llm.count_tokens(messages) + assert counts == 5 @pytest.mark.hf_tokenizer @@ -235,12 +130,13 @@ def test_tokenize_and_count_tokens_with_real_tokenizer( llm._hf_tokenizer = real_qwen3_tokenizer text = "Hello world!" 
+ messages = [{"role": "user", "content": text}] hf_ids = real_qwen3_tokenizer.encode(text, add_special_tokens=False) hf_tokens = real_qwen3_tokenizer.convert_ids_to_tokens(hf_ids) - tokens = llm.tokenize(text) - assert tokens == hf_tokens - assert llm.count_tokens(text) == len(hf_ids) + tokens = llm.tokenize(messages) + assert tokens == [hf_tokens] + assert llm.count_tokens(messages) == len(hf_ids) @patch.object( @@ -263,8 +159,9 @@ def test_tokenize_and_count_tokens_with_real_tokenizer( ) def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) - tokens = llm.tokenize("hello world") - assert tokens == ["hello", "Ġworld"] + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + assert tokens == [["hello", "Ġworld"]] @patch.object( @@ -288,16 +185,27 @@ def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) - tokens = llm.tokenize("hello world") + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + # When using chat template, all messages are tokenized together, so returns single list assert tokens == [ - "<|im_start|>", - "assistant", - "Ċ", - "", - "ĊĊ", - "", - "ĊĊ", + [ + "<|im_start|>", + "user", + "Ċ", + "hello", + "Ġworld", + "<|im_end|>", + "Ċ", + "<|im_start|>", + "assistant", + "Ċ", + "", + "ĊĊ", + "", + "ĊĊ", + ] ] @@ -323,10 +231,21 @@ def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): def test_hf_tokenize_apply_chat_template_thinking(mock_llm_config, logger_mock): llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) - tokens = llm.tokenize("hello world") + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + # When using chat template, all messages are tokenized together, so returns single list assert tokens == [ - "<|im_start|>", - "assistant", - "Ċ", + [ + "<|im_start|>", + "user", + "Ċ", + "hello", + "Ġworld", + "<|im_end|>", + "Ċ", + "<|im_start|>", + "assistant", + "Ċ", + ] ] diff --git a/tests/llms/test_human.py b/tests/llms/test_human.py index f0adf300..5c0152ee 100644 --- a/tests/llms/test_human.py +++ b/tests/llms/test_human.py @@ -589,14 +589,16 @@ def test_human_max_retries_exceeded(build_env_info): def test_human_tokenize(): """Test Human tokenization""" human = Human() - tokens = human.tokenize("hello world test") - assert tokens == ["hello", "world", "test"] + messages = [{"role": "user", "content": "hello world test"}] + tokens = human.tokenize(messages) + assert tokens == [["hello", "world", "test"]] def test_human_count_tokens(): """Test Human token counting""" human = Human() - count = human.count_tokens("hello world test") + messages = [{"role": "user", "content": "hello world test"}] + count = human.count_tokens(messages) assert count == 3 diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index bda76b94..abd90751 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -532,3 +532,35 @@ def test_llm_without_reasoning_content_attribute( # The response should be just the regular content assert llm_response.response == "Regular response only" + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", + "context_limit": 4, + "api_key": "test-api-key", 
+ "endpoint": "https://test-endpoint", + "api_version": "v1", + "tags": ["openai"], # Using openai tag to force OpenAILLM class + } + } + ), +) +def test_openai_llm_raises_error_for_non_gpt_tokenizer(mock_llm_config, logger_mock): + """Test that OpenAILLM raises ValueError when tokenizer is not a GPT model""" + import pytest + + llm = OpenAILLM(model_name="qwen", logger=logger_mock) + messages = [{"role": "user", "content": "test"}] + + # Should raise ValueError when trying to tokenize with a non-GPT tokenizer + with pytest.raises(ValueError) as exc_info: + llm.tokenize(messages) + + assert "Tokenizer `Qwen/Qwen3-0.6B` not found" in str(exc_info.value) + assert "set tag `vllm`" in str(exc_info.value) diff --git a/tests/llms/test_utils.py b/tests/llms/test_utils.py index b8e1e794..39815b73 100644 --- a/tests/llms/test_utils.py +++ b/tests/llms/test_utils.py @@ -1,4 +1,6 @@ -from debug_gym.llms.utils import print_messages +import pytest + +from debug_gym.llms.utils import print_messages, trim, trim_prompt_messages def test_print_messages(logger_mock): @@ -40,3 +42,276 @@ def test_print_messages_unknown_role(logger_mock): assert "Unknown role" in str(e) else: assert False, "ValueError not raised for unknown role" + + +def test_trim_prompt_messages(): + def count_tokens(messages): + return sum(len(msg.get("content", msg.get("tool_calls"))) for msg in messages) + + # Test basic validation + with pytest.raises(AssertionError, match="messages should not be empty"): + trim_prompt_messages([], 5, count_tokens) + + with pytest.raises( + AssertionError, match="the last message should be from the user or the tool" + ): + messages = [ + {"role": "system", "content": "System message"}, + {"role": "assistant", "content": "Assistant message"}, + ] + trim_prompt_messages(messages, 20, count_tokens) + + with pytest.raises(AssertionError, match="context_length should be non-negative"): + messages = [{"role": "user", "content": "User message"}] + trim_prompt_messages(messages, -1, count_tokens) + + # Test system message too long + with pytest.raises( + AssertionError, match="System message tokens exceed context length" + ): + messages = [ + {"role": "system", "content": "Very long system message"}, + {"role": "user", "content": "Hi"}, + ] + trim_prompt_messages(messages, 10, count_tokens) + + # Test simple case: just user message + messages = [{"role": "user", "content": "Hello"}] + result = trim_prompt_messages(messages, 10, count_tokens) + assert result == messages + + # Test case: system + user, fits completely + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + ] + result = trim_prompt_messages(messages, 10, count_tokens) + assert result == messages + + # Test case: system + user + assistant + tool, fits completely + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello"}, # 5 tokens + {"role": "tool", "content": "Result"}, # 6 tokens + ] + result = trim_prompt_messages(messages, 20, count_tokens) + assert result == messages + + # Test case: Keep system + most recent assistant-tool pair, drop user + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello"}, # 5 tokens + {"role": "tool", "content": "Result"}, # 6 tokens + ] + expected = [ + {"role": "system", "content": "Sys"}, + {"role": "assistant", "content": "Hello"}, + {"role": "tool", 
"content": "Result"}, + ] + result = trim_prompt_messages( + messages, 14, count_tokens + ) # Just enough for sys + assistant + tool + assert result == expected + + # Test case: Multiple assistant-tool pairs, keep most recent ones + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello1"}, # 6 tokens + {"role": "tool", "content": "Result1"}, # 7 tokens + {"role": "assistant", "content": "Hello2"}, # 6 tokens + {"role": "tool", "content": "Result2"}, # 7 tokens + ] + # Keep system + most recent assistant-tool pair only + expected = [ + {"role": "system", "content": "Sys"}, + {"role": "assistant", "content": "Hello2"}, + {"role": "tool", "content": "Result2"}, + ] + result = trim_prompt_messages( + messages, 16, count_tokens + ) # sys(3) + hello2(6) + result2(7) = 16 + assert result == expected + + # Test case: Can fit all assistant-tool pairs + user message + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello1"}, # 6 tokens + {"role": "tool", "content": "Result1"}, # 7 tokens + {"role": "assistant", "content": "Hello2"}, # 6 tokens + {"role": "tool", "content": "Result2"}, # 7 tokens + ] + # All pairs fit, so include user message too + result = trim_prompt_messages(messages, 50, count_tokens) + assert result == messages + + # Test case: Can fit all assistant-tool pairs but not user message + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello1"}, # 6 tokens + {"role": "tool", "content": "Result1"}, # 7 tokens + {"role": "assistant", "content": "Hello2"}, # 6 tokens + {"role": "tool", "content": "Result2"}, # 7 tokens + ] + expected = [ + {"role": "system", "content": "Sys"}, + {"role": "assistant", "content": "Hello1"}, + {"role": "tool", "content": "Result1"}, + {"role": "assistant", "content": "Hello2"}, + {"role": "tool", "content": "Result2"}, + ] + result = trim_prompt_messages( + messages, 29, count_tokens + ) # sys(3) + all pairs(26) = 29, no room for user(2) + assert result == expected + + # Test case: No system message + messages = [ + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "assistant", "content": "Hello"}, # 5 tokens + {"role": "tool", "content": "Result"}, # 6 tokens + ] + result = trim_prompt_messages(messages, 20, count_tokens) + assert result == messages + + # Test case: No assistant-tool pairs, just system and user + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + ] + result = trim_prompt_messages(messages, 10, count_tokens) + assert result == messages + + # Test case: No system, no assistant-tool pairs, just user + messages = [{"role": "user", "content": "Hi"}] # 2 tokens + result = trim_prompt_messages(messages, 10, count_tokens) + assert result == messages + + # Test case: Tool message without preceding assistant (edge case) + messages = [ + {"role": "system", "content": "Sys"}, # 3 tokens + {"role": "user", "content": "Hi"}, # 2 tokens + {"role": "tool", "content": "Result"}, # 6 tokens + ] + expected = [ + {"role": "system", "content": "Sys"}, + {"role": "user", "content": "Hi"}, + ] + result = trim_prompt_messages(messages, 10, count_tokens) + assert result == expected + + # Test case: Message with tool_calls instead of content + messages = [ + {"role": "system", "content": "Sys"}, # 3 
tokens + {"role": "user", "content": "Hi"}, # 2 tokens + { + "role": "assistant", + "tool_calls": [{"function": {"name": "test"}}], + }, # ~30 tokens (str representation) + {"role": "tool", "content": "Result"}, # 6 tokens + ] + result = trim_prompt_messages(messages, 100, count_tokens) + assert len(result) == 4 # Should keep all messages if context is large enough + + +def test_trim(): + def count_tokens(text): + return len(text) + + # Test basic cases - no trimming needed + assert trim("Hello world", 11, count_tokens) == "Hello world" + assert trim("Hello world", 20, count_tokens) == "Hello world" + assert trim("Hi", 2, count_tokens) == "Hi" + assert trim("Hi", 10, count_tokens) == "Hi" + assert trim("A", 1, count_tokens) == "A" # Exactly fits, no trimming needed + + # Test edge cases + assert trim("Hello world", 0, count_tokens) == "" + assert trim("", 5, count_tokens) == "" + assert trim("", 0, count_tokens) == "" + + # Test cases requiring trimming to single token (ellipsis only) + assert trim("Hello world", 1, count_tokens) == "…" + assert trim("Hi", 1, count_tokens) == "…" + assert trim("ABC", 1, count_tokens) == "…" + + # Test trimming from the middle (default behavior) + assert trim("Hello world", 5, count_tokens) == "He…ld" + assert trim("Hello world", 6, count_tokens) == "He…rld" + assert trim("Hello world", 7, count_tokens) == "Hel…rld" + assert trim("123456789", 5, count_tokens) == "12…89" + assert trim("123456789", 7, count_tokens) == "123…789" + + # Test trimming from the end + assert trim("Hello world", 5, count_tokens, where="end") == "Hell…" + assert trim("Hello world", 6, count_tokens, where="end") == "Hello…" + assert trim("Hello world", 7, count_tokens, where="end") == "Hello …" + assert trim("123456789", 5, count_tokens, where="end") == "1234…" + + # Test trimming from the start + assert trim("Hello world", 5, count_tokens, where="start") == "…orld" + assert trim("Hello world", 6, count_tokens, where="start") == "…world" + assert trim("Hello world", 7, count_tokens, where="start") == "… world" + assert trim("123456789", 5, count_tokens, where="start") == "…6789" + + # Test invalid `where` value + with pytest.raises(ValueError, match="Invalid value for `where`"): + trim("Hello world", 5, count_tokens, where="invalid") + + # Test with different token counter + def another_count_tokens(text): + return len(text) // 2 + + # For "1234567890" (10 chars), another_count_tokens returns 5 tokens + # Original text has 5 tokens, so no trimming needed when max_tokens >= 5 + assert trim("1234567890", 5, another_count_tokens) == "1234567890" + assert trim("1234567890", 6, another_count_tokens) == "1234567890" + + # When max_tokens < 5, trimming is needed + # With max_tokens=4, we need 3 tokens for content + 1 for ellipsis + assert ( + trim("1234567890", 4, another_count_tokens) == "123…67890" + ) # Result has 4 tokens + assert ( + trim("1234567890", 3, another_count_tokens) == "123…890" + ) # Result has 3 tokens + assert trim("1234567890", 2, another_count_tokens) == "1…890" # Result has 2 tokens + assert trim("1234567890", 1, another_count_tokens) == "1…0" # Result has 1 token + + # Test with different trimming positions using the alternative counter + assert ( + trim("1234567890", 3, another_count_tokens, where="end") == "12345…" + ) # Result has 3 tokens + assert ( + trim("1234567890", 3, another_count_tokens, where="start") == "…67890" + ) # Result has 3 tokens + + # Test edge case with very short text and alternative counter + assert trim("AB", 1, another_count_tokens) == "AB" # 
"AB" has 1 token, fits exactly + assert ( + trim("ABCD", 1, another_count_tokens) == "A…D" + ) # "ABCD" has 2 tokens, needs trimming to 1 + + # Test boundary conditions with precise scenarios + def word_count_tokens(text): + # Count words as tokens + return len(text.split()) + + text = "Hello world test example" # 4 words = 4 tokens + assert trim(text, 4, word_count_tokens) == text # No trimming needed + assert ( + trim(text, 3, word_count_tokens, where="middle") == "Hello … example" + ) # Should fit in 3 tokens + assert ( + trim(text, 2, word_count_tokens, where="end") == "Hello …" + ) # Should fit in 2 tokens + assert ( + trim(text, 2, word_count_tokens, where="start") == "… example" + ) # Should fit in 2 tokens + + # Test very short max_tokens with word counter + assert trim("Hello world", 1, word_count_tokens) == "…" # Only ellipsis fits