diff --git a/.gitignore b/.gitignore index b7faf40..fd51d2d 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,5 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +b-pchanda/ \ No newline at end of file diff --git a/README.md b/README.md index 03a763e..dc5ca94 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,21 @@ We provide examples using three datasets: Maze, Game of 24, and SpatialMap. python ./examples/TTSwithVerification/[your_dataset]_stepverifier.py -n 1 # dataset=maze,game24, or spatialmap ``` +For using TreeofThought + +```bash +python ./examples/TTSwithVerification/tot_baseline.py \ + --task maze \ + --num_examples 4 \ + --ports 8000,8001,8002,8003 \ + --concurrency 4 \ + --model Qwen/QwQ-32B \ + --max_tokens 32768 \ +``` + +This script loads the same datasets as the verification examples, spins up `TreeOfThoughtSearch` with configurable branching/depth, and round-robins requests across multiple vLLM instances for faster experimentation. + + ### Monitors for Early stopping ```bash python ./examples/EarlyStopping/[your_dataset]_example.py -n 1 diff --git a/examples/TTSwithVerification/README.md b/examples/TTSwithVerification/README.md index ed7fbb4..a2c52dd 100644 --- a/examples/TTSwithVerification/README.md +++ b/examples/TTSwithVerification/README.md @@ -156,6 +156,39 @@ The Z3 solver handles diagonal directions (`Northwest`, `Northeast`, `Southwest` --- +# Best-of-K Baseline + +A simple best-of-K baseline that generates K independent reasoning traces per example and selects the best based on: +1. **Ground-truth matching** (default): Greedy selection of first correct answer among K samples +2. **Critic model evaluation** (optional): Use a separate critic LLM to evaluate correctness without access to ground truth + +This baseline demonstrates that with sufficient sampling, even simple CoT can achieve good performance. + +## Usage + +```bash +# Best-of-K with ground-truth evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 --use_critic --critic_model Qwen/Qwen3-30B-A3B-Thinking-2507 --critic_port 8001 +``` + +### Parameters + +| Argument | Description | Default | +|----------|-------------|---------| +| `--task` | Task: `game24`, `maze`, or `spatialmap` | required | +| `--k` | Number of samples per example | `4` | +| `--use_critic` | Use critic model for evaluation instead of ground truth | `False` | +| `--critic_model` | Model to use for critic evaluation | MAIN_MODEL | +| `--critic_port` | vLLM server port for critic model | `8001` | +| `--num_examples`, `-n` | Number of examples to run | varies | +| `--main_model` | Model for generation | `Qwen/Qwen3-30B-A3B-Thinking-2507` | +| `--port` | vLLM server port for main model | `8000` | + +--- + ## Example Scripts Each script runs a full evaluation: loading a dataset, building structured prompts, running inference with step verification, and computing accuracy/token statistics. @@ -169,6 +202,14 @@ python ./examples/TTSwithVerification/maze_stepverifier.py -n 1 # SpatialMap with step verification python ./examples/TTSwithVerification/spatialmap_stepverifier.py -n 1 + +# Best-of-K baseline (standard CoT, no monitors) +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task maze -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task spatialmap -n 1 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 --use_critic ``` ### Common arguments diff --git a/examples/TTSwithVerification/tot_baseline.py b/examples/TTSwithVerification/tot_baseline.py new file mode 100644 index 0000000..e8f7660 --- /dev/null +++ b/examples/TTSwithVerification/tot_baseline.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +"""Command-line Tree-of-Thought baseline runner for interwhen datasets.""" + +import argparse +import asyncio +import json +import logging +import os +import re +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import httpx +import numpy as np +from datasets import load_dataset + +from interwhen.tree_of_thought import ( + SearchMethod, + ToTSearchConfig, + TreeOfThoughtSearch, + build_tot_problem, +) + +LOGGER = logging.getLogger("tot_baseline") + + +# ============== Helper Functions ============== + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + return False, expr, "Expression does not evaluate to 24" + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + patterns = [ + r"\\boxed\{([^}]*)\}", + r"boxed\{([^}]*)\}", + r"\*\*([A-D])\*\*", + r"answer[:\s]*([A-D])", + r"(?:^|\n)([A-D])(?:\s|$|\.)", + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution '{sol}' not found in options or ground truth" + + +def extract_solution_zebralogic(text): + """Extract JSON solution from ZebraLogic model output.""" + if not text: + return None + + def _try_parse(candidate: str): + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + # Unwrap if it's a wrapper with "answer" key + if "answer" in parsed and isinstance(parsed["answer"], dict): + inner = parsed["answer"] + if any(re.match(r"^house\s*\d+$", str(k).strip(), flags=re.IGNORECASE) for k in inner.keys()): + return inner + return parsed + except json.JSONDecodeError: + return None + return None + + # Try to extract JSON from code blocks + patterns = [ + r"```json\s*({.*?})\s*```", # Markdown code block + r"```\s*({.*?})\s*```", # Generic code block + r"({\s*['\"]House\s*\d+['\"].*?})", # Direct JSON starting with House + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE) + if matches: + json_str = matches[-1].strip() + solution = _try_parse(json_str) + if solution is not None: + return solution + + # Try parsing entire last large JSON-like structure + try: + # Find potential JSON starting with { + json_match = re.search(r"({\s*(?:['\"]House|['{\"\[])+[\s\S]*})", text) + if json_match: + json_str = json_match.group(1) + solution = _try_parse(json_str) + if solution is not None: + return solution + except (json.JSONDecodeError, AttributeError): + pass + + # Last-chance extraction: parse top-level JSON object spans and keep the + # last one that parses and looks like a house assignment dictionary. + stack = [] + spans = [] + for idx, ch in enumerate(text): + if ch == "{": + stack.append(idx) + elif ch == "}" and stack: + start = stack.pop() + if not stack: + spans.append((start, idx + 1)) + for start, end in reversed(spans): + candidate = text[start:end] + solution = _try_parse(candidate) + if solution is not None: + # Handle wrapped solution with "answer" key + if isinstance(solution, dict) and "answer" in solution: + answer = solution["answer"] + if isinstance(answer, dict) and any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in answer.keys() + ): + return answer + # Direct house keys + if any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in solution.keys() + ): + return solution + + return None + + +async def _request_zebralogic_json(prompt: str, llm_server: Dict[str, Any]) -> str: + """Submit a strict-JSON request for ZebraLogic and return raw model content.""" + payload = dict(llm_server["payload"]) + payload["temperature"] = 0.0 + payload["messages"] = [ + { + "role": "system", + "content": ( + "You solve Zebra Logic puzzles and MUST return strictly valid JSON only. " + "No markdown fences. No explanation. No extra text." + ), + }, + { + "role": "user", + "content": prompt, + }, + ] + payload["response_format"] = {"type": "json_object"} + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + llm_server["url"], + headers=llm_server["headers"], + json=payload, + ) + response.raise_for_status() + body = response.json() + return body["choices"][0]["message"]["content"].strip() + + +async def finalize_zebralogic_json(problem: str, trajectory: str, llm_server: Dict[str, Any]) -> str: + """Ask the model to convert an existing trajectory into strict final JSON only.""" + prompt = ( + "Convert the reasoning into the final Zebra Logic answer JSON.\n" + "Output ONLY valid JSON (no markdown, no explanation).\n" + "Use exact feature/value names from the puzzle.\n\n" + "PUZZLE:\n" + f"{problem}\n\n" + "REASONING:\n" + f"{trajectory}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +async def solve_zebralogic_json_direct(problem: str, llm_server: Dict[str, Any]) -> str: + """Directly solve ZebraLogic and return strict final JSON.""" + prompt = ( + "Solve the Zebra Logic puzzle and provide the final house assignments.\n" + "Output ONLY valid JSON with keys like 'House 1', 'House 2', etc.\n" + "Use exact feature/value names from the puzzle text.\n\n" + "PUZZLE:\n" + f"{problem}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +def evaluate_zebralogic_answer(answer, ground_truth): + """Evaluate ZebraLogic solution against ground truth.""" + extracted = extract_solution_zebralogic(answer) + + if extracted is None: + return False, None, "Could not extract valid JSON solution" + + # Normalize keys (handle 'House X', 'house x', etc.) + def normalize_solution(sol): + normalized = {} + for key, value in sol.items(): + # Normalize house key + house_match = re.search(r"House\s*(\d+)", key, re.IGNORECASE) + if house_match: + house_num = house_match.group(1) + normalized[f"House {house_num}"] = value if isinstance(value, dict) else value + else: + normalized[key] = value + return normalized + + extracted_norm = normalize_solution(extracted) + ground_truth_norm = normalize_solution(ground_truth) if isinstance(ground_truth, dict) else ground_truth + + # Simple exact match on normalized solution + if extracted_norm == ground_truth_norm: + return True, extracted_norm, "Correct: Solution matches ground truth exactly" + + # Check if as string they're close (for JSON format differences) + extracted_str = json.dumps(extracted_norm, sort_keys=True) + gt_str = json.dumps(ground_truth_norm, sort_keys=True) if isinstance(ground_truth_norm, dict) else str(ground_truth_norm) + + if extracted_str == gt_str: + return True, extracted_norm, "Correct: Solution matches ground truth (format normalized)" + + # Partial credit: check if majority of houses match + if isinstance(ground_truth_norm, dict) and isinstance(extracted_norm, dict): + matches = sum(1 for k in extracted_norm if k in ground_truth_norm and extracted_norm[k] == ground_truth_norm[k]) + total = max(len(extracted_norm), len(ground_truth_norm)) + if matches > 0: + accuracy = matches / total + return False, extracted_norm, f"Partial match: {matches}/{total} houses correct ({accuracy:.1%})" + + return False, extracted_norm, "Incorrect: Solution does not match ground truth" + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + if task == "zebralogic": + return load_dataset("WildEval/ZebraLogic", name="grid_mode", split="test") + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return list(range(start, end)) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + return list(np.linspace(0, dataset_len - 1, args.num_examples, dtype=int)) + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return list(range(start, end)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run Tree-of-Thought search on a subset of the supported tasks", + ) + parser.add_argument("--task", choices=["game24", "maze", "spatialmap", "zebralogic"], required=True) + parser.add_argument("--k", type=int, default=1, help="Unused placeholder to mirror other baselines") + parser.add_argument("--num_examples", "-n", type=int, default=None) + parser.add_argument("--indices", type=str, default=None) + parser.add_argument("--xrange", type=str, default=None) + parser.add_argument("--start", type=int, default=None) + parser.add_argument("--end", type=int, default=None) + parser.add_argument("--model", default="Qwen/QwQ-32B") + parser.add_argument("--llm_url", default="http://localhost:{port}/v1/chat/completions") + parser.add_argument( + "--ports", + default="8000", + help="Comma-separated list of vLLM ports to round-robin across", + ) + parser.add_argument("--temperature", type=float, default=0.3) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--top_k", type=int, default=20) + parser.add_argument("--max_tokens", type=int, default=32768) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--search_method", choices=["bfs", "dfs", "beam"], default="beam") + parser.add_argument("--branching_factor", type=int, default=4) + parser.add_argument("--max_depth", type=int, default=1) + parser.add_argument("--beam_width", type=int, default=2) + parser.add_argument("--sure_threshold", type=float, default=0.7) + parser.add_argument("--likely_threshold", type=float, default=0.5) + parser.add_argument("--impossible_threshold", type=float, default=0.2) + parser.add_argument("--max_candidates_per_level", type=int, default=3) + parser.add_argument("--early_termination", action="store_true") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Maximum number of ToT examples to run concurrently", + ) + parser.add_argument( + "--output_dir", + default="outputs/tot_baseline", + help="Directory to store per-example JSON logs and summary", + ) + parser.add_argument("--log_level", default="INFO") + return parser.parse_args() + + +def parse_port_list(port_str: str) -> List[int]: + return [int(p.strip()) for p in port_str.split(",") if p.strip()] + + +def build_llm_server(args: argparse.Namespace, port: int) -> Dict[str, Any]: + payload = { + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "max_tokens": args.max_tokens, + "stream": False, + "seed": args.seed, + } + return { + "url": args.llm_url.format(port=port), + "headers": {"content-type": "application/json"}, + "payload": payload, + } + + +def build_tot_config(args: argparse.Namespace) -> ToTSearchConfig: + method = SearchMethod[args.search_method.upper()] + return ToTSearchConfig( + branching_factor=args.branching_factor, + max_depth=args.max_depth, + search_method=method, + beam_width=args.beam_width, + sure_threshold=args.sure_threshold, + likely_threshold=args.likely_threshold, + impossible_threshold=args.impossible_threshold, + early_termination=args.early_termination, + cache_evaluations=not args.no_cache, + max_candidates_per_level=args.max_candidates_per_level, + ) + + +def ensure_output_dir(base_dir: str, task: str) -> Path: + path = Path(base_dir).expanduser().resolve() / task + path.mkdir(parents=True, exist_ok=True) + return path + + +def prepare_eval(task: str, example: Dict[str, Any]) -> Tuple: + if task == "game24": + nums = list(example.get("numbers", [])) + return (lambda output: evaluate_game24_answer(output, nums), {"numbers": nums}) + if task == "zebralogic": + # ZebraLogic ground truth is the solution JSON + ground_truth = example.get("solution", {}) + meta = {"ground_truth_sample": str(ground_truth)[:100]} + return (lambda output: evaluate_zebralogic_answer(output, ground_truth), meta) + gt = str(example.get("ground_truth", "")).strip() + target_options = ["A", "B"] if gt == "Q4" else ["A", "B", "C", "D"] + if task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + meta = {"options": options, "ground_truth": gt} + return (lambda output: evaluate_mcq_answer(output, options, gt), meta) + + +async def run_single_example( + idx: int, + task: str, + example: Dict[str, Any], + tot_config: ToTSearchConfig, + llm_server: Dict[str, Any], +) -> Dict[str, Any]: + eval_fn, eval_meta = prepare_eval(task, example) + problem = build_tot_problem(task, example, nums=example.get("numbers")) + tot = TreeOfThoughtSearch(tot_config) + search_result = await tot.search(task, problem, llm_server) + best_traj = search_result.get("best_trajectory", "") + best_value = search_result.get("best_value", 0.0) + is_correct, extracted, message = eval_fn(best_traj) + + # ZebraLogic often ends with partial reasoning trajectories; add strict-JSON + # recovery passes before scoring. + finalized_answer = None + direct_answer = None + if task == "zebralogic" and (not is_correct): + try: + finalized_answer = await finalize_zebralogic_json(problem, best_traj, llm_server) + final_is_correct, final_extracted, final_message = eval_fn(finalized_answer) + if final_extracted is not None or final_is_correct: + is_correct = final_is_correct + extracted = final_extracted + message = final_message + best_traj = finalized_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic finalization failed for index %s: %s", idx, exc) + + if task == "zebralogic" and (not is_correct): + try: + direct_answer = await solve_zebralogic_json_direct(problem, llm_server) + direct_is_correct, direct_extracted, direct_message = eval_fn(direct_answer) + if direct_extracted is not None or direct_is_correct: + is_correct = direct_is_correct + extracted = direct_extracted + message = direct_message + best_traj = direct_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic direct solve failed for index %s: %s", idx, exc) + + return { + "index": int(idx), + "best_value": best_value, + "best_trajectory": best_traj, + "raw_best_trajectory": search_result.get("best_trajectory", ""), + "finalized_answer": finalized_answer, + "direct_answer": direct_answer, + "search_stats": search_result.get("search_stats", {}), + "decision_tree": search_result.get("decision_tree", []), + "correct": bool(is_correct), + "extracted": extracted, + "message": message, + "evaluation_meta": eval_meta, + } + + +async def run_tot_baseline(args: argparse.Namespace) -> None: + logging.basicConfig(level=getattr(logging, args.log_level.upper(), logging.INFO)) + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + output_dir = ensure_output_dir(args.output_dir, args.task) + tot_config = build_tot_config(args) + ports = parse_port_list(args.ports) + if not ports: + raise ValueError("At least one port must be specified via --ports") + concurrency = max(1, args.concurrency) + port_lock = asyncio.Lock() + port_index = {"value": 0} + + async def next_port() -> int: + async with port_lock: + port = ports[port_index["value"] % len(ports)] + port_index["value"] += 1 + return port + + semaphore = asyncio.Semaphore(concurrency) + + async def process_index(idx: int) -> Dict[str, Any]: + async with semaphore: + example = dataset[int(idx)] + port = await next_port() + llm_server = build_llm_server(args, port) + LOGGER.info("Running ToT on example %s via port %s", idx, port) + try: + record = await run_single_example(idx, args.task, example, tot_config, llm_server) + except Exception as exc: # pragma: no cover + LOGGER.exception("Failed example %s", idx) + record = { + "index": int(idx), + "error": str(exc), + "best_trajectory": "", + "correct": False, + } + example_path = output_dir / f"example_{idx}.json" + with example_path.open("w", encoding="utf-8") as handle: + json.dump(record, handle, indent=2) + return record + + processed = await asyncio.gather(*[process_index(idx) for idx in indices]) + + total = len(processed) + correct = sum(1 for r in processed if r.get("correct")) + summary = { + "task": args.task, + "model": args.model, + "total_examples": total, + "correct": correct, + "accuracy": (correct / total) if total else 0.0, + "search_method": args.search_method, + "config": { + "branching_factor": args.branching_factor, + "max_depth": args.max_depth, + "beam_width": args.beam_width, + "sure_threshold": args.sure_threshold, + "likely_threshold": args.likely_threshold, + "impossible_threshold": args.impossible_threshold, + "max_candidates_per_level": args.max_candidates_per_level, + "early_termination": args.early_termination, + "cache_evaluations": not args.no_cache, + "ports": ports, + "concurrency": concurrency, + }, + } + summary_path = output_dir / "summary.json" + with summary_path.open("w", encoding="utf-8") as handle: + json.dump(summary, handle, indent=2) + LOGGER.info("Accuracy %.2f (%d/%d)", summary["accuracy"], correct, total) + + +if __name__ == "__main__": + asyncio.run(run_tot_baseline(parse_args())) \ No newline at end of file diff --git a/interwhen/tree_of_thought.py b/interwhen/tree_of_thought.py new file mode 100644 index 0000000..cbca864 --- /dev/null +++ b/interwhen/tree_of_thought.py @@ -0,0 +1,1268 @@ +""" +Tree of Thought implementation for interwhen-style streaming completion. + +Implements proper ToT search using: +1. Propose function to generate candidate next steps +2. Value function to evaluate intermediate states +3. Search algorithm (BFS/DFS/beam) to explore the tree +4. Integrated with interwhen's async streaming architecture +""" + +import asyncio +import json +import logging +import re +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, field +from enum import Enum +import time + +from .value_prompts import ( + build_game24_value_prompt, + build_mcq_value_prompt, + build_tot_value_prompt as build_tot_value_prompt_impl, +) + +logger = logging.getLogger(__name__) + + +# --------------------- Dataset prompt helpers --------------------- + +def build_game24_prompt(nums: List[int]) -> str: + """Return the canonical Game24 instruction block used across baselines.""" + if len(nums) != 4: + raise ValueError("Game24 requires exactly four numbers.") + a, b, c, d = nums + boxed = r"\\boxed{}" + return ( + "You are solving the Game of 24.\n\n" + f"You are given four numbers: {a}, {b}, {c}, {d}\n\n" + "Your job is to produce a valid arithmetic expression using:\n" + "- ALL four numbers exactly once\n- ONLY +, -, *, /\n" + "- The expression must evaluate to exactly 24.\n\n" + "Please reason step by step, and put your final answer containing" + f" only the expression within {boxed}." + ) + + +def build_maze_prompt(example: Dict[str, Any]) -> str: + """Construct the maze reasoning instructions used in other pipelines.""" + pre_prompt = ( + "You are an expert problem solver. Carefully read the following " + "multiple-choice question and think through the solution step-by-step " + "before providing your final answer. Provide the final answer option by " + "enclosing it within \\boxed{A/B/C/D}." + ) + description = str(example.get("prompt", "")) + return f"{pre_prompt}\n\n{description.strip()}" + + +def build_spatialmap_prompt(example: Dict[str, Any]) -> str: + """Construct the spatial reasoning instructions for TOT experiments.""" + pre_prompt = ( + "You are an expert problem solver. Carefully read the following " + "multiple-choice question and think through the solution step-by-step " + "before providing your final answer. Provide the final answer option by " + "enclosing it within \\boxed{A/B/C/D}." + ) + description = str(example.get("prompt", "")) + return f"{pre_prompt}\n\n{description.strip()}" + + +def build_zebralogic_prompt(example: Dict[str, Any]) -> str: + """Construct the Zebra Logic puzzle solving instructions for TOT experiments.""" + puzzle_text = str(example.get("puzzle", "")) + prompt = ( + "# Problem Description\n\n" + "You are solving a house grid logic puzzle. You are given:\n" + "1. Features and Domains\n" + " - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right.\n" + " - A set of features (e.g., color, name, pet, book genre).\n" + " - Each feature has a finite domain of possible values.\n" + "2. Constraints:\n" + " - Each house has exactly one value per feature.\n" + " - No two houses share the same value for the same feature.\n" + "3. Clues / Constraints describing:\n" + " - Houses and their positions\n" + " - Feature values\n" + " - Relative ordering (e.g., 'next to', 'to the left of', '2 houses away from')\n\n" + "Solve this puzzle to your best ability by determining the arrangement of features across the houses.\n\n" + "# Puzzle\n\n" + f"{puzzle_text}\n\n" + "# Solution Format\n\n" + "Provide your final answer in this exact JSON format:\n" + "```json\n" + '{\n' + ' "House 1": { "feature1": "value1", "feature2": "value2", ... },\n' + ' "House 2": { "feature1": "value1", "feature2": "value2", ... },\n' + ' ...\n' + '}\n' + "```\n\n" + "Make sure to use the exact feature/value names as given in the puzzle.\n" + "Ensure the JSON is valid and parsable." + ) + return prompt + + +def build_tot_problem(task: str, example: Dict[str, Any], nums: Optional[List[int]] = None) -> str: + """Helper that mirrors the best-of-k prompt builders for ToT runs.""" + task_lower = task.lower() + if task_lower == "game24": + numbers = nums or example.get("numbers") + if numbers is None: + raise ValueError("Game24 prompt requires 'numbers' in the example") + return build_game24_prompt(list(numbers)) + if task_lower == "maze": + return build_maze_prompt(example) + if task_lower == "spatialmap": + return build_spatialmap_prompt(example) + if task_lower == "zebralogic": + return build_zebralogic_prompt(example) + raise ValueError(f"Unsupported task for ToT prompt building: {task}") + + +def build_tot_value_prompt( + task: str, + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """ + Build value prompt for Tree of Thought evaluation. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: The original problem statement + trajectory: Current partial solution (or 'No progress yet' if empty) + use_fewshot: Whether to use few-shot examples (default True for better evaluation) + + Returns: + Formatted value prompt with or without few-shot examples + """ + if not trajectory.strip(): + trajectory = "No progress yet" + return build_tot_value_prompt_impl(task, problem, trajectory, use_fewshot=use_fewshot) + + +class SearchMethod(Enum): + """Search algorithm types""" + BFS = "bfs" + DFS = "dfs" + BEAM = "beam" + + +@dataclass +class TreeNode: + """Represents a node in the Tree of Thought""" + trajectory: str + depth: int + value: float = 0.5 + parent: Optional['TreeNode'] = None + children: List['TreeNode'] = field(default_factory=list) + is_terminal: bool = False + proposals: List[str] = field(default_factory=list) + evaluation_log: Dict[str, Any] = field(default_factory=dict) + + def __hash__(self): + return hash(self.trajectory) + + def __eq__(self, other): + return isinstance(other, TreeNode) and self.trajectory == other.trajectory + + +@dataclass +class ToTSearchConfig: + """Configuration for Tree of Thought search""" + branching_factor: int = 4 + max_depth: int = 6 + search_method: SearchMethod = SearchMethod.BFS + beam_width: int = 2 + + # Value thresholds + sure_threshold: float = 0.7 + likely_threshold: float = 0.5 + impossible_threshold: float = 0.2 + + # Optimization settings + early_termination: bool = True + cache_evaluations: bool = True + max_candidates_per_level: int = 3 + + +class TreeOfThoughtSearch: + """ + Tree of Thought search controller compatible with interwhen's streaming. + + Provides propose/evaluate/search methods that work with vLLM API calls + via the llm_server interface used in interwhen. + """ + + def __init__(self, config: ToTSearchConfig = None): + self.config = config or ToTSearchConfig() + self.evaluation_cache = {} + self.proposal_cache = {} + self.search_stats = { + "nodes_explored": 0, + "evaluations_performed": 0, + "branches_pruned": 0, + "cache_hits": 0, + "solutions_found": 0, + "total_nodes_in_tree": 0, + } + self.decision_tree = [] + self.root = None + + # ===================== PROPOSE FUNCTION ===================== + + async def propose_next_steps( + self, + task: str, + problem: str, + current_trajectory: str, + llm_server: Dict, + num_proposals: Optional[int] = None, + ) -> List[str]: + """ + Generate candidate next steps using the model's propose capability. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: The original problem statement + current_trajectory: Current partial solution + llm_server: vLLM server config (url, headers, payload template) + num_proposals: Number of proposals to generate (defaults to branching_factor) + + Returns: + List of proposed next steps + """ + if num_proposals is None: + num_proposals = self.config.branching_factor + + # Check cache + cache_key = f"propose_{hash(problem)}_{hash(current_trajectory)}" + if self.config.cache_evaluations and cache_key in self.proposal_cache: + self.search_stats["cache_hits"] += 1 + return self.proposal_cache[cache_key] + + self.search_stats["nodes_explored"] += 1 + + # Build propose prompt + propose_prompt = self._build_propose_prompt( + task, + problem, + current_trajectory, + num_proposals + ) + + # Call model with streaming + proposal_text = await self._call_llm_streaming( + llm_server, + propose_prompt + ) + + # Parse proposals from response + proposals = self._parse_proposals(proposal_text, num_proposals) + + logger.info( + "Generated %d proposals at depth hint=%s", + len(proposals), + "root" if not current_trajectory.strip() else "non-root", + ) + + # Log decision point + decision_log = { + "type": "proposal_generation", + "timestamp": time.time(), + "problem_hash": hash(problem), + "trajectory": current_trajectory, + "prompt": propose_prompt, + "prompt_preview": propose_prompt[:200] + "..." if len(propose_prompt) > 200 else propose_prompt, + "raw_response": proposal_text, + "raw_response_preview": proposal_text[:300] + "..." if len(proposal_text) > 300 else proposal_text, + "parsed_proposals": proposals, + } + self.decision_tree.append(decision_log) + + # Cache + if self.config.cache_evaluations: + self.proposal_cache[cache_key] = proposals + + return proposals + + def _build_propose_prompt( + self, + task: str, + problem: str, + trajectory: str, + num_proposals: int + ) -> str: + """Build a prompt requesting proposals for next steps.""" + if task == "maze": + return self._build_maze_propose_prompt(problem, trajectory, num_proposals) + if task == "spatialmap": + return self._build_spatialmap_propose_prompt(problem, trajectory, num_proposals) + + return f"""Given the following problem and current progress, propose {num_proposals} possible next steps. + +PROBLEM: +{problem} + +CURRENT PROGRESS/TRAJECTORY: +{trajectory if trajectory.strip() else "Starting fresh - no progress yet"} + +Generate {num_proposals} distinct next steps that could advance the solution. Be specific and actionable. + +Format each proposal clearly, one per line: +1. [Proposal 1] +2. [Proposal 2] +... + +Think step by step about what makes each proposal viable. +""" + + def _detect_maze_question_type(self, problem: str) -> str: + """Detect maze subtype for proposal steering (Q0/Q2/Q4).""" + lower = problem.lower() + if "how many right turns" in lower: + logger.debug("Detected Q0: right turn counting") + return "q0" + if "how many turns" in lower and "right turns" not in lower: + logger.debug("Detected Q2: total turn counting") + return "q2" + if "starting from s" in lower and "where is e" in lower: + logger.debug("Detected Q4: spatial relation") + return "q4" + if "relative" in lower and "s" in lower and "e" in lower: + logger.debug("Detected Q4: spatial relation (relative)") + return "q4" + logger.warning(f"Maze question type not recognized, using generic. Problem preview: {lower[:200]}") + return "generic" + + def _extract_last_move_info(self, trajectory: str) -> dict: + """Extract previous direction and counter from last step in trajectory.""" + if not trajectory.strip(): + return {"prev_direction": None, "right_count": 0, "left_count": 0, "total_count": 0} + + lines = trajectory.strip().split('\n') + last_line = lines[-1] if lines else "" + + # Try to extract direction from "Next move: [DIRECTION]" + import re + direction_match = re.search(r'Next move:\s*(UP|DOWN|LEFT|RIGHT)', last_line, re.IGNORECASE) + prev_direction = direction_match.group(1).upper() if direction_match else None + + # Extract counters + right_match = re.search(r'Right-turn count:\s*(\d+)', last_line) + left_match = re.search(r'Left-turn count:\s*(\d+)', last_line) + total_match = re.search(r'Total-turn count:\s*(\d+)', last_line) + + right_count = int(right_match.group(1)) if right_match else 0 + left_count = int(left_match.group(1)) if left_match else 0 + total_count = int(total_match.group(1)) if total_match else 0 + + return { + "prev_direction": prev_direction, + "right_count": right_count, + "left_count": left_count, + "total_count": total_count + } + + def _build_maze_propose_prompt( + self, + problem: str, + trajectory: str, + num_proposals: int, + ) -> str: + """Build maze-specific atomic next-step proposal prompts by question type.""" + question_type = self._detect_maze_question_type(problem) + logger.debug(f"Detected maze question type: {question_type}") + + # Extract previous move info for bookkeeping + prev_info = self._extract_last_move_info(trajectory) + + if not trajectory.strip(): + current = "Starting fresh - no progress yet" + last_step_hint = "" + else: + current = trajectory.strip() + lines = current.split('\n') + last_line = lines[-1] if lines else "" + last_step_hint = f"\nLAST COMPLETED STEP: {last_line}\nNow generate the NEXT move after this (do NOT repeat this move).\n" + + if question_type == "q0": + prev_dir = prev_info["prev_direction"] + prev_count = prev_info["right_count"] + + if prev_dir is None: + # First move - all directions result in STRAIGHT with count 0 + examples = f"""PARENT: first move, count=0 + +Valid answers (pick {num_proposals}): +Next move: UP | Turn: STRAIGHT | Right-turn count: 0 +Next move: DOWN | Turn: STRAIGHT | Right-turn count: 0 +Next move: LEFT | Turn: STRAIGHT | Right-turn count: 0 +Next move: RIGHT | Turn: STRAIGHT | Right-turn count: 0""" + else: + # Define turn mappings + turn_map = { + "UP": {"RIGHT": "RIGHT", "LEFT": "LEFT", "UP": "STRAIGHT", "DOWN": "STRAIGHT"}, + "DOWN": {"LEFT": "RIGHT", "RIGHT": "LEFT", "DOWN": "STRAIGHT", "UP": "STRAIGHT"}, + "LEFT": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "STRAIGHT", "RIGHT": "STRAIGHT"}, + "RIGHT": {"DOWN": "RIGHT", "UP": "LEFT", "RIGHT": "STRAIGHT", "LEFT": "STRAIGHT"} + } + + moves = turn_map.get(prev_dir, {}) + examples_list = [] + for next_dir, turn_type in moves.items(): + new_count = prev_count + 1 if turn_type == "RIGHT" else prev_count + examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Right-turn count: {new_count}") + + examples = f"""PARENT: direction={prev_dir}, count={prev_count} + +Valid answers (pick {num_proposals}): +{chr(10).join(examples_list)}""" + + return f"""{examples} + +DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above.""" + + if question_type == "q2": + prev_dir = prev_info["prev_direction"] + prev_count = prev_info["total_count"] + + if prev_dir is None: + examples = f"""PARENT: first move, count=0 + +Valid answers (pick {num_proposals}): +Next move: UP | Turn: STRAIGHT | Total-turn count: 0 +Next move: DOWN | Turn: STRAIGHT | Total-turn count: 0 +Next move: LEFT | Turn: STRAIGHT | Total-turn count: 0 +Next move: RIGHT | Turn: STRAIGHT | Total-turn count: 0""" + else: + # Define turn mappings (same as Q0) + turn_map = { + "UP": {"RIGHT": "RIGHT", "LEFT": "LEFT", "UP": "STRAIGHT", "DOWN": "STRAIGHT"}, + "DOWN": {"LEFT": "RIGHT", "RIGHT": "LEFT", "DOWN": "STRAIGHT", "UP": "STRAIGHT"}, + "LEFT": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "STRAIGHT", "RIGHT": "STRAIGHT"}, + "RIGHT": {"DOWN": "RIGHT", "UP": "LEFT", "RIGHT": "STRAIGHT", "LEFT": "STRAIGHT"} + } + + moves = turn_map.get(prev_dir, {}) + examples_list = [] + for next_dir, turn_type in moves.items(): + new_count = prev_count + 1 if turn_type in ["RIGHT", "LEFT"] else prev_count + examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Total-turn count: {new_count}") + + examples = f"""PARENT: direction={prev_dir}, count={prev_count} + +Valid answers (pick {num_proposals}): +{chr(10).join(examples_list)}""" + + return f"""{examples} + +DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above.""" + + if question_type == "q4": + # Q4 should also be structured - no long reasoning + return f"""Maze spatial question. Generate {num_proposals} brief factual statements. + +{current if current else "Starting."} + +Output {num_proposals} lines. Each line: one short fact. NO long explanations.""" + + # Generic fallback - also keep it structured + return f"""Maze question. Generate {num_proposals} next steps. + +{current if current else "Starting."} + +Output {num_proposals} lines. Each line: one short step. NO explanations.""" + + def _detect_spatialmap_question_type(self, problem: str) -> str: + """Detect spatialmap subtype for proposal steering (direction/object/counting).""" + lower = problem.lower() + if "how many" in lower and ("objects" in lower or "places" in lower or "locations" in lower): + return "counting" + if "which object" in lower or "what object" in lower or "which place" in lower or "which location" in lower: + return "object" + if "in which direction" in lower or "what direction" in lower or "relative to" in lower: + return "direction" + return "generic" + + def _build_spatialmap_propose_prompt( + self, + problem: str, + trajectory: str, + num_proposals: int, + ) -> str: + """Build spatialmap-specific atomic next-step proposal prompts by question type.""" + question_type = self._detect_spatialmap_question_type(problem) + logger.debug(f"Detected spatialmap question type: {question_type}") + + if not trajectory.strip(): + current = "Starting fresh - no progress yet" + last_step_hint = "" + else: + current = trajectory.strip() + lines = [line.strip() for line in current.split("\n") if line.strip()] + last_line = lines[-1] if lines else "" + last_step_hint = ( + f"\nLAST COMPLETED STEP: {last_line}\n" + "Now generate the NEXT atomic step after this (do NOT repeat this step).\n" + ) + + if question_type == "direction": + return f"""You are solving a spatial-map DIRECTION question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step must advance exactly ONE concrete spatial inference +- Prefer one of: parse one relation, apply reversibility once, apply transitivity once, or map target-vs-reference direction +- Do NOT restate the whole map + +Output format (one line per proposal, no preamble): +1. [Atomic spatial inference] +2. [Atomic spatial inference] +...""" + + if question_type == "object": + return f"""You are solving a spatial-map OBJECT-IDENTIFICATION question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step should do one action: identify candidate set, eliminate one candidate, or validate one relation against query direction +- Keep steps local and specific to the asked direction/object +- Do NOT rewrite all relationships + +Output format (one line per proposal, no preamble): +1. [Atomic candidate/evidence step] +2. [Atomic candidate/evidence step] +...""" + + if question_type == "counting": + return f"""You are solving a spatial-map COUNTING question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Your task: Propose {num_proposals} ATOMIC next steps only. +- Each step should do one action: identify one qualifying object, rule out one object, or update running count by exactly one justified change +- Keep a clear running count state +- Do NOT provide final answer yet unless count is fully justified + +Output format (one line per proposal, no preamble): +1. [Atomic counting step, e.g., "Qualifies: ; running count = n"] +2. [Atomic counting step, e.g., "Ruled out: ; running count = n"] +...""" + + return f"""You are solving a spatial-map reasoning question. + +PROBLEM: +{problem} + +TRAJECTORY SO FAR: +{current} +{last_step_hint} +Propose {num_proposals} atomic, actionable next steps. +Each step must add ONE new spatial fact/inference and be different from prior steps. + +Output format (one line per proposal, no preamble): +1. ... +2. ... +...""" + + def _parse_proposals(self, response: str, num_proposals: int) -> List[str]: + """ + Parse proposals from model response. + Handles various formats (numbered lists, bullets, etc.) + """ + proposals: List[str] = [] + + # First, try to extract exact format lines: "Next move: X | Turn: Y | ...count: Z" + # Match line by line to avoid cross-line pollution + for line in response.split('\n'): + line = line.strip() + # Check if it matches our exact format + if 'Next move:' in line and 'Turn:' in line and 'count:' in line: + proposals.append(line) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # If we got enough exact format proposals, return them + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # Fallback: numbered multiline blocks (preserve continuation lines) + numbered_blocks = re.findall( + r"(?:^|\n)\s*\d+[\.)]\s*(.+?)(?=(?:\n\s*\d+[\.)]\s*)|\Z)", + response, + flags=re.DOTALL, + ) + for block in numbered_blocks: + cleaned = " ".join(block.strip().split()) + if cleaned and len(cleaned) > 3: + proposals.append(cleaned) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + # Second pass: fallback line parser for bullets/single-line proposals + for line in response.split("\n"): + cleaned = line.strip() + if not cleaned or cleaned in ["Next steps:", "Proposals:", "possible next steps"]: + continue + + cleaned = re.sub(r"^\s*(?:\d+[\.)]|[-•*])\s*", "", cleaned) + cleaned = re.sub(r"^\s*\[\d+\]\s*", "", cleaned) + cleaned = re.sub(r"^\s*proposal\s*[:\-]\s*", "", cleaned, flags=re.IGNORECASE) + cleaned = " ".join(cleaned.split()) + + if cleaned and len(cleaned) > 3: + proposals.append(cleaned) + if len(proposals) >= num_proposals: + return proposals[:num_proposals] + + return proposals[:num_proposals] + + # ===================== EVALUATE FUNCTION ===================== + + async def evaluate_state( + self, + task: str, + problem: str, + trajectory: str, + llm_server: Dict, + ) -> float: + """ + Evaluate the quality/progress of current state. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: Original problem + trajectory: Current solution trajectory + llm_server: vLLM server config + + Returns: + Value score between 0.0 and 1.0 + """ + # Check cache + cache_key = f"evaluate_{hash(problem)}_{hash(trajectory)}" + if self.config.cache_evaluations and cache_key in self.evaluation_cache: + self.search_stats["cache_hits"] += 1 + return self.evaluation_cache[cache_key] + + self.search_stats["evaluations_performed"] += 1 + + # Build evaluation prompt + eval_prompt = self._build_evaluation_prompt(task, problem, trajectory) + + # Call model + eval_response = await self._call_llm_streaming(llm_server, eval_prompt) + + # Parse evaluation into score + score = self._parse_evaluation(eval_response) + confidence_label = self._extract_confidence_label(eval_response, score) + + # Log evaluation + eval_log = { + "type": "state_evaluation", + "timestamp": time.time(), + "trajectory": trajectory, + "prompt_preview": eval_prompt[:200] + "...", + "response_preview": eval_response[:200] + "...", + "score": score, + "confidence": confidence_label, + } + if self.decision_tree: + if "evaluations" not in self.decision_tree[-1]: + self.decision_tree[-1]["evaluations"] = [] + self.decision_tree[-1]["evaluations"].append(eval_log) + + # Cache + if self.config.cache_evaluations: + self.evaluation_cache[cache_key] = score + + return score + + def _build_evaluation_prompt(self, task: str, problem: str, trajectory: str) -> str: + """Build dataset-aware evaluation prompts reused by ToT scoring.""" + return build_tot_value_prompt(task, problem, trajectory) + + def _parse_evaluation(self, response: str) -> float: + """ + Parse evaluation response into a scalar score [0, 1] + """ + response_lower = response.lower() + + confidence_keywords = { + "sure": 0.9, "certain": 0.9, "confident": 0.9, + "likely": 0.7, "probably": 0.7, + "possible": 0.5, "maybe": 0.5, + "unlikely": 0.3, "doubtful": 0.3, + "impossible": 0.1, "blocked": 0.1, + } + + for keyword, score in confidence_keywords.items(): + if keyword in response_lower: + return score + + # Try to extract numeric score if present (1-9 scale) + for i, char in enumerate(response): + if char.isdigit(): + digit = int(char) + if 1 <= digit <= 9: + return digit / 9.0 # Normalize to [0, 1] + + return 0.5 # Default neutral score + + def _score_to_confidence(self, score: float) -> str: + """Map scalar score [0,1] to confidence bucket.""" + if score >= 0.8: + return "sure" + if score >= 0.6: + return "likely" + if score >= 0.4: + return "possible" + if score >= 0.2: + return "unlikely" + return "impossible" + + def _extract_confidence_label(self, response: str, score: float) -> str: + """Extract confidence label from value response; fallback to score mapping.""" + lower = response.lower() + for label in ["sure", "likely", "possible", "unlikely", "impossible"]: + if label in lower: + return label + return self._score_to_confidence(score) + + def _log_proposal_transition( + self, + depth: int, + parent_trajectory: str, + proposal: str, + next_state: str, + value: float, + is_terminal: bool, + pruned: bool, + ) -> None: + """Log proposal -> next-state -> value transition for debugging/analysis.""" + self.decision_tree.append( + { + "type": "proposal_transition", + "timestamp": time.time(), + "depth": depth, + "parent_trajectory": parent_trajectory, + "proposal": proposal, + "next_state": next_state, + "value": value, + "value_confidence": self._score_to_confidence(value), + "is_terminal": is_terminal, + "pruned": pruned, + } + ) + + logger.info( + "ToT transition | depth=%d | proposal=%s | value=%.3f | confidence=%s | pruned=%s | terminal=%s", + depth, + proposal, + value, + self._score_to_confidence(value), + pruned, + is_terminal, + ) + + # ===================== SEARCH IMPLEMENTATION ===================== + + async def search( + self, + task: str, + problem: str, + llm_server: Dict, + ) -> Dict[str, Any]: + """ + Perform Tree of Thought search on the problem. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap") + problem: Problem statement + llm_server: vLLM server config + + Returns: + Dictionary with best_trajectory, best_value, search_log + """ + logger.info(f"Starting ToT search with method={self.config.search_method.value}") + + # Initialize root node + self.root = TreeNode(trajectory="", depth=0, value=0.5) + + if self.config.search_method == SearchMethod.BFS: + return await self._bfs_search(task, problem, llm_server) + elif self.config.search_method == SearchMethod.BEAM: + return await self._beam_search(task, problem, llm_server) + else: + return await self._dfs_search(task, problem, llm_server) + + async def _bfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Breadth-First Search implementation""" + queue = [self.root] + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + for depth in range(self.config.max_depth): + if not queue: + break + + next_queue = [] + + for node in queue: + # Generate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + # Create child nodes + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + parent=node, + ) + + # Evaluate + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + child.value = value + + # Track best candidate regardless of terminal status + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + + # Check if terminal and meets threshold + if self._is_terminal(new_trajectory): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + # Early termination if high confidence + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return self._format_search_result(best_terminal, problem) + else: + is_terminal = False + + # Prune low-value nodes + if value < self.config.impossible_threshold: + self.search_stats["branches_pruned"] += 1 + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=True, + ) + continue + + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + + node.children.append(child) + next_queue.append(child) + self.search_stats["total_nodes_in_tree"] += 1 + + queue = next_queue[:self.config.max_candidates_per_level] + + return self._format_search_result(best_terminal or best_candidate, problem) + + async def _beam_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Beam Search implementation""" + beam = [self.root] + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + for depth in range(self.config.max_depth): + candidates = [] + + for node in beam: + # Generate and evaluate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + ) + + candidates.append((child, value)) + + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + self.search_stats["total_nodes_in_tree"] += 1 + + if self._is_terminal(new_trajectory): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return self._format_search_result(best_terminal, problem) + else: + is_terminal = False + + pruned = value < self.config.impossible_threshold + if pruned: + self.search_stats["branches_pruned"] += 1 + + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=pruned, + ) + + if pruned: + continue + + # Keep top-k by value + candidates.sort(key=lambda x: x[1], reverse=True) + beam = [child for child, _ in candidates[:self.config.beam_width]] + + if not beam: + break + + return self._format_search_result(best_terminal or best_candidate, problem) + + async def _dfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Depth-First Search implementation""" + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + async def dfs(node: TreeNode, depth: int): + nonlocal best_terminal, best_value + + if depth >= self.config.max_depth: + return + + # Generate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + ) + node.children.append(child) + + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + self.search_stats["total_nodes_in_tree"] += 1 + + if self._is_terminal(new_trajectory): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return + else: + is_terminal = False + + # Prune + if value >= self.config.impossible_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + await dfs(child, depth + 1) + else: + self.search_stats["branches_pruned"] += 1 + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=True, + ) + + await dfs(self.root, 0) + return self._format_search_result(best_terminal or best_candidate, problem) + + # ===================== UTILITIES ===================== + + def _is_terminal(self, trajectory: str) -> bool: + """Check if trajectory represents a complete solution""" + keywords = [ + "final answer", + "reached goal", + "solution:", + "answer:", + "conclusion:", + "result:", + ] + trajectory_lower = trajectory.lower() + return any(kw in trajectory_lower for kw in keywords) + + def _format_search_result( + self, + best_node: Optional[TreeNode], + problem: str + ) -> Dict[str, Any]: + """Format search results for return""" + if best_node: + best_trajectory = best_node.trajectory + best_value = best_node.value + else: + best_trajectory = "" + best_value = 0.0 + + return { + "best_trajectory": best_trajectory, + "best_value": best_value, + "search_stats": self.search_stats, + "decision_tree": self.decision_tree, + "root_node": self.root, + } + + async def _call_llm_streaming( + self, + llm_server: Dict, + prompt: str + ) -> str: + """Call chat-completions endpoint and return the full response text.""" + import httpx + + payload = llm_server["payload"].copy() + payload.pop("prompt", None) + payload.pop("messages", None) + payload["messages"] = [{"role": "user", "content": prompt}] + payload["stream"] = False + + try: + async with httpx.AsyncClient(timeout=None) as client: + response = await client.post( + llm_server["url"], + headers=llm_server["headers"], + json=payload, + ) + response.raise_for_status() + data = response.json() + except Exception as e: + logger.error(f"Error calling LLM: {e}") + raise + + choices = data.get("choices", []) + if not choices: + logger.warning("LLM response missing choices: %s", data.keys()) + return "" + + choice = choices[0] + if isinstance(choice, dict): + msg = choice.get("message") or {} + return msg.get("content") or choice.get("text", "") + return str(choice) + + def get_decision_tree_json(self) -> str: + """Export decision tree as JSON""" + return json.dumps({ + "search_stats": self.search_stats, + "decision_points": self.decision_tree, + "num_decision_points": len(self.decision_tree), + }, indent=2, default=str) + + def _serialize_node(self, node: Optional[TreeNode], max_depth: Optional[int] = None) -> Optional[Dict[str, Any]]: + """Serialize tree nodes recursively for debugging/state inspection.""" + if node is None: + return None + + if max_depth is not None and node.depth >= max_depth: + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + "trajectory": node.trajectory, + "num_children": len(node.children), + "children": [], + "truncated": True, + } + + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + "trajectory": node.trajectory, + "num_children": len(node.children), + "children": [self._serialize_node(child, max_depth=max_depth) for child in node.children], + } + + def get_state_snapshot( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + ) -> Dict[str, Any]: + """Return a comprehensive snapshot of the current ToT search state.""" + proposal_cache_keys = list(self.proposal_cache.keys()) + evaluation_cache_keys = list(self.evaluation_cache.keys()) + + snapshot: Dict[str, Any] = { + "config": { + "branching_factor": self.config.branching_factor, + "max_depth": self.config.max_depth, + "search_method": self.config.search_method.value, + "beam_width": self.config.beam_width, + "sure_threshold": self.config.sure_threshold, + "likely_threshold": self.config.likely_threshold, + "impossible_threshold": self.config.impossible_threshold, + "early_termination": self.config.early_termination, + "cache_evaluations": self.config.cache_evaluations, + "max_candidates_per_level": self.config.max_candidates_per_level, + }, + "search_stats": dict(self.search_stats), + "decision_tree_size": len(self.decision_tree), + "cache_state": { + "proposal_cache_size": len(self.proposal_cache), + "evaluation_cache_size": len(self.evaluation_cache), + }, + "root_present": self.root is not None, + "root_depth": self.root.depth if self.root is not None else None, + "root_value": self.root.value if self.root is not None else None, + "root_is_terminal": self.root.is_terminal if self.root is not None else None, + } + + if decision_tail is None: + snapshot["decision_tree"] = self.decision_tree + else: + snapshot["decision_tree_tail"] = self.decision_tree[-decision_tail:] + + if include_cache_samples: + sample_size = max(0, cache_sample_size) + snapshot["cache_state"]["proposal_cache_key_samples"] = proposal_cache_keys[:sample_size] + snapshot["cache_state"]["evaluation_cache_key_samples"] = evaluation_cache_keys[:sample_size] + + if include_tree: + snapshot["tree"] = self._serialize_node(self.root, max_depth=max_tree_depth) + + return snapshot + + def get_state_snapshot_json( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + indent: int = 2, + ) -> str: + """Return JSON string for the current ToT state snapshot.""" + snapshot = self.get_state_snapshot( + include_tree=include_tree, + max_tree_depth=max_tree_depth, + decision_tail=decision_tail, + include_cache_samples=include_cache_samples, + cache_sample_size=cache_sample_size, + ) + return json.dumps(snapshot, indent=indent, default=str) \ No newline at end of file diff --git a/interwhen/value_prompts.py b/interwhen/value_prompts.py new file mode 100644 index 0000000..1f1fd19 --- /dev/null +++ b/interwhen/value_prompts.py @@ -0,0 +1,432 @@ +""" +Value prompts with few-shot examples for Tree of Thought evaluation across datasets. +""" + +# ============================================================================ +# GAME24 VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +GAME24_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate if given numbers can reach 24 (sure/likely/impossible). + +PROBLEM STATEMENT: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Game of 24 trajectories: + +EXAMPLE 1 - SURE: +Numbers: 4 4 6 8 +Trajectory: 4 + 8 = 12 (left: 4 6 12), 6 - 4 = 2 (left: 2 12), 2 * 12 = 24 +Analysis: Reaches exactly 24 using each number exactly once. +Confidence: sure (9) + +EXAMPLE 2 - SURE: +Numbers: 2 9 10 12 +Trajectory: 12 * 2 = 24 (left: 9 10 24), then 24 * (10 - 9) = 24 * 1 = 24 +Analysis: Valid path found with each number used once, equals 24. +Confidence: sure (9) + +EXAMPLE 3 - SURE: +Numbers: 10 14 +Trajectory: 10 + 14 = 24 +Analysis: Direct calculation using both numbers reaches exactly 24. +Confidence: sure (9) + +EXAMPLE 4 - SURE: +Numbers: 4 4 10 +Trajectory: (10 - 4) * 4 = 6 * 4 = 24 +Analysis: Uses all three numbers exactly once with factorization that reaches 24. +Confidence: sure (9) + +EXAMPLE 5 - SURE: +Numbers: 4 9 11 +Trajectory: 9 + 11 + 4 = 24 +Analysis: All numbers used, arithmetic valid, equals 24. +Confidence: sure (9) + +EXAMPLE 6 - LIKELY: +Numbers: 5 7 8 +Trajectory: 5 + 7 + 8 = 20, or (8 - 5) * 7 = 21 +Analysis: Cannot reach 24 immediately, but numbers in reasonable arithmetic range where 24 might be achievable. +Confidence: likely (7) + +EXAMPLE 7 - LIKELY: +Numbers: 5 6 6 +Trajectory: 5 + 6 + 6 = 17, or (6 - 5) * 6 = 6 +Analysis: Current attempts don't reach 24, but numbers are within reasonable range. +Confidence: likely (7) + +EXAMPLE 8 - IMPOSSIBLE: +Numbers: 1 3 3 +Trajectory: 1 * 3 * 3 = 9, or (1 + 3) * 3 = 12 +Analysis: Maximum reachable with any operations is much less than 24. Numbers all too small. +Confidence: impossible (1) + +EXAMPLE 9 - IMPOSSIBLE: +Numbers: 10 10 11 +Trajectory: 10 + 10 + 11 = 31, or (11 - 10) * 10 = 10 +Analysis: Sum exceeds 24, factorizations fall short. Cannot reach exactly 24. +Confidence: impossible (1) + +EXAMPLE 10 - IMPOSSIBLE: +Numbers: 11 12 +Trajectory: 11 + 12 = 23, or 12 - 11 = 1, or 11 * 12 = 132, or 11 / 12 ≈ 0.91 +Analysis: No operation reaches 24. Sum close but not exact. +Confidence: impossible (1) + +Rubric: +- sure (9): Reaches 24 using each number exactly once +- likely (7): Cannot reach 24 yet, but numbers in reasonable range +- possible (5): Uncertain if 24 is reachable +- unlikely (3): Numbers seem misaligned +- impossible (1): Numbers demonstrably cannot reach 24 + +Respond with "Confidence: " followed by brief justification tied to arithmetic evaluations. +""" + +GAME24_VALUE_PROMPT_SIMPLE = """Evaluate if given numbers can reach 24. + +PROBLEM STATEMENT: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Rate confidence on this rubric: +- sure (9): Reaches 24 using each number exactly once +- likely (7): Cannot reach 24 yet, but numbers in reasonable range +- possible (5): Uncertain if 24 is reachable +- unlikely (3): Numbers seem misaligned +- impossible (1): Numbers cannot reach 24 + +Respond with "Confidence: " and brief justification. +""" + +# ============================================================================ +# MAZE VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +MAZE_VALUE_PROMPT_WITH_FEWSHOT = """Verify a maze reasoning trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +EXAMPLE 1 - SURE: +Question: Count right turns in path X from S to E +Trajectory: Carefully trace X-marked path. Starting at S, move UP (initial direction). Then RIGHT (90 degrees clockwise = right turn 1). Then DOWN (90 degrees clockwise = right turn 2). Then RIGHT (90 degrees clockwise = right turn 3). Continuing pattern: 6 right turns total. +Answer: B (6 right turns) +Analysis: Systematic path tracing with correct turn geometry, defensible count. +Confidence: sure (9) + +EXAMPLE 2 - SURE: +Question: What is the sequence of grid direction? +Trajectory: Following marked path from S: [0,0] then [0,1] (UP) then [1,1] (RIGHT) then [1,0] (DOWN) then [2,0] (RIGHT). Each step verified against grid. +Answer: UP, RIGHT, DOWN, RIGHT +Analysis: Clear coordinate tracking, systematic verification. +Confidence: sure (9) + +EXAMPLE 3 - LIKELY: +Question: Count right turns in path from S to E +Trajectory: Observing marked path shows mostly straight movements with a zigzag pattern. Zigzags suggest mostly left turns. Likely 0-2 right turns based on pattern. +Answer: A (0 right turns) +Confidence: likely (7) +Analysis: Shows reasonable spatial intuition but lacks systematic verification. + +EXAMPLE 4 - LIKELY: +Question: Is path continuous S to E? +Trajectory: I trace the marked path and it appears to connect from S all the way to E without breaks. The X marks form a continuous line. +Answer: Yes +Confidence: likely (7) +Analysis: Reasonable assessment but could benefit from detailed step verification. + +EXAMPLE 5 - POSSIBLE: +Question: Navigate maze from S to E +Trajectory: Following path X... I see turns but the specific sequence is unclear to me. Could be 3 or 4 right turns. +Answer: Uncertain between options +Confidence: possible (5) +Analysis: Recognizes task but cannot decisively trace path geometry. + +EXAMPLE 6 - UNLIKELY: +Question: Count right turns in path X +Trajectory: Tracing marked path X from S. Moving DOWN, then RIGHT (left turn?), then DOWN, then RIGHT (right turn?). I'm confused about turn geometry. +Answer: Some right turns but not sure +Confidence: unlikely (3) +Analysis: Confused about path following and angle identification. + +EXAMPLE 7 - IMPOSSIBLE: +Question: Navigate from S to E following marked path +Trajectory: I move RIGHT, then up, then left, I'm not sure where the path goes. I think I hit a wall. +Answer: I'm stuck +Confidence: impossible (1) +Analysis: Abandons task without following clearly marked X path provided. + +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + +MAZE_VALUE_PROMPT_SIMPLE = """Verify a maze reasoning trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if reasoning is consistent with maze/spatial relationships and if final answer is defensible. +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + +# ============================================================================ +# SPATIAL REASONING VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT = """You are verifying a spatial reasoning multiple-choice trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate spatial reasoning: + +EXAMPLE 1: +Question: Based on the map, which location is northeast of the library? +Trajectory: I look at the map and see the library in the center. To the northeast means both north AND east of that point. Looking at that quadrant, I see the museum is northeast of the library. +Answer: A (Museum) + +Analysis: The student correctly understands the spatial direction (northeast = north AND east), correctly identifies it on the map, and selects the correct option. +Confidence: sure/certain (9) +Justification: The reasoning correctly applies spatial relationships and identifies the appropriate location. + +EXAMPLE 2: +Question: Which building is closest to the park? +Trajectory: The park looks like it's in the middle of the map. Near it I see a building... looks like it could be the school or maybe the library. I think the school is closer. +Answer: C (School) + +Analysis: The student makes reasonable spatial observations but doesn't verify distances or compare alternatives systematically. +Confidence: likely/probably (7) +Justification: The reasoning shows spatial awareness but lacks systematic comparison of distances to verify the answer. + +EXAMPLE 3: +Question: What is north of the train station? +Trajectory: I see the train station. North is... up on the map. I see some buildings, but I'm not sure exactly which one. Could be the post office or the police station. +Answer: I'm not sure + +Analysis: The student recognizes the direction but fails to identify the specific location clearly. +Confidence: possible/maybe (5) +Justification: The student understands the spatial direction but cannot decisively identify which building is in that location. + +EXAMPLE 4: +Question: If you're at the bank facing east, what's behind you? +Trajectory: At the bank facing east means I'm looking east. Behind me would be... west? I need to think about what's west of the bank. I think there's a hotel or a store but I'm not sure. +Answer: Maybe the hotel + +Analysis: The student correctly understands relative directions (east/behind = west) but isn't certain about the specific feature. +Confidence: unlikely/doubtful (3) +Justification: While the directional reasoning is sound, the uncertainty about the specific location makes this answer questionable. + +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references spatial relationships and map features. +""" + +SPATIALMAP_VALUE_PROMPT_SIMPLE = """You are verifying a spatial reasoning multiple-choice trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning correctly applies spatial relationships (north, south, east, west, near, far, etc.) and whether the final \\boxed{{choice}} is defensible. +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references the spatial relationships and locations. +""" +# ============================================================================ +# ZEBRA LOGIC VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Zebra Logic trajectories: + +EXAMPLE 1 - SURE: +Puzzle: Houses with colors, pets, beverages, and nationalities with clues about relationships. +Trajectory: I've systematically worked through the constraints. House 1 has British resident. Red house owner has Panda. Coffee drinker speaks Japanese. Working through elimination, I've determined all houses uniquely and the solution satisfies all clues without contradictions. +Analysis: Systematic constraint satisfaction with clear justification for each assignment. Solution verifiable. +Confidence: sure (9) + +EXAMPLE 2 - LIKELY: +Trajectory: Working through the clues methodically. I've identified several definite assignments (House 2 has Swedish resident with bird). For the remaining houses, the constraints are narrowing down possibilities and should lead to a unique solution. +Analysis: Reasonable progress using logic, but not yet complete verification of all constraints. +Confidence: likely (7) + +EXAMPLE 3 - POSSIBLE: +Trajectory: I understand the puzzle structure. I'm working through clues but some deductions are unclear to me. I think House 1 might have the British resident, but I'm not certain. +Analysis: Shows problem understanding but lacks decisive constraint application. +Confidence: possible (5) + +EXAMPLE 4 - UNLIKELY: +Trajectory: I'm trying to assign attributes to houses. House 1 has red color and Swedish resident. House 2 has green... wait, but green is next to red. I'm getting confused by the adjacency constraints. +Analysis: Fundamental misunderstanding of spatial/logical constraints. +Confidence: unlikely (3) + +EXAMPLE 5 - IMPOSSIBLE: +Trajectory: I'm going to assign all attributes randomly since I don't see how the clues relate to each other. +Analysis: Abandons logical reasoning without attempting systematic constraint satisfaction. +Confidence: impossible (1) + +Rubric for Zebra Logic: +- sure (9): Complete solution derived with clear constraint verification, all assignments justified +- likely (7): Systematic progress with mostly confident deductions, minor uncertainties remain +- possible (5): Some correct deductions but missing clear constraint application +- unlikely (3): Attempting logic but making errors in constraint application or showing confusion +- impossible (1): No meaningful attempt at systematic constraint satisfaction + +Respond with "Confidence: " followed by brief justification referencing the logical deductions and constraint satisfaction. +""" + +ZEBRALOGIC_VALUE_PROMPT_SIMPLE = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning systematically applies logical constraints and whether the solution assignments are well-justified. +Use the confidence rubric: +- sure (9): Complete solution with clear constraint verification +- likely (7): Systematic progress with mostly confident deductions +- possible (5): Some correct deductions with minor gaps +- unlikely (3): Attempting logic but making constraint errors +- impossible (1): No meaningful systematic reasoning + +Respond with "Confidence: " and brief justification referencing constraint satisfaction. +""" +# ============================================================================ +# GENERIC VALUE PROMPT +# ============================================================================ + +GENERIC_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate how close the following trajectory is to solving the problem. + +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of trajectory evaluations: + +EXAMPLE 1 (Strong progress): +Trajectory: I've broken down the problem into steps, identified key constraints, and I'm halfway through the solution with correct logic so far. +Confidence: likely/probably (7) +Justification: Clear methodology and correct intermediate progress toward the solution. + +EXAMPLE 2 (Uncertain progress): +Trajectory: I've started the problem and my approach seems reasonable, but I'm not confident about the next steps. +Confidence: possible/maybe (5) +Justification: Direction is sound but execution and completeness require verification. + +EXAMPLE 3 (Low chance of success): +Trajectory: I tried an approach but it seems to have led to a contradiction. +Confidence: unlikely/doubtful (3) +Justification: The approach has fundamental issues that need to be reconsidered. + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. +""" + +GENERIC_VALUE_PROMPT_SIMPLE = """Evaluate how close the following trajectory is to solving the problem. + +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. +""" + + +def build_game24_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build game24 value prompt with or without few-shot examples.""" + if use_fewshot: + return GAME24_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return GAME24_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_mcq_value_prompt( + problem: str, trajectory: str, task_name: str, use_fewshot: bool = True +) -> str: + """Build MCQ (maze/spatial) value prompt with or without few-shot examples.""" + if task_name.lower() == "maze": + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + elif task_name.lower() in ("spatial", "spatialmap", "spatial reasoning"): + if use_fewshot: + return SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return SPATIALMAP_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + else: + # Default MCQ template + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_generic_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build generic value prompt with or without few-shot examples.""" + if use_fewshot: + return GENERIC_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return GENERIC_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_zebralogic_value_prompt(problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """Build zebralogic value prompt with or without few-shot examples.""" + if use_fewshot: + return ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT.format(problem=problem, trajectory=trajectory) + else: + return ZEBRALOGIC_VALUE_PROMPT_SIMPLE.format(problem=problem, trajectory=trajectory) + + +def build_tot_value_prompt(task: str, problem: str, trajectory: str, use_fewshot: bool = True) -> str: + """ + Build value prompt for Tree of Thought evaluation. + + Args: + task: The task type (e.g., "game24", "maze", "spatialmap", "zebralogic") + problem: The original problem statement + trajectory: Current partial solution + use_fewshot: Whether to use few-shot examples (default True) + + Returns: + Formatted value prompt + """ + if task == "game24": + return build_game24_value_prompt(problem, trajectory, use_fewshot) + if task == "maze": + return build_mcq_value_prompt(problem, trajectory, "maze", use_fewshot) + if task == "spatialmap": + return build_mcq_value_prompt(problem, trajectory, "spatial reasoning", use_fewshot) + if task == "zebralogic": + return build_zebralogic_value_prompt(problem, trajectory, use_fewshot) + return build_generic_value_prompt(problem, trajectory, use_fewshot) \ No newline at end of file