From 720c46a2fcd3ce8f3fd93785caea84176ea2a788 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Feb 2026 04:35:18 +0000 Subject: [PATCH 1/4] add best of n sampling baselines + generate test baselines --- examples/TTSwithVerification/README.md | 41 ++ .../TTSwithVerification/bestofk_baseline.py | 608 ++++++++++++++++++ 2 files changed, 649 insertions(+) create mode 100644 examples/TTSwithVerification/bestofk_baseline.py diff --git a/examples/TTSwithVerification/README.md b/examples/TTSwithVerification/README.md index ed7fbb4..a2c52dd 100644 --- a/examples/TTSwithVerification/README.md +++ b/examples/TTSwithVerification/README.md @@ -156,6 +156,39 @@ The Z3 solver handles diagonal directions (`Northwest`, `Northeast`, `Southwest` --- +# Best-of-K Baseline + +A simple best-of-K baseline that generates K independent reasoning traces per example and selects the best based on: +1. **Ground-truth matching** (default): Greedy selection of first correct answer among K samples +2. **Critic model evaluation** (optional): Use a separate critic LLM to evaluate correctness without access to ground truth + +This baseline demonstrates that with sufficient sampling, even simple CoT can achieve good performance. 
+ +## Usage + +```bash +# Best-of-K with ground-truth evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 --use_critic --critic_model Qwen/Qwen3-30B-A3B-Thinking-2507 --critic_port 8001 +``` + +### Parameters + +| Argument | Description | Default | +|----------|-------------|---------| +| `--task` | Task: `game24`, `maze`, or `spatialmap` | required | +| `--k` | Number of samples per example | `4` | +| `--use_critic` | Use critic model for evaluation instead of ground truth | `False` | +| `--critic_model` | Model to use for critic evaluation | MAIN_MODEL | +| `--critic_port` | vLLM server port for critic model | `8000` (same as main model port unless overridden) | +| `--num_examples`, `-n` | Number of examples to run | varies | +| `--main_model` | Model for generation | `Qwen/Qwen3-30B-A3B-Thinking-2507` | +| `--port` | vLLM server port for main model | `8000` | + +--- + +## Example Scripts + +Each script runs a full evaluation: loading a dataset, building structured prompts, running inference with step verification, and computing accuracy/token statistics. 
@@ -169,6 +202,14 @@ python ./examples/TTSwithVerification/maze_stepverifier.py -n 1 # SpatialMap with step verification python ./examples/TTSwithVerification/spatialmap_stepverifier.py -n 1 + +# Best-of-K baseline (standard CoT, no monitors) +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task maze -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task spatialmap -n 1 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 --use_critic ``` ### Common arguments diff --git a/examples/TTSwithVerification/bestofk_baseline.py b/examples/TTSwithVerification/bestofk_baseline.py new file mode 100644 index 0000000..8864a38 --- /dev/null +++ b/examples/TTSwithVerification/bestofk_baseline.py @@ -0,0 +1,608 @@ +import argparse +import asyncio +import json +import logging +import os +import re +from dataclasses import dataclass +from typing import List, Optional + +import numpy as np +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from interwhen import stream_completion + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" +# ================================================= + +logger = logging.getLogger(__name__) + + +@dataclass +class SampleResult: + output: str + correct: bool + extracted: Optional[str] + message: str + tokens: int + critic_correct: Optional[bool] = None + + +def get_model_short_name(model_name: str) -> str: + short_name = model_name.split("/")[-1] + return short_name.replace(" ", "_").replace(":", "-") + + +def get_output_dirs(task: str, main_model: str, base_dir: str = "../../b-pchanda/Outputs_TTS/BestOfKResults"): + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, task, model_short_name) + dirs = { + "base": 
output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + } + for dir_path in dirs.values(): + os.makedirs(dir_path, exist_ok=True) + return dirs + + +def init_llm_server(model_name, max_tokens=32768, port=8000, temperature=0.6, seed=42): + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": model_name, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample": True, + "temperature": temperature, + "stream": True, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed": seed, + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + + +def count_tokens(text: str, tokenizer) -> int: + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + + +def save_outputs(idx: int, outputs: List[SampleResult], best_idx: int, output_dir: str): + os.makedirs(output_dir, exist_ok=True) + filepath = os.path.join(output_dir, f"output_{idx}.txt") + with open(filepath, "w", encoding="utf-8") as f: + f.write(f"BEST_INDEX={best_idx}\n") + for i, result in enumerate(outputs): + f.write("\n" + "=" * 80 + "\n") + f.write(f"SAMPLE {i}\n") + f.write(f"CORRECT={result.correct}\n") + f.write(f"CRITIC_CORRECT={result.critic_correct}\n") + f.write(f"EXTRACTED={result.extracted}\n") + f.write(f"TOKENS={result.tokens}\n") + f.write(f"MESSAGE={result.message}\n") + f.write("\n") + f.write(result.output) + f.write("\n") + logger.info(f"Saved outputs to {filepath}") + + +# --------------------- Game24 helpers --------------------- + +def build_game24_prompt(nums): + a, b, c, d = nums + boxed = r"\\boxed{}" + base_prompt = f""" +You are solving the Game of 24. + +You are given four numbers: {a}, {b}, {c}, {d} + +Your job is to produce a valid arithmetic expression using: +- ALL four numbers exactly once +- ONLY +, -, *, / +- The expression must evaluate to exactly 24. 
+ +Please reason step by step, and put your final answer containing only the expression within {boxed}. +""".strip() + return base_prompt + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = extract_numbers_from_expr(expr) + if 
sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + return False, expr, "Expression does not evaluate to 24" + + +# --------------------- Maze/SpatialMap helpers --------------------- + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_mcq(text): + matches = re.findall(r"\\boxed\{([^}]*)\}", text) + if not matches: + return None + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if not choice_match: + return None + return choice_match.group(1).upper() + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression 
found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution '{sol}' not found in options or ground truth" + + +def build_full_prompt(task, example, nums=None): + if task == "game24": + prompt = build_game24_prompt(nums) + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + if task == "maze": + system_prompt, user_prompt = build_maze_prompt(example) + else: + system_prompt, user_prompt = build_spatialmap_prompt(example) + return ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return range(start, end) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. 
Use 'start-end'") + if args.num_examples: + return np.linspace(0, dataset_len - 1, args.num_examples, dtype=int) + # Default: use full range + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return range(start, end) + + +def run_k_samples(prompt, llm_server, k, seed): + outputs = [] + for i in range(k): + llm_server["payload"]["seed"] = seed + i + outputs.append(asyncio.run(stream_completion( + prompt, + llm_server=llm_server, + monitors=(), + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ))) + return outputs + + +# --------------------- Critic model helpers --------------------- + +def build_game24_critic_prompt(nums, reasoning_output): + """Build critic prompt to evaluate Game of 24 solution.""" + return f"""You are a math verifier. Evaluate the following Game of 24 solution. + +Numbers: {nums} +Target: 24 + +Student's reasoning and answer: +{reasoning_output} + +Verify: +1. Does it use ALL four numbers exactly once? +2. Does each step follow correct arithmetic? +3. Does the final expression evaluate to exactly 24? + +Respond with ONLY: "CORRECT" or "INCORRECT" +""" + + +def build_mcq_critic_prompt(task, task_description, reasoning_output): + """Build critic prompt to evaluate MCQ solution.""" + task_name = "Maze" if task == "maze" else "Spatial Reasoning" + return f"""You are an expert {task_name} verifier. Evaluate the following solution. + +Task: +{task_description} + +Student's reasoning and answer: +{reasoning_output} + +Verify the correctness of the step-by-step reasoning and final answer. 
+ +Respond with ONLY: "CORRECT" or "INCORRECT" +""" + + +async def evaluate_with_critic(output_text, task, example, critic_llm_server, tokenizer, nums=None): + """Use critic model to evaluate correctness of output.""" + try: + if task == "game24": + critic_prompt = build_game24_critic_prompt(nums, output_text) + else: + if task == "maze": + _, task_desc = build_maze_prompt(example) + else: + _, task_desc = build_spatialmap_prompt(example) + critic_prompt = build_mcq_critic_prompt(task, task_desc, output_text) + + critic_system = "You are a strict academic verifier." + full_prompt = f"<|im_start|>system\n{critic_system}<|im_end|>\n<|im_start|>user\n{critic_prompt}<|im_end|>\n<|im_start|>assistant\n" + + critic_output = await stream_completion( + full_prompt, + llm_server=critic_llm_server, + monitors=(), + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ) + + is_correct = "CORRECT" in critic_output.upper() + return is_correct, critic_output + except Exception as e: + logger.warning(f"Critic evaluation failed: {e}") + return False, "" + + +def run_k_samples_with_critic( + prompt, + llm_server, + critic_llm_server, + k, + seed, + task, + example, + tokenizer, + eval_fn, + nums=None, + early_stop=False, +): + """Run up to K samples, evaluate with critic, and score with ground truth.""" + sample_results = [] + for i in range(k): + llm_server["payload"]["seed"] = seed + i + output = asyncio.run(stream_completion( + prompt, + llm_server=llm_server, + monitors=(), + add_delay=False, + termination_requires_validation=False, + async_execution=True, + )) + + critic_correct, critic_response = asyncio.run(evaluate_with_critic( + output, task, example, critic_llm_server, tokenizer, nums=nums + )) + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if 
critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + )) + + if early_stop and critic_correct: + break + + return sample_results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets") + parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap"], + help="Task to run") + parser.add_argument("--k", type=int, default=4, help="Number of samples per example") + parser.add_argument("--num_examples", "-n", type=int, default=None, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run") + parser.add_argument("--xrange", type=str, default=None, + help="Range of indices to run (format: 'start-end')") + parser.add_argument("--start", type=int, default=None, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index") + parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--use_critic", action="store_true", help="Use critic model for evaluation instead of ground truth") + parser.add_argument("--critic_model", type=str, default=MAIN_MODEL, help="Critic model to use for evaluation") + parser.add_argument("--critic_port", type=int, default=8000, help="vLLM server port for critic model (default: same as main model port)") + parser.add_argument("--critic_early_stop", action="store_true", help="Stop sampling after first critic-correct trace") + parser.add_argument("--seed", type=int, default=42, help="Base random seed") + parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation") + parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") + 
parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + log_level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig(level=log_level, format="%(message)s") + + + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + + llm_server = init_llm_server( + args.main_model, + max_tokens=args.max_tokens, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + critic_llm_server = None + if args.use_critic: + critic_llm_server = init_llm_server( + args.critic_model, + max_tokens=512, + port=args.critic_port, + temperature=0.2, + seed=args.seed, + ) + logger.info(f"Using critic model: {args.critic_model} on port {args.critic_port}") + + logger.info(f"Loading tokenizer for {args.main_model}...") + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + logger.info("Tokenizer loaded successfully.") + + output_dirs = get_output_dirs(args.task, args.main_model) + + total_examples = 0 + total_correct = 0 + total_correct_samples = 0 + total_samples = 0 + critic_correct_samples = 0 + critic_total_samples = 0 + total_tokens = 0 + total_tokens_all_samples = 0 + results = [] + + for idx in tqdm(indices, desc="Processing examples", unit="example"): + example = dataset[int(idx)] + if args.task == "game24": + nums = example["numbers"] + prompt = build_full_prompt(args.task, example, nums=nums) + eval_fn = lambda output: evaluate_game24_answer(output, nums) + options = None + else: + prompt = build_full_prompt(args.task, example) + gt = str(example.get("ground_truth", "")).strip() + if gt == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + if args.task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + eval_fn = lambda output: 
evaluate_mcq_answer(output, options, gt) + + logger.info(f"---- Example {idx} ----") + + if args.use_critic: + sample_results = run_k_samples_with_critic( + prompt, llm_server, critic_llm_server, args.k, args.seed, + args.task, example, tokenizer, eval_fn, nums=(nums if args.task == "game24" else None), + early_stop=args.critic_early_stop + ) + else: + outputs = run_k_samples(prompt, llm_server, args.k, args.seed) + sample_results = [] + for output in outputs: + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=message, + tokens=token_count, + critic_correct=None, + )) + + if args.use_critic: + best_idx = next((i for i, r in enumerate(sample_results) if r.critic_correct), 0) + else: + best_idx = next((i for i, r in enumerate(sample_results) if r.correct), 0) + best_result = sample_results[best_idx] + any_correct = any(r.correct for r in sample_results) + correct_samples = sum(1 for r in sample_results if r.correct) + critic_correct_samples_example = sum(1 for r in sample_results if r.critic_correct) + + save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"]) + + total_examples += 1 + if any_correct: + total_correct += 1 + total_correct_samples += correct_samples + total_samples += len(sample_results) + critic_correct_samples += critic_correct_samples_example + critic_total_samples += len(sample_results) + total_tokens += best_result.tokens + total_tokens_all_samples += sum(r.tokens for r in sample_results) + + results.append({ + "idx": int(idx), + "best_idx": best_idx, + "any_correct": any_correct, + "best_correct": best_result.correct, + "best_critic_correct": best_result.critic_correct, + "best_extracted": best_result.extracted, + "best_message": best_result.message, + "best_tokens": best_result.tokens, + "all_tokens": [r.tokens for r in sample_results], + "all_correct": [r.correct for r in 
sample_results], + "all_critic_correct": [r.critic_correct for r in sample_results], + "options": options, + }) + + logger.info(f"Best sample: {best_idx} | Correct in K: {any_correct}") + logger.info(f"Best message: {best_result.message}") + + accuracy = total_correct / total_examples if total_examples else 0 + avg_best_tokens = total_tokens / total_examples if total_examples else 0 + avg_all_tokens = total_tokens_all_samples / total_examples if total_examples else 0 + + summary = { + "task": args.task, + "model": args.main_model, + "k": args.k, + "use_critic": args.use_critic, + "total_examples": total_examples, + "correct": total_correct, + "correct_samples": total_correct_samples, + "total_samples": total_samples, + "critic_correct_samples": critic_correct_samples, + "critic_total_samples": critic_total_samples, + "critic_accuracy": (critic_correct_samples / critic_total_samples) if critic_total_samples else 0, + "accuracy": accuracy, + "avg_best_tokens": avg_best_tokens, + "avg_all_tokens": avg_all_tokens, + "total_tokens_best": total_tokens, + "total_tokens_all_samples": total_tokens_all_samples, + "results": results, + } + + if args.use_critic: + summary["critic_model"] = args.critic_model + summary["critic_port"] = args.critic_port + summary["critic_early_stop"] = args.critic_early_stop + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + logger.info(f"Saved summary to {summary_path}") From e706fc9067bdddfd3b414fc2bf5ab914ab30030c Mon Sep 17 00:00:00 2001 From: Prateek Chanda Date: Mon, 2 Mar 2026 04:43:35 +0000 Subject: [PATCH 2/4] Add b-pchanda to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index b7faf40..fd51d2d 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,5 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +b-pchanda/ \ No newline at end of file From 
fe030362b7305b665cf3d8dd784cfaaca8f86c29 Mon Sep 17 00:00:00 2001 From: Prateek Chanda Date: Mon, 2 Mar 2026 04:53:10 +0000 Subject: [PATCH 3/4] add tree of thoughts --- README.md | 15 + examples/TTSwithVerification/tot_baseline.py | 630 +++++++++ interwhen/tree_of_thought.py | 1268 ++++++++++++++++++ interwhen/value_prompts.py | 432 ++++++ 4 files changed, 2345 insertions(+) create mode 100644 examples/TTSwithVerification/tot_baseline.py create mode 100644 interwhen/tree_of_thought.py create mode 100644 interwhen/value_prompts.py diff --git a/README.md b/README.md index 03a763e..dc5ca94 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,21 @@ We provide examples using three datasets: Maze, Game of 24, and SpatialMap. python ./examples/TTSwithVerification/[your_dataset]_stepverifier.py -n 1 # dataset=maze,game24, or spatialmap ``` +For using TreeofThought + +```bash +python ./examples/TTSwithVerification/tot_baseline.py \ + --task maze \ + --num_examples 4 \ + --ports 8000,8001,8002,8003 \ + --concurrency 4 \ + --model Qwen/QwQ-32B \ + --max_tokens 32768 \ +``` + +This script loads the same datasets as the verification examples, spins up `TreeOfThoughtSearch` with configurable branching/depth, and round-robins requests across multiple vLLM instances for faster experimentation. 
+ + ### Monitors for Early stopping ```bash python ./examples/EarlyStopping/[your_dataset]_example.py -n 1 diff --git a/examples/TTSwithVerification/tot_baseline.py b/examples/TTSwithVerification/tot_baseline.py new file mode 100644 index 0000000..e8f7660 --- /dev/null +++ b/examples/TTSwithVerification/tot_baseline.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +"""Command-line Tree-of-Thought baseline runner for interwhen datasets.""" + +import argparse +import asyncio +import json +import logging +import os +import re +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import httpx +import numpy as np +from datasets import load_dataset + +from interwhen.tree_of_thought import ( + SearchMethod, + ToTSearchConfig, + TreeOfThoughtSearch, + build_tot_problem, +) + +LOGGER = logging.getLogger("tot_baseline") + + +# ============== Helper Functions ============== + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." 
+ "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = 
extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + return False, expr, "Expression does not evaluate to 24" + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + patterns = [ + r"\\boxed\{([^}]*)\}", + r"boxed\{([^}]*)\}", + r"\*\*([A-D])\*\*", + r"answer[:\s]*([A-D])", + r"(?:^|\n)([A-D])(?:\s|$|\.)", + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution 
'{sol}' not found in options or ground truth" + + +def extract_solution_zebralogic(text): + """Extract JSON solution from ZebraLogic model output.""" + if not text: + return None + + def _try_parse(candidate: str): + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + # Unwrap if it's a wrapper with "answer" key + if "answer" in parsed and isinstance(parsed["answer"], dict): + inner = parsed["answer"] + if any(re.match(r"^house\s*\d+$", str(k).strip(), flags=re.IGNORECASE) for k in inner.keys()): + return inner + return parsed + except json.JSONDecodeError: + return None + return None + + # Try to extract JSON from code blocks + patterns = [ + r"```json\s*({.*?})\s*```", # Markdown code block + r"```\s*({.*?})\s*```", # Generic code block + r"({\s*['\"]House\s*\d+['\"].*?})", # Direct JSON starting with House + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE) + if matches: + json_str = matches[-1].strip() + solution = _try_parse(json_str) + if solution is not None: + return solution + + # Try parsing entire last large JSON-like structure + try: + # Find potential JSON starting with { + json_match = re.search(r"({\s*(?:['\"]House|['{\"\[])+[\s\S]*})", text) + if json_match: + json_str = json_match.group(1) + solution = _try_parse(json_str) + if solution is not None: + return solution + except (json.JSONDecodeError, AttributeError): + pass + + # Last-chance extraction: parse top-level JSON object spans and keep the + # last one that parses and looks like a house assignment dictionary. 
+ stack = [] + spans = [] + for idx, ch in enumerate(text): + if ch == "{": + stack.append(idx) + elif ch == "}" and stack: + start = stack.pop() + if not stack: + spans.append((start, idx + 1)) + for start, end in reversed(spans): + candidate = text[start:end] + solution = _try_parse(candidate) + if solution is not None: + # Handle wrapped solution with "answer" key + if isinstance(solution, dict) and "answer" in solution: + answer = solution["answer"] + if isinstance(answer, dict) and any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in answer.keys() + ): + return answer + # Direct house keys + if any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in solution.keys() + ): + return solution + + return None + + +async def _request_zebralogic_json(prompt: str, llm_server: Dict[str, Any]) -> str: + """Submit a strict-JSON request for ZebraLogic and return raw model content.""" + payload = dict(llm_server["payload"]) + payload["temperature"] = 0.0 + payload["messages"] = [ + { + "role": "system", + "content": ( + "You solve Zebra Logic puzzles and MUST return strictly valid JSON only. " + "No markdown fences. No explanation. No extra text." 
+ ), + }, + { + "role": "user", + "content": prompt, + }, + ] + payload["response_format"] = {"type": "json_object"} + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + llm_server["url"], + headers=llm_server["headers"], + json=payload, + ) + response.raise_for_status() + body = response.json() + return body["choices"][0]["message"]["content"].strip() + + +async def finalize_zebralogic_json(problem: str, trajectory: str, llm_server: Dict[str, Any]) -> str: + """Ask the model to convert an existing trajectory into strict final JSON only.""" + prompt = ( + "Convert the reasoning into the final Zebra Logic answer JSON.\n" + "Output ONLY valid JSON (no markdown, no explanation).\n" + "Use exact feature/value names from the puzzle.\n\n" + "PUZZLE:\n" + f"{problem}\n\n" + "REASONING:\n" + f"{trajectory}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +async def solve_zebralogic_json_direct(problem: str, llm_server: Dict[str, Any]) -> str: + """Directly solve ZebraLogic and return strict final JSON.""" + prompt = ( + "Solve the Zebra Logic puzzle and provide the final house assignments.\n" + "Output ONLY valid JSON with keys like 'House 1', 'House 2', etc.\n" + "Use exact feature/value names from the puzzle text.\n\n" + "PUZZLE:\n" + f"{problem}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +def evaluate_zebralogic_answer(answer, ground_truth): + """Evaluate ZebraLogic solution against ground truth.""" + extracted = extract_solution_zebralogic(answer) + + if extracted is None: + return False, None, "Could not extract valid JSON solution" + + # Normalize keys (handle 'House X', 'house x', etc.) 
+ def normalize_solution(sol): + normalized = {} + for key, value in sol.items(): + # Normalize house key + house_match = re.search(r"House\s*(\d+)", key, re.IGNORECASE) + if house_match: + house_num = house_match.group(1) + normalized[f"House {house_num}"] = value if isinstance(value, dict) else value + else: + normalized[key] = value + return normalized + + extracted_norm = normalize_solution(extracted) + ground_truth_norm = normalize_solution(ground_truth) if isinstance(ground_truth, dict) else ground_truth + + # Simple exact match on normalized solution + if extracted_norm == ground_truth_norm: + return True, extracted_norm, "Correct: Solution matches ground truth exactly" + + # Check if as string they're close (for JSON format differences) + extracted_str = json.dumps(extracted_norm, sort_keys=True) + gt_str = json.dumps(ground_truth_norm, sort_keys=True) if isinstance(ground_truth_norm, dict) else str(ground_truth_norm) + + if extracted_str == gt_str: + return True, extracted_norm, "Correct: Solution matches ground truth (format normalized)" + + # Partial credit: check if majority of houses match + if isinstance(ground_truth_norm, dict) and isinstance(extracted_norm, dict): + matches = sum(1 for k in extracted_norm if k in ground_truth_norm and extracted_norm[k] == ground_truth_norm[k]) + total = max(len(extracted_norm), len(ground_truth_norm)) + if matches > 0: + accuracy = matches / total + return False, extracted_norm, f"Partial match: {matches}/{total} houses correct ({accuracy:.1%})" + + return False, extracted_norm, "Incorrect: Solution does not match ground truth" + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + if task == "zebralogic": + return 
load_dataset("WildEval/ZebraLogic", name="grid_mode", split="test") + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return list(range(start, end)) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + return list(np.linspace(0, dataset_len - 1, args.num_examples, dtype=int)) + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return list(range(start, end)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run Tree-of-Thought search on a subset of the supported tasks", + ) + parser.add_argument("--task", choices=["game24", "maze", "spatialmap", "zebralogic"], required=True) + parser.add_argument("--k", type=int, default=1, help="Unused placeholder to mirror other baselines") + parser.add_argument("--num_examples", "-n", type=int, default=None) + parser.add_argument("--indices", type=str, default=None) + parser.add_argument("--xrange", type=str, default=None) + parser.add_argument("--start", type=int, default=None) + parser.add_argument("--end", type=int, default=None) + parser.add_argument("--model", default="Qwen/QwQ-32B") + parser.add_argument("--llm_url", default="http://localhost:{port}/v1/chat/completions") + parser.add_argument( + "--ports", + default="8000", + help="Comma-separated list of vLLM ports to round-robin across", + ) + parser.add_argument("--temperature", type=float, default=0.3) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--top_k", type=int, default=20) + parser.add_argument("--max_tokens", type=int, default=32768) + parser.add_argument("--seed", type=int, default=42) 
+ parser.add_argument("--search_method", choices=["bfs", "dfs", "beam"], default="beam") + parser.add_argument("--branching_factor", type=int, default=4) + parser.add_argument("--max_depth", type=int, default=1) + parser.add_argument("--beam_width", type=int, default=2) + parser.add_argument("--sure_threshold", type=float, default=0.7) + parser.add_argument("--likely_threshold", type=float, default=0.5) + parser.add_argument("--impossible_threshold", type=float, default=0.2) + parser.add_argument("--max_candidates_per_level", type=int, default=3) + parser.add_argument("--early_termination", action="store_true") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Maximum number of ToT examples to run concurrently", + ) + parser.add_argument( + "--output_dir", + default="outputs/tot_baseline", + help="Directory to store per-example JSON logs and summary", + ) + parser.add_argument("--log_level", default="INFO") + return parser.parse_args() + + +def parse_port_list(port_str: str) -> List[int]: + return [int(p.strip()) for p in port_str.split(",") if p.strip()] + + +def build_llm_server(args: argparse.Namespace, port: int) -> Dict[str, Any]: + payload = { + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "max_tokens": args.max_tokens, + "stream": False, + "seed": args.seed, + } + return { + "url": args.llm_url.format(port=port), + "headers": {"content-type": "application/json"}, + "payload": payload, + } + + +def build_tot_config(args: argparse.Namespace) -> ToTSearchConfig: + method = SearchMethod[args.search_method.upper()] + return ToTSearchConfig( + branching_factor=args.branching_factor, + max_depth=args.max_depth, + search_method=method, + beam_width=args.beam_width, + sure_threshold=args.sure_threshold, + likely_threshold=args.likely_threshold, + impossible_threshold=args.impossible_threshold, + 
early_termination=args.early_termination, + cache_evaluations=not args.no_cache, + max_candidates_per_level=args.max_candidates_per_level, + ) + + +def ensure_output_dir(base_dir: str, task: str) -> Path: + path = Path(base_dir).expanduser().resolve() / task + path.mkdir(parents=True, exist_ok=True) + return path + + +def prepare_eval(task: str, example: Dict[str, Any]) -> Tuple: + if task == "game24": + nums = list(example.get("numbers", [])) + return (lambda output: evaluate_game24_answer(output, nums), {"numbers": nums}) + if task == "zebralogic": + # ZebraLogic ground truth is the solution JSON + ground_truth = example.get("solution", {}) + meta = {"ground_truth_sample": str(ground_truth)[:100]} + return (lambda output: evaluate_zebralogic_answer(output, ground_truth), meta) + gt = str(example.get("ground_truth", "")).strip() + target_options = ["A", "B"] if gt == "Q4" else ["A", "B", "C", "D"] + if task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + meta = {"options": options, "ground_truth": gt} + return (lambda output: evaluate_mcq_answer(output, options, gt), meta) + + +async def run_single_example( + idx: int, + task: str, + example: Dict[str, Any], + tot_config: ToTSearchConfig, + llm_server: Dict[str, Any], +) -> Dict[str, Any]: + eval_fn, eval_meta = prepare_eval(task, example) + problem = build_tot_problem(task, example, nums=example.get("numbers")) + tot = TreeOfThoughtSearch(tot_config) + search_result = await tot.search(task, problem, llm_server) + best_traj = search_result.get("best_trajectory", "") + best_value = search_result.get("best_value", 0.0) + is_correct, extracted, message = eval_fn(best_traj) + + # ZebraLogic often ends with partial reasoning trajectories; add strict-JSON + # recovery passes before scoring. 
+ finalized_answer = None + direct_answer = None + if task == "zebralogic" and (not is_correct): + try: + finalized_answer = await finalize_zebralogic_json(problem, best_traj, llm_server) + final_is_correct, final_extracted, final_message = eval_fn(finalized_answer) + if final_extracted is not None or final_is_correct: + is_correct = final_is_correct + extracted = final_extracted + message = final_message + best_traj = finalized_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic finalization failed for index %s: %s", idx, exc) + + if task == "zebralogic" and (not is_correct): + try: + direct_answer = await solve_zebralogic_json_direct(problem, llm_server) + direct_is_correct, direct_extracted, direct_message = eval_fn(direct_answer) + if direct_extracted is not None or direct_is_correct: + is_correct = direct_is_correct + extracted = direct_extracted + message = direct_message + best_traj = direct_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic direct solve failed for index %s: %s", idx, exc) + + return { + "index": int(idx), + "best_value": best_value, + "best_trajectory": best_traj, + "raw_best_trajectory": search_result.get("best_trajectory", ""), + "finalized_answer": finalized_answer, + "direct_answer": direct_answer, + "search_stats": search_result.get("search_stats", {}), + "decision_tree": search_result.get("decision_tree", []), + "correct": bool(is_correct), + "extracted": extracted, + "message": message, + "evaluation_meta": eval_meta, + } + + +async def run_tot_baseline(args: argparse.Namespace) -> None: + logging.basicConfig(level=getattr(logging, args.log_level.upper(), logging.INFO)) + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + output_dir = ensure_output_dir(args.output_dir, args.task) + tot_config = build_tot_config(args) + ports = parse_port_list(args.ports) + if not ports: + raise ValueError("At least one port must be 
specified via --ports") + concurrency = max(1, args.concurrency) + port_lock = asyncio.Lock() + port_index = {"value": 0} + + async def next_port() -> int: + async with port_lock: + port = ports[port_index["value"] % len(ports)] + port_index["value"] += 1 + return port + + semaphore = asyncio.Semaphore(concurrency) + + async def process_index(idx: int) -> Dict[str, Any]: + async with semaphore: + example = dataset[int(idx)] + port = await next_port() + llm_server = build_llm_server(args, port) + LOGGER.info("Running ToT on example %s via port %s", idx, port) + try: + record = await run_single_example(idx, args.task, example, tot_config, llm_server) + except Exception as exc: # pragma: no cover + LOGGER.exception("Failed example %s", idx) + record = { + "index": int(idx), + "error": str(exc), + "best_trajectory": "", + "correct": False, + } + example_path = output_dir / f"example_{idx}.json" + with example_path.open("w", encoding="utf-8") as handle: + json.dump(record, handle, indent=2) + return record + + processed = await asyncio.gather(*[process_index(idx) for idx in indices]) + + total = len(processed) + correct = sum(1 for r in processed if r.get("correct")) + summary = { + "task": args.task, + "model": args.model, + "total_examples": total, + "correct": correct, + "accuracy": (correct / total) if total else 0.0, + "search_method": args.search_method, + "config": { + "branching_factor": args.branching_factor, + "max_depth": args.max_depth, + "beam_width": args.beam_width, + "sure_threshold": args.sure_threshold, + "likely_threshold": args.likely_threshold, + "impossible_threshold": args.impossible_threshold, + "max_candidates_per_level": args.max_candidates_per_level, + "early_termination": args.early_termination, + "cache_evaluations": not args.no_cache, + "ports": ports, + "concurrency": concurrency, + }, + } + summary_path = output_dir / "summary.json" + with summary_path.open("w", encoding="utf-8") as handle: + json.dump(summary, handle, indent=2) + 
LOGGER.info("Accuracy %.2f (%d/%d)", summary["accuracy"], correct, total) + + +if __name__ == "__main__": + asyncio.run(run_tot_baseline(parse_args())) \ No newline at end of file diff --git a/interwhen/tree_of_thought.py b/interwhen/tree_of_thought.py new file mode 100644 index 0000000..cbca864 --- /dev/null +++ b/interwhen/tree_of_thought.py @@ -0,0 +1,1268 @@ +""" +Tree of Thought implementation for interwhen-style streaming completion. + +Implements proper ToT search using: +1. Propose function to generate candidate next steps +2. Value function to evaluate intermediate states +3. Search algorithm (BFS/DFS/beam) to explore the tree +4. Integrated with interwhen's async streaming architecture +""" + +import asyncio +import json +import logging +import re +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, field +from enum import Enum +import time + +from .value_prompts import ( + build_game24_value_prompt, + build_mcq_value_prompt, + build_tot_value_prompt as build_tot_value_prompt_impl, +) + +logger = logging.getLogger(__name__) + + +# --------------------- Dataset prompt helpers --------------------- + +def build_game24_prompt(nums: List[int]) -> str: + """Return the canonical Game24 instruction block used across baselines.""" + if len(nums) != 4: + raise ValueError("Game24 requires exactly four numbers.") + a, b, c, d = nums + boxed = r"\\boxed{}" + return ( + "You are solving the Game of 24.\n\n" + f"You are given four numbers: {a}, {b}, {c}, {d}\n\n" + "Your job is to produce a valid arithmetic expression using:\n" + "- ALL four numbers exactly once\n- ONLY +, -, *, /\n" + "- The expression must evaluate to exactly 24.\n\n" + "Please reason step by step, and put your final answer containing" + f" only the expression within {boxed}." 
+ ) + + +def build_maze_prompt(example: Dict[str, Any]) -> str: + """Construct the maze reasoning instructions used in other pipelines.""" + pre_prompt = ( + "You are an expert problem solver. Carefully read the following " + "multiple-choice question and think through the solution step-by-step " + "before providing your final answer. Provide the final answer option by " + "enclosing it within \\boxed{A/B/C/D}." + ) + description = str(example.get("prompt", "")) + return f"{pre_prompt}\n\n{description.strip()}" + + +def build_spatialmap_prompt(example: Dict[str, Any]) -> str: + """Construct the spatial reasoning instructions for TOT experiments.""" + pre_prompt = ( + "You are an expert problem solver. Carefully read the following " + "multiple-choice question and think through the solution step-by-step " + "before providing your final answer. Provide the final answer option by " + "enclosing it within \\boxed{A/B/C/D}." + ) + description = str(example.get("prompt", "")) + return f"{pre_prompt}\n\n{description.strip()}" + + +def build_zebralogic_prompt(example: Dict[str, Any]) -> str: + """Construct the Zebra Logic puzzle solving instructions for TOT experiments.""" + puzzle_text = str(example.get("puzzle", "")) + prompt = ( + "# Problem Description\n\n" + "You are solving a house grid logic puzzle. You are given:\n" + "1. Features and Domains\n" + " - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right.\n" + " - A set of features (e.g., color, name, pet, book genre).\n" + " - Each feature has a finite domain of possible values.\n" + "2. Constraints:\n" + " - Each house has exactly one value per feature.\n" + " - No two houses share the same value for the same feature.\n" + "3. 
Clues / Constraints describing:\n" + " - Houses and their positions\n" + " - Feature values\n" + " - Relative ordering (e.g., 'next to', 'to the left of', '2 houses away from')\n\n" + "Solve this puzzle to your best ability by determining the arrangement of features across the houses.\n\n" + "# Puzzle\n\n" + f"{puzzle_text}\n\n" + "# Solution Format\n\n" + "Provide your final answer in this exact JSON format:\n" + "```json\n" + '{\n' + ' "House 1": { "feature1": "value1", "feature2": "value2", ... },\n' + ' "House 2": { "feature1": "value1", "feature2": "value2", ... },\n' + ' ...\n' + '}\n' + "```\n\n" + "Make sure to use the exact feature/value names as given in the puzzle.\n" + "Ensure the JSON is valid and parsable." + ) + return prompt + + +def build_tot_problem(task: str, example: Dict[str, Any], nums: Optional[List[int]] = None) -> str: + """Helper that mirrors the best-of-k prompt builders for ToT runs.""" + task_lower = task.lower() + if task_lower == "game24": + numbers = nums or example.get("numbers") + if numbers is None: + raise ValueError("Game24 prompt requires 'numbers' in the example") + return build_game24_prompt(list(numbers)) + if task_lower == "maze": + return build_maze_prompt(example) + if task_lower == "spatialmap": + return build_spatialmap_prompt(example) + if task_lower == "zebralogic": + return build_zebralogic_prompt(example) + raise ValueError(f"Unsupported task for ToT prompt building: {task}") + + +def build_tot_value_prompt( + task: str, + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """ + Build value prompt for Tree of Thought evaluation. 
+ + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: The original problem statement + trajectory: Current partial solution (or 'No progress yet' if empty) + use_fewshot: Whether to use few-shot examples (default True for better evaluation) + + Returns: + Formatted value prompt with or without few-shot examples + """ + if not trajectory.strip(): + trajectory = "No progress yet" + return build_tot_value_prompt_impl(task, problem, trajectory, use_fewshot=use_fewshot) + + +class SearchMethod(Enum): + """Search algorithm types""" + BFS = "bfs" + DFS = "dfs" + BEAM = "beam" + + +@dataclass +class TreeNode: + """Represents a node in the Tree of Thought""" + trajectory: str + depth: int + value: float = 0.5 + parent: Optional['TreeNode'] = None + children: List['TreeNode'] = field(default_factory=list) + is_terminal: bool = False + proposals: List[str] = field(default_factory=list) + evaluation_log: Dict[str, Any] = field(default_factory=dict) + + def __hash__(self): + return hash(self.trajectory) + + def __eq__(self, other): + return isinstance(other, TreeNode) and self.trajectory == other.trajectory + + +@dataclass +class ToTSearchConfig: + """Configuration for Tree of Thought search""" + branching_factor: int = 4 + max_depth: int = 6 + search_method: SearchMethod = SearchMethod.BFS + beam_width: int = 2 + + # Value thresholds + sure_threshold: float = 0.7 + likely_threshold: float = 0.5 + impossible_threshold: float = 0.2 + + # Optimization settings + early_termination: bool = True + cache_evaluations: bool = True + max_candidates_per_level: int = 3 + + +class TreeOfThoughtSearch: + """ + Tree of Thought search controller compatible with interwhen's streaming. + + Provides propose/evaluate/search methods that work with vLLM API calls + via the llm_server interface used in interwhen. 
+ """ + + def __init__(self, config: ToTSearchConfig = None): + self.config = config or ToTSearchConfig() + self.evaluation_cache = {} + self.proposal_cache = {} + self.search_stats = { + "nodes_explored": 0, + "evaluations_performed": 0, + "branches_pruned": 0, + "cache_hits": 0, + "solutions_found": 0, + "total_nodes_in_tree": 0, + } + self.decision_tree = [] + self.root = None + + # ===================== PROPOSE FUNCTION ===================== + + async def propose_next_steps( + self, + task: str, + problem: str, + current_trajectory: str, + llm_server: Dict, + num_proposals: Optional[int] = None, + ) -> List[str]: + """ + Generate candidate next steps using the model's propose capability. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: The original problem statement + current_trajectory: Current partial solution + llm_server: vLLM server config (url, headers, payload template) + num_proposals: Number of proposals to generate (defaults to branching_factor) + + Returns: + List of proposed next steps + """ + if num_proposals is None: + num_proposals = self.config.branching_factor + + # Check cache + cache_key = f"propose_{hash(problem)}_{hash(current_trajectory)}" + if self.config.cache_evaluations and cache_key in self.proposal_cache: + self.search_stats["cache_hits"] += 1 + return self.proposal_cache[cache_key] + + self.search_stats["nodes_explored"] += 1 + + # Build propose prompt + propose_prompt = self._build_propose_prompt( + task, + problem, + current_trajectory, + num_proposals + ) + + # Call model with streaming + proposal_text = await self._call_llm_streaming( + llm_server, + propose_prompt + ) + + # Parse proposals from response + proposals = self._parse_proposals(proposal_text, num_proposals) + + logger.info( + "Generated %d proposals at depth hint=%s", + len(proposals), + "root" if not current_trajectory.strip() else "non-root", + ) + + # Log decision point + decision_log = { + "type": "proposal_generation", + 
"timestamp": time.time(), + "problem_hash": hash(problem), + "trajectory": current_trajectory, + "prompt": propose_prompt, + "prompt_preview": propose_prompt[:200] + "..." if len(propose_prompt) > 200 else propose_prompt, + "raw_response": proposal_text, + "raw_response_preview": proposal_text[:300] + "..." if len(proposal_text) > 300 else proposal_text, + "parsed_proposals": proposals, + } + self.decision_tree.append(decision_log) + + # Cache + if self.config.cache_evaluations: + self.proposal_cache[cache_key] = proposals + + return proposals + + def _build_propose_prompt( + self, + task: str, + problem: str, + trajectory: str, + num_proposals: int + ) -> str: + """Build a prompt requesting proposals for next steps.""" + if task == "maze": + return self._build_maze_propose_prompt(problem, trajectory, num_proposals) + if task == "spatialmap": + return self._build_spatialmap_propose_prompt(problem, trajectory, num_proposals) + + return f"""Given the following problem and current progress, propose {num_proposals} possible next steps. + +PROBLEM: +{problem} + +CURRENT PROGRESS/TRAJECTORY: +{trajectory if trajectory.strip() else "Starting fresh - no progress yet"} + +Generate {num_proposals} distinct next steps that could advance the solution. Be specific and actionable. + +Format each proposal clearly, one per line: +1. [Proposal 1] +2. [Proposal 2] +... + +Think step by step about what makes each proposal viable. 
+""" + + def _detect_maze_question_type(self, problem: str) -> str: + """Detect maze subtype for proposal steering (Q0/Q2/Q4).""" + lower = problem.lower() + if "how many right turns" in lower: + logger.debug("Detected Q0: right turn counting") + return "q0" + if "how many turns" in lower and "right turns" not in lower: + logger.debug("Detected Q2: total turn counting") + return "q2" + if "starting from s" in lower and "where is e" in lower: + logger.debug("Detected Q4: spatial relation") + return "q4" + if "relative" in lower and "s" in lower and "e" in lower: + logger.debug("Detected Q4: spatial relation (relative)") + return "q4" + logger.warning(f"Maze question type not recognized, using generic. Problem preview: {lower[:200]}") + return "generic" + + def _extract_last_move_info(self, trajectory: str) -> dict: + """Extract previous direction and counter from last step in trajectory.""" + if not trajectory.strip(): + return {"prev_direction": None, "right_count": 0, "left_count": 0, "total_count": 0} + + lines = trajectory.strip().split('\n') + last_line = lines[-1] if lines else "" + + # Try to extract direction from "Next move: [DIRECTION]" + import re + direction_match = re.search(r'Next move:\s*(UP|DOWN|LEFT|RIGHT)', last_line, re.IGNORECASE) + prev_direction = direction_match.group(1).upper() if direction_match else None + + # Extract counters + right_match = re.search(r'Right-turn count:\s*(\d+)', last_line) + left_match = re.search(r'Left-turn count:\s*(\d+)', last_line) + total_match = re.search(r'Total-turn count:\s*(\d+)', last_line) + + right_count = int(right_match.group(1)) if right_match else 0 + left_count = int(left_match.group(1)) if left_match else 0 + total_count = int(total_match.group(1)) if total_match else 0 + + return { + "prev_direction": prev_direction, + "right_count": right_count, + "left_count": left_count, + "total_count": total_count + } + + def _build_maze_propose_prompt( + self, + problem: str, + trajectory: str, + num_proposals: 
int, + ) -> str: + """Build maze-specific atomic next-step proposal prompts by question type.""" + question_type = self._detect_maze_question_type(problem) + logger.debug(f"Detected maze question type: {question_type}") + + # Extract previous move info for bookkeeping + prev_info = self._extract_last_move_info(trajectory) + + if not trajectory.strip(): + current = "Starting fresh - no progress yet" + last_step_hint = "" + else: + current = trajectory.strip() + lines = current.split('\n') + last_line = lines[-1] if lines else "" + last_step_hint = f"\nLAST COMPLETED STEP: {last_line}\nNow generate the NEXT move after this (do NOT repeat this move).\n" + + if question_type == "q0": + prev_dir = prev_info["prev_direction"] + prev_count = prev_info["right_count"] + + if prev_dir is None: + # First move - all directions result in STRAIGHT with count 0 + examples = f"""PARENT: first move, count=0 + +Valid answers (pick {num_proposals}): +Next move: UP | Turn: STRAIGHT | Right-turn count: 0 +Next move: DOWN | Turn: STRAIGHT | Right-turn count: 0 +Next move: LEFT | Turn: STRAIGHT | Right-turn count: 0 +Next move: RIGHT | Turn: STRAIGHT | Right-turn count: 0""" + else: + # Define turn mappings + turn_map = { + "UP": {"RIGHT": "RIGHT", "LEFT": "LEFT", "UP": "STRAIGHT", "DOWN": "STRAIGHT"}, + "DOWN": {"LEFT": "RIGHT", "RIGHT": "LEFT", "DOWN": "STRAIGHT", "UP": "STRAIGHT"}, + "LEFT": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "STRAIGHT", "RIGHT": "STRAIGHT"}, + "RIGHT": {"DOWN": "RIGHT", "UP": "LEFT", "RIGHT": "STRAIGHT", "LEFT": "STRAIGHT"} + } + + moves = turn_map.get(prev_dir, {}) + examples_list = [] + for next_dir, turn_type in moves.items(): + new_count = prev_count + 1 if turn_type == "RIGHT" else prev_count + examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Right-turn count: {new_count}") + + examples = f"""PARENT: direction={prev_dir}, count={prev_count} + +Valid answers (pick {num_proposals}): +{chr(10).join(examples_list)}""" + + return f"""{examples} 
+ +DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above.""" + + if question_type == "q2": + prev_dir = prev_info["prev_direction"] + prev_count = prev_info["total_count"] + + if prev_dir is None: + examples = f"""PARENT: first move, count=0 + +Valid answers (pick {num_proposals}): +Next move: UP | Turn: STRAIGHT | Total-turn count: 0 +Next move: DOWN | Turn: STRAIGHT | Total-turn count: 0 +Next move: LEFT | Turn: STRAIGHT | Total-turn count: 0 +Next move: RIGHT | Turn: STRAIGHT | Total-turn count: 0""" + else: + # Define turn mappings (same as Q0) + turn_map = { + "UP": {"RIGHT": "RIGHT", "LEFT": "LEFT", "UP": "STRAIGHT", "DOWN": "STRAIGHT"}, + "DOWN": {"LEFT": "RIGHT", "RIGHT": "LEFT", "DOWN": "STRAIGHT", "UP": "STRAIGHT"}, + "LEFT": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "STRAIGHT", "RIGHT": "STRAIGHT"}, + "RIGHT": {"DOWN": "RIGHT", "UP": "LEFT", "RIGHT": "STRAIGHT", "LEFT": "STRAIGHT"} + } + + moves = turn_map.get(prev_dir, {}) + examples_list = [] + for next_dir, turn_type in moves.items(): + new_count = prev_count + 1 if turn_type in ["RIGHT", "LEFT"] else prev_count + examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Total-turn count: {new_count}") + + examples = f"""PARENT: direction={prev_dir}, count={prev_count} + +Valid answers (pick {num_proposals}): +{chr(10).join(examples_list)}""" + + return f"""{examples} + +DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above.""" + + if question_type == "q4": + # Q4 should also be structured - no long reasoning + return f"""Maze spatial question. Generate {num_proposals} brief factual statements. + +{current if current else "Starting."} + +Output {num_proposals} lines. Each line: one short fact. NO long explanations.""" + + # Generic fallback - also keep it structured + return f"""Maze question. Generate {num_proposals} next steps. + +{current if current else "Starting."} + +Output {num_proposals} lines. Each line: one short step. 
NO explanations.""" + + def _detect_spatialmap_question_type(self, problem: str) -> str: + """Detect spatialmap subtype for proposal steering (direction/object/counting).""" + lower = problem.lower() + if "how many" in lower and ("objects" in lower or "places" in lower or "locations" in lower): + return "counting" + if "which object" in lower or "what object" in lower or "which place" in lower or "which location" in lower: + return "object" + if "in which direction" in lower or "what direction" in lower or "relative to" in lower: + return "direction" + return "generic" + + def _build_spatialmap_propose_prompt( + self, + problem: str, + trajectory: str, + num_proposals: int, + ) -> str: + """Build spatialmap-specific atomic next-step proposal prompts by question type.""" + question_type = self._detect_spatialmap_question_type(problem) + logger.debug(f"Detected spatialmap question type: {question_type}") + + if not trajectory.strip(): + current = "Starting fresh - no progress yet" + last_step_hint = "" + else: + current = trajectory.strip() + lines = [line.strip() for line in current.split("\n") if line.strip()] + last_line = lines[-1] if lines else "" + last_step_hint = ( + f"\nLAST COMPLETED STEP: {last_line}\n" + "Now generate the NEXT atomic step after this (do NOT repeat this step).\n" + ) + + if question_type == "direction": + return f"""You are solving a spatial-map DIRECTION question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step must advance exactly ONE concrete spatial inference +- Prefer one of: parse one relation, apply reversibility once, apply transitivity once, or map target-vs-reference direction +- Do NOT restate the whole map + +Output format (one line per proposal, no preamble): +1. [Atomic spatial inference] +2. 
[Atomic spatial inference] +...""" + + if question_type == "object": + return f"""You are solving a spatial-map OBJECT-IDENTIFICATION question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step should do one action: identify candidate set, eliminate one candidate, or validate one relation against query direction +- Keep steps local and specific to the asked direction/object +- Do NOT rewrite all relationships + +Output format (one line per proposal, no preamble): +1. [Atomic candidate/evidence step] +2. [Atomic candidate/evidence step] +...""" + + if question_type == "counting": + return f"""You are solving a spatial-map COUNTING question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step should do one action: identify one qualifying object, rule out one object, or update running count by exactly one justified change +- Keep a clear running count state +- Do NOT provide final answer yet unless count is fully justified + +Output format (one line per proposal, no preamble): +1. [Atomic counting step, e.g., "Qualifies: ; running count = n"] +2. [Atomic counting step, e.g., "Ruled out: ; running count = n"] +...""" + + return f"""You are solving a spatial-map reasoning question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Propose {num_proposals} atomic, actionable next steps. +Each step must add ONE new spatial fact/inference and be different from prior steps. + +Output format (one line per proposal, no preamble): +1. ... +2. ... +...""" + + def _parse_proposals(self, response: str, num_proposals: int) -> List[str]: + """ + Parse proposals from model response. + Handles various formats (numbered lists, bullets, etc.) 
+ """ + proposals: List[str] = [] + + # First, try to extract exact format lines: "Next move: X | Turn: Y | ...count: Z" + # Match line by line to avoid cross-line pollution + for line in response.split('\n'): + line = line.strip() + # Check if it matches our exact format + if 'Next move:' in line and 'Turn:' in line and 'count:' in line: + proposals.append(line) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # If we got enough exact format proposals, return them + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # Fallback: numbered multiline blocks (preserve continuation lines) + numbered_blocks = re.findall( + r"(?:^|\n)\s*\d+[\.)]\s*(.+?)(?=(?:\n\s*\d+[\.)]\s*)|\Z)", + response, + flags=re.DOTALL, + ) + for block in numbered_blocks: + cleaned = " ".join(block.strip().split()) + if cleaned and len(cleaned) > 3: + proposals.append(cleaned) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # Second pass: fallback line parser for bullets/single-line proposals + for line in response.split("\n"): + cleaned = line.strip() + if not cleaned or cleaned in ["Next steps:", "Proposals:", "possible next steps"]: + continue + + cleaned = re.sub(r"^\s*(?:\d+[\.)]|[-•*])\s*", "", cleaned) + cleaned = re.sub(r"^\s*\[\d+\]\s*", "", cleaned) + cleaned = re.sub(r"^\s*proposal\s*[:\-]\s*", "", cleaned, flags=re.IGNORECASE) + cleaned = " ".join(cleaned.split()) + + if cleaned and len(cleaned) > 3: + proposals.append(cleaned) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + return proposals[:num_proposals] + + # ===================== EVALUATE FUNCTION ===================== + + async def evaluate_state( + self, + task: str, + problem: str, + trajectory: str, + llm_server: Dict, + ) -> float: + """ + Evaluate the quality/progress of current state. 
+ + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: Original problem + trajectory: Current solution trajectory + llm_server: vLLM server config + + Returns: + Value score between 0.0 and 1.0 + """ + # Check cache + cache_key = f"evaluate_{hash(problem)}_{hash(trajectory)}" + if self.config.cache_evaluations and cache_key in self.evaluation_cache: + self.search_stats["cache_hits"] += 1 + return self.evaluation_cache[cache_key] + + self.search_stats["evaluations_performed"] += 1 + + # Build evaluation prompt + eval_prompt = self._build_evaluation_prompt(task, problem, trajectory) + + # Call model + eval_response = await self._call_llm_streaming(llm_server, eval_prompt) + + # Parse evaluation into score + score = self._parse_evaluation(eval_response) + confidence_label = self._extract_confidence_label(eval_response, score) + + # Log evaluation + eval_log = { + "type": "state_evaluation", + "timestamp": time.time(), + "trajectory": trajectory, + "prompt_preview": eval_prompt[:200] + "...", + "response_preview": eval_response[:200] + "...", + "score": score, + "confidence": confidence_label, + } + if self.decision_tree: + if "evaluations" not in self.decision_tree[-1]: + self.decision_tree[-1]["evaluations"] = [] + self.decision_tree[-1]["evaluations"].append(eval_log) + + # Cache + if self.config.cache_evaluations: + self.evaluation_cache[cache_key] = score + + return score + + def _build_evaluation_prompt(self, task: str, problem: str, trajectory: str) -> str: + """Build dataset-aware evaluation prompts reused by ToT scoring.""" + return build_tot_value_prompt(task, problem, trajectory) + + def _parse_evaluation(self, response: str) -> float: + """ + Parse evaluation response into a scalar score [0, 1] + """ + response_lower = response.lower() + + confidence_keywords = { + "sure": 0.9, "certain": 0.9, "confident": 0.9, + "likely": 0.7, "probably": 0.7, + "possible": 0.5, "maybe": 0.5, + "unlikely": 0.3, "doubtful": 0.3, + "impossible": 
0.1, "blocked": 0.1,
+        }
+
+        for keyword, score in sorted(confidence_keywords.items(), key=lambda kv: -len(kv[0])):  # longest first: "impossible"/"unlikely" must not be shadowed by substrings "possible"/"likely"
+            if keyword in response_lower:
+                return score
+
+        # Try to extract numeric score if present (1-9 scale)
+        for i, char in enumerate(response):
+            if char.isdigit():
+                digit = int(char)
+                if 1 <= digit <= 9:
+                    return digit / 9.0  # Normalize to [0, 1]
+
+        return 0.5  # Default neutral score
+
+    def _score_to_confidence(self, score: float) -> str:
+        """Map scalar score [0,1] to confidence bucket."""
+        if score >= 0.8:
+            return "sure"
+        if score >= 0.6:
+            return "likely"
+        if score >= 0.4:
+            return "possible"
+        if score >= 0.2:
+            return "unlikely"
+        return "impossible"
+
+    def _extract_confidence_label(self, response: str, score: float) -> str:
+        """Extract confidence label from value response; fallback to score mapping."""
+        lower = response.lower()
+        for label in ["impossible", "unlikely", "sure", "likely", "possible"]:  # negated labels first so they are not shadowed by "possible"/"likely"
+            if label in lower:
+                return label
+        return self._score_to_confidence(score)
+
+    def _log_proposal_transition(
+        self,
+        depth: int,
+        parent_trajectory: str,
+        proposal: str,
+        next_state: str,
+        value: float,
+        is_terminal: bool,
+        pruned: bool,
+    ) -> None:
+        """Log proposal -> next-state -> value transition for debugging/analysis."""
+        self.decision_tree.append(
+            {
+                "type": "proposal_transition",
+                "timestamp": time.time(),
+                "depth": depth,
+                "parent_trajectory": parent_trajectory,
+                "proposal": proposal,
+                "next_state": next_state,
+                "value": value,
+                "value_confidence": self._score_to_confidence(value),
+                "is_terminal": is_terminal,
+                "pruned": pruned,
+            }
+        )
+
+        logger.info(
+            "ToT transition | depth=%d | proposal=%s | value=%.3f | confidence=%s | pruned=%s | terminal=%s",
+            depth,
+            proposal,
+            value,
+            self._score_to_confidence(value),
+            pruned,
+            is_terminal,
+        )
+
+    # ===================== SEARCH IMPLEMENTATION =====================
+
+    async def search(
+        self,
+        task: str,
+        problem: str,
+        llm_server: Dict,
+    ) -> Dict[str, Any]:
+        """
+        Perform Tree of 
Thought search on the problem. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: Problem statement + llm_server: vLLM server config + + Returns: + Dictionary with best_trajectory, best_value, search_log + """ + logger.info(f"Starting ToT search with method={self.config.search_method.value}") + + # Initialize root node + self.root = TreeNode(trajectory="", depth=0, value=0.5) + + if self.config.search_method == SearchMethod.BFS: + return await self._bfs_search(task, problem, llm_server) + elif self.config.search_method == SearchMethod.BEAM: + return await self._beam_search(task, problem, llm_server) + else: + return await self._dfs_search(task, problem, llm_server) + + async def _bfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Breadth-First Search implementation""" + queue = [self.root] + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + for depth in range(self.config.max_depth): + if not queue: + break + + next_queue = [] + + for node in queue: + # Generate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + # Create child nodes + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + parent=node, + ) + + # Evaluate + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + child.value = value + + # Track best candidate regardless of terminal status + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + + # Check if terminal and meets threshold + if self._is_terminal(new_trajectory): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + # 
Early termination if high confidence + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return self._format_search_result(best_terminal, problem) + else: + is_terminal = False + + # Prune low-value nodes + if value < self.config.impossible_threshold: + self.search_stats["branches_pruned"] += 1 + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=True, + ) + continue + + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + + node.children.append(child) + next_queue.append(child) + self.search_stats["total_nodes_in_tree"] += 1 + + queue = next_queue[:self.config.max_candidates_per_level] + + return self._format_search_result(best_terminal or best_candidate, problem) + + async def _beam_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Beam Search implementation""" + beam = [self.root] + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + for depth in range(self.config.max_depth): + candidates = [] + + for node in beam: + # Generate and evaluate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + 
)
+
+                    if value > best_candidate_value:
+                        best_candidate = child
+                        best_candidate_value = value
+                    self.search_stats["total_nodes_in_tree"] += 1
+
+                    if self._is_terminal(new_trajectory):
+                        child.is_terminal = True
+                        self.search_stats["solutions_found"] += 1
+                        if value > best_value:
+                            best_value = value
+                            best_terminal = child
+                        is_terminal = True
+
+                        if self.config.early_termination and value >= self.config.sure_threshold:
+                            self._log_proposal_transition(
+                                depth=depth + 1,
+                                parent_trajectory=node.trajectory,
+                                proposal=prop,
+                                next_state=new_trajectory,
+                                value=value,
+                                is_terminal=is_terminal,
+                                pruned=False,
+                            )
+                            return self._format_search_result(best_terminal, problem)
+                    else:
+                        is_terminal = False
+
+                    pruned = value < self.config.impossible_threshold
+                    if pruned:
+                        self.search_stats["branches_pruned"] += 1
+
+                    self._log_proposal_transition(
+                        depth=depth + 1,
+                        parent_trajectory=node.trajectory,
+                        proposal=prop,
+                        next_state=new_trajectory,
+                        value=value,
+                        is_terminal=is_terminal,
+                        pruned=pruned,
+                    )
+
+                    if pruned:
+                        continue
+
+                    candidates.append((child, value))  # only non-pruned children may enter the next beam
+
+            # Keep top-k by value
+            candidates.sort(key=lambda x: x[1], reverse=True)
+            beam = [child for child, _ in candidates[:self.config.beam_width]]
+
+            if not beam:
+                break
+
+        return self._format_search_result(best_terminal or best_candidate, problem)
+
+    async def _dfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]:
+        """Depth-First Search implementation"""
+        best_terminal = None
+        best_value = 0.0
+        best_candidate = None
+        best_candidate_value = float('-inf')
+
+        async def dfs(node: TreeNode, depth: int):
+            nonlocal best_terminal, best_value, best_candidate, best_candidate_value  # all four are reassigned below
+
+            if depth >= self.config.max_depth:
+                return
+
+            # Generate proposals
+            proposals = await self.propose_next_steps(
+                task,
+                problem,
+                node.trajectory,
+                llm_server,
+                self.config.branching_factor
+            )
+            node.proposals = proposals
+
+            for prop in proposals:
+                new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else 
prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + ) + node.children.append(child) + + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + self.search_stats["total_nodes_in_tree"] += 1 + + if self._is_terminal(new_trajectory): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return + else: + is_terminal = False + + # Prune + if value >= self.config.impossible_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + await dfs(child, depth + 1) + else: + self.search_stats["branches_pruned"] += 1 + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=True, + ) + + await dfs(self.root, 0) + return self._format_search_result(best_terminal or best_candidate, problem) + + # ===================== UTILITIES ===================== + + def _is_terminal(self, trajectory: str) -> bool: + """Check if trajectory represents a complete solution""" + keywords = [ + "final answer", + "reached goal", + "solution:", + "answer:", + "conclusion:", + "result:", + ] + trajectory_lower = trajectory.lower() + return any(kw in trajectory_lower for kw in keywords) + + def _format_search_result( + self, + best_node: Optional[TreeNode], + problem: str + ) -> 
Dict[str, Any]: + """Format search results for return""" + if best_node: + best_trajectory = best_node.trajectory + best_value = best_node.value + else: + best_trajectory = "" + best_value = 0.0 + + return { + "best_trajectory": best_trajectory, + "best_value": best_value, + "search_stats": self.search_stats, + "decision_tree": self.decision_tree, + "root_node": self.root, + } + + async def _call_llm_streaming( + self, + llm_server: Dict, + prompt: str + ) -> str: + """Call chat-completions endpoint and return the full response text.""" + import httpx + + payload = llm_server["payload"].copy() + payload.pop("prompt", None) + payload.pop("messages", None) + payload["messages"] = [{"role": "user", "content": prompt}] + payload["stream"] = False + + try: + async with httpx.AsyncClient(timeout=None) as client: + response = await client.post( + llm_server["url"], + headers=llm_server["headers"], + json=payload, + ) + response.raise_for_status() + data = response.json() + except Exception as e: + logger.error(f"Error calling LLM: {e}") + raise + + choices = data.get("choices", []) + if not choices: + logger.warning("LLM response missing choices: %s", data.keys()) + return "" + + choice = choices[0] + if isinstance(choice, dict): + msg = choice.get("message") or {} + return msg.get("content") or choice.get("text", "") + return str(choice) + + def get_decision_tree_json(self) -> str: + """Export decision tree as JSON""" + return json.dumps({ + "search_stats": self.search_stats, + "decision_points": self.decision_tree, + "num_decision_points": len(self.decision_tree), + }, indent=2, default=str) + + def _serialize_node(self, node: Optional[TreeNode], max_depth: Optional[int] = None) -> Optional[Dict[str, Any]]: + """Serialize tree nodes recursively for debugging/state inspection.""" + if node is None: + return None + + if max_depth is not None and node.depth >= max_depth: + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + 
"trajectory": node.trajectory, + "num_children": len(node.children), + "children": [], + "truncated": True, + } + + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + "trajectory": node.trajectory, + "num_children": len(node.children), + "children": [self._serialize_node(child, max_depth=max_depth) for child in node.children], + } + + def get_state_snapshot( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + ) -> Dict[str, Any]: + """Return a comprehensive snapshot of the current ToT search state.""" + proposal_cache_keys = list(self.proposal_cache.keys()) + evaluation_cache_keys = list(self.evaluation_cache.keys()) + + snapshot: Dict[str, Any] = { + "config": { + "branching_factor": self.config.branching_factor, + "max_depth": self.config.max_depth, + "search_method": self.config.search_method.value, + "beam_width": self.config.beam_width, + "sure_threshold": self.config.sure_threshold, + "likely_threshold": self.config.likely_threshold, + "impossible_threshold": self.config.impossible_threshold, + "early_termination": self.config.early_termination, + "cache_evaluations": self.config.cache_evaluations, + "max_candidates_per_level": self.config.max_candidates_per_level, + }, + "search_stats": dict(self.search_stats), + "decision_tree_size": len(self.decision_tree), + "cache_state": { + "proposal_cache_size": len(self.proposal_cache), + "evaluation_cache_size": len(self.evaluation_cache), + }, + "root_present": self.root is not None, + "root_depth": self.root.depth if self.root is not None else None, + "root_value": self.root.value if self.root is not None else None, + "root_is_terminal": self.root.is_terminal if self.root is not None else None, + } + + if decision_tail is None: + snapshot["decision_tree"] = self.decision_tree + else: + snapshot["decision_tree_tail"] = 
self.decision_tree[-decision_tail:] + + if include_cache_samples: + sample_size = max(0, cache_sample_size) + snapshot["cache_state"]["proposal_cache_key_samples"] = proposal_cache_keys[:sample_size] + snapshot["cache_state"]["evaluation_cache_key_samples"] = evaluation_cache_keys[:sample_size] + + if include_tree: + snapshot["tree"] = self._serialize_node(self.root, max_depth=max_tree_depth) + + return snapshot + + def get_state_snapshot_json( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + indent: int = 2, + ) -> str: + """Return JSON string for the current ToT state snapshot.""" + snapshot = self.get_state_snapshot( + include_tree=include_tree, + max_tree_depth=max_tree_depth, + decision_tail=decision_tail, + include_cache_samples=include_cache_samples, + cache_sample_size=cache_sample_size, + ) + return json.dumps(snapshot, indent=indent, default=str) \ No newline at end of file diff --git a/interwhen/value_prompts.py b/interwhen/value_prompts.py new file mode 100644 index 0000000..1f1fd19 --- /dev/null +++ b/interwhen/value_prompts.py @@ -0,0 +1,432 @@ +""" +Value prompts with few-shot examples for Tree of Thought evaluation across datasets. +""" + +# ============================================================================ +# GAME24 VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +GAME24_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate if given numbers can reach 24 (sure/likely/impossible). + +PROBLEM STATEMENT: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Game of 24 trajectories: + +EXAMPLE 1 - SURE: +Numbers: 4 4 6 8 +Trajectory: 4 + 8 = 12 (left: 4 6 12), 6 - 4 = 2 (left: 2 12), 2 * 12 = 24 +Analysis: Reaches exactly 24 using each number exactly once. 
+Confidence: sure (9) + +EXAMPLE 2 - SURE: +Numbers: 2 9 10 12 +Trajectory: 12 * 2 = 24 (left: 9 10 24), then 24 * (10 - 9) = 24 * 1 = 24 +Analysis: Valid path found with each number used once, equals 24. +Confidence: sure (9) + +EXAMPLE 3 - SURE: +Numbers: 10 14 +Trajectory: 10 + 14 = 24 +Analysis: Direct calculation using both numbers reaches exactly 24. +Confidence: sure (9) + +EXAMPLE 4 - SURE: +Numbers: 4 4 10 +Trajectory: (10 - 4) * 4 = 6 * 4 = 24 +Analysis: Uses all three numbers exactly once with factorization that reaches 24. +Confidence: sure (9) + +EXAMPLE 5 - SURE: +Numbers: 4 9 11 +Trajectory: 9 + 11 + 4 = 24 +Analysis: All numbers used, arithmetic valid, equals 24. +Confidence: sure (9) + +EXAMPLE 6 - LIKELY: +Numbers: 5 7 8 +Trajectory: 5 + 7 + 8 = 20, or (8 - 5) * 7 = 21 +Analysis: Cannot reach 24 immediately, but numbers in reasonable arithmetic range where 24 might be achievable. +Confidence: likely (7) + +EXAMPLE 7 - LIKELY: +Numbers: 5 6 6 +Trajectory: 5 + 6 + 6 = 17, or (6 - 5) * 6 = 6 +Analysis: Current attempts don't reach 24, but numbers are within reasonable range. +Confidence: likely (7) + +EXAMPLE 8 - IMPOSSIBLE: +Numbers: 1 3 3 +Trajectory: 1 * 3 * 3 = 9, or (1 + 3) * 3 = 12 +Analysis: Maximum reachable with any operations is much less than 24. Numbers all too small. +Confidence: impossible (1) + +EXAMPLE 9 - IMPOSSIBLE: +Numbers: 10 10 11 +Trajectory: 10 + 10 + 11 = 31, or (11 - 10) * 10 = 10 +Analysis: Sum exceeds 24, factorizations fall short. Cannot reach exactly 24. +Confidence: impossible (1) + +EXAMPLE 10 - IMPOSSIBLE: +Numbers: 11 12 +Trajectory: 11 + 12 = 23, or 12 - 11 = 1, or 11 * 12 = 132, or 11 / 12 ≈ 0.91 +Analysis: No operation reaches 24. Sum close but not exact. 
+Confidence: impossible (1)
+
+Rubric:
+- sure (9): Reaches 24 using each number exactly once
+- likely (7): Cannot reach 24 yet, but numbers in reasonable range
+- possible (5): Uncertain if 24 is reachable
+- unlikely (3): Numbers seem misaligned
+- impossible (1): Numbers demonstrably cannot reach 24
+
+Respond with "Confidence: " followed by brief justification tied to arithmetic evaluations.
+"""
+
+GAME24_VALUE_PROMPT_SIMPLE = """Evaluate if given numbers can reach 24.
+
+PROBLEM STATEMENT:
+{problem}
+
+CURRENT TRAJECTORY:
+{trajectory}
+
+Rate confidence on this rubric:
+- sure (9): Reaches 24 using each number exactly once
+- likely (7): Cannot reach 24 yet, but numbers in reasonable range
+- possible (5): Uncertain if 24 is reachable
+- unlikely (3): Numbers seem misaligned
+- impossible (1): Numbers cannot reach 24
+
+Respond with "Confidence: " and brief justification.
+"""
+
+# ============================================================================
+# MAZE VALUE PROMPTS WITH FEW-SHOT EXAMPLES
+# ============================================================================
+
+MAZE_VALUE_PROMPT_WITH_FEWSHOT = """Verify a maze reasoning trace.
+
+TASK PROMPT:
+{problem}
+
+MODEL TRAJECTORY:
+{trajectory}
+
+EXAMPLE 1 - SURE:
+Question: Count right turns in path X from S to E
+Trajectory: Carefully trace X-marked path. Starting at S, move UP (initial direction). Then RIGHT (90 degrees clockwise = right turn 1). Then DOWN (90 degrees clockwise = right turn 2). Then LEFT (90 degrees clockwise = right turn 3). Continuing pattern: 6 right turns total.
+Answer: B (6 right turns)
+Analysis: Systematic path tracing with correct turn geometry, defensible count.
+Confidence: sure (9)
+
+EXAMPLE 2 - SURE:
+Question: What is the sequence of grid direction?
+Trajectory: Following marked path from S: [0,0] then [0,1] (UP) then [1,1] (RIGHT) then [1,0] (DOWN) then [2,0] (RIGHT). Each step verified against grid.
+Answer: UP, RIGHT, DOWN, RIGHT +Analysis: Clear coordinate tracking, systematic verification. +Confidence: sure (9) + +EXAMPLE 3 - LIKELY: +Question: Count right turns in path from S to E +Trajectory: Observing marked path shows mostly straight movements with a zigzag pattern. Zigzags suggest mostly left turns. Likely 0-2 right turns based on pattern. +Answer: A (0 right turns) +Confidence: likely (7) +Analysis: Shows reasonable spatial intuition but lacks systematic verification. + +EXAMPLE 4 - LIKELY: +Question: Is path continuous S to E? +Trajectory: I trace the marked path and it appears to connect from S all the way to E without breaks. The X marks form a continuous line. +Answer: Yes +Confidence: likely (7) +Analysis: Reasonable assessment but could benefit from detailed step verification. + +EXAMPLE 5 - POSSIBLE: +Question: Navigate maze from S to E +Trajectory: Following path X... I see turns but the specific sequence is unclear to me. Could be 3 or 4 right turns. +Answer: Uncertain between options +Confidence: possible (5) +Analysis: Recognizes task but cannot decisively trace path geometry. + +EXAMPLE 6 - UNLIKELY: +Question: Count right turns in path X +Trajectory: Tracing marked path X from S. Moving DOWN, then RIGHT (left turn?), then DOWN, then RIGHT (right turn?). I'm confused about turn geometry. +Answer: Some right turns but not sure +Confidence: unlikely (3) +Analysis: Confused about path following and angle identification. + +EXAMPLE 7 - IMPOSSIBLE: +Question: Navigate from S to E following marked path +Trajectory: I move RIGHT, then up, then left, I'm not sure where the path goes. I think I hit a wall. +Answer: I'm stuck +Confidence: impossible (1) +Analysis: Abandons task without following clearly marked X path provided. + +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + +MAZE_VALUE_PROMPT_SIMPLE = """Verify a maze reasoning trace. 
+ +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if reasoning is consistent with maze/spatial relationships and if final answer is defensible. +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + +# ============================================================================ +# SPATIAL REASONING VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT = """You are verifying a spatial reasoning multiple-choice trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate spatial reasoning: + +EXAMPLE 1: +Question: Based on the map, which location is northeast of the library? +Trajectory: I look at the map and see the library in the center. To the northeast means both north AND east of that point. Looking at that quadrant, I see the museum is northeast of the library. +Answer: A (Museum) + +Analysis: The student correctly understands the spatial direction (northeast = north AND east), correctly identifies it on the map, and selects the correct option. +Confidence: sure/certain (9) +Justification: The reasoning correctly applies spatial relationships and identifies the appropriate location. + +EXAMPLE 2: +Question: Which building is closest to the park? +Trajectory: The park looks like it's in the middle of the map. Near it I see a building... looks like it could be the school or maybe the library. I think the school is closer. +Answer: C (School) + +Analysis: The student makes reasonable spatial observations but doesn't verify distances or compare alternatives systematically. +Confidence: likely/probably (7) +Justification: The reasoning shows spatial awareness but lacks systematic comparison of distances to verify the answer. + +EXAMPLE 3: +Question: What is north of the train station? 
+Trajectory: I see the train station. North is... up on the map. I see some buildings, but I'm not sure exactly which one. Could be the post office or the police station. +Answer: I'm not sure + +Analysis: The student recognizes the direction but fails to identify the specific location clearly. +Confidence: possible/maybe (5) +Justification: The student understands the spatial direction but cannot decisively identify which building is in that location. + +EXAMPLE 4: +Question: If you're at the bank facing east, what's behind you? +Trajectory: At the bank facing east means I'm looking east. Behind me would be... west? I need to think about what's west of the bank. I think there's a hotel or a store but I'm not sure. +Answer: Maybe the hotel + +Analysis: The student correctly understands relative directions (east/behind = west) but isn't certain about the specific feature. +Confidence: unlikely/doubtful (3) +Justification: While the directional reasoning is sound, the uncertainty about the specific location makes this answer questionable. + +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references spatial relationships and map features. +""" + +SPATIALMAP_VALUE_PROMPT_SIMPLE = """You are verifying a spatial reasoning multiple-choice trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning correctly applies spatial relationships (north, south, east, west, near, far, etc.) and whether the final \\boxed{{choice}} is defensible. +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references the spatial relationships and locations. 
+""" +# ============================================================================ +# ZEBRA LOGIC VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Zebra Logic trajectories: + +EXAMPLE 1 - SURE: +Puzzle: Houses with colors, pets, beverages, and nationalities with clues about relationships. +Trajectory: I've systematically worked through the constraints. House 1 has British resident. Red house owner has Panda. Coffee drinker speaks Japanese. Working through elimination, I've determined all houses uniquely and the solution satisfies all clues without contradictions. +Analysis: Systematic constraint satisfaction with clear justification for each assignment. Solution verifiable. +Confidence: sure (9) + +EXAMPLE 2 - LIKELY: +Trajectory: Working through the clues methodically. I've identified several definite assignments (House 2 has Swedish resident with bird). For the remaining houses, the constraints are narrowing down possibilities and should lead to a unique solution. +Analysis: Reasonable progress using logic, but not yet complete verification of all constraints. +Confidence: likely (7) + +EXAMPLE 3 - POSSIBLE: +Trajectory: I understand the puzzle structure. I'm working through clues but some deductions are unclear to me. I think House 1 might have the British resident, but I'm not certain. +Analysis: Shows problem understanding but lacks decisive constraint application. +Confidence: possible (5) + +EXAMPLE 4 - UNLIKELY: +Trajectory: I'm trying to assign attributes to houses. House 1 has red color and Swedish resident. House 2 has green... wait, but green is next to red. I'm getting confused by the adjacency constraints. +Analysis: Fundamental misunderstanding of spatial/logical constraints. 
+Confidence: unlikely (3) + +EXAMPLE 5 - IMPOSSIBLE: +Trajectory: I'm going to assign all attributes randomly since I don't see how the clues relate to each other. +Analysis: Abandons logical reasoning without attempting systematic constraint satisfaction. +Confidence: impossible (1) + +Rubric for Zebra Logic: +- sure (9): Complete solution derived with clear constraint verification, all assignments justified +- likely (7): Systematic progress with mostly confident deductions, minor uncertainties remain +- possible (5): Some correct deductions but missing clear constraint application +- unlikely (3): Attempting logic but making errors in constraint application or showing confusion +- impossible (1): No meaningful attempt at systematic constraint satisfaction + +Respond with "Confidence: " followed by brief justification referencing the logical deductions and constraint satisfaction. +""" + +ZEBRALOGIC_VALUE_PROMPT_SIMPLE = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning systematically applies logical constraints and whether the solution assignments are well-justified. +Use the confidence rubric: +- sure (9): Complete solution with clear constraint verification +- likely (7): Systematic progress with mostly confident deductions +- possible (5): Some correct deductions with minor gaps +- unlikely (3): Attempting logic but making constraint errors +- impossible (1): No meaningful systematic reasoning + +Respond with "Confidence: " and brief justification referencing constraint satisfaction. +""" +# ============================================================================ +# GENERIC VALUE PROMPT +# ============================================================================ + +GENERIC_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate how close the following trajectory is to solving the problem. 
+ +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of trajectory evaluations: + +EXAMPLE 1 (Strong progress): +Trajectory: I've broken down the problem into steps, identified key constraints, and I'm halfway through the solution with correct logic so far. +Confidence: likely/probably (7) +Justification: Clear methodology and correct intermediate progress toward the solution. + +EXAMPLE 2 (Uncertain progress): +Trajectory: I've started the problem and my approach seems reasonable, but I'm not confident about the next steps. +Confidence: possible/maybe (5) +Justification: Direction is sound but execution and completeness require verification. + +EXAMPLE 3 (Low chance of success): +Trajectory: I tried an approach but it seems to have led to a contradiction. +Confidence: unlikely/doubtful (3) +Justification: The approach has fundamental issues that need to be reconsidered. + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. +""" + +GENERIC_VALUE_PROMPT_SIMPLE = """Evaluate how close the following trajectory is to solving the problem. + +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. 
+""" + + +def build_game24_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build game24 value prompt with or without few-shot examples.""" + if use_fewshot: + return GAME24_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return GAME24_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_mcq_value_prompt( + problem: str, trajectory: str, task_name: str, use_fewshot: bool = True +) -> str: + """Build MCQ (maze/spatial) value prompt with or without few-shot examples.""" + if task_name.lower() == "maze": + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + elif task_name.lower() in ("spatial", "spatialmap", "spatial reasoning"): + if use_fewshot: + return SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return SPATIALMAP_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + else: + # Default MCQ template + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_generic_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build generic value prompt with or without few-shot examples.""" + if use_fewshot: + return GENERIC_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return GENERIC_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_zebralogic_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build zebralogic value prompt with or without few-shot examples.""" + if use_fewshot: + return ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return 
ZEBRALOGIC_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_tot_value_prompt(task: str, problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """ + Build value prompt for Tree of Thought evaluation. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap", "zebralogic") + problem: The original problem statement + trajectory: Current partial solution + use_fewshot: Whether to use few-shot examples (default True) + + Returns: + Formatted value prompt + """ + if task == "game24": + return build_game24_value_prompt(problem, trajectory, use_fewshot) + if task == "maze": + return build_mcq_value_prompt(problem, trajectory, "maze", use_fewshot) + if task == "spatialmap": + return build_mcq_value_prompt(problem, trajectory, "spatial reasoning", use_fewshot) + if task == "zebralogic": + return build_zebralogic_value_prompt(problem, trajectory, use_fewshot) + return build_generic_value_prompt(problem, trajectory, use_fewshot) \ No newline at end of file From 73fc80a1516eda669c78b26c131f128ac6ed3fb2 Mon Sep 17 00:00:00 2001 From: Prateek Chanda Date: Mon, 2 Mar 2026 04:55:27 +0000 Subject: [PATCH 4/4] pull rebase --- .../TTSwithVerification/bestofk_baseline.py | 608 ------------------ 1 file changed, 608 deletions(-) delete mode 100644 examples/TTSwithVerification/bestofk_baseline.py diff --git a/examples/TTSwithVerification/bestofk_baseline.py b/examples/TTSwithVerification/bestofk_baseline.py deleted file mode 100644 index 8864a38..0000000 --- a/examples/TTSwithVerification/bestofk_baseline.py +++ /dev/null @@ -1,608 +0,0 @@ -import argparse -import asyncio -import json -import logging -import os -import re -from dataclasses import dataclass -from typing import List, Optional - -import numpy as np -from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoTokenizer - -from interwhen import stream_completion - -# ============== MODEL CONFIGURATION ============== -MAIN_MODEL = 
"Qwen/Qwen3-30B-A3B-Thinking-2507" -# ================================================= - -logger = logging.getLogger(__name__) - - -@dataclass -class SampleResult: - output: str - correct: bool - extracted: Optional[str] - message: str - tokens: int - critic_correct: Optional[bool] = None - - -def get_model_short_name(model_name: str) -> str: - short_name = model_name.split("/")[-1] - return short_name.replace(" ", "_").replace(":", "-") - - -def get_output_dirs(task: str, main_model: str, base_dir: str = "../../b-pchanda/Outputs_TTS/BestOfKResults"): - model_short_name = get_model_short_name(main_model) - output_base = os.path.join(base_dir, task, model_short_name) - dirs = { - "base": output_base, - "reasoning": os.path.join(output_base, "Reasoning_output"), - } - for dir_path in dirs.values(): - os.makedirs(dir_path, exist_ok=True) - return dirs - - -def init_llm_server(model_name, max_tokens=32768, port=8000, temperature=0.6, seed=42): - url = f"http://localhost:{port}/v1/completions" - payload = { - "model": model_name, - "max_tokens": max_tokens, - "top_k": 20, - "top_p": 0.95, - "min_p": 0.0, - "do_sample": True, - "temperature": temperature, - "stream": True, - "logprobs": 20, - "use_beam_search": False, - "prompt_cache": True, - "seed": seed, - } - headers = {"Content-Type": "application/json"} - return {"url": url, "payload": payload, "headers": headers} - - -def count_tokens(text: str, tokenizer) -> int: - tokens = tokenizer.encode(text, add_special_tokens=False) - return len(tokens) - - -def save_outputs(idx: int, outputs: List[SampleResult], best_idx: int, output_dir: str): - os.makedirs(output_dir, exist_ok=True) - filepath = os.path.join(output_dir, f"output_{idx}.txt") - with open(filepath, "w", encoding="utf-8") as f: - f.write(f"BEST_INDEX={best_idx}\n") - for i, result in enumerate(outputs): - f.write("\n" + "=" * 80 + "\n") - f.write(f"SAMPLE {i}\n") - f.write(f"CORRECT={result.correct}\n") - f.write(f"CRITIC_CORRECT={result.critic_correct}\n") 
- f.write(f"EXTRACTED={result.extracted}\n") - f.write(f"TOKENS={result.tokens}\n") - f.write(f"MESSAGE={result.message}\n") - f.write("\n") - f.write(result.output) - f.write("\n") - logger.info(f"Saved outputs to {filepath}") - - -# --------------------- Game24 helpers --------------------- - -def build_game24_prompt(nums): - a, b, c, d = nums - boxed = r"\\boxed{}" - base_prompt = f""" -You are solving the Game of 24. - -You are given four numbers: {a}, {b}, {c}, {d} - -Your job is to produce a valid arithmetic expression using: -- ALL four numbers exactly once -- ONLY +, -, *, / -- The expression must evaluate to exactly 24. - -Please reason step by step, and put your final answer containing only the expression within {boxed}. -""".strip() - return base_prompt - - -def extract_solution_game24(text): - boxed_pattern = r"\\boxed\{" - matches = list(re.finditer(boxed_pattern, text)) - if not matches: - return None - last_match = matches[-1] - start = last_match.end() - brace_count = 1 - end = start - while end < len(text) and brace_count > 0: - if text[end] == "{": - brace_count += 1 - elif text[end] == "}": - brace_count -= 1 - end += 1 - expr = text[start:end - 1].strip() - - frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" - while re.search(frac_pattern, expr): - expr = re.sub(frac_pattern, r"(\1/\2)", expr) - - replacements = { - r"\times": "*", - r"\cdot": "*", - r"\div": "/", - } - for latex, op in replacements.items(): - expr = expr.replace(latex, op) - - expr = expr.replace(r"\\,", "").replace(r"\\ ", "") - expr = re.sub(r"\)\s*\(", ")*(", expr) - expr = re.sub(r"\)\s*(\d)", r")*\1", expr) - expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) - - return expr - - -def extract_numbers_from_expr(expr): - numbers = re.findall(r"\d+\.?\d*", expr) - return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] - - -def validate_numbers_used(expr, expected_nums): - used_nums = extract_numbers_from_expr(expr) - return sorted(used_nums) == 
sorted(expected_nums) - - -def evaluate_expression(expr, expected_nums=None): - try: - if expected_nums is not None and not validate_numbers_used(expr, expected_nums): - return False - value = eval(expr, {"__builtins__": None}, {}) - return abs(value - 24) < 1e-6 - except Exception: - return False - - -def evaluate_game24_answer(answer, nums): - expr = extract_solution_game24(answer) - if not expr: - return False, None, "No expression found" - if evaluate_expression(expr, expected_nums=nums): - return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" - used_nums = extract_numbers_from_expr(expr) - if sorted(used_nums) != sorted(nums): - return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" - return False, expr, "Expression does not evaluate to 24" - - -# --------------------- Maze/SpatialMap helpers --------------------- - -def remove_last_paragraph(s: str) -> str: - return s[:-143] if len(s) > 143 else s - - -def build_maze_prompt(example): - pre_prompt = ( - "You are an expert problem solver. Carefully read the following multiple-choice question " - "and think through the solution step-by-step before providing your final answer. " - "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" - ) - description = remove_last_paragraph(str(example.get("prompt"))) - return pre_prompt, description - - -def build_spatialmap_prompt(example): - pre_prompt = ( - "You are an expert problem solver. Carefully read the following multiple-choice question " - "and think through the solution step-by-step before providing your final answer." 
- "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" - ) - description = remove_last_paragraph(str(example.get("prompt"))) - return pre_prompt, description - - -def extract_solution_mcq(text): - matches = re.findall(r"\\boxed\{([^}]*)\}", text) - if not matches: - return None - expr = matches[-1].strip() - choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) - if not choice_match: - return None - return choice_match.group(1).upper() - - -def extract_options_from_prompt(prompt_text, target_options): - pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" - raw = re.findall(pattern, prompt_text, flags=re.DOTALL) - options = {k: v.strip().rstrip(".") for k, v in raw} - if target_options: - options = {k: v for k, v in options.items() if k in target_options} - return options - - -def evaluate_mcq_answer(answer, options, ground_truth): - sol = extract_solution_mcq(answer) - gt_sol = str(ground_truth).strip() - if not sol: - return False, None, "No expression found" - sol = sol.strip() - if sol in options: - if options[sol] == gt_sol: - return True, sol, f"Correct: option {sol} -> {options[sol]}" - return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" - if sol.lower() == gt_sol.lower(): - return True, sol, f"Correct: answer text matches ground truth: {sol}" - for opt_letter, opt_value in options.items(): - if sol.lower() == opt_value.lower(): - if opt_value == gt_sol: - return True, sol, f"Correct: answer text {sol} (option {opt_letter})" - return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" - return False, sol, f"Solution '{sol}' not found in options or ground truth" - - -def build_full_prompt(task, example, nums=None): - if task == "game24": - prompt = build_game24_prompt(nums) - return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - if task == "maze": - system_prompt, user_prompt = build_maze_prompt(example) - else: - system_prompt, 
user_prompt = build_spatialmap_prompt(example) - return ( - f"<|im_start|>system\n{system_prompt}<|im_end|>\n" - f"<|im_start|>user\n{user_prompt}<|im_end|>\n" - f"<|im_start|>assistant\n" - ) - - -def load_dataset_for_task(task): - if task == "game24": - return load_dataset("nlile/24-game", split="train") - if task == "maze": - return load_dataset("microsoft/VISION_LANGUAGE", "maze", split="val") - if task == "spatialmap": - return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") - raise ValueError(f"Unsupported task: {task}") - - -def resolve_indices(task, dataset_len, args): - if args.indices: - return [int(x.strip()) for x in args.indices.split(",")] - if args.xrange: - parts = args.xrange.split("-") - if len(parts) == 2: - try: - start = int(parts[0].strip()) - end = int(parts[1].strip()) - return range(start, end) - except ValueError: - raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") - if args.num_examples: - return np.linspace(0, dataset_len - 1, args.num_examples, dtype=int) - # Default: use full range - start = args.start if args.start is not None else 0 - end = args.end if args.end is not None else dataset_len - return range(start, end) - - -def run_k_samples(prompt, llm_server, k, seed): - outputs = [] - for i in range(k): - llm_server["payload"]["seed"] = seed + i - outputs.append(asyncio.run(stream_completion( - prompt, - llm_server=llm_server, - monitors=(), - add_delay=False, - termination_requires_validation=False, - async_execution=True, - ))) - return outputs - - -# --------------------- Critic model helpers --------------------- - -def build_game24_critic_prompt(nums, reasoning_output): - """Build critic prompt to evaluate Game of 24 solution.""" - return f"""You are a math verifier. Evaluate the following Game of 24 solution. - -Numbers: {nums} -Target: 24 - -Student's reasoning and answer: -{reasoning_output} - -Verify: -1. Does it use ALL four numbers exactly once? -2. 
Does each step follow correct arithmetic? -3. Does the final expression evaluate to exactly 24? - -Respond with ONLY: "CORRECT" or "INCORRECT" -""" - - -def build_mcq_critic_prompt(task, task_description, reasoning_output): - """Build critic prompt to evaluate MCQ solution.""" - task_name = "Maze" if task == "maze" else "Spatial Reasoning" - return f"""You are an expert {task_name} verifier. Evaluate the following solution. - -Task: -{task_description} - -Student's reasoning and answer: -{reasoning_output} - -Verify the correctness of the step-by-step reasoning and final answer. - -Respond with ONLY: "CORRECT" or "INCORRECT" -""" - - -async def evaluate_with_critic(output_text, task, example, critic_llm_server, tokenizer, nums=None): - """Use critic model to evaluate correctness of output.""" - try: - if task == "game24": - critic_prompt = build_game24_critic_prompt(nums, output_text) - else: - if task == "maze": - _, task_desc = build_maze_prompt(example) - else: - _, task_desc = build_spatialmap_prompt(example) - critic_prompt = build_mcq_critic_prompt(task, task_desc, output_text) - - critic_system = "You are a strict academic verifier." 
- full_prompt = f"<|im_start|>system\n{critic_system}<|im_end|>\n<|im_start|>user\n{critic_prompt}<|im_end|>\n<|im_start|>assistant\n" - - critic_output = await stream_completion( - full_prompt, - llm_server=critic_llm_server, - monitors=(), - add_delay=False, - termination_requires_validation=False, - async_execution=True, - ) - - is_correct = "CORRECT" in critic_output.upper() - return is_correct, critic_output - except Exception as e: - logger.warning(f"Critic evaluation failed: {e}") - return False, "" - - -def run_k_samples_with_critic( - prompt, - llm_server, - critic_llm_server, - k, - seed, - task, - example, - tokenizer, - eval_fn, - nums=None, - early_stop=False, -): - """Run up to K samples, evaluate with critic, and score with ground truth.""" - sample_results = [] - for i in range(k): - llm_server["payload"]["seed"] = seed + i - output = asyncio.run(stream_completion( - prompt, - llm_server=llm_server, - monitors=(), - add_delay=False, - termination_requires_validation=False, - async_execution=True, - )) - - critic_correct, critic_response = asyncio.run(evaluate_with_critic( - output, task, example, critic_llm_server, tokenizer, nums=nums - )) - is_correct, extracted, message = eval_fn(output) - token_count = count_tokens(output, tokenizer) - - sample_results.append(SampleResult( - output=output, - correct=is_correct, - extracted=extracted, - message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", - tokens=token_count, - critic_correct=critic_correct, - )) - - if early_stop and critic_correct: - break - - return sample_results - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets") - parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap"], - help="Task to run") - parser.add_argument("--k", type=int, default=4, help="Number of samples per example") - parser.add_argument("--num_examples", 
"-n", type=int, default=None, - help="Number of examples to run (overrides start/end)") - parser.add_argument("--indices", type=str, default=None, - help="Comma-separated indices to run") - parser.add_argument("--xrange", type=str, default=None, - help="Range of indices to run (format: 'start-end')") - parser.add_argument("--start", type=int, default=None, help="Start index") - parser.add_argument("--end", type=int, default=None, help="End index") - parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") - parser.add_argument("--port", type=int, default=8000, help="vLLM server port") - parser.add_argument("--use_critic", action="store_true", help="Use critic model for evaluation instead of ground truth") - parser.add_argument("--critic_model", type=str, default=MAIN_MODEL, help="Critic model to use for evaluation") - parser.add_argument("--critic_port", type=int, default=8000, help="vLLM server port for critic model (default: same as main model port)") - parser.add_argument("--critic_early_stop", action="store_true", help="Stop sampling after first critic-correct trace") - parser.add_argument("--seed", type=int, default=42, help="Base random seed") - parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation") - parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") - parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") - args = parser.parse_args() - - log_level = logging.DEBUG if args.debug else logging.INFO - logging.basicConfig(level=log_level, format="%(message)s") - - - dataset = load_dataset_for_task(args.task) - indices = resolve_indices(args.task, len(dataset), args) - - llm_server = init_llm_server( - args.main_model, - max_tokens=args.max_tokens, - port=args.port, - temperature=args.temperature, - seed=args.seed, - ) - - critic_llm_server = None - if args.use_critic: - critic_llm_server = 
init_llm_server( - args.critic_model, - max_tokens=512, - port=args.critic_port, - temperature=0.2, - seed=args.seed, - ) - logger.info(f"Using critic model: {args.critic_model} on port {args.critic_port}") - - logger.info(f"Loading tokenizer for {args.main_model}...") - tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) - logger.info("Tokenizer loaded successfully.") - - output_dirs = get_output_dirs(args.task, args.main_model) - - total_examples = 0 - total_correct = 0 - total_correct_samples = 0 - total_samples = 0 - critic_correct_samples = 0 - critic_total_samples = 0 - total_tokens = 0 - total_tokens_all_samples = 0 - results = [] - - for idx in tqdm(indices, desc="Processing examples", unit="example"): - example = dataset[int(idx)] - if args.task == "game24": - nums = example["numbers"] - prompt = build_full_prompt(args.task, example, nums=nums) - eval_fn = lambda output: evaluate_game24_answer(output, nums) - options = None - else: - prompt = build_full_prompt(args.task, example) - gt = str(example.get("ground_truth", "")).strip() - if gt == "Q4": - target_options = ["A", "B"] - else: - target_options = ["A", "B", "C", "D"] - if args.task == "maze": - _, user_prompt = build_maze_prompt(example) - else: - _, user_prompt = build_spatialmap_prompt(example) - options = extract_options_from_prompt(user_prompt, target_options) - eval_fn = lambda output: evaluate_mcq_answer(output, options, gt) - - logger.info(f"---- Example {idx} ----") - - if args.use_critic: - sample_results = run_k_samples_with_critic( - prompt, llm_server, critic_llm_server, args.k, args.seed, - args.task, example, tokenizer, eval_fn, nums=(nums if args.task == "game24" else None), - early_stop=args.critic_early_stop - ) - else: - outputs = run_k_samples(prompt, llm_server, args.k, args.seed) - sample_results = [] - for output in outputs: - is_correct, extracted, message = eval_fn(output) - token_count = count_tokens(output, tokenizer) - 
sample_results.append(SampleResult( - output=output, - correct=is_correct, - extracted=extracted, - message=message, - tokens=token_count, - critic_correct=None, - )) - - if args.use_critic: - best_idx = next((i for i, r in enumerate(sample_results) if r.critic_correct), 0) - else: - best_idx = next((i for i, r in enumerate(sample_results) if r.correct), 0) - best_result = sample_results[best_idx] - any_correct = any(r.correct for r in sample_results) - correct_samples = sum(1 for r in sample_results if r.correct) - critic_correct_samples_example = sum(1 for r in sample_results if r.critic_correct) - - save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"]) - - total_examples += 1 - if any_correct: - total_correct += 1 - total_correct_samples += correct_samples - total_samples += len(sample_results) - critic_correct_samples += critic_correct_samples_example - critic_total_samples += len(sample_results) - total_tokens += best_result.tokens - total_tokens_all_samples += sum(r.tokens for r in sample_results) - - results.append({ - "idx": int(idx), - "best_idx": best_idx, - "any_correct": any_correct, - "best_correct": best_result.correct, - "best_critic_correct": best_result.critic_correct, - "best_extracted": best_result.extracted, - "best_message": best_result.message, - "best_tokens": best_result.tokens, - "all_tokens": [r.tokens for r in sample_results], - "all_correct": [r.correct for r in sample_results], - "all_critic_correct": [r.critic_correct for r in sample_results], - "options": options, - }) - - logger.info(f"Best sample: {best_idx} | Correct in K: {any_correct}") - logger.info(f"Best message: {best_result.message}") - - accuracy = total_correct / total_examples if total_examples else 0 - avg_best_tokens = total_tokens / total_examples if total_examples else 0 - avg_all_tokens = total_tokens_all_samples / total_examples if total_examples else 0 - - summary = { - "task": args.task, - "model": args.main_model, - "k": args.k, - "use_critic": 
args.use_critic, - "total_examples": total_examples, - "correct": total_correct, - "correct_samples": total_correct_samples, - "total_samples": total_samples, - "critic_correct_samples": critic_correct_samples, - "critic_total_samples": critic_total_samples, - "critic_accuracy": (critic_correct_samples / critic_total_samples) if critic_total_samples else 0, - "accuracy": accuracy, - "avg_best_tokens": avg_best_tokens, - "avg_all_tokens": avg_all_tokens, - "total_tokens_best": total_tokens, - "total_tokens_all_samples": total_tokens_all_samples, - "results": results, - } - - if args.use_critic: - summary["critic_model"] = args.critic_model - summary["critic_port"] = args.critic_port - summary["critic_early_stop"] = args.critic_early_stop - - summary_path = os.path.join(output_dirs["base"], "summary.json") - with open(summary_path, "w", encoding="utf-8") as f: - json.dump(summary, f, indent=2) - logger.info(f"Saved summary to {summary_path}")