diff --git a/README.md b/README.md index 2c1d15ba..6ac79648 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Then, edit this file with your endpoint and credentials. You can choose one of t - For `az login` or Managed Identity authentication on Azure, remove `api_key` and include `scope` instead. > [!WARNING] -> When using open-sourced LLMs, e.g., via vLLM, you need to correctly setup `HF_TOKEN` required by the tokenizer. +> When using open-sourced LLMs, e.g., via vLLM, you need to correctly setup `HF_TOKEN` required by the tokenizer. You can also provide `tokenizer_kwargs` in your `llm.yaml` entry (for example `trust_remote_code: true`) to control how the Hugging Face tokenizer is instantiated. By default, `debug-gym` looks for the LLM config file at `$HOME/.config/debug_gym/llm.yaml`. You can change this behavior by exporting the environment variable `LLM_CONFIG_FILE_PATH` or by setting `llm_config_file_path` in your script config file (see [Running Baselines](#3-running-baselines)). @@ -65,7 +65,7 @@ debug_gym `debug_gym.agents` are LLM-based debugging agents that use `debug_gym.gym` to interact with code repositories to seek necessary information and thus fix potential bugs. At an interaction step, the agent takes a text observation that describes the environment states and tool states as input, it is expected to generate a command, subsequently, the environment will provide a new text observation in response, describing the state change caused by that command. -`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, and Anthropic. +`debug_gym.llms` are the different LLM backends that can be used to instantiate agents. Currently, we support OpenAI, Azure OpenAI, Hugging Face/vLLM deployments (via an OpenAI-compatible endpoint), and Anthropic. For Hugging Face models served through vLLM, the tokenizer's chat template is applied automatically to ensure token counting and truncation match the hosted model. > [!WARNING] > `debug-gym` has limited support on non-Linux platforms. Interactive terminal sessions using PTY (pseudo-terminal) in Docker are not fully supported on macOS or Windows. As a result, the `pdb` tool (see [2.1. Environment and Tools](#21-environment-and-tools)) only works on Linux. @@ -202,7 +202,7 @@ In addition to all [built-in Jinja filters](https://jinja.palletsprojects.com/en {{ info.tools | to_pretty_json }} ``` -- **`trim_message`**: Trims a string to fit within a token or character limit, also filtering out non-UTF8 characters. This is helpful for ensuring that large outputs (such as directory trees or evaluation results) do not exceed the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed: + - **`trim_message`**: Trims a string to approximately fit within a token or character limit while filtering non-UTF8 characters. This helps keep large outputs (such as directory trees or evaluation results) within the LLM's context window. The `trim_message` filter accepts the following arguments to control how messages are trimmed: - **`max_length`**: The maximum number of tokens to keep in the message. If the message exceeds this length, it will be trimmed. - **`max_length_percentage`**: Instead of specifying an absolute number, you can provide a percentage (e.g., `0.1` for 10%) of the LLM's context window. The message will be trimmed to fit within this percentage of the model's maximum context length. 
- **`where`**: Specifies where to trim the message if it exceeds the limit. The default is `"middle"`, which trims from the middle of the message. Other options are `start` or `end`. diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index 2907f076..c9fe3d6d 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -8,10 +8,10 @@ from jinja2 import Environment, Template from debug_gym.agents.history_tracker import HistoryTracker, build_history_prompt -from debug_gym.agents.utils import trim from debug_gym.gym.envs.env import RepoEnv from debug_gym.gym.utils import filter_non_utf8 from debug_gym.llms.base import LLM +from debug_gym.llms.utils import trim from debug_gym.logger import DebugGymLogger AGENT_REGISTRY = {} @@ -119,7 +119,7 @@ def to_pretty_json(value): def trim_message( self, - message, + message: str, count_tokens=None, max_length=None, max_length_percentage=0, @@ -140,10 +140,8 @@ def trim_message( if count_tokens is None or max_length is None or max_length <= 0: return message - tokens = count_tokens(message) - if tokens > max_length: - return trim(message, max_length, count_tokens=count_tokens, where=where) - return message + + return trim(message, max_length, count_tokens=count_tokens, where=where) def _load_system_prompt_template(self) -> Template | None: """Load system prompt template from config if specified and register custom filters. diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index d86e2554..3d694237 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -5,184 +5,6 @@ import yaml -def trim(text: str, max_tokens: int, count_tokens: callable, where: str = "middle"): - """Trim text to fit within max_tokens by working directly at the token level.""" - if max_tokens <= 0: - return "" - - nb_tokens = count_tokens(text) - if nb_tokens <= max_tokens: - return text - - ellipsis = "…" # assume ellipsis is a single token - available_tokens = max_tokens - 1 # account for ellipsis - - def find_char_position_for_tokens( - target_tokens: int, from_start: bool = True - ) -> int: - """Binary search to find character position that gives approximately target_tokens.""" - left, right = 0, len(text) - best_pos = left if from_start else right - - while left <= right: - mid = (left + right) // 2 - test_text = text[:mid] if from_start else text[mid:] - test_tokens = count_tokens(test_text) - if test_tokens <= target_tokens: - best_pos = mid - if from_start: - left = mid + 1 - else: - right = mid - 1 - else: - if from_start: - right = mid - 1 - else: - left = mid + 1 - return best_pos - - if where == "end": - # Keep the beginning, trim the end - trim_point = find_char_position_for_tokens(available_tokens, from_start=True) - return text[:trim_point] + ellipsis - elif where == "start": - # Keep the end, trim the beginning - trim_point = find_char_position_for_tokens(available_tokens, from_start=False) - return ellipsis + text[trim_point:] - elif where == "middle": - # Keep both ends, trim the middle - half_tokens = available_tokens // 2 - - # Find how much we can keep from the start - start_chars = find_char_position_for_tokens(half_tokens, from_start=True) - - # Find how much we can keep from the end with remaining tokens - remaining_tokens = available_tokens - count_tokens(text[:start_chars]) - end_chars = find_char_position_for_tokens(remaining_tokens, from_start=False) - - return text[:start_chars] + ellipsis + text[end_chars:] - else: - raise ValueError(f"Invalid value for `where`: {where!r}.") - 
- -def get_message_tokens(message, count_tokens): - """Count tokens in a message.""" - message_content = str(message.get("content", message.get("tool_calls", message))) - return count_tokens(message_content) - - -def trim_prompt_messages( - messages: list[dict], context_length: int, count_tokens: callable -): - """ - Trim message on the assistant-tool pair level to fit context length. - - Strategy: - 1. Keep the system message (assert if system itself is too long) - 2. Keep as many most recent (assistant, tool) pairs as possible - 3. Only when we can keep all (assistant, tool) pairs, keep the user message - - Args: - messages: List of message dicts with 'role' and 'content'/'tool_calls' keys - context_length: Maximum number of tokens allowed - count_tokens: Function to count tokens in a string - - Returns: - Trimmed list of messages that fit within context_length - """ - assert len(messages) > 0, "messages should not be empty" - assert messages[-1]["role"] in [ - "user", - "tool", - ], "the last message should be from the user or the tool" - assert context_length >= 0, "context_length should be non-negative" - - # Calculate token count for all messages - message_tokens = [get_message_tokens(msg, count_tokens) for msg in messages] - total_tokens = sum(message_tokens) - - # If we're already within limit, return as-is - if total_tokens <= context_length: - return messages - - # Find system message - system_msg_idx = 0 if messages[0]["role"] == "system" else None - system_tokens = message_tokens[0] if system_msg_idx is not None else 0 - - # Assert system message fits within context - assert ( - system_tokens <= context_length - ), f"System message tokens exceed context length: {system_tokens} > {context_length}!" - - # Find user message - user_msg_idx = None - for i, msg in enumerate(messages): - if msg["role"] == "user": - user_msg_idx = i - break - - # Find all (assistant, tool) pairs by going backwards - assistant_tool_pairs = [] - i = len(messages) - 1 - while i >= 0: - if ( - messages[i]["role"] == "tool" - and i > 0 - and messages[i - 1]["role"] == "assistant" - ): - assistant_tool_pairs.append((i - 1, i)) # (assistant_idx, tool_idx) - i -= 2 - else: - i -= 1 - - # Start building result with system message - result = [] - remaining_tokens = context_length - - if system_msg_idx is not None: - result.append(messages[system_msg_idx]) - remaining_tokens -= system_tokens - - # Add as many recent (assistant, tool) pairs as possible - included_pairs = [] - for assistant_idx, tool_idx in assistant_tool_pairs: - pair_tokens = message_tokens[assistant_idx] + message_tokens[tool_idx] - if pair_tokens <= remaining_tokens: - included_pairs.append((assistant_idx, tool_idx)) - remaining_tokens -= pair_tokens - else: - break - - # Only include user message if we can fit all (assistant, tool) pairs - include_user = False - if len(included_pairs) == len(assistant_tool_pairs) and user_msg_idx is not None: - user_tokens = message_tokens[user_msg_idx] - if user_tokens <= remaining_tokens: - include_user = True - - # Build final result - if include_user: - result.append(messages[user_msg_idx]) - - # Sort by assistant index to maintain chronological order - included_pairs.sort(key=lambda pair: pair[0]) - for assistant_idx, tool_idx in included_pairs: - result.append(messages[assistant_idx]) - result.append(messages[tool_idx]) - - assert ( - len(result) > 0 - ), f"After trimming, no messages fit within context length: {context_length}!" 
- - # Verify final token count - final_tokens = sum(get_message_tokens(msg, count_tokens) for msg in result) - assert ( - final_tokens <= context_length - ), f"After trimming, the message length still exceeds: {final_tokens} > {context_length}!" - - return result - - def load_config(): parser = argparse.ArgumentParser() parser.add_argument("config_file", help="path to config file") diff --git a/debug_gym/llms/__init__.py b/debug_gym/llms/__init__.py index b7cfa674..91e09d9f 100644 --- a/debug_gym/llms/__init__.py +++ b/debug_gym/llms/__init__.py @@ -1,5 +1,6 @@ from debug_gym.llms.anthropic import AnthropicLLM from debug_gym.llms.azure_openai import AzureOpenAILLM from debug_gym.llms.base import LLM +from debug_gym.llms.huggingface import HuggingFaceLLM from debug_gym.llms.human import Human from debug_gym.llms.openai import OpenAILLM diff --git a/debug_gym/llms/anthropic.py b/debug_gym/llms/anthropic.py index 947da94d..d19e6124 100644 --- a/debug_gym/llms/anthropic.py +++ b/debug_gym/llms/anthropic.py @@ -41,24 +41,18 @@ def client(self): self._client = Anthropic(api_key=self.config.api_key) return self._client - def tokenize(self, text: str) -> list[str]: - raise NotImplementedError("Tokenization is not supported by Anthropic.") + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Tokenization is not directly supported by Anthropic. + This method returns empty token lists as a placeholder.""" + raise NotImplementedError("Direct tokenization is not supported by Anthropic.") + + def count_tokens(self, messages: list[dict] | str) -> int: + """Count the number of tokens in a text using the Anthropic API.""" + if isinstance(messages, str): + messages = [ + {"role": "user", "content": [{"type": "text", "text": messages}]} + ] - def count_tokens(self, text: str) -> int: - """Count the number of tokens in a text using the Anthropic API. - Dump content to JSON for cases such as: - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "id123", - "content": "results", - } - ], - } - """ - messages = [{"role": "user", "content": [{"type": "text", "text": text}]}] try: response = self.client.messages.count_tokens( model=self.tokenizer_name, messages=messages @@ -67,7 +61,7 @@ def count_tokens(self, text: str) -> int: except Exception as e: self.logger.warning( f"Error calling Claude token count API: {e!r}. " - f"The message was: {messages}. Will return 0 tokens." + f"The messages were: {messages}. Will return 0 tokens." ) return 0 diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index b0f14139..2dfe9b32 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -16,7 +16,7 @@ from debug_gym.gym.envs.env import EnvInfo from debug_gym.gym.tools.tool import EnvironmentTool, ToolCall from debug_gym.llms.constants import DEFAULT_LLM_CONFIG -from debug_gym.llms.utils import print_messages +from debug_gym.llms.utils import print_messages, trim_prompt_messages from debug_gym.logger import DebugGymLogger # Set logging level down to WARNING for endpoint queries. 
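The hunks below switch `debug_gym`'s token counting from plain strings to chat messages: `tokenize()` now receives a list of message dicts and returns one token list per message, and the base `count_tokens()` wraps a bare string into a single user message before summing the per-message lengths. A minimal, standalone sketch of that contract follows (it mirrors the base-class logic and the mock in `tests/conftest.py`; the `WhitespaceTokenCounter` class and the sample registry entry are illustrative only, not part of the library, and the endpoint/api_key values are placeholders):

```python
# Illustrative `llm.yaml`-style entry for a vLLM-served Hugging Face model,
# mirroring the `qwen-3` registry used in this PR's tests. The `vllm` tag
# routes instantiation to `HuggingFaceLLM`; `tokenizer_kwargs`,
# `apply_chat_template`, and `enable_thinking` are the config fields added
# by the hunks below.
EXAMPLE_REGISTRY = {
    "qwen-3": {
        "model": "Qwen/Qwen3-0.6B",
        "tokenizer": "Qwen/Qwen3-0.6B",
        "context_limit": 4,
        "api_key": "test-api-key",
        "endpoint": "https://test-endpoint",
        "tags": ["vllm"],
        "tokenizer_kwargs": {"trust_remote_code": True},
        "apply_chat_template": True,
    }
}


class WhitespaceTokenCounter:
    """Standalone stand-in for an LLM backend under the new contract."""

    def tokenize(self, messages: list[dict]) -> list[list[str]]:
        # One token list per message; fall back to tool_calls (or the raw
        # dict) when a message carries no "content", as the real backends do.
        return [
            str(msg.get("content", msg.get("tool_calls", msg))).split()
            for msg in messages
        ]

    def count_tokens(self, messages: list[dict] | str) -> int:
        # Bare strings are wrapped into a single user message, then the
        # per-message token counts are summed.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        return sum(len(tokens) for tokens in self.tokenize(messages))


if __name__ == "__main__":
    counter = WhitespaceTokenCounter()
    assert counter.count_tokens("hello world") == 2
    assert counter.count_tokens(
        [
            {"role": "user", "content": "hello world"},
            {"role": "tool", "content": "ok"},
        ]
    ) == 3
```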
@@ -57,6 +57,8 @@ class LLMConfig: api_key: Optional[str] = None endpoint: Optional[str] = None tokenizer: Optional[str] = None + apply_chat_template: Optional[bool] = False + enable_thinking: Optional[bool] = False reasoning_end_token: Optional[str] = None system_prompt_support: bool = True ignore_kwargs: List[str] = None @@ -66,6 +68,8 @@ class LLMConfig: scope: Optional[str] = None # Custom parameters to pass to generate generate_kwargs: dict = None + # Additional kwargs for tokenizer construction (e.g., trust_remote_code) + tokenizer_kwargs: dict | None = None def __post_init__(self): # Set tokenizer to model if not specified @@ -78,6 +82,8 @@ def __post_init__(self): self.tags = [] if self.generate_kwargs is None: self.generate_kwargs = {} + if self.tokenizer_kwargs is None: + self.tokenizer_kwargs = {} @dataclass @@ -201,6 +207,8 @@ def __init__( ) self.tokenizer_name = self.config.tokenizer self.context_length = self.config.context_limit * 1000 + self.apply_chat_template = self.config.apply_chat_template + self.enable_thinking = self.config.enable_thinking self.reasoning_end_token = self.config.reasoning_end_token self.logger.debug( @@ -237,6 +245,7 @@ def instantiate( llm_config = LLMConfigRegistry.from_file(llm_config_file_path)[llm_name] tags = llm_config.tags + if "copilot openai" in tags: from debug_gym.llms.copilot import CopilotOpenAILLM @@ -246,18 +255,27 @@ def instantiate( from debug_gym.llms.copilot import CopilotClaudeLLM klass = CopilotClaudeLLM + elif "azure openai" in tags: from debug_gym.llms import AzureOpenAILLM klass = AzureOpenAILLM + + elif "vllm" in tags: + from debug_gym.llms import HuggingFaceLLM + + klass = HuggingFaceLLM + elif "anthropic" in tags: from debug_gym.llms import AnthropicLLM klass = AnthropicLLM + else: from debug_gym.llms import OpenAILLM klass = OpenAILLM + llm = klass(llm_name, logger=logger, llm_config=llm_config) return llm @@ -269,13 +287,30 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse: pass @abstractmethod - def tokenize(self, text: str) -> list[str]: - """Abstract method to tokenize a text.""" + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Abstract method to tokenize messages. + + Args: + messages: List of message dicts + + Returns: + List of token lists, one per message + """ pass - def count_tokens(self, text: str) -> int: - """Count the number of tokens in a text.""" - return len(self.tokenize(text)) + def count_tokens(self, messages: list[dict] | str) -> int: + """Count the total number of tokens across all messages. + + Args: + messages: List of message dicts + + Returns: + Total token count across all messages + """ + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + tokenized = self.tokenize(messages) + return sum(len(tokens) for tokens in tokenized) @abstractmethod def define_tools(self, tool_call_list: list[EnvironmentTool]) -> list[dict]: @@ -311,8 +346,6 @@ def __call__(self, messages, tools, *args, **kwargs) -> LLMResponse: should be implemented by subclasses. Returns an LLMResponse object with the prompt, response and token usage. 
""" - from debug_gym.agents.utils import get_message_tokens, trim_prompt_messages - # Add custom generation parameters from config for key, value in self.config.generate_kwargs.items(): # Only set if not already specified in the call @@ -347,10 +380,8 @@ def generate_with_drop_message_and_retry(messages, tools, **kwargs): for retry_count in range(max_retries + 1): try: # pre-truncate messages if they are too long, to avoid unnecessary retries - message_tokens = sum( - get_message_tokens(msg, self.count_tokens) for msg in messages - ) - if message_tokens > self.context_length * 1.2: + message_tokens = self.count_tokens(messages) + if message_tokens > self.context_length: trimmed_messages = trim_prompt_messages( messages, self.context_length, self.count_tokens ) diff --git a/debug_gym/llms/copilot.py b/debug_gym/llms/copilot.py index eb6eff9c..ee8075c6 100644 --- a/debug_gym/llms/copilot.py +++ b/debug_gym/llms/copilot.py @@ -177,15 +177,23 @@ def _create_copilot_client(self) -> OpenAI: timeout=None, ) - def tokenize(self, text: str) -> list[str]: + def tokenize(self, messages: list[dict]) -> list[list[str]]: if getattr(self, "_tk_func", None) is None: try: - self._tk_func = tiktoken.encoding_for_model("gpt-4o").encode + encoder = tiktoken.encoding_for_model("gpt-4o") + # For tiktoken, encode returns list of ints, convert to strings + self._tk_func = lambda text: [str(t) for t in encoder.encode(text)] except KeyError: # Simple word-based tokenization as fallback - # For Claude, you might want to use tiktoken or another tokenizer self._tk_func = lambda x: x.split() - return self._tk_func(text) + + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = self._tk_func(content) + result.append(tokens) + return result def need_to_be_retried(self, exception) -> bool: # re-use the need_to_be_retried function from the parent class diff --git a/debug_gym/llms/huggingface.py b/debug_gym/llms/huggingface.py new file mode 100644 index 00000000..27fa1c6c --- /dev/null +++ b/debug_gym/llms/huggingface.py @@ -0,0 +1,61 @@ +import json +from typing import Iterable + +from transformers import AutoTokenizer + +from debug_gym.llms.openai import OpenAILLM + + +class HuggingFaceLLM(OpenAILLM): + """LLM implementation backed by a Hugging Face tokenizer.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._hf_tokenizer = None + + def _load_tokenizer(self): + if self._hf_tokenizer is None: + tokenizer_kwargs = getattr(self.config, "tokenizer_kwargs", None) or {} + try: + self._hf_tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_name, **tokenizer_kwargs + ) + except OSError: + raise ValueError( + f"Tokenizer `{self.tokenizer_name}` not found for model " + f"{self.model_name}, make sure you have access to " + "the model (e.g., HuggingFace API key is correctly set)." + ) + + # Ensure we have a pad token to avoid downstream warnings when invoking + # the tokenizer in encode mode. 
+ if ( + getattr(self._hf_tokenizer, "pad_token", None) is None + and getattr(self._hf_tokenizer, "eos_token", None) is not None + ): + self._hf_tokenizer.pad_token = self._hf_tokenizer.eos_token + return self._hf_tokenizer + + def tokenize(self, messages: list[dict]) -> list[list[str]]: + tokenizer = self._load_tokenizer() + + if self.apply_chat_template: + # When applying chat template, tokenize all messages together + # then return as a single list + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + tokens = tokenizer.tokenize(text) + # Return as list with single element (all tokens together) + return [tokens] + else: + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg["content"]) + tokens = tokenizer.tokenize(content) + result.append(tokens) + return result diff --git a/debug_gym/llms/human.py b/debug_gym/llms/human.py index 7557c8fa..28b34176 100644 --- a/debug_gym/llms/human.py +++ b/debug_gym/llms/human.py @@ -471,12 +471,21 @@ def __init__( if prompt_toolkit_available: self._history = InMemoryHistory() - def tokenize(self, text: str) -> list[str]: - """Tokenizes a text by splitting it by spaces.""" - return text.split() - - def count_tokens(self, text: str) -> int: - return len(self.tokenize(text)) + def tokenize(self, messages: list[dict]) -> list[list[str]]: + """Tokenizes messages by splitting content by spaces.""" + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = content.split() + result.append(tokens) + return result + + def count_tokens(self, messages: list[dict] | str) -> int: + """Count tokens across all messages.""" + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + tokenized = self.tokenize(messages) + return sum(len(tokens) for tokens in tokenized) def define_tools(self, tool_call_list: list[EnvironmentTool]) -> list[dict]: available_commands = [] @@ -615,6 +624,6 @@ def __call__(self, messages, tools, *args, **kwargs) -> LLMResponse: prompt=messages, response=action, tool=tool_call, - prompt_token_count=self.count_tokens(json.dumps(messages)), - response_token_count=self.count_tokens(action), + prompt_token_count=self.count_tokens(messages), + response_token_count=self.count_tokens([{"tool_calls": action}]), ) diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index 7fa00051..ed0b9006 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -62,22 +62,25 @@ def client(self): ) return self._client - def tokenize(self, text: str) -> list[str]: + def tokenize(self, messages: list[dict]) -> list[list[str]]: if getattr(self, "_tk_func", None) is None: try: - self._tk_func = tiktoken.encoding_for_model(self.tokenizer_name).encode + encoder = tiktoken.encoding_for_model(self.tokenizer_name) + # For tiktoken, encode returns list of ints, we need to convert to list of "tokens" + self._tk_func = lambda text: [str(t) for t in encoder.encode(text)] except KeyError: - try: # Try to load from transformers. - self._tk_func = AutoTokenizer.from_pretrained( - self.tokenizer_name - ).tokenize - except OSError: - raise ValueError( - f"Tokenizer `{self.tokenizer_name}` not found for model " - f"{self.model_name}, make sure you have access to " - "the model (e.g., HuggingFace API key is correctly set)." 
- ) - return self._tk_func(text) + raise ValueError( + f"Tokenizer `{self.tokenizer_name}` not found for model " + f"{self.model_name}. If using Hugging Face models, please " + f"set tag `vllm` to load the HuggingFaceLLM class instead." + ) + # Tokenize each message individually + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + tokens = self._tk_func(content) + result.append(tokens) + return result def need_to_be_retried(self, exception) -> bool: # List of fully qualified names of RateLimitError exceptions from various libraries diff --git a/debug_gym/llms/utils.py b/debug_gym/llms/utils.py index 5bbcd0cf..8a5c4cd2 100644 --- a/debug_gym/llms/utils.py +++ b/debug_gym/llms/utils.py @@ -1,6 +1,191 @@ from debug_gym.logger import DebugGymLogger, log_with_color +def get_message_tokens(message, count_tokens): + """Count tokens in a single message. + + Args: + message: A single message dict + count_tokens: Function that takes a list of messages and returns total token count + + Returns: + Token count for this message + """ + return count_tokens([message]) + + +def trim(text: str, max_tokens: int, count_tokens: callable, where: str = "middle"): + """Trim text to fit within max_tokens by working directly at the token level.""" + if max_tokens <= 0: + return "" + + nb_tokens = count_tokens(text) + if nb_tokens <= max_tokens: + return text + + ellipsis = "…" # assume ellipsis is a single token + available_tokens = max_tokens - 1 # account for ellipsis + + def find_char_position_for_tokens( + target_tokens: int, from_start: bool = True + ) -> int: + """Binary search to find character position that gives approximately target_tokens.""" + left, right = 0, len(text) + best_pos = left if from_start else right + + while left <= right: + mid = (left + right) // 2 + test_text = text[:mid] if from_start else text[mid:] + test_tokens = count_tokens(test_text) + if test_tokens <= target_tokens: + best_pos = mid + if from_start: + left = mid + 1 + else: + right = mid - 1 + else: + if from_start: + right = mid - 1 + else: + left = mid + 1 + return best_pos + + if where == "end": + # Keep the beginning, trim the end + trim_point = find_char_position_for_tokens(available_tokens, from_start=True) + return text[:trim_point] + ellipsis + elif where == "start": + # Keep the end, trim the beginning + trim_point = find_char_position_for_tokens(available_tokens, from_start=False) + return ellipsis + text[trim_point:] + elif where == "middle": + # Keep both ends, trim the middle + half_tokens = available_tokens // 2 + + # Find how much we can keep from the start + start_chars = find_char_position_for_tokens(half_tokens, from_start=True) + + # Find how much we can keep from the end with remaining tokens + remaining_tokens = available_tokens - count_tokens(text[:start_chars]) + end_chars = find_char_position_for_tokens(remaining_tokens, from_start=False) + + return text[:start_chars] + ellipsis + text[end_chars:] + else: + raise ValueError(f"Invalid value for `where`: {where!r}.") + + +def trim_prompt_messages( + messages: list[dict], context_length: int, count_tokens: callable +): + """ + Trim message on the assistant-tool pair level to fit context length. + + Strategy: + 1. Keep the system message (assert if system itself is too long) + 2. Keep as many most recent (assistant, tool) pairs as possible + 3. 
Only when we can keep all (assistant, tool) pairs, keep the user message + + Args: + messages: List of message dicts with 'role' and 'content'/'tool_calls' keys + context_length: Maximum number of tokens allowed + count_tokens: Function to count tokens in a string + + Returns: + Trimmed list of messages that fit within context_length + """ + assert len(messages) > 0, "messages should not be empty" + assert messages[-1]["role"] in [ + "user", + "tool", + ], "the last message should be from the user or the tool" + assert context_length >= 0, "context_length should be non-negative" + + # Calculate token count for all messages + message_tokens = [get_message_tokens(msg, count_tokens) for msg in messages] + total_tokens = sum(message_tokens) + + # If we're already within limit, return as-is + if total_tokens <= context_length: + return messages + + # Find system message + system_msg_idx = 0 if messages[0]["role"] == "system" else None + system_tokens = message_tokens[0] if system_msg_idx is not None else 0 + + # Assert system message fits within context + assert ( + system_tokens <= context_length + ), f"System message tokens exceed context length: {system_tokens} > {context_length}!" + + # Find user message + user_msg_idx = None + for i, msg in enumerate(messages): + if msg["role"] == "user": + user_msg_idx = i + break + + # Find all (assistant, tool) pairs by going backwards + assistant_tool_pairs = [] + i = len(messages) - 1 + while i >= 0: + if ( + messages[i]["role"] == "tool" + and i > 0 + and messages[i - 1]["role"] == "assistant" + ): + assistant_tool_pairs.append((i - 1, i)) # (assistant_idx, tool_idx) + i -= 2 + else: + i -= 1 + + # Start building result with system message + result = [] + remaining_tokens = context_length + + if system_msg_idx is not None: + result.append(messages[system_msg_idx]) + remaining_tokens -= system_tokens + + # Add as many recent (assistant, tool) pairs as possible + included_pairs = [] + for assistant_idx, tool_idx in assistant_tool_pairs: + pair_tokens = message_tokens[assistant_idx] + message_tokens[tool_idx] + if pair_tokens <= remaining_tokens: + included_pairs.append((assistant_idx, tool_idx)) + remaining_tokens -= pair_tokens + else: + break + + # Only include user message if we can fit all (assistant, tool) pairs + include_user = False + if len(included_pairs) == len(assistant_tool_pairs) and user_msg_idx is not None: + user_tokens = message_tokens[user_msg_idx] + if user_tokens <= remaining_tokens: + include_user = True + + # Build final result + if include_user: + result.append(messages[user_msg_idx]) + + # Sort by assistant index to maintain chronological order + included_pairs.sort(key=lambda pair: pair[0]) + for assistant_idx, tool_idx in included_pairs: + result.append(messages[assistant_idx]) + result.append(messages[tool_idx]) + + assert ( + len(result) > 0 + ), f"After trimming, no messages fit within context length: {context_length}!" + + # Verify final token count + final_tokens = sum(get_message_tokens(msg, count_tokens) for msg in result) + assert ( + final_tokens <= context_length + ), f"After trimming, the message length still exceeds: {final_tokens} > {context_length}!" + + return result + + def print_messages(messages: list[dict], logger: DebugGymLogger): """Print messages coloring each role differently. 
Colors: diff --git a/pytest.ini b/pytest.ini index 7654985f..dbde339b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,6 @@ norecursedirs = data/* asyncio_default_fixture_loop_scope = function env = - BASH_SILENCE_DEPRECATION_WARNING=1 # Suppress deprecation warnings in MacOS (zsh default) \ No newline at end of file + BASH_SILENCE_DEPRECATION_WARNING=1 # Suppress deprecation warnings in MacOS (zsh default) +markers = + hf_tokenizer: tests that require downloading a Hugging Face tokenizer and may be slow \ No newline at end of file diff --git a/tests/agents/test_utils.py b/tests/agents/test_utils.py index a68b7550..bb479a21 100644 --- a/tests/agents/test_utils.py +++ b/tests/agents/test_utils.py @@ -1,183 +1,7 @@ import logging from unittest.mock import patch -import pytest - -from debug_gym.agents.utils import load_config, trim, trim_prompt_messages - - -def test_trim_prompt_messages(): - def count_tokens(text): - return len(text) - - # Test basic validation - with pytest.raises(AssertionError, match="messages should not be empty"): - trim_prompt_messages([], 5, count_tokens) - - with pytest.raises( - AssertionError, match="the last message should be from the user or the tool" - ): - messages = [ - {"role": "system", "content": "System message"}, - {"role": "assistant", "content": "Assistant message"}, - ] - trim_prompt_messages(messages, 20, count_tokens) - - with pytest.raises(AssertionError, match="context_length should be non-negative"): - messages = [{"role": "user", "content": "User message"}] - trim_prompt_messages(messages, -1, count_tokens) - - # Test system message too long - with pytest.raises( - AssertionError, match="System message tokens exceed context length" - ): - messages = [ - {"role": "system", "content": "Very long system message"}, - {"role": "user", "content": "Hi"}, - ] - trim_prompt_messages(messages, 10, count_tokens) - - # Test simple case: just user message - messages = [{"role": "user", "content": "Hello"}] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: system + user, fits completely - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: system + user + assistant + tool, fits completely - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 20, count_tokens) - assert result == messages - - # Test case: Keep system + most recent assistant-tool pair, drop user - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello"}, - {"role": "tool", "content": "Result"}, - ] - result = trim_prompt_messages( - messages, 14, count_tokens - ) # Just enough for sys + assistant + tool - assert result == expected - - # Test case: Multiple assistant-tool pairs, keep most recent ones - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": 
"Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - # Keep system + most recent assistant-tool pair only - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello2"}, - {"role": "tool", "content": "Result2"}, - ] - result = trim_prompt_messages( - messages, 16, count_tokens - ) # sys(3) + hello2(6) + result2(7) = 16 - assert result == expected - - # Test case: Can fit all assistant-tool pairs + user message - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": "Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - # All pairs fit, so include user message too - result = trim_prompt_messages(messages, 50, count_tokens) - assert result == messages - - # Test case: Can fit all assistant-tool pairs but not user message - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello1"}, # 6 tokens - {"role": "tool", "content": "Result1"}, # 7 tokens - {"role": "assistant", "content": "Hello2"}, # 6 tokens - {"role": "tool", "content": "Result2"}, # 7 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "assistant", "content": "Hello1"}, - {"role": "tool", "content": "Result1"}, - {"role": "assistant", "content": "Hello2"}, - {"role": "tool", "content": "Result2"}, - ] - result = trim_prompt_messages( - messages, 29, count_tokens - ) # sys(3) + all pairs(26) = 29, no room for user(2) - assert result == expected - - # Test case: No system message - messages = [ - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "assistant", "content": "Hello"}, # 5 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 20, count_tokens) - assert result == messages - - # Test case: No assistant-tool pairs, just system and user - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: No system, no assistant-tool pairs, just user - messages = [{"role": "user", "content": "Hi"}] # 2 tokens - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == messages - - # Test case: Tool message without preceding assistant (edge case) - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - {"role": "tool", "content": "Result"}, # 6 tokens - ] - expected = [ - {"role": "system", "content": "Sys"}, - {"role": "user", "content": "Hi"}, - ] - result = trim_prompt_messages(messages, 10, count_tokens) - assert result == expected - - # Test case: Message with tool_calls instead of content - messages = [ - {"role": "system", "content": "Sys"}, # 3 tokens - {"role": "user", "content": "Hi"}, # 2 tokens - { - "role": "assistant", - "tool_calls": [{"function": {"name": "test"}}], - }, # ~30 tokens (str representation) - {"role": "tool", "content": "Result"}, # 6 tokens - ] - result = trim_prompt_messages(messages, 100, count_tokens) - assert len(result) == 4 # Should keep all messages if context is large enough +from debug_gym.agents.utils import load_config def 
test_load_config(): @@ -263,102 +87,3 @@ def test_load_config(): assert _config == expected_config assert _args.debug is True assert _args.logging_level == logging.INFO - - -def test_trim(): - def count_tokens(text): - return len(text) - - # Test basic cases - no trimming needed - assert trim("Hello world", 11, count_tokens) == "Hello world" - assert trim("Hello world", 20, count_tokens) == "Hello world" - assert trim("Hi", 2, count_tokens) == "Hi" - assert trim("Hi", 10, count_tokens) == "Hi" - assert trim("A", 1, count_tokens) == "A" # Exactly fits, no trimming needed - - # Test edge cases - assert trim("Hello world", 0, count_tokens) == "" - assert trim("", 5, count_tokens) == "" - assert trim("", 0, count_tokens) == "" - - # Test cases requiring trimming to single token (ellipsis only) - assert trim("Hello world", 1, count_tokens) == "…" - assert trim("Hi", 1, count_tokens) == "…" - assert trim("ABC", 1, count_tokens) == "…" - - # Test trimming from the middle (default behavior) - assert trim("Hello world", 5, count_tokens) == "He…ld" - assert trim("Hello world", 6, count_tokens) == "He…rld" - assert trim("Hello world", 7, count_tokens) == "Hel…rld" - assert trim("123456789", 5, count_tokens) == "12…89" - assert trim("123456789", 7, count_tokens) == "123…789" - - # Test trimming from the end - assert trim("Hello world", 5, count_tokens, where="end") == "Hell…" - assert trim("Hello world", 6, count_tokens, where="end") == "Hello…" - assert trim("Hello world", 7, count_tokens, where="end") == "Hello …" - assert trim("123456789", 5, count_tokens, where="end") == "1234…" - - # Test trimming from the start - assert trim("Hello world", 5, count_tokens, where="start") == "…orld" - assert trim("Hello world", 6, count_tokens, where="start") == "…world" - assert trim("Hello world", 7, count_tokens, where="start") == "… world" - assert trim("123456789", 5, count_tokens, where="start") == "…6789" - - # Test invalid `where` value - with pytest.raises(ValueError, match="Invalid value for `where`"): - trim("Hello world", 5, count_tokens, where="invalid") - - # Test with different token counter - def another_count_tokens(text): - return len(text) // 2 - - # For "1234567890" (10 chars), another_count_tokens returns 5 tokens - # Original text has 5 tokens, so no trimming needed when max_tokens >= 5 - assert trim("1234567890", 5, another_count_tokens) == "1234567890" - assert trim("1234567890", 6, another_count_tokens) == "1234567890" - - # When max_tokens < 5, trimming is needed - # With max_tokens=4, we need 3 tokens for content + 1 for ellipsis - assert ( - trim("1234567890", 4, another_count_tokens) == "123…67890" - ) # Result has 4 tokens - assert ( - trim("1234567890", 3, another_count_tokens) == "123…890" - ) # Result has 3 tokens - assert trim("1234567890", 2, another_count_tokens) == "1…890" # Result has 2 tokens - assert trim("1234567890", 1, another_count_tokens) == "1…0" # Result has 1 token - - # Test with different trimming positions using the alternative counter - assert ( - trim("1234567890", 3, another_count_tokens, where="end") == "12345…" - ) # Result has 3 tokens - assert ( - trim("1234567890", 3, another_count_tokens, where="start") == "…67890" - ) # Result has 3 tokens - - # Test edge case with very short text and alternative counter - assert trim("AB", 1, another_count_tokens) == "AB" # "AB" has 1 token, fits exactly - assert ( - trim("ABCD", 1, another_count_tokens) == "A…D" - ) # "ABCD" has 2 tokens, needs trimming to 1 - - # Test boundary conditions with precise scenarios - def 
word_count_tokens(text): - # Count words as tokens - return len(text.split()) - - text = "Hello world test example" # 4 words = 4 tokens - assert trim(text, 4, word_count_tokens) == text # No trimming needed - assert ( - trim(text, 3, word_count_tokens, where="middle") == "Hello … example" - ) # Should fit in 3 tokens - assert ( - trim(text, 2, word_count_tokens, where="end") == "Hello …" - ) # Should fit in 2 tokens - assert ( - trim(text, 2, word_count_tokens, where="start") == "… example" - ) # Should fit in 2 tokens - - # Test very short max_tokens with word counter - assert trim("Hello world", 1, word_count_tokens) == "…" # Only ellipsis fits diff --git a/tests/conftest.py b/tests/conftest.py index 8248360b..17ee3fdd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,8 +45,13 @@ def generate(self, messages, tools, **kwargs): response_token_count=20, ) - def tokenize(self, text): - return [c for c in text] + def tokenize(self, messages): + # Return list of token lists, one per message + result = [] + for msg in messages: + content = str(msg.get("content", msg.get("tool_calls", msg))) + result.append([c for c in content]) + return result def define_tools(self, tool_call_list): return tool_call_list diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py index b48d63e4..a0801a97 100644 --- a/tests/llms/test_base.py +++ b/tests/llms/test_base.py @@ -5,7 +5,13 @@ from debug_gym.gym.entities import Observation from debug_gym.gym.tools.tool import EnvironmentTool -from debug_gym.llms import AnthropicLLM, AzureOpenAILLM, Human, OpenAILLM +from debug_gym.llms import ( + AnthropicLLM, + AzureOpenAILLM, + HuggingFaceLLM, + Human, + OpenAILLM, +) from debug_gym.llms.base import ( LLM, ContextLengthExceededError, @@ -47,6 +53,15 @@ "api_key": "test-api-key", "tags": ["anthropic", "claude", "claude-3.7"], }, + "qwen-3": { + "model": "Qwen/Qwen3-0.6B", + "tokenizer": "Qwen/Qwen3-0.6B", + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + }, } ), ) @@ -64,6 +79,9 @@ def test_instantiate_llm(mock_open, logger_mock): llm = LLM.instantiate("claude-3.7", logger=logger_mock) assert isinstance(llm, AnthropicLLM) + llm = LLM.instantiate("qwen-3", logger=logger_mock) + assert isinstance(llm, HuggingFaceLLM) + llm = LLM.instantiate("human", logger=logger_mock) assert isinstance(llm, Human) @@ -225,6 +243,7 @@ def test_llm_config_initialization(): assert config.tokenizer == "llm-mock" # Default to model when tokenizer is None assert config.ignore_kwargs == [] # Default empty list assert config.tags == [] # Default empty list + assert config.tokenizer_kwargs == {} def test_llm_config_optional_fields(basic_config): @@ -508,7 +527,7 @@ def generate(self, messages, tools, **kwargs): messages = [{"role": "user", "content": "Long message"}] # Mock trim_prompt_messages to return the same messages (no reduction) - with patch("debug_gym.agents.utils.trim_prompt_messages") as mock_trim: + with patch("debug_gym.llms.utils.trim_prompt_messages") as mock_trim: mock_trim.return_value = messages # Should raise ContextLengthExceededError, not RecursionError @@ -563,7 +582,7 @@ def generate(self, messages, tools, **kwargs): messages = [{"role": "user", "content": "Long message"}] # Mock trim_prompt_messages to return shorter messages - with patch("debug_gym.agents.utils.trim_prompt_messages") as mock_trim: + with patch("debug_gym.llms.base.trim_prompt_messages") as mock_trim: shorter_messages = 
[{"role": "user", "content": "Short"}] mock_trim.return_value = shorter_messages diff --git a/tests/llms/test_copilot.py b/tests/llms/test_copilot.py index 7a8997d2..deb78196 100644 --- a/tests/llms/test_copilot.py +++ b/tests/llms/test_copilot.py @@ -111,9 +111,11 @@ def test_tokenize_with_tiktoken(self, mock_config, logger_mock): mock_encode = MagicMock(return_value=[1, 2, 3, 4]) with patch("tiktoken.encoding_for_model") as mock_tiktoken: mock_tiktoken.return_value.encode = mock_encode - tokens = llm.tokenize("test text") + messages = [{"role": "user", "content": "test text"}] + tokens = llm.tokenize(messages) - assert tokens == [1, 2, 3, 4] + # Should return list of token lists (one per message) + assert tokens == [["1", "2", "3", "4"]] mock_tiktoken.assert_called_once_with("gpt-4o") mock_encode.assert_called_once_with("test text") @@ -125,9 +127,11 @@ def test_tokenize_fallback(self, mock_config, logger_mock): with patch( "tiktoken.encoding_for_model", side_effect=KeyError("model not found") ): - tokens = llm.tokenize("hello world test") + messages = [{"role": "user", "content": "hello world test"}] + tokens = llm.tokenize(messages) - assert tokens == ["hello", "world", "test"] + # Should return list of token lists (one per message) + assert tokens == [["hello", "world", "test"]] def test_need_to_be_retried_hmac_timestamp_error(self, mock_config, logger_mock): """Test retry logic for HMAC timestamp errors""" diff --git a/tests/llms/test_huggingface.py b/tests/llms/test_huggingface.py new file mode 100644 index 00000000..e8a126d2 --- /dev/null +++ b/tests/llms/test_huggingface.py @@ -0,0 +1,251 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from transformers import AutoTokenizer + +from debug_gym.llms import HuggingFaceLLM, OpenAILLM +from debug_gym.llms.base import LLMConfig, LLMConfigRegistry + +# Run these tests with `pytest tests/llms/test_huggingface.py -m hf_tokenizer` +# to include the integration case that downloads the real Qwen tokenizer. 
+ + +HF_MODEL_ID = "Qwen/Qwen3-0.6B" + +MODEL_REGISTRY = { + "qwen-3": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "context_limit": 4, + "api_key": "test-api-key", + "endpoint": "https://test-endpoint", + "tags": ["vllm"], + "tokenizer_kwargs": {"trust_remote_code": True}, + } +} + +MODEL_REGISTRY_WITH_CHAT_TEMPLATE = { + "qwen-3": { + **MODEL_REGISTRY["qwen-3"], + "apply_chat_template": True, + }, +} + + +@pytest.fixture(scope="session") +def real_qwen3_tokenizer(): + try: + return AutoTokenizer.from_pretrained(HF_MODEL_ID) + except ( + OSError, + ValueError, + ImportError, + ) as exc: # pragma: no cover - network-dependent + pytest.skip(f"Unable to load tokenizer {HF_MODEL_ID}: {exc}") + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY), +) +def test_tokenize_uses_hf_tokenizer_with_pad_fallback(mock_llm_config, logger_mock): + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID) + tokenizer.pad_token = None + tokenizer.eos_token = "" + with patch( + "debug_gym.llms.huggingface.AutoTokenizer.from_pretrained" + ) as mock_auto_tokenizer: + mock_auto_tokenizer.return_value = tokenizer + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + messages = [{"role": "user", "content": "hello world"}] + assert llm.tokenize(messages) == [["hello", "Ġworld"]] + assert llm.count_tokens(messages) == 2 + assert tokenizer.eos_token == "" + assert tokenizer.pad_token == "" + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all(MODEL_REGISTRY_WITH_CHAT_TEMPLATE), +) +def test_message_token_counts_uses_chat_template(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock) + + messages = [ + {"role": "system", "content": "Instructions"}, + {"role": "user", "content": "Hello world!"}, + {"role": "tool", "content": "Result"}, + ] + + counts = llm.count_tokens(messages) + + # When using chat template, each message gets template tokens added + # The exact counts depend on the template format + assert counts == 31 + + +@pytest.mark.hf_tokenizer +def test_chat_template_counts_with_real_tokenizer(real_qwen3_tokenizer, logger_mock): + config = LLMConfig( + model=HF_MODEL_ID, + tokenizer=HF_MODEL_ID, + context_limit=4, + api_key="placeholder", + endpoint="http://localhost", + tags=["vllm"], + tokenizer_kwargs={"trust_remote_code": True}, + ) + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock, llm_config=config) + llm._hf_tokenizer = real_qwen3_tokenizer + + messages = [ + {"role": "system", "content": "Instructions"}, + {"role": "user", "content": "Hello world!"}, + {"role": "tool", "content": "Result"}, + ] + + counts = llm.count_tokens(messages) + assert counts == 5 + + +@pytest.mark.hf_tokenizer +def test_tokenize_and_count_tokens_with_real_tokenizer( + real_qwen3_tokenizer, logger_mock +): + config = LLMConfig( + model=HF_MODEL_ID, + tokenizer=HF_MODEL_ID, + context_limit=4, + api_key="placeholder", + endpoint="http://localhost", + tags=["vllm"], + tokenizer_kwargs={"trust_remote_code": True}, + ) + + llm = HuggingFaceLLM(model_name="qwen-3", logger=logger_mock, llm_config=config) + llm._hf_tokenizer = real_qwen3_tokenizer + + text = "Hello world!" 
+ messages = [{"role": "user", "content": text}] + hf_ids = real_qwen3_tokenizer.encode(text, add_special_tokens=False) + hf_tokens = real_qwen3_tokenizer.convert_ids_to_tokens(hf_ids) + + tokens = llm.tokenize(messages) + assert tokens == [hf_tokens] + assert llm.count_tokens(messages) == len(hf_ids) + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": False, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_no_chat_template(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + assert tokens == [["hello", "Ġworld"]] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + + # When using chat template, all messages are tokenized together, so returns single list + assert tokens == [ + [ + "<|im_start|>", + "user", + "Ċ", + "hello", + "Ġworld", + "<|im_end|>", + "Ċ", + "<|im_start|>", + "assistant", + "Ċ", + "", + "ĊĊ", + "", + "ĊĊ", + ] + ] + + +@patch.object( + LLMConfigRegistry, + "from_file", + return_value=LLMConfigRegistry.register_all( + { + "qwen": { + "model": HF_MODEL_ID, + "tokenizer": HF_MODEL_ID, + "apply_chat_template": True, + "enable_thinking": True, + "context_limit": 4096, + "api_key": "fake", + "endpoint": "fake", + "api_version": "1", + "tags": ["vllm"], + } + } + ), +) +def test_hf_tokenize_apply_chat_template_thinking(mock_llm_config, logger_mock): + llm = HuggingFaceLLM(model_name="qwen", logger=logger_mock) + + messages = [{"role": "user", "content": "hello world"}] + tokens = llm.tokenize(messages) + + # When using chat template, all messages are tokenized together, so returns single list + assert tokens == [ + [ + "<|im_start|>", + "user", + "Ċ", + "hello", + "Ġworld", + "<|im_end|>", + "Ċ", + "<|im_start|>", + "assistant", + "Ċ", + ] + ] diff --git a/tests/llms/test_human.py b/tests/llms/test_human.py index f0adf300..5c0152ee 100644 --- a/tests/llms/test_human.py +++ b/tests/llms/test_human.py @@ -589,14 +589,16 @@ def test_human_max_retries_exceeded(build_env_info): def test_human_tokenize(): """Test Human tokenization""" human = Human() - tokens = human.tokenize("hello world test") - assert tokens == ["hello", "world", "test"] + messages = [{"role": "user", "content": "hello world test"}] + tokens = human.tokenize(messages) + assert tokens == [["hello", "world", "test"]] def test_human_count_tokens(): """Test Human token counting""" human = Human() - count = human.count_tokens("hello world test") + messages = [{"role": "user", "content": "hello world test"}] + count = human.count_tokens(messages) assert count == 3 diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index bda76b94..abd90751 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -532,3 +532,35 @@ def 
test_llm_without_reasoning_content_attribute(
     # The response should be just the regular content
     assert llm_response.response == "Regular response only"
+
+
+@patch.object(
+    LLMConfigRegistry,
+    "from_file",
+    return_value=LLMConfigRegistry.register_all(
+        {
+            "qwen": {
+                "model": "Qwen/Qwen3-0.6B",
+                "tokenizer": "Qwen/Qwen3-0.6B",
+                "context_limit": 4,
+                "api_key": "test-api-key",
+                "endpoint": "https://test-endpoint",
+                "api_version": "v1",
+                "tags": ["openai"],  # Using openai tag to force OpenAILLM class
+            }
+        }
+    ),
+)
+def test_openai_llm_raises_error_for_non_gpt_tokenizer(mock_llm_config, logger_mock):
+    """Test that OpenAILLM raises ValueError when tokenizer is not a GPT model"""
+    import pytest
+
+    llm = OpenAILLM(model_name="qwen", logger=logger_mock)
+    messages = [{"role": "user", "content": "test"}]
+
+    # Should raise ValueError when trying to tokenize with a non-GPT tokenizer
+    with pytest.raises(ValueError) as exc_info:
+        llm.tokenize(messages)
+
+    assert "Tokenizer `Qwen/Qwen3-0.6B` not found" in str(exc_info.value)
+    assert "set tag `vllm`" in str(exc_info.value)
diff --git a/tests/llms/test_utils.py b/tests/llms/test_utils.py
index b8e1e794..39815b73 100644
--- a/tests/llms/test_utils.py
+++ b/tests/llms/test_utils.py
@@ -1,4 +1,6 @@
-from debug_gym.llms.utils import print_messages
+import pytest
+
+from debug_gym.llms.utils import print_messages, trim, trim_prompt_messages
 
 
 def test_print_messages(logger_mock):
@@ -40,3 +42,276 @@ def test_print_messages_unknown_role(logger_mock):
         assert "Unknown role" in str(e)
     else:
         assert False, "ValueError not raised for unknown role"
+
+
+def test_trim_prompt_messages():
+    def count_tokens(messages):
+        return sum(len(msg.get("content", msg.get("tool_calls"))) for msg in messages)
+
+    # Test basic validation
+    with pytest.raises(AssertionError, match="messages should not be empty"):
+        trim_prompt_messages([], 5, count_tokens)
+
+    with pytest.raises(
+        AssertionError, match="the last message should be from the user or the tool"
+    ):
+        messages = [
+            {"role": "system", "content": "System message"},
+            {"role": "assistant", "content": "Assistant message"},
+        ]
+        trim_prompt_messages(messages, 20, count_tokens)
+
+    with pytest.raises(AssertionError, match="context_length should be non-negative"):
+        messages = [{"role": "user", "content": "User message"}]
+        trim_prompt_messages(messages, -1, count_tokens)
+
+    # Test system message too long
+    with pytest.raises(
+        AssertionError, match="System message tokens exceed context length"
+    ):
+        messages = [
+            {"role": "system", "content": "Very long system message"},
+            {"role": "user", "content": "Hi"},
+        ]
+        trim_prompt_messages(messages, 10, count_tokens)
+
+    # Test simple case: just user message
+    messages = [{"role": "user", "content": "Hello"}]
+    result = trim_prompt_messages(messages, 10, count_tokens)
+    assert result == messages
+
+    # Test case: system + user, fits completely
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+    ]
+    result = trim_prompt_messages(messages, 10, count_tokens)
+    assert result == messages
+
+    # Test case: system + user + assistant + tool, fits completely
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello"},  # 5 tokens
+        {"role": "tool", "content": "Result"},  # 6 tokens
+    ]
+    result = trim_prompt_messages(messages, 20, count_tokens)
+    assert result == messages
+
+    # Test case: Keep system + most recent assistant-tool pair, drop user
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello"},  # 5 tokens
+        {"role": "tool", "content": "Result"},  # 6 tokens
+    ]
+    expected = [
+        {"role": "system", "content": "Sys"},
+        {"role": "assistant", "content": "Hello"},
+        {"role": "tool", "content": "Result"},
+    ]
+    result = trim_prompt_messages(
+        messages, 14, count_tokens
+    )  # Just enough for sys + assistant + tool
+    assert result == expected
+
+    # Test case: Multiple assistant-tool pairs, keep most recent ones
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello1"},  # 6 tokens
+        {"role": "tool", "content": "Result1"},  # 7 tokens
+        {"role": "assistant", "content": "Hello2"},  # 6 tokens
+        {"role": "tool", "content": "Result2"},  # 7 tokens
+    ]
+    # Keep system + most recent assistant-tool pair only
+    expected = [
+        {"role": "system", "content": "Sys"},
+        {"role": "assistant", "content": "Hello2"},
+        {"role": "tool", "content": "Result2"},
+    ]
+    result = trim_prompt_messages(
+        messages, 16, count_tokens
+    )  # sys(3) + hello2(6) + result2(7) = 16
+    assert result == expected
+
+    # Test case: Can fit all assistant-tool pairs + user message
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello1"},  # 6 tokens
+        {"role": "tool", "content": "Result1"},  # 7 tokens
+        {"role": "assistant", "content": "Hello2"},  # 6 tokens
+        {"role": "tool", "content": "Result2"},  # 7 tokens
+    ]
+    # All pairs fit, so include user message too
+    result = trim_prompt_messages(messages, 50, count_tokens)
+    assert result == messages
+
+    # Test case: Can fit all assistant-tool pairs but not user message
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello1"},  # 6 tokens
+        {"role": "tool", "content": "Result1"},  # 7 tokens
+        {"role": "assistant", "content": "Hello2"},  # 6 tokens
+        {"role": "tool", "content": "Result2"},  # 7 tokens
+    ]
+    expected = [
+        {"role": "system", "content": "Sys"},
+        {"role": "assistant", "content": "Hello1"},
+        {"role": "tool", "content": "Result1"},
+        {"role": "assistant", "content": "Hello2"},
+        {"role": "tool", "content": "Result2"},
+    ]
+    result = trim_prompt_messages(
+        messages, 29, count_tokens
+    )  # sys(3) + all pairs(26) = 29, no room for user(2)
+    assert result == expected
+
+    # Test case: No system message
+    messages = [
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "assistant", "content": "Hello"},  # 5 tokens
+        {"role": "tool", "content": "Result"},  # 6 tokens
+    ]
+    result = trim_prompt_messages(messages, 20, count_tokens)
+    assert result == messages
+
+    # Test case: No assistant-tool pairs, just system and user
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+    ]
+    result = trim_prompt_messages(messages, 10, count_tokens)
+    assert result == messages
+
+    # Test case: No system, no assistant-tool pairs, just user
+    messages = [{"role": "user", "content": "Hi"}]  # 2 tokens
+    result = trim_prompt_messages(messages, 10, count_tokens)
+    assert result == messages
+
+    # Test case: Tool message without preceding assistant (edge case)
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {"role": "tool", "content": "Result"},  # 6 tokens
+    ]
+    expected = [
+        {"role": "system", "content": "Sys"},
+        {"role": "user", "content": "Hi"},
+    ]
+    result = trim_prompt_messages(messages, 10, count_tokens)
+    assert result == expected
+
+    # Test case: Message with tool_calls instead of content
+    messages = [
+        {"role": "system", "content": "Sys"},  # 3 tokens
+        {"role": "user", "content": "Hi"},  # 2 tokens
+        {
+            "role": "assistant",
+            "tool_calls": [{"function": {"name": "test"}}],
+        },  # 1 token (len() of the tool_calls list)
+        {"role": "tool", "content": "Result"},  # 6 tokens
+    ]
+    result = trim_prompt_messages(messages, 100, count_tokens)
+    assert len(result) == 4  # Should keep all messages if context is large enough
+
+
+def test_trim():
+    def count_tokens(text):
+        return len(text)
+
+    # Test basic cases - no trimming needed
+    assert trim("Hello world", 11, count_tokens) == "Hello world"
+    assert trim("Hello world", 20, count_tokens) == "Hello world"
+    assert trim("Hi", 2, count_tokens) == "Hi"
+    assert trim("Hi", 10, count_tokens) == "Hi"
+    assert trim("A", 1, count_tokens) == "A"  # Exactly fits, no trimming needed
+
+    # Test edge cases
+    assert trim("Hello world", 0, count_tokens) == ""
+    assert trim("", 5, count_tokens) == ""
+    assert trim("", 0, count_tokens) == ""
+
+    # Test cases requiring trimming to single token (ellipsis only)
+    assert trim("Hello world", 1, count_tokens) == "…"
+    assert trim("Hi", 1, count_tokens) == "…"
+    assert trim("ABC", 1, count_tokens) == "…"
+
+    # Test trimming from the middle (default behavior)
+    assert trim("Hello world", 5, count_tokens) == "He…ld"
+    assert trim("Hello world", 6, count_tokens) == "He…rld"
+    assert trim("Hello world", 7, count_tokens) == "Hel…rld"
+    assert trim("123456789", 5, count_tokens) == "12…89"
+    assert trim("123456789", 7, count_tokens) == "123…789"
+
+    # Test trimming from the end
+    assert trim("Hello world", 5, count_tokens, where="end") == "Hell…"
+    assert trim("Hello world", 6, count_tokens, where="end") == "Hello…"
+    assert trim("Hello world", 7, count_tokens, where="end") == "Hello …"
+    assert trim("123456789", 5, count_tokens, where="end") == "1234…"
+
+    # Test trimming from the start
+    assert trim("Hello world", 5, count_tokens, where="start") == "…orld"
+    assert trim("Hello world", 6, count_tokens, where="start") == "…world"
+    assert trim("Hello world", 7, count_tokens, where="start") == "… world"
+    assert trim("123456789", 5, count_tokens, where="start") == "…6789"
+
+    # Test invalid `where` value
+    with pytest.raises(ValueError, match="Invalid value for `where`"):
+        trim("Hello world", 5, count_tokens, where="invalid")
+
+    # Test with different token counter
+    def another_count_tokens(text):
+        return len(text) // 2
+
+    # For "1234567890" (10 chars), another_count_tokens returns 5 tokens
+    # Original text has 5 tokens, so no trimming needed when max_tokens >= 5
+    assert trim("1234567890", 5, another_count_tokens) == "1234567890"
+    assert trim("1234567890", 6, another_count_tokens) == "1234567890"
+
+    # When max_tokens < 5, trimming is needed
+    # With max_tokens=4, we need 3 tokens for content + 1 for ellipsis
+    assert (
+        trim("1234567890", 4, another_count_tokens) == "123…67890"
+    )  # Result has 4 tokens
+    assert (
+        trim("1234567890", 3, another_count_tokens) == "123…890"
+    )  # Result has 3 tokens
+    assert trim("1234567890", 2, another_count_tokens) == "1…890"  # Result has 2 tokens
+    assert trim("1234567890", 1, another_count_tokens) == "1…0"  # Result has 1 token
+
+    # Test with different trimming positions using the alternative counter
+    assert (
+        trim("1234567890", 3, another_count_tokens, where="end") == "12345…"
+    )  # Result has 3 tokens
+    assert (
+        trim("1234567890", 3, another_count_tokens, where="start") == "…67890"
+    )  # Result has 3 tokens
+
+    # Test edge case with very short text and alternative counter
+    assert trim("AB", 1, another_count_tokens) == "AB"  # "AB" has 1 token, fits exactly
+    assert (
+        trim("ABCD", 1, another_count_tokens) == "A…D"
+    )  # "ABCD" has 2 tokens, needs trimming to 1
+
+    # Test boundary conditions with precise scenarios
+    def word_count_tokens(text):
+        # Count words as tokens
+        return len(text.split())
+
+    text = "Hello world test example"  # 4 words = 4 tokens
+    assert trim(text, 4, word_count_tokens) == text  # No trimming needed
+    assert (
+        trim(text, 3, word_count_tokens, where="middle") == "Hello … example"
+    )  # Should fit in 3 tokens
+    assert (
+        trim(text, 2, word_count_tokens, where="end") == "Hello …"
+    )  # Should fit in 2 tokens
+    assert (
+        trim(text, 2, word_count_tokens, where="start") == "… example"
+    )  # Should fit in 2 tokens
+
+    # Test very short max_tokens with word counter
+    assert trim("Hello world", 1, word_count_tokens) == "…"  # Only ellipsis fits