From 1075898bafe85a28dd68455b01ff74f20703698e Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Tue, 25 Nov 2025 17:43:39 -0500
Subject: [PATCH 01/15] added compilation util

---
 src/gfn/utils/compile.py                      |  52 ++
 .../examples/train_hypergrid_optimized.py     | 633 ++++++++++++++++++
 2 files changed, 685 insertions(+)
 create mode 100644 src/gfn/utils/compile.py
 create mode 100644 tutorials/examples/train_hypergrid_optimized.py

diff --git a/src/gfn/utils/compile.py b/src/gfn/utils/compile.py
new file mode 100644
index 00000000..7456c5c3
--- /dev/null
+++ b/src/gfn/utils/compile.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+
+
+def try_compile_gflownet(
+    gfn,
+    *,
+    mode: str = "default",
+    components: Iterable[str] = ("pf", "pb", "logZ", "logF"),
+) -> dict[str, bool]:
+    """Best-effort compilation of estimator modules attached to a GFlowNet.
+
+    Args:
+        gfn: The GFlowNet instance to compile.
+        mode: Compilation mode forwarded to ``torch.compile``.
+        components: Attribute names to attempt compilation on (e.g., ``pf``).
+
+    Returns:
+        Mapping from component name to compilation success status.
+    """
+
+    if not hasattr(torch, "compile"):
+        return {name: False for name in components}
+
+    results: dict[str, bool] = {}
+    for name in components:
+        if not hasattr(gfn, name):
+            msg = (
+                f"GFlowNet of type {type(gfn).__name__} has no '{name}' attribute; "
+                "expected a valid estimator when attempting compilation."
+            )
+            raise AttributeError(msg)
+
+        estimator = getattr(gfn, name)
+        module = getattr(estimator, "module", None)
+
+        # If the estimator does not have a module, we cannot compile it.
+        if module is None:
+            results[name] = False
+            continue
+
+        # Attempt compilation, falling back to the eager module on any failure.
+        try:
+            assert isinstance(estimator.module, torch.nn.Module)
+            estimator.module = torch.compile(module, mode=mode)
+            results[name] = True
+        except Exception:
+            results[name] = False
+    return results
diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
new file mode 100644
index 00000000..9bcfcd5d
--- /dev/null
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -0,0 +1,633 @@
+#!/usr/bin/env python
+r"""
+Optimized HyperGrid training script with optional torch.compile, vmap, and benchmarking.
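+
+Example invocations (all flags are defined in ``parse_args`` below):
+
+    python tutorials/examples/train_hypergrid_optimized.py --compile \
+        --compile-mode reduce-overhead
+    python tutorials/examples/train_hypergrid_optimized.py --benchmark --use-vmap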
+""" + +from __future__ import annotations + +import argparse +import statistics +import time +from pathlib import Path +from typing import Any, Dict, Iterable, cast + +import torch +from torch.func import vmap +from tqdm import tqdm + +from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator +from gfn.gflownet.detailed_balance import DBGFlowNet +from gfn.gflownet.flow_matching import FMGFlowNet +from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.gym import HyperGrid +from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler +from gfn.states import DiscreteStates +from gfn.utils.common import set_seed +from gfn.utils.compile import try_compile_gflownet +from gfn.utils.modules import MLP, DiscreteUniform +from gfn.utils.training import validate + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--loss", choices=["FM", "TB", "DB"], default="TB") + parser.add_argument("--ndim", type=int, default=2) + parser.add_argument("--height", type=int, default=32) + parser.add_argument("--R0", type=float, default=0.1) + parser.add_argument("--R1", type=float, default=0.5) + parser.add_argument("--R2", type=float, default=2.0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr_logz", type=float, default=1e-1) + parser.add_argument("--uniform_pb", action="store_true") + parser.add_argument("--n_iterations", type=int, default=100) + parser.add_argument("--validation_interval", type=int, default=100) + parser.add_argument("--validation_samples", type=int, default=200_000) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--epsilon", type=float, default=0.0) + parser.add_argument( + "--device", + choices=["auto", "cpu", "mps", "cuda"], + default="auto", + help="Device to run on; auto prefers CUDA>MPS>CPU.", + ) + parser.add_argument("--compile", action="store_true", help="Enable torch.compile.") + parser.add_argument( + "--compile-mode", + choices=["default", "reduce-overhead", "max-autotune"], + default="reduce-overhead", + help="Mode passed to torch.compile.", + ) + parser.add_argument("--use-vmap", action="store_true", help="Use vmap TB loss.") + parser.add_argument("--benchmark", action="store_true", help="Run benchmark mode.") + parser.add_argument( + "--benchmark-output", + type=str, + default="hypergrid_benchmark.png", + help="Output path for benchmark plot.", + ) + parser.add_argument( + "--warmup-iters", + type=int, + default=50, + help="Warmup iterations before timing (benchmark mode).", + ) + return parser.parse_args() + + +def init_metrics() -> Dict[str, Any]: + return { + "validation_info": {"l1_dist": float("inf")}, + "discovered_modes": set(), + "total_steps": 0, + "measured_steps": 0, + } + + +def main() -> None: + args = parse_args() + device = resolve_device(args.device) + + if args.benchmark: + scenarios = [ + ("Baseline", False, False), + (f"Compile ({args.compile_mode})", True, False), + ("Vmap", False, True), + (f"Compile+Vmap ({args.compile_mode})", True, True), + ] + results: list[dict[str, Any]] = [] + for label, enable_compile, use_vmap in scenarios: + result = train_with_options( + args, + device, + enable_compile=enable_compile, + use_vmap=use_vmap, + warmup_iters=args.warmup_iters, + quiet=True, + timing=True, + record_history=True, + ) + result["label"] = label + results.append(result) + + baseline_elapsed = results[0]["elapsed"] + print("Benchmark summary 
(speedups vs baseline):") + for result in results: + speedup = ( + baseline_elapsed / result["elapsed"] + if result["elapsed"] + else float("inf") + ) + print( + f"- {result['label']}: {result['elapsed']:.2f}s " + f"({speedup:.2f}x) | compile_mode={result['compile_mode']} " + f"| vmap={'on' if result['effective_vmap'] else 'off'}" + ) + + plot_benchmark(results, args.benchmark_output) + return + + train_with_options( + args, + device, + enable_compile=args.compile, + use_vmap=args.use_vmap, + warmup_iters=0, + quiet=False, + timing=False, + record_history=False, + ) + + +def train_with_options( + args: argparse.Namespace, + device: torch.device, + *, + enable_compile: bool, + use_vmap: bool, + warmup_iters: int, + quiet: bool, + timing: bool, + record_history: bool, +) -> dict[str, Any]: + set_seed(args.seed) + ( + env, + gflownet, + sampler, + optimizer, + visited_states, + ) = build_training_components(args, device) + metrics = init_metrics() + + compile_mode = args.compile_mode if enable_compile else "none" + if enable_compile: + compile_results = try_compile_gflownet( + gflownet, + mode=args.compile_mode, + ) + if not quiet: + formatted = ", ".join( + f"{name}:{'✓' if success else 'x'}" + for name, success in compile_results.items() + ) + print(f"[compile] {formatted}") + + requested_vmap = use_vmap + if use_vmap and not isinstance(gflownet, TBGFlowNet): + if not quiet: + print("vmap is currently only supported for TBGFlowNet; ignoring flag.") + use_vmap = False + effective_vmap = use_vmap + + if warmup_iters > 0: + run_iterations( + env, + gflownet, + sampler, + optimizer, + visited_states, + metrics, + args, + n_iters=warmup_iters, + use_vmap=use_vmap, + quiet=True, + collect_metrics=False, + track_time=False, + record_history=False, + ) + + elapsed, history = run_iterations( + env, + gflownet, + sampler, + optimizer, + visited_states, + metrics, + args, + n_iters=args.n_iterations, + use_vmap=use_vmap, + quiet=quiet, + collect_metrics=True, + track_time=timing, + record_history=record_history, + ) + + if not quiet: + validation_info = metrics["validation_info"] + l1 = validation_info.get("l1_dist", float("nan")) + print( + f"Finished training | iterations={metrics['measured_steps']} | " + f"modes={len(metrics['discovered_modes'])} / {env.n_modes} | " + f"L1 distance={l1:.6f}" + ) + + return { + "elapsed": elapsed or 0.0, + "losses": history["losses"] if history else None, + "iter_times": history["iter_times"] if history else None, + "compile_mode": compile_mode, + "use_compile": enable_compile, + "requested_vmap": requested_vmap, + "effective_vmap": effective_vmap, + } + + +def run_iterations( + env: HyperGrid, + gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, + sampler: Sampler, + optimizer: torch.optim.Optimizer, + visited_states: DiscreteStates, + metrics: Dict[str, Any], + args: argparse.Namespace, + *, + n_iters: int, + use_vmap: bool, + quiet: bool, + collect_metrics: bool, + track_time: bool, + record_history: bool, +) -> tuple[float | None, Dict[str, list[float]] | None]: + if n_iters <= 0: + empty_history = {"losses": [], "iter_times": []} if record_history else None + return (0.0 if track_time else None), empty_history + + iterator: Iterable[int] + if quiet: + iterator = range(n_iters) + else: + iterator = tqdm(range(n_iters), dynamic_ncols=True) + + start_time = time.perf_counter() if track_time else None + last_loss = 0.0 + losses_history: list[float] | None = [] if record_history else None + iter_time_history: list[float] | None = [] if record_history else None + + for 
_ in iterator: + iter_start = time.perf_counter() if (track_time or record_history) else None + trajectories = sampler.sample_trajectories( + env, + n=args.batch_size, + save_logprobs=False, + save_estimator_outputs=False, + epsilon=args.epsilon, + ) + + terminating_states = cast(DiscreteStates, trajectories.terminating_states) + visited_states.extend(terminating_states) + + optimizer.zero_grad() + loss = compute_loss(gflownet, env, trajectories, use_vmap=use_vmap) + loss.backward() + gflownet.assert_finite_gradients() + torch.nn.utils.clip_grad_norm_(gflownet.parameters(), 1.0) + optimizer.step() + gflownet.assert_finite_parameters() + + metrics["total_steps"] += 1 + if collect_metrics: + metrics["measured_steps"] += 1 + + last_loss = loss.item() + if ( + record_history + and losses_history is not None + and iter_time_history is not None + ): + losses_history.append(last_loss) + iter_duration = ( + (time.perf_counter() - iter_start) if iter_start is not None else 0.0 + ) + iter_time_history.append(iter_duration) + + if collect_metrics: + run_validation_if_needed( + env, + gflownet, + visited_states, + metrics, + args, + quiet=quiet, + ) + + if not quiet and isinstance(iterator, tqdm): + iterator.set_postfix( + { + "loss": last_loss, + "trajectories_sampled": ( + metrics["measured_steps"] * args.batch_size + ), + } + ) + + if track_time: + synchronize_if_needed(env.device) + assert start_time is not None + elapsed_time = time.perf_counter() - start_time + else: + elapsed_time = None + + history = None + if record_history and losses_history is not None and iter_time_history is not None: + history = { + "losses": losses_history, + "iter_times": iter_time_history, + } + + return elapsed_time, history + + +def compute_loss( + gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, + env: HyperGrid, + trajectories, + *, + use_vmap: bool, +) -> torch.Tensor: + if use_vmap and isinstance(gflownet, TBGFlowNet): + return trajectory_balance_loss_vmap(gflownet, trajectories) + + return gflownet.loss_from_trajectories( + env, trajectories, recalculate_all_logprobs=False + ) + + +def trajectory_balance_loss_vmap( + gflownet: TBGFlowNet, + trajectories, +) -> torch.Tensor: + log_pf, log_pb = gflownet.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=False + ) + log_rewards = trajectories.log_rewards + if log_rewards is None: + raise ValueError("Log rewards required for TB loss.") + + def tb_residual( + log_pf_seq: torch.Tensor, log_pb_seq: torch.Tensor, log_reward: torch.Tensor + ) -> torch.Tensor: + return log_pf_seq.sum() - log_pb_seq.sum() - log_reward + + residuals = vmap(tb_residual)( + log_pf.transpose(0, 1), + log_pb.transpose(0, 1), + log_rewards, + ) + + log_z = gflownet.logZ + if isinstance(log_z, ScalarEstimator): + if trajectories.conditions is None: + raise ValueError("Conditional logZ requires conditions tensor.") + log_z_value = log_z(trajectories.conditions) + else: + log_z_value = log_z + + if isinstance(log_z_value, torch.Tensor): + log_z_tensor = log_z_value + else: + log_z_tensor = torch.as_tensor(log_z_value, device=residuals.device) + log_z_tensor = log_z_tensor.squeeze() + scores = (residuals + log_z_tensor).pow(2) + + return scores.mean() + + +def run_validation_if_needed( + env: HyperGrid, + gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, + visited_states: DiscreteStates, + metrics: Dict[str, Any], + args: argparse.Namespace, + *, + quiet: bool, +) -> None: + if args.validation_interval <= 0: + return + measured_steps = metrics["measured_steps"] + if measured_steps == 0: + 
return + if measured_steps % args.validation_interval != 0: + return + + validation_info, _ = validate( + env, + gflownet, + args.validation_samples, + visited_states, + ) + metrics["validation_info"] = validation_info + modes_found = env.modes_found(visited_states) + metrics["discovered_modes"].update(modes_found) + + if not quiet: + str_info = ( + f"Iter {measured_steps}: " + f"L1 distance={validation_info.get('l1_dist', float('nan')):.8f} " + f"modes discovered={len(metrics['discovered_modes'])} / {env.n_modes} " + f"n terminating states {len(visited_states)}" + ) + print(str_info) + + +def build_training_components(args: argparse.Namespace, device: torch.device) -> tuple[ + HyperGrid, + TBGFlowNet | DBGFlowNet | FMGFlowNet, + Sampler, + torch.optim.Optimizer, + DiscreteStates, +]: + env = HyperGrid( + ndim=args.ndim, + height=args.height, + reward_fn_str="original", + reward_fn_kwargs={ + "R0": args.R0, + "R1": args.R1, + "R2": args.R2, + }, + device=device, + calculate_partition=True, + store_all_states=True, + check_action_validity=__debug__, + ) + + preprocessor = KHotPreprocessor(height=env.height, ndim=env.ndim) + module_PF = MLP( + input_dim=preprocessor.output_dim, + output_dim=env.n_actions, + ) + if not args.uniform_pb: + module_PB = MLP( + input_dim=preprocessor.output_dim, + output_dim=env.n_actions - 1, + trunk=module_PF.trunk, + ) + else: + module_PB = DiscreteUniform(output_dim=env.n_actions - 1) + + if args.loss == "FM": + logF_estimator = DiscretePolicyEstimator( + module=module_PF, + n_actions=env.n_actions, + preprocessor=preprocessor, + ) + gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet = FMGFlowNet(logF_estimator).to( + device + ) + optimizer = torch.optim.Adam(gflownet.logF.parameters(), lr=args.lr) + sampler = Sampler(estimator=logF_estimator) + else: + pf_estimator = DiscretePolicyEstimator( + module_PF, env.n_actions, preprocessor=preprocessor, is_backward=False + ) + pb_estimator = DiscretePolicyEstimator( + module_PB, env.n_actions, preprocessor=preprocessor, is_backward=True + ) + + if args.loss == "DB": + logF_module = MLP( + input_dim=preprocessor.output_dim, + output_dim=1, + ) + logF_estimator = ScalarEstimator( + module=logF_module, + preprocessor=preprocessor, + ) + gflownet = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator) + else: + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + if isinstance(gflownet, DBGFlowNet): + optimizer.add_param_group( + {"params": gflownet.logF.parameters(), "lr": args.lr} + ) + else: + optimizer.add_param_group( + {"params": gflownet.logz_parameters(), "lr": args.lr_logz} + ) + sampler = Sampler(estimator=pf_estimator) + + visited_states = env.states_from_batch_shape((0,)) + return env, gflownet, sampler, optimizer, visited_states + + +def _mps_backend_available() -> bool: + backend = getattr(torch.backends, "mps", None) + return bool(backend and backend.is_available()) + + +def resolve_device(requested: str) -> torch.device: + if requested == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + if _mps_backend_available(): + return torch.device("mps") + return torch.device("cpu") + + device = torch.device(requested) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available.") + if device.type == "mps" and not _mps_backend_available(): + raise RuntimeError("MPS requested but not available.") + return 
device + + +def synchronize_if_needed(device: torch.device) -> None: + if device.type == "cuda" and torch.cuda.is_available(): + torch.cuda.synchronize() + elif device.type == "mps" and _mps_backend_available() and hasattr(torch, "mps"): + torch.mps.synchronize() + + +def plot_benchmark(results: list[Dict[str, Any]], output_path: str) -> None: + try: + import matplotlib.pyplot as plt + except ImportError as exc: + raise RuntimeError( + "matplotlib is required for plotting; install it or omit --benchmark." + ) from exc + + def summarize_iteration_times(times: list[float]) -> tuple[float, float]: + if not times: + return 0.0, 0.0 + mean_time = statistics.fmean(times) + std_time = statistics.pstdev(times) if len(times) > 1 else 0.0 + return mean_time, std_time + + labels = [res.get("label", f"Run {idx+1}") for idx, res in enumerate(results)] + times = [res["elapsed"] for res in results] + losses_list = [res.get("losses") or [] for res in results] + iter_times_list = [res.get("iter_times") or [] for res in results] + + fig, axes = plt.subplots(1, 3, figsize=(20, 5)) + + # Subplot 1: total time comparison + colors = ["#6c757d", "#1f77b4", "#2ca02c", "#d62728", "#9467bd", "#8c564b"] + bar_colors = [colors[i % len(colors)] for i in range(len(results))] + bars = axes[0].bar(labels, times, color=bar_colors) + axes[0].set_ylabel("Wall-clock time (s)") + axes[0].set_title("Total Training Time") + baseline_time = times[0] if times else 1.0 + for i, (bar, value) in enumerate(zip(bars, times)): + speedup = baseline_time / value if value else float("inf") + axes[0].text( + bar.get_x() + bar.get_width() / 2, + value, + f"{value:.2f}s\n{speedup:.2f}x", + ha="center", + va="bottom", + ) + + # Subplot 2: training curves + line_styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1)), (0, (5, 5))] + + for idx, losses in enumerate(losses_list): + if not losses: + continue + axes[1].plot( + range(1, len(losses) + 1), + losses, + label=labels[idx], + color=bar_colors[idx], + linestyle=line_styles[idx % len(line_styles)], + linewidth=2.0, + alpha=0.5, + ) + axes[1].set_title("Training Loss") + axes[1].set_xlabel("Iteration") + axes[1].set_ylabel("Loss") + axes[1].legend() + + # Subplot 3: per-iteration timing with error bars + summary_stats = [summarize_iteration_times(times) for times in iter_times_list] + means_ms = [mean * 1000.0 for mean, _ in summary_stats] + stds_ms = [std * 1000.0 for _, std in summary_stats] + axes[2].bar( + labels, + means_ms, + yerr=stds_ms, + capsize=6, + color=bar_colors, + ) + axes[2].set_ylabel("Per-iteration time (ms)") + axes[2].set_title("Iteration Timing (mean ± std)") + + for ax in axes: + for label in ax.get_xticklabels(): + label.set_rotation(30) + label.set_ha("right") + + output = Path(output_path) + output.parent.mkdir(parents=True, exist_ok=True) + fig.tight_layout() + fig.savefig(output, dpi=150) + plt.close(fig) + print(f"Saved benchmark plot to {output}") + + +if __name__ == "__main__": + main() From f1d0674686fd8402eddebf6bcf9a3844ed7a91c0 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 25 Nov 2025 18:21:06 -0500 Subject: [PATCH 02/15] updated path --- tutorials/examples/train_hypergrid_optimized.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py index 9bcfcd5d..6608939f 100644 --- a/tutorials/examples/train_hypergrid_optimized.py +++ b/tutorials/examples/train_hypergrid_optimized.py @@ -64,7 +64,7 @@ def parse_args() -> argparse.Namespace: 
parser.add_argument( "--benchmark-output", type=str, - default="hypergrid_benchmark.png", + default=str(Path.home() / "hypergrid_benchmark.png"), help="Output path for benchmark plot.", ) parser.add_argument( From df563f01e808d98d1364c1523d1705fcb6f9fd74 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 26 Nov 2025 07:20:40 -0500 Subject: [PATCH 03/15] compile tests --- src/gfn/utils/compile.py | 10 +- src/gfn/utils/prob_calculations.py | 20 ++ .../examples/train_hypergrid_optimized.py | 282 +++++++++++++++++- 3 files changed, 292 insertions(+), 20 deletions(-) diff --git a/src/gfn/utils/compile.py b/src/gfn/utils/compile.py index 7456c5c3..db6ded2d 100644 --- a/src/gfn/utils/compile.py +++ b/src/gfn/utils/compile.py @@ -27,12 +27,11 @@ def try_compile_gflownet( results: dict[str, bool] = {} for name in components: + + # If the estimator does not exist, we cannot compile it. if not hasattr(gfn, name): - msg = ( - f"GFlowNet of type {type(gfn).__name__} has no '{name}' attribute; " - "expected a valid estimator when attempting compilation." - ) - raise AttributeError(msg) + results[name] = False + continue estimator = getattr(gfn, name) module = getattr(estimator, "module", None) @@ -49,4 +48,5 @@ def try_compile_gflownet( results[name] = True except Exception: results[name] = False + return results diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 5f1a75e8..b007f6e0 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -96,6 +96,26 @@ def get_trajectory_pfs( valid_actions = trajectories.actions[action_mask] if valid_states.batch_shape != valid_actions.batch_shape: + print( + "[DEBUG get_trajectory_pfs] state_mask shape:", + state_mask.shape, + "action_mask shape:", + action_mask.shape, + ) + print( + "[DEBUG get_trajectory_pfs] valid_states.batch_shape:", + valid_states.batch_shape, + "valid_actions.batch_shape:", + valid_actions.batch_shape, + ) + print( + "[DEBUG get_trajectory_pfs] trajectories.states.is_sink_state:", + trajectories.states.is_sink_state.shape, + ) + print( + "[DEBUG get_trajectory_pfs] trajectories.actions.is_dummy:", + trajectories.actions.is_dummy.shape, + ) raise AssertionError("Something wrong happening with log_pf evaluations") if trajectories.has_log_probs and not recalculate_all_logprobs: diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py index 6608939f..8daad1c4 100644 --- a/tutorials/examples/train_hypergrid_optimized.py +++ b/tutorials/examples/train_hypergrid_optimized.py @@ -9,12 +9,13 @@ import statistics import time from pathlib import Path -from typing import Any, Dict, Iterable, cast +from typing import Any, Dict, Iterable, List, cast import torch from torch.func import vmap from tqdm import tqdm +from gfn.containers import Trajectories from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator from gfn.gflownet.detailed_balance import DBGFlowNet from gfn.gflownet.flow_matching import FMGFlowNet @@ -29,6 +30,215 @@ from gfn.utils.training import validate +# Local subclasses for benchmarking-only optimizations (no core library changes) +class HyperGridWithTensorStep(HyperGrid): + def step_tensor( + self, states: torch.Tensor, actions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert states.dtype == torch.long + device = states.device + batch = states.shape[0] + ndim = self.ndim + exit_idx = self.n_actions - 1 + + if actions.ndim == 1: + actions_idx = actions.view(-1, 1) + 
else:
+            assert actions.shape[-1] == 1
+            actions_idx = actions
+
+        is_exit = actions_idx.squeeze(-1) == exit_idx
+
+        next_states = states.clone()
+        non_exit_mask = ~is_exit
+        if torch.any(non_exit_mask):
+            sel_states = next_states[non_exit_mask]
+            sel_actions = actions_idx[non_exit_mask]
+            sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add")
+            next_states[non_exit_mask] = sel_states
+        if torch.any(is_exit):
+            # Ensure exit actions land exactly on the sink state so downstream
+            # `is_sink_state` masks match the action padding semantics assumed
+            # by `Trajectories` and probability calculations.
+            next_states[is_exit] = self.sf.to(device=device)
+
+        next_forward_masks = torch.ones(
+            (batch, self.n_actions), dtype=torch.bool, device=device
+        )
+        next_forward_masks[:, :ndim] = next_states != (self.height - 1)
+        next_forward_masks[:, ndim] = True
+        return next_states, next_forward_masks, is_exit
+
+    def forward_action_masks(self, states_tensor: torch.Tensor) -> torch.Tensor:
+        """Returns forward-action masks for a batch of state tensors."""
+        base = states_tensor != (self.height - 1)
+        return torch.cat(
+            [
+                base,
+                torch.ones(
+                    (states_tensor.shape[0], 1),
+                    dtype=torch.bool,
+                    device=states_tensor.device,
+                ),
+            ],
+            dim=-1,
+        )
+
+
+class ChunkedHyperGridSampler(Sampler):
+    def __init__(self, estimator, chunk_size: int):
+        super().__init__(estimator)
+        self.chunk_size = int(chunk_size)
+
+    def sample_trajectories(  # noqa: C901
+        self,
+        env: HyperGridWithTensorStep,
+        n: int | None = None,
+        states: DiscreteStates | None = None,
+        conditions: torch.Tensor | None = None,
+        save_estimator_outputs: bool = False,  # unused in chunked fast path
+        save_logprobs: bool = False,  # unused in chunked fast path
+        **policy_kwargs: Any,
+    ):
+        assert self.chunk_size > 0
+        assert hasattr(env, "step_tensor")
+        epsilon = float(policy_kwargs.get("epsilon", 0.0))
+
+        if states is None:
+            assert n is not None
+            states_obj = env.reset(batch_shape=(n,))
+        else:
+            states_obj = states
+
+        estimator = self.estimator
+        module = getattr(estimator, "module", None)
+        assert module is not None
+        height = int(env.height)
+        exit_idx = env.n_actions - 1
+
+        curr_states = states_obj.tensor
+        batch = curr_states.shape[0]
+        device = curr_states.device
+
+        forward_masks = env.forward_action_masks(curr_states)
+        done = torch.zeros(batch, dtype=torch.bool, device=device)
+        actions_seq: List[torch.Tensor] = []
+        dones_seq: List[torch.Tensor] = []
+
+        def sample_actions_from_logits(
+            logits: torch.Tensor, masks: torch.Tensor, eps: float
+        ) -> torch.Tensor:
+            masked_logits = logits.masked_fill(~masks, float("-inf"))
+            probs = torch.softmax(masked_logits, dim=-1)
+
+            if eps > 0.0:
+                valid_counts = masks.sum(dim=-1, keepdim=True).clamp_min(1)
+                uniform = masks.to(probs.dtype) / valid_counts.to(probs.dtype)
+                probs = (1.0 - eps) * probs + eps * uniform
+
+            # Rows whose actions are all masked come out of the masked softmax
+            # as NaN; force those rows onto the exit action so they land on the
+            # sink state, matching the `is_sink_state` padding semantics.
+ nan_rows = torch.isnan(probs).any(dim=-1) + if nan_rows.any(): + probs[nan_rows] = 0.0 + probs[nan_rows, exit_idx] = 1.0 + + return torch.multinomial(probs, 1) + + def _chunk_loop( + current_states: torch.Tensor, + current_masks: torch.Tensor, + done_mask: torch.Tensor, + ): + actions_list: List[torch.Tensor] = [] + dones_list: List[torch.Tensor] = [] + for _ in range(self.chunk_size): + if done_mask.any(): + current_masks = current_masks.clone() + current_masks[done_mask] = False + current_masks[done_mask, exit_idx] = True + khot = torch.nn.functional.one_hot( + current_states, num_classes=height + ).to(dtype=torch.get_default_dtype()) + khot = khot.view(current_states.shape[0], -1) + logits = module(khot) + actions = sample_actions_from_logits(logits, current_masks, epsilon) + next_states, next_masks, is_exit = env.step_tensor( + current_states, actions + ) + record_actions = actions.clone() + + # Replace actions for already-finished trajectories with the dummy + # action so that their timeline matches the padded semantics expected + # by Trajectories (actions.is_dummy aligns with states.is_sink_state[:-1]). + if done_mask.any(): + dummy_val = env.dummy_action.to(device=device) + record_actions[done_mask] = dummy_val + actions_list.append(record_actions) + dones_list.append(is_exit) + + current_states = next_states + current_masks = next_masks + done_mask = done_mask | is_exit + + if bool(done_mask.all().item()): + break + + return current_states, current_masks, done_mask, actions_list, dones_list + + chunk_fn = _chunk_loop + if hasattr(torch, "compile"): + try: + chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore + except Exception: + pass + + while not bool(done.all().item()): + curr_states, forward_masks, done, actions_chunk, dones_chunk = chunk_fn( + curr_states, forward_masks, done + ) + if actions_chunk: + actions_seq.extend(actions_chunk) + dones_seq.extend(dones_chunk) + + if actions_seq: + actions_tsr = torch.stack([a for a in actions_seq], dim=0) + T = actions_tsr.shape[0] + s = states_obj.tensor + states_stack = [s] + for t in range(T): + s, fm, is_exit = env.step_tensor(s, actions_tsr[t]) + states_stack.append(s) + states_tsr = torch.stack(states_stack, dim=0) + is_exit_seq = torch.stack(dones_seq, dim=0) + first_exit = torch.argmax(is_exit_seq.to(torch.long), dim=0) + never_exited = ~is_exit_seq.any(dim=0) + first_exit = torch.where( + never_exited, torch.tensor(T - 1, device=device), first_exit + ) + terminating_idx = first_exit + 1 + else: + states_tsr = states_obj.tensor.unsqueeze(0) + actions_tsr = env.actions_from_batch_shape((0, states_tsr.shape[1])).tensor + terminating_idx = torch.zeros( + states_tsr.shape[1], dtype=torch.long, device=device + ) + + trajectories = Trajectories( + env=env, + states=env.states_from_tensor(states_tsr), + conditions=None, + actions=env.actions_from_tensor(actions_tsr), + terminating_idx=terminating_idx, + is_backward=False, + log_rewards=None, + log_probs=None, + estimator_outputs=None, + ) + return trajectories + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--loss", choices=["FM", "TB", "DB"], default="TB") @@ -61,6 +271,12 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--use-vmap", action="store_true", help="Use vmap TB loss.") parser.add_argument("--benchmark", action="store_true", help="Run benchmark mode.") + parser.add_argument( + "--chunk-size", + type=int, + default=0, + help="Enable chunked sampler fast path when > 0.", + ) 
parser.add_argument( "--benchmark-output", type=str, @@ -90,14 +306,32 @@ def main() -> None: device = resolve_device(args.device) if args.benchmark: - scenarios = [ - ("Baseline", False, False), - (f"Compile ({args.compile_mode})", True, False), - ("Vmap", False, True), - (f"Compile+Vmap ({args.compile_mode})", True, True), + base_scenarios: list[tuple[str, bool, bool, bool]] = [ + ("Baseline", False, False, False), + (f"Compile ({args.compile_mode})", True, False, False), + ("Vmap", False, True, False), + (f"Compile+Vmap ({args.compile_mode})", True, True, False), ] + if args.chunk_size > 0: + base_scenarios += [ + (f"Chunk ({args.chunk_size})", False, False, True), + ( + f"Compile+Chunk ({args.compile_mode},{args.chunk_size})", + True, + False, + True, + ), + (f"Chunk+Vmap ({args.chunk_size})", False, True, True), + ( + f"Compile+Chunk+Vmap ({args.compile_mode},{args.chunk_size})", + True, + True, + True, + ), + ] + scenarios = base_scenarios results: list[dict[str, Any]] = [] - for label, enable_compile, use_vmap in scenarios: + for label, enable_compile, use_vmap, use_chunk in scenarios: result = train_with_options( args, device, @@ -107,6 +341,7 @@ def main() -> None: quiet=True, timing=True, record_history=True, + use_chunk=use_chunk, ) result["label"] = label results.append(result) @@ -122,7 +357,8 @@ def main() -> None: print( f"- {result['label']}: {result['elapsed']:.2f}s " f"({speedup:.2f}x) | compile_mode={result['compile_mode']} " - f"| vmap={'on' if result['effective_vmap'] else 'off'}" + f"| vmap={'on' if result['effective_vmap'] else 'off'} " + f"| chunk={'on' if result.get('chunk_size_effective', 0) > 0 else 'off'}" ) plot_benchmark(results, args.benchmark_output) @@ -137,6 +373,7 @@ def main() -> None: quiet=False, timing=False, record_history=False, + use_chunk=args.chunk_size > 0, ) @@ -150,6 +387,7 @@ def train_with_options( quiet: bool, timing: bool, record_history: bool, + use_chunk: bool = False, ) -> dict[str, Any]: set_seed(args.seed) ( @@ -158,7 +396,7 @@ def train_with_options( sampler, optimizer, visited_states, - ) = build_training_components(args, device) + ) = build_training_components(args, device, use_chunk=use_chunk) metrics = init_metrics() compile_mode = args.compile_mode if enable_compile else "none" @@ -231,6 +469,7 @@ def train_with_options( "use_compile": enable_compile, "requested_vmap": requested_vmap, "effective_vmap": effective_vmap, + "chunk_size_effective": (args.chunk_size if use_chunk else 0), } @@ -293,8 +532,8 @@ def run_iterations( last_loss = loss.item() if ( record_history - and losses_history is not None - and iter_time_history is not None + and (losses_history is not None) + and (iter_time_history is not None) ): losses_history.append(last_loss) iter_duration = ( @@ -431,14 +670,19 @@ def run_validation_if_needed( print(str_info) -def build_training_components(args: argparse.Namespace, device: torch.device) -> tuple[ +def build_training_components( + args: argparse.Namespace, device: torch.device, *, use_chunk: bool = False +) -> tuple[ HyperGrid, TBGFlowNet | DBGFlowNet | FMGFlowNet, Sampler, torch.optim.Optimizer, DiscreteStates, ]: - env = HyperGrid( + EnvClass = ( + HyperGridWithTensorStep if (use_chunk and args.chunk_size > 0) else HyperGrid + ) + env = EnvClass( ndim=args.ndim, height=args.height, reward_fn_str="original", @@ -477,7 +721,11 @@ def build_training_components(args: argparse.Namespace, device: torch.device) -> device ) optimizer = torch.optim.Adam(gflownet.logF.parameters(), lr=args.lr) - sampler = 
Sampler(estimator=logF_estimator) + sampler = ( + ChunkedHyperGridSampler(estimator=logF_estimator, chunk_size=args.chunk_size) + if use_chunk and args.chunk_size > 0 + else Sampler(estimator=logF_estimator) + ) else: pf_estimator = DiscretePolicyEstimator( module_PF, env.n_actions, preprocessor=preprocessor, is_backward=False @@ -509,7 +757,11 @@ def build_training_components(args: argparse.Namespace, device: torch.device) -> optimizer.add_param_group( {"params": gflownet.logz_parameters(), "lr": args.lr_logz} ) - sampler = Sampler(estimator=pf_estimator) + sampler = ( + ChunkedHyperGridSampler(estimator=pf_estimator, chunk_size=args.chunk_size) + if use_chunk and args.chunk_size > 0 + else Sampler(estimator=pf_estimator) + ) visited_states = env.states_from_batch_shape((0,)) return env, gflownet, sampler, optimizer, visited_states From 70daba84f7491e33642d057671431bb0da428b05 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 14:20:45 -0500 Subject: [PATCH 04/15] changes to get_geometric_within_contributions for numerical stability --- src/gfn/gflownet/sub_trajectory_balance.py | 45 ++++++++++++++-------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 37296c61..e338c86d 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -492,18 +492,32 @@ def get_geometric_within_contributions( Returns: The contributions tensor of shape (max_len * (max_len+1) / 2, n_trajectories). """ - L = self.lamda + max_len = trajectories.max_length - t_idx = trajectories.terminating_idx + if max_len == 0 or len(trajectories) == 0: + return torch.zeros( + (0, len(trajectories)), + device=trajectories.device, + dtype=torch.get_default_dtype(), + ) - # The following tensor represents the weights given to each possible - # sub-trajectory length. - contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( - torch.get_default_dtype() - ) - contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) + dtype = torch.get_default_dtype() + device = trajectories.device + t_idx = trajectories.terminating_idx.to(dtype) + + # Clamp lambda away from 0/1 to avoid divisions by zero or log(0) while keeping + # the computation compatible with torch.compile. + lamda = torch.as_tensor(self.lamda, device=device, dtype=dtype) + finfo = torch.finfo(dtype) + lamda = torch.clamp(lamda, finfo.tiny, 1 - finfo.eps) + + # Geometric weights for each possible sub-trajectory length, computed in log + # space to reduce error when lamda is close to 1. 
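+        # Since exp(i * log(lamda)) == lamda ** i, these weights equal the
+        # previous direct powers exactly (up to floating point), while the
+        # clamp above keeps log(lamda) finite.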
+ lengths = torch.arange(max_len, device=device, dtype=dtype) + log_weights = lengths * torch.log(lamda) + contributions = torch.exp(log_weights).unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( - torch.arange(max_len, 0, -1, device=t_idx.device), + torch.arange(max_len, 0, -1, device=device), dim=0, output_size=int(max_len * (max_len + 1) / 2), ) @@ -512,13 +526,14 @@ def get_geometric_within_contributions( # where n is the length of the trajectory corresponding to that column # We can do it the ugly way, or using the cool identity: # https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29 - per_trajectory_denom = ( - 1.0 - / (1 - L) ** 2 - * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) - ).to(torch.get_default_dtype()) - contributions = contributions / per_trajectory_denom / len(trajectories) + # Closed-form normalization: + # sum_{i=0}^{n-1} (n - i) * lamda^i + lamda_pow_n = torch.pow(lamda, t_idx) + numerator = lamda * (lamda_pow_n - 1) + (1 - lamda) * t_idx + denominator = (1 - lamda) ** 2 + per_trajectory_denom = numerator / denominator + contributions = contributions / per_trajectory_denom / len(trajectories) return contributions def loss( From 91d745a36b12da9c18db8f99fc9d13384a52a29e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 20:05:46 -0500 Subject: [PATCH 05/15] a lot of vibe coded fast paths --- multi.plan.md | 34 + src/gfn/env.py | 104 +- src/gfn/estimators.py | 103 +- src/gfn/gym/bitSequence.py | 149 +- src/gfn/gym/box.py | 24 +- src/gfn/gym/diffusion_sampling.py | 43 +- src/gfn/gym/discrete_ebm.py | 57 +- src/gfn/gym/helpers/box_utils.py | 58 +- src/gfn/gym/hypergrid.py | 52 +- src/gfn/gym/line.py | 30 +- src/gfn/gym/perfect_tree.py | 76 +- src/gfn/gym/set_addition.py | 78 +- src/gfn/samplers.py | 278 +- testing/test_environments.py | 332 +- .../examples/train_hypergrid_optimized.py | 1415 +++-- tutorials/examples/train_line.py | 30 +- .../torch_compile_discrete_states.ipynb | 4889 +++++++++++++++++ 17 files changed, 7406 insertions(+), 346 deletions(-) create mode 100644 multi.plan.md create mode 100644 tutorials/notebooks/torch_compile_discrete_states.ipynb diff --git a/multi.plan.md b/multi.plan.md new file mode 100644 index 00000000..0df97141 --- /dev/null +++ b/multi.plan.md @@ -0,0 +1,34 @@ + +# Plan: Extend Benchmark to Diffusion Sampling + +## 1. Refactor scenario management + +- Extract existing HyperGrid logic into a reusable `EnvironmentBenchmark` structure (e.g., dataclass with name, color, scenario list, builder function). +- Keep HyperGrid’s current scenarios (baseline / library fast path / script fast path) but register them under the new structure. +- Update the main loop to iterate over environments sequentially, collecting per-env results (including histories) and tagging each record with both env and scenario identifiers. + +## 2. Add diffusion sampling environment support + +- Review `tutorials/examples/train_diffusion_sampler.py` to reuse its estimator construction (`DiffusionSampling`, `DiffusionPISGradNetForward`, `DiffusionFixedBackwardModule`, `PinnedBrownianMotionForward/Backward`). +- Implement a new `DiffusionEnvConfig` builder under `build_training_components` (or a dedicated helper) that creates the env, forward/backward estimators, optimizer groups, and default hyperparameters mirroring the standalone script. +- Define diffusion-specific scenarios: +- Baseline: standard sampler, no compilation. 
+- Library Fast Path: use `CompiledChunkSampler` (env already inherits `EnvFastPathMixin`). +- Script Fast Path: implement a local chunked sampler analogous to `ChunkedHyperGridSampler`, but operating on diffusion states/tensors (handle continuous actions, exit padding, dummy actions). Expose it only for diffusion. + +## 3. Integrate new sampler/env wiring + +- Update `build_training_components` to dispatch based on the environment key (hypergrid vs diffusion) so each path can select the correct preprocessor, estimator modules, sampler type, and optimizer parameter groups. +- Ensure the diffusion path still returns metrics compatible with the existing training loop (needs `validate`?—if not available for diffusion, skip validation or provide a stub message). + +## 4. Expand plotting to multi-row layout + +- Adjust `plot_benchmark` to group results by environment and create one row per environment (HyperGrid row retains three scenarios; Diffusion row shows its two/three variants). +- Reuse the existing color mapping for GFlowNet variants; introduce per-environment scenario linestyles (or reuse existing names when overlapping). +- Update subplot titles/labels to mention the environment name so viewers can distinguish rows easily. + +## 5. Final polish + +- Update CLI help text to mention multi-environment benchmarking and any diffusion-specific knobs (e.g., target selection, num steps) if exposed; otherwise, explain defaults in docstring/comments. +- Verify histories are recorded for both environments so the new loss/timing plots aren’t empty. +- Refresh documentation/comments at the top of the script to describe the new diffusion benchmark capability. \ No newline at end of file diff --git a/src/gfn/env.py b/src/gfn/env.py index 646094f0..3146ceb0 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,7 +1,14 @@ import warnings from abc import ABC, abstractmethod from collections import Counter -from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + Tuple, + cast, +) if TYPE_CHECKING: from gfn.gflownet import GFlowNet @@ -17,6 +24,25 @@ NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) +class EnvFastPathMixin: + """Marker mixin for environments exposing tensor-only fast-path helpers. + + Environments inheriting this mixin are expected to override: + + - ``step_tensor``: vectorized transition operating purely on tensors. + - ``forward_action_masks_tensor``: tensor-based forward action masks. + - ``states_from_tensor_fast``: lightweight wrapper that avoids redundant + allocations when reconstructing ``States`` objects from raw tensors. + + The mixin itself does not provide implementations; it purely signals that + the environment intends to support the fast path and enables nominal checks + such as ``isinstance(env, EnvFastPathMixin)`` without relying on structural + typing. + """ + + fast_path_enabled: bool = True + + class Env(ABC): """Base class for all environments. @@ -37,6 +63,22 @@ class Env(ABC): is_discrete: bool = False + @dataclass + class TensorStepResult: + """Container returned by tensor-level step helpers. + + Attributes: + next_states: Tensor containing the next states produced by the step. + is_sink_state: Optional boolean tensor indicating which rows are sink. + forward_masks: Optional boolean tensor with forward action masks. + backward_masks: Optional boolean tensor with backward action masks. 
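+
+        Only ``next_states`` is always populated; the default object-based
+        fallback in ``Env.step_tensor`` leaves the other fields as ``None``,
+        so fast-path callers should treat them as optional.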
+ """ + + next_states: torch.Tensor + is_sink_state: torch.Tensor | None = None + forward_masks: torch.Tensor | None = None + backward_masks: torch.Tensor | None = None + def __init__( self, s0: torch.Tensor | GeometricData, @@ -145,6 +187,51 @@ def actions_from_batch_shape(self, batch_shape: Tuple) -> Actions: """ return self.Actions.make_dummy_actions(batch_shape, device=self.device) + @property + def has_tensor_fast_path(self) -> bool: + """Whether this environment opts into the tensor-only fast API.""" + + return isinstance(self, EnvFastPathMixin) + + def states_from_tensor_fast(self, tensor: torch.Tensor) -> States: + """Fallback helper recreating ``States`` objects from tensors. + + Fast-path environments can override this to avoid redundant mask + recomputation or to attach cached metadata. The default simply calls + ``states_from_tensor``. + """ + + return self.states_from_tensor(tensor) + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> "Env.TensorStepResult": + """Tensor equivalent of `_step` with default object-based fallback. + + Environments can override this method to provide compiler-friendly + implementations that avoid constructing `States`/`Actions`. The default + fallback simply wraps tensors into the standard containers and delegates + to `_step`, ensuring parity with the legacy path. + """ + + states = self.states_from_tensor(states_tensor.clone()) + actions = self.actions_from_tensor(actions_tensor.clone()) + new_states = self._step(states, actions) + return self.TensorStepResult(next_states=new_states.tensor.clone()) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor helper returning forward masks for the supplied states. + + Base environments do not provide a generic implementation because mask + semantics are environment-specific. Subclasses (e.g., ``DiscreteEnv``) + are expected to override this to expose a fallback compatible with the + fast sampler path. + """ + + raise NotImplementedError( + f"{self.__class__.__name__} does not expose tensor forward masks." + ) + @abstractmethod def step(self, states: States, actions: Actions) -> States: """Forward transition function of the environment. @@ -559,6 +646,21 @@ def states_from_batch_shape( assert isinstance(out, DiscreteStates) return out + def states_from_tensor_fast(self, tensor: torch.Tensor) -> DiscreteStates: + """Return `DiscreteStates` without extra bookkeeping for fast paths.""" + + states = self.states_from_tensor(tensor) + assert isinstance(states, DiscreteStates) + return states + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Recompute forward masks for the supplied state tensor.""" + + states = self.states_from_tensor(states_tensor.clone()) + self.update_masks(states) + assert states.forward_masks is not None + return states.forward_masks.clone() + def reset( self, batch_shape: int | Tuple[int, ...], diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 9cb6839d..ac727a5c 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -241,6 +241,45 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: return getattr(ctx, "current_estimator_output", None) +class FastPolicyMixin(PolicyMixin): + """Optional mixin for policies that ingest tensors directly on fast paths. 
+ + Estimators inheriting this mixin should implement the tensor-oriented hooks + below so samplers can bypass `States`/`Actions` allocation when environments + expose compatible helpers. + """ + + fast_path_enabled: bool = True + + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, + conditions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Preprocess raw tensors into module-ready features.""" + + raise NotImplementedError( + f"{self.__class__.__name__} does not implement fast_features." + ) + + def fast_distribution( + self, + features: torch.Tensor, + *, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, + **policy_kwargs: Any, + ) -> Distribution: + """Build the action distribution from tensor features.""" + + raise NotImplementedError( + f"{self.__class__.__name__} does not implement fast_distribution." + ) + + class RecurrentPolicyMixin(PolicyMixin): """Mixin for recurrent policies that maintain and update a rollout carry.""" @@ -1227,7 +1266,7 @@ def init_carry( return init_carry_fn(batch_size, device) -class DiffusionPolicyEstimator(PolicyMixin, Estimator): +class DiffusionPolicyEstimator(FastPolicyMixin, Estimator): """Base class for diffusion policy estimators.""" def __init__(self, s_dim: int, module: nn.Module, is_backward: bool = False): @@ -1282,6 +1321,16 @@ def to_probability_distribution( """ raise NotImplementedError + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + class PinnedBrownianMotionForward(DiffusionPolicyEstimator): # TODO: support OU process def __init__( @@ -1345,8 +1394,13 @@ def to_probability_distribution( A IsotropicGaussian distribution (distribution of the next states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] - t_curr = states.tensor[:, [-1]] + return self._distribution_from_tensor(states.tensor, module_output) + + def _distribution_from_tensor( + self, states_tensor: torch.Tensor, module_output: torch.Tensor + ) -> IsotropicGaussian: + s_curr = states_tensor[:, :-1] + t_curr = states_tensor[:, [-1]] module_output = torch.where( (1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0 @@ -1359,6 +1413,22 @@ def to_probability_distribution( fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1) return IsotropicGaussian(fwd_mean, fwd_std) + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> IsotropicGaussian: + if states_tensor is None: + raise ValueError( + "states_tensor is required for PinnedBrownianMotionForward fast path." 
+ ) + module_output = self.module(features) + return self._distribution_from_tensor(states_tensor, module_output) + class PinnedBrownianMotionBackward(DiffusionPolicyEstimator): # TODO: support OU process def __init__( @@ -1422,10 +1492,15 @@ def to_probability_distribution( A IsotropicGaussian distribution (distribution of the previous states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] - t_curr = states.tensor[:, [-1]] # shape: (*batch_shape,) + return self._distribution_from_tensor(states.tensor, module_output) + + def _distribution_from_tensor( + self, states_tensor: torch.Tensor, module_output: torch.Tensor + ) -> IsotropicGaussian: + s_curr = states_tensor[:, :-1] + t_curr = states_tensor[:, [-1]] - is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0 + is_s0 = (t_curr - self.dt) < self.dt * 1e-2 bwd_mean = torch.where( is_s0, s_curr, @@ -1437,3 +1512,19 @@ def to_probability_distribution( self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(), ) return IsotropicGaussian(bwd_mean, bwd_std) + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> IsotropicGaussian: + if states_tensor is None: + raise ValueError( + "states_tensor is required for PinnedBrownianMotionBackward fast path." + ) + module_output = self.module(features) + return self._distribution_from_tensor(states_tensor, module_output) diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index 2aca72bf..cc13eb9a 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -6,12 +6,14 @@ from gfn.actions import Actions from gfn.containers import Trajectories -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates from gfn.utils.common import is_int_dtype -# This environment is the torchgfn implmentation of the bit sequences task presented in :Malkin, Nikolay & Jain, Moksh & Bengio, Emmanuel & Sun, Chen & Bengio, Yoshua. (2022). -# Trajectory Balance: Improved Credit Assignment in GFlowNets. https://arxiv.org/pdf/2201.13259 +# This environment is the torchgfn implmentation of the bit sequences task presented in +# :Malkin, Nikolay & Jain, Moksh & Bengio, Emmanuel & Sun, Chen & Bengio, Yoshua. +# (2022). Trajectory Balance: Improved Credit Assignment in GFlowNets. +# https://arxiv.org/pdf/2201.13259 class BitSequenceStates(DiscreteStates): @@ -186,7 +188,7 @@ def row_to_binary_string(row, row_mask): return [row_to_binary_string(tensor[i], mask[i]) for i in range(tensor.shape[0])] -class BitSequence(DiscreteEnv): +class BitSequence(EnvFastPathMixin, DiscreteEnv): """Append-only BitSequence environment. 
This environment represents a sequence of binary words and provides methods to @@ -347,6 +349,18 @@ def update_masks(self, states: BitSequenceStates) -> None: ) states.backward_masks[~is_sink, last_actions] = True + def _lengths_from_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + return torch.count_nonzero(states_tensor != -1, dim=-1) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.ones((batch, self.n_actions), dtype=torch.bool, device=device) + lengths = self._lengths_from_tensor(states_tensor) + masks[lengths == self.words_per_seq, :-1] = False + masks[lengths < self.words_per_seq, -1] = False + return masks + def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. @@ -368,6 +382,56 @@ def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates ) return self.States(old_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + actions_vals = actions_tensor.squeeze(-1) + else: + actions_vals = actions_tensor + + exit_val = self.n_actions - 1 + is_exit = actions_vals == exit_val + next_states = states_tensor.clone() + + lengths = self._lengths_from_tensor(states_tensor) + + non_exit_idx = (~is_exit).nonzero(as_tuple=True)[0] + if len(non_exit_idx) > 0: + insert_pos = lengths[non_exit_idx] + next_states[non_exit_idx, insert_pos] = actions_vals[non_exit_idx] + + if is_exit.any(): + sink_row = torch.full( + (self.words_per_seq,), + exit_val, + dtype=torch.long, + device=states_tensor.device, + ) + next_states[is_exit] = sink_row + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.n_actions - 1), + dtype=torch.bool, + device=states_tensor.device, + ) + is_sink_state = torch.all(next_states == exit_val, dim=-1) + non_sink = ~is_sink_state + if non_sink.any(): + new_lengths = self._lengths_from_tensor(next_states[non_sink]) + last_idx = torch.clamp(new_lengths - 1, min=0) + rows = non_sink.nonzero(as_tuple=True)[0] + last_actions = next_states[rows, last_idx] + backward_masks[rows, last_actions] = True + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step( self, states: BitSequenceStates, actions: Actions ) -> BitSequenceStates: @@ -794,6 +858,69 @@ def update_masks(self, states: BitSequenceStates) -> None: states.backward_masks[~is_sink, last_actions] = True states.backward_masks[~is_sink, first_actions + (self.n_actions - 1) // 2] = True + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + return super().forward_action_masks_tensor(states_tensor) + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + actions_vals = actions_tensor.squeeze(-1) + else: + actions_vals = actions_tensor + + exit_val = self.n_actions - 1 + append_threshold = (self.n_actions - 1) // 2 + is_exit = actions_vals == exit_val + append_mask = actions_vals < append_threshold + prepend_mask = (~is_exit) & (~append_mask) + + next_states = states_tensor.clone() + lengths = 
self._lengths_from_tensor(states_tensor) + + if append_mask.any(): + idx = append_mask.nonzero(as_tuple=True)[0] + insert_pos = lengths[idx] + next_states[idx, insert_pos] = actions_vals[idx] + + if prepend_mask.any(): + idx = prepend_mask.nonzero(as_tuple=True)[0] + next_states[idx, 1:] = next_states[idx, :-1] + next_states[idx, 0] = actions_vals[idx] - append_threshold + + if is_exit.any(): + sink_row = torch.full( + (self.words_per_seq,), + exit_val, + dtype=torch.long, + device=states_tensor.device, + ) + next_states[is_exit] = sink_row + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.n_actions - 1), + dtype=torch.bool, + device=states_tensor.device, + ) + is_sink_state = torch.all(next_states == exit_val, dim=-1) + non_sink = ~is_sink_state + if non_sink.any(): + new_lengths = self._lengths_from_tensor(next_states[non_sink]) + last_idx = torch.clamp(new_lengths - 1, min=0) + rows = non_sink.nonzero(as_tuple=True)[0] + last_actions = next_states[rows, last_idx] + first_actions = next_states[rows, 0] + backward_masks[rows, last_actions] = True + backward_masks[rows, first_actions + append_threshold] = True + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. @@ -808,16 +935,18 @@ def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates old_tensor = states.tensor.clone() append_mask = (actions.tensor < (self.n_actions - 1) // 2).squeeze() prepend_mask = ~append_mask - assert states.length - old_tensor[append_mask & ~is_exit, states.length[append_mask & ~is_exit]] = ( - actions.tensor[append_mask & ~is_exit].squeeze() - ) + assert states.length is not None + append_rows = append_mask & ~is_exit + old_tensor[append_rows, states.length[append_rows]] = actions.tensor[ + append_rows + ].squeeze() old_tensor[prepend_mask & ~is_exit, 1:] = old_tensor[ prepend_mask & ~is_exit, :-1 ] - old_tensor[prepend_mask & ~is_exit, 0] = ( - actions.tensor[prepend_mask & ~is_exit].squeeze() - (self.n_actions - 1) // 2 + prepend_rows = prepend_mask & ~is_exit + old_tensor[prepend_rows, 0] = ( + actions.tensor[prepend_rows].squeeze() - (self.n_actions - 1) // 2 ) old_tensor[is_exit] = torch.full_like( diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 583d0b90..90ef4df4 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -4,11 +4,11 @@ import torch from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.states import States -class Box(Env): +class Box(EnvFastPathMixin, Env): """Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594 Attributes: @@ -101,6 +101,26 @@ def backward_step(self, states: States, actions: Actions) -> States: """ return self.States(states.tensor - actions.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + next_states = states_tensor.clone() + exit_action = self.exit_action.to(states_tensor.device).to(states_tensor.dtype) + exit_mask = torch.all(actions_tensor == exit_action, dim=-1) + non_exit = ~exit_mask + + next_states[non_exit] = next_states[non_exit] + actions_tensor[non_exit] + + if exit_mask.any(): + assert isinstance(self.sf, torch.Tensor) + sf_tensor = 
self.sf.to(states_tensor.device).to(states_tensor.dtype) + next_states[exit_mask] = sf_tensor + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=exit_mask.clone(), + ) + @staticmethod def norm(x: torch.Tensor) -> torch.Tensor: """Computes the L2 norm of the input tensor along the last dimension. diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 72e5975e..4ed88f55 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -11,7 +11,7 @@ from scipy.stats import wishart from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.gym.helpers.diffusion_utils import viz_2d_slice from gfn.states import States from gfn.utils.common import filter_kwargs_for_callable, temporarily_set_seed @@ -672,7 +672,7 @@ def visualize( ###################################### -class DiffusionSampling(Env): +class DiffusionSampling(EnvFastPathMixin, Env): """Diffusion sampling environment. Attributes: @@ -802,6 +802,45 @@ def step(self, states: States, actions: Actions) -> States: next_states_tensor[..., -1] = next_states_tensor[..., -1] + self.dt return self.States(next_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + """Tensor fast-path equivalent of `_step`. + + Mirrors the legacy wrapper by skipping already-sink rows, applying the action + update to the remaining states, and forcing exit actions onto the sink state. + """ + + assert states_tensor.shape[-1] == self.dim + 1 + assert actions_tensor.shape[-1] == self.dim + + device = states_tensor.device + dtype = states_tensor.dtype + sf_tensor = cast(torch.Tensor, self.sf).to(device=device, dtype=dtype) + exit_action = self.exit_action.to(device=device, dtype=dtype) + + # Detect rows that are already padded sink states, and exit rows that should + # transition to the sink regardless of their current state. + sink_mask = torch.all(states_tensor == sf_tensor, dim=-1) + exit_mask = torch.all(actions_tensor == exit_action, dim=-1) + update_mask = ~(sink_mask | exit_mask) + + next_states = states_tensor.clone() + if update_mask.any(): + next_states[update_mask, :-1] = ( + next_states[update_mask, :-1] + actions_tensor[update_mask] + ) + dt = torch.as_tensor(self.dt, device=device, dtype=dtype) + next_states[update_mask, -1] = next_states[update_mask, -1] + dt + + if exit_mask.any(): + next_states[exit_mask] = sf_tensor + + next_sink_mask = sink_mask | exit_mask + return self.TensorStepResult( + next_states=next_states, is_sink_state=next_sink_mask + ) + def backward_step(self, states: States, actions: Actions) -> States: """Backward step function for the SimpleGaussianMixtureModel environment. diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 3ed1c5eb..9162badf 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -5,7 +5,7 @@ import torch.nn as nn from gfn.actions import Actions -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates, States @@ -59,7 +59,7 @@ def forward(self, states: torch.Tensor) -> torch.Tensor: return -(states * tmp).sum(-1) -class DiscreteEBM(DiscreteEnv): +class DiscreteEBM(EnvFastPathMixin, DiscreteEnv): """Environment for discrete energy-based models. This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf. 
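For orientation, the mask convention implemented by the hunk below follows directly from the ternary state encoding described above: a dimension holds -1 while unset, each unset dimension enables both its set-to-0 and set-to-1 action, and the exit action unlocks only once every dimension is set. A minimal standalone sketch of that convention (plain tensors rather than the env API; illustrative only):

    import torch

    # Two 3-dim DiscreteEBM-style states: one partially built, one fully set.
    states = torch.tensor([[-1, 0, -1], [1, 0, 1]])
    ndim = states.shape[-1]

    available = states == -1                   # dims that can still be written
    masks = torch.zeros((2, 2 * ndim + 1), dtype=torch.bool)
    masks[:, :ndim] = available                # "set dim i to 0" actions
    masks[:, ndim : 2 * ndim] = available      # "set dim i to 1" actions
    masks[:, -1] = (states != -1).all(dim=-1)  # exit only when fully set

    # Row 0 may still write dims 0 and 2 (either value); row 1 can only exit.
    assert masks[0, [0, 2, 3, 5]].all() and not masks[0, -1]
    assert masks[1, -1] and not masks[1, :-1].any()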
@@ -132,6 +132,16 @@ def update_masks(self, states: DiscreteStates) -> None: states.backward_masks[..., : self.ndim] = states.tensor == 0 states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + available = states_tensor == -1 + masks[:, : self.ndim] = available + masks[:, self.ndim : 2 * self.ndim] = available + masks[:, -1] = torch.all(states_tensor != -1, dim=-1) + return masks + def make_random_states( self, batch_shape: Tuple, device: torch.device | None = None ) -> DiscreteStates: @@ -186,6 +196,49 @@ def step(self, states: States, actions: Actions) -> States: ) return self.States(states.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 1: + actions_idx = actions_tensor + else: + actions_idx = actions_tensor.squeeze(-1) + + exit_idx = self.n_actions - 1 + next_states = states_tensor.clone() + device = states_tensor.device + + is_exit = actions_idx == exit_idx + mask0 = (actions_idx < self.ndim) & ~is_exit + mask1 = (actions_idx >= self.ndim) & (actions_idx < 2 * self.ndim) & ~is_exit + + if mask0.any(): + rows = mask0.nonzero(as_tuple=True)[0] + cols = actions_idx[rows] + next_states[rows, cols] = 0 + + if mask1.any(): + rows = mask1.nonzero(as_tuple=True)[0] + cols = actions_idx[rows] - self.ndim + next_states[rows, cols] = 1 + + if is_exit.any(): + next_states[is_exit] = self.sf.to(device=device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros_like(forward_masks) + backward_masks[:, : self.ndim] = next_states == 0 + backward_masks[:, self.ndim : 2 * self.ndim] = next_states == 1 + + is_sink_state = torch.all(next_states == self.sf.to(device=device), dim=-1) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: States, actions: Actions) -> States: """Performs a backward step. diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 57933bbf..379d283d 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -8,7 +8,7 @@ from torch import Size, Tensor from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from gfn.estimators import Estimator, PolicyMixin +from gfn.estimators import Estimator, FastPolicyMixin from gfn.gym import Box from gfn.states import States from gfn.utils.modules import MLP @@ -936,7 +936,7 @@ def split_PF_module_output( return (exit_probability, mixture_logits, alpha_theta, beta_theta, alpha_r, beta_r) -class BoxPFEstimator(Estimator, PolicyMixin): +class BoxPFEstimator(FastPolicyMixin, Estimator): r"""Estimator for `P_F` for the Box environment. This estimator uses the `DistributionWrapper` distribution. @@ -978,6 +978,7 @@ def __init__( max_concentration: The maximum concentration for the Beta distributions. 
""" super().__init__(module) + self.env = env self._n_comp_max = max(n_components_s0, n_components) self.n_components_s0 = n_components_s0 self.n_components = n_components @@ -1059,8 +1060,33 @@ def _normalize(x: Tensor) -> Tensor: self.n_components_s0, ) + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for BoxPFEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output) + -class BoxPBEstimator(Estimator, PolicyMixin): +class BoxPBEstimator(FastPolicyMixin, Estimator): r"""Estimator for `P_B` for the Box environment. This estimator uses the `QuarterCircle(northeastern=False)` distribution. @@ -1096,6 +1122,7 @@ def __init__( """ super().__init__(module, is_backward=True) self.module = module + self.env = env self.n_components = n_components self.min_concentration = min_concentration @@ -1145,3 +1172,28 @@ def _normalize(x: Tensor) -> Tensor: alpha=alpha, beta=beta, ) + + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for BoxPBEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 1215b5f5..e2d65553 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -14,7 +14,7 @@ import torch from gfn.actions import Actions -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates from gfn.utils.common import ensure_same_device @@ -48,7 +48,7 @@ def smallest_multiplier_to_integers(float_vector, precision=3): return smallest_multiplier -class HyperGrid(DiscreteEnv): +class HyperGrid(EnvFastPathMixin, DiscreteEnv): """HyperGrid environment from the GFlowNets paper. 
The states are represented as 1-d tensors of length `ndim` with values in @@ -159,6 +159,15 @@ def update_masks(self, states: DiscreteStates) -> None: ) states.backward_masks = states.tensor != 0 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor-only equivalent of `update_masks` for forward masks.""" + + base = states_tensor != (self.height - 1) + exit_column = torch.ones( + (states_tensor.shape[0], 1), dtype=torch.bool, device=states_tensor.device + ) + return torch.cat([base, exit_column], dim=-1) + def make_random_states( self, batch_shape: Tuple[int, ...], device: torch.device | None = None ) -> DiscreteStates: @@ -191,6 +200,45 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: assert new_states_tensor.shape == states.tensor.shape return self.States(new_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + """Tensor-only transition combined with mask outputs for fast paths.""" + + assert states_tensor.dtype == torch.long + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + is_exit_action = actions_idx.squeeze(-1) == exit_idx + next_states = states_tensor.clone() + + non_exit_mask = ~is_exit_action + if torch.any(non_exit_mask): + sel_states = next_states[non_exit_mask] + sel_actions = actions_idx[non_exit_mask] + sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add") + next_states[non_exit_mask] = sel_states + + if torch.any(is_exit_action): + next_states[is_exit_action] = self.sf.to(device=states_tensor.device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = next_states != 0 + is_sink_state = (next_states == self.sf.to(device=states_tensor.device)).all( + dim=-1 + ) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a backward step in the environment. diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index e4b34b8c..b2af9df5 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -4,11 +4,11 @@ from torch.distributions import Normal # TODO: extend to Beta from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.states import States -class Line(Env): +class Line(EnvFastPathMixin, Env): """Mixture of Gaussians Line environment. 
Attributes: @@ -84,6 +84,32 @@ def step(self, states: States, actions: Actions) -> States: assert states.tensor.shape == states.batch_shape + (2,) return self.States(states.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + next_states = states_tensor.clone() + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + action_vals = actions_tensor.squeeze(-1) + else: + action_vals = actions_tensor + + exit_val = float(self.exit_action.item()) + exit_mask = action_vals == exit_val + non_exit = ~exit_mask + + next_states[non_exit, 0] = next_states[non_exit, 0] + action_vals[non_exit] + next_states[non_exit, 1] = next_states[non_exit, 1] + 1 + + if exit_mask.any(): + assert isinstance(self.sf, torch.Tensor) + sf_tensor = self.sf.to(states_tensor.device) + sf_tensor = sf_tensor.to(states_tensor.dtype) + next_states[exit_mask] = sf_tensor + + return self.TensorStepResult( + next_states=next_states, is_sink_state=exit_mask.clone() + ) + def backward_step(self, states: States, actions: Actions) -> States: """Performs a backward step in the environment. diff --git a/src/gfn/gym/perfect_tree.py b/src/gfn/gym/perfect_tree.py index 7c3512da..220ef5a0 100644 --- a/src/gfn/gym/perfect_tree.py +++ b/src/gfn/gym/perfect_tree.py @@ -2,11 +2,11 @@ import torch -from gfn.env import Actions, DiscreteEnv, DiscreteStates +from gfn.env import Actions, DiscreteEnv, DiscreteStates, EnvFastPathMixin from gfn.states import States -class PerfectBinaryTree(DiscreteEnv): +class PerfectBinaryTree(EnvFastPathMixin, DiscreteEnv): r"""Perfect Tree Environment. This environment is a perfect binary tree, where there is a bijection between @@ -75,6 +75,8 @@ def __init__( self.inverse_transition_table, self.term_states, ) = self._build_tree() + self._leaf_lower = 2**self.depth - 1 + self._leaf_upper = 2 ** (self.depth + 1) - 1 def _build_tree(self) -> tuple[dict, dict, DiscreteStates]: """Builds the tree and the transition tables. 
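The `_leaf_lower`/`_leaf_upper` bounds cached above follow from the usual heap-style numbering of a perfect binary tree: node `i` has children `2*i + 1` and `2*i + 2`, so the leaves of a depth-`d` tree occupy the half-open index range `[2**d - 1, 2**(d + 1) - 1)`, which is exactly the arithmetic the step hunk below relies on. A quick standalone check (illustrative only):

    def children(i: int) -> tuple[int, int]:
        # Heap-style numbering: root is node 0; children of i are 2i+1 and 2i+2.
        return 2 * i + 1, 2 * i + 2

    depth = 3
    leaf_lower, leaf_upper = 2**depth - 1, 2 ** (depth + 1) - 1  # [7, 15) for depth 3

    # Following left children from the root lands exactly on the first leaf index.
    node = 0
    for _ in range(depth):
        node = children(node)[0]
    assert node == leaf_lower

    # Every node below the leaf band keeps both children inside the tree.
    n_nodes = 2 ** (depth + 1) - 1
    assert all(children(i)[1] < n_nodes for i in range(leaf_lower))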
@@ -192,6 +194,76 @@ def update_masks(self, states: DiscreteStates) -> None: # Initial state has no available backward action states.backward_masks[initial_state_mask] = False + def _is_leaf_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + values = states_tensor.view(-1) + return (values >= self._leaf_lower) & (values < self._leaf_upper) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + leaf_mask = self._is_leaf_tensor(states_tensor) + sink_mask = (states_tensor == self.sf.to(device)).all(dim=-1) + non_leaf = ~(leaf_mask | sink_mask) + masks[non_leaf, : self.branching_factor] = True + masks[leaf_mask | sink_mask, -1] = True + return masks + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + device = states_tensor.device + next_states = states_tensor.clone() + actions_flat = actions_idx.squeeze(-1) + state_vals = next_states.squeeze(-1) + + is_exit = actions_flat == exit_idx + non_exit = ~is_exit + if non_exit.any(): + parents = state_vals[non_exit] + child_idx = parents.clone() + left_mask = actions_flat[non_exit] == 0 + right_mask = actions_flat[non_exit] == 1 + if left_mask.any(): + child_idx[left_mask] = 2 * parents[left_mask] + 1 + if right_mask.any(): + child_idx[right_mask] = 2 * parents[right_mask] + 2 + next_states[non_exit, 0] = child_idx + + if is_exit.any(): + next_states[is_exit] = self.sf.to(device=device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.branching_factor), + dtype=torch.bool, + device=device, + ) + next_vals = next_states.squeeze(-1) + sink_mask = (next_states == self.sf.to(device)).all(dim=-1) + initial_mask = next_vals == self.s0.item() + even_mask = (next_vals % 2 == 0) & ~sink_mask + odd_mask = (next_vals % 2 == 1) & ~sink_mask + backward_masks[even_mask, 1] = True + backward_masks[odd_mask, 0] = True + backward_masks[initial_mask] = False + + is_sink_state = sink_mask + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def get_states_indices(self, states: States): """Returns the indices of the states. diff --git a/src/gfn/gym/set_addition.py b/src/gfn/gym/set_addition.py index 65f0707a..71ea58df 100644 --- a/src/gfn/gym/set_addition.py +++ b/src/gfn/gym/set_addition.py @@ -2,10 +2,10 @@ import torch -from gfn.env import Actions, DiscreteEnv, DiscreteStates +from gfn.env import Actions, DiscreteEnv, DiscreteStates, EnvFastPathMixin -class SetAddition(DiscreteEnv): +class SetAddition(EnvFastPathMixin, DiscreteEnv): """Append only MDP, similarly to what is described in Remark 8 of Shen et al. 
2023 [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html) @@ -118,6 +118,41 @@ def update_masks(self, states: DiscreteStates) -> None: states.backward_masks[..., : self.n_items] = states.tensor != 0 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor equivalent of `update_masks` for forward masks.""" + + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + + n_items_per_state = states_tensor.sum(dim=-1) + states_that_must_end = n_items_per_state >= self.max_traj_len + states_that_may_continue = ~states_that_must_end + + if states_that_may_continue.any(): + cont_states = states_tensor[states_that_may_continue] == 0 + cont_masks = torch.zeros( + (cont_states.shape[0], self.n_actions), + dtype=torch.bool, + device=device, + ) + cont_masks[:, : self.n_items] = cont_states + masks[states_that_may_continue] = cont_masks + + if states_that_must_end.any(): + end_masks = torch.zeros( + (int(states_that_must_end.sum().item()), self.n_actions), + dtype=torch.bool, + device=device, + ) + end_masks[:, -1] = True + masks[states_that_must_end] = end_masks + + if not self.fixed_length: + masks[..., -1] = True + + return masks + def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a step in the environment. @@ -131,6 +166,45 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add") return self.States(new_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + """Tensor-only transition mirroring the legacy `_step` path.""" + + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + is_exit = actions_idx.squeeze(-1) == exit_idx + next_states = states_tensor.clone() + + non_exit_mask = ~is_exit + if torch.any(non_exit_mask): + sel_states = next_states[non_exit_mask] + sel_actions = actions_idx[non_exit_mask] + sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add") + next_states[non_exit_mask] = sel_states + + if torch.any(is_exit): + next_states[is_exit] = self.sf.to(device=states_tensor.device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros_like(forward_masks) + backward_masks[..., : self.n_items] = next_states != 0 + is_sink_state = (next_states == self.sf.to(device=states_tensor.device)).all( + dim=-1 + ) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a backward step in the environment. 
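Before the sampler changes below, it is worth seeing the shape of the tensor fast path these environments now expose. The following is a minimal sketch of a uniform-random rollout on HyperGrid, assuming this patch is applied; `CompiledChunkSampler` in the next file packages essentially this loop (plus exit/dummy action padding and trajectory reconstruction) into a unit that `torch.compile` can handle:

    import torch

    from gfn.gym import HyperGrid

    env = HyperGrid(ndim=2, height=4)
    states = env.reset(batch_shape=(8,)).tensor  # (8, ndim) long tensor
    done = torch.zeros(8, dtype=torch.bool)

    while not bool(done.all()):
        masks = env.forward_action_masks_tensor(states)
        probs = masks.float()
        # Force finished rows onto the exit action so they stay at the sink.
        probs[done] = 0.0
        probs[done, -1] = 1.0
        actions = torch.distributions.Categorical(probs=probs).sample().unsqueeze(-1)
        result = env.step_tensor(states, actions)  # returns a TensorStepResult
        states = result.next_states
        done = done | result.is_sink_state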
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index a4d74409..053ef290 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,11 +1,12 @@ +import warnings from typing import Any, List, Optional, Tuple, cast import torch from gfn.actions import Actions from gfn.containers import Trajectories -from gfn.env import Env -from gfn.estimators import Estimator, PolicyEstimatorProtocol +from gfn.env import Env, EnvFastPathMixin +from gfn.estimators import Estimator, FastPolicyMixin, PolicyEstimatorProtocol from gfn.states import GraphStates, States from gfn.utils.common import ensure_same_device from gfn.utils.graphs import graph_states_share_storage @@ -890,3 +891,276 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 ) return new_trajectories, new_trajectories_log_pf, new_trajectories_log_pb + + +class CompiledChunkSampler(Sampler): + """Chunked tensor sampler that stays on the fast path for torch.compile.""" + + def __init__( + self, + estimator: Estimator, + *, + chunk_size: int = 32, + compile_mode: str = "reduce-overhead", + ) -> None: + super().__init__(estimator) + self.chunk_size = int(chunk_size) + self.compile_mode = compile_mode + + def sample_trajectories( + self, + env: Env, + n: Optional[int] = None, + states: Optional[States] = None, + conditions: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, + save_logprobs: bool = False, + **policy_kwargs: Any, + ) -> Trajectories: + + # Log-probs: we’d need to store each chunk’s dist (or sampled actions) plus a + # boolean mask of which rows were active, then call policy.fast_distribution + # (...).log_prob(...) during or after the chunk loop. Because done rows get + # forced to the exit action, we’d have to mask those out when accumulating + # log-probs so the padded semantics match Trajectories.log_probs. That means + # keeping per-step tensors shaped (chunk_len, batch, action_dim) and writing + # them into the context at the end. + + # Estimator outputs: same idea—capture the raw tensor returned by policy. + # fast_features/fast_distribution (whatever we consider the “estimator output”) + # for active rows, pad them back to batch size, and append to a list per chunk + # so we can stack them like the legacy sampler. + if save_estimator_outputs or save_logprobs: + raise NotImplementedError( + "CompiledChunkSampler does not yet record log-probs or estimator outputs." + ) + + if not isinstance(env, EnvFastPathMixin): + raise TypeError( + "CompiledChunkSampler requires environments implementing EnvFastPathMixin." + ) + + if not isinstance(self.estimator, FastPolicyMixin): + raise TypeError( + "CompiledChunkSampler requires estimators implementing FastPolicyMixin." + ) + + assert self.chunk_size > 0, "chunk_size must be positive" + + policy = cast(FastPolicyMixin, self.estimator) + + if states is None: + assert n is not None, "Either `n` or `states` must be provided." + states_obj = env.reset(batch_shape=(n,)) + else: + states_obj = states + assert len(states_obj.batch_shape) == 1, "States batch must be 1-D." + + batch = states_obj.batch_shape[0] + device = states_obj.device + + if conditions is not None: + assert ( + conditions.shape[0] == batch + ), "Conditions batch dimension must match states batch size." 
+            ensure_same_device(device, conditions.device)
+
+        curr_states = states_obj.tensor
+        done = states_obj.is_sink_state.clone()
+        exit_action_value = env.exit_action.to(device=curr_states.device)
+        dummy_action_value = env.dummy_action.to(device=curr_states.device)
+
+        # `step_actions_seq` keeps the raw sampled actions (with exits injected for
+        # finished rows) so we can exactly replay the tensor-forward environment when
+        # rebuilding the state stack after the chunk loop. `recorded_actions_seq`
+        # mirrors those actions but rewrites already-finished rows with the env dummy
+        # action so that downstream Trajectories consumers (DBG/SubTB losses) never see
+        # transitions originating from sink states.
+        recorded_actions_seq: List[torch.Tensor] = []
+        step_actions_seq: List[torch.Tensor] = []
+        sink_seq: List[torch.Tensor] = []
+
+        chunk_size = int(policy_kwargs.pop("chunk_size", self.chunk_size))
+
+        def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
+            torch.Tensor,
+            torch.Tensor,
+            List[torch.Tensor],
+            List[torch.Tensor],
+            List[torch.Tensor],
+        ]:
+            """Samples actions for up to `chunk_size` steps for a batch of states.
+
+            This function is the core of the chunked sampler: it samples actions for
+            a chunk of states and is called in a loop until all states are done. It
+            is also the unit handed to `torch.compile`, so the bulk of the per-step
+            work (feature extraction, distribution construction, exit/dummy masking)
+            can run inside a single compiled region.
+
+            Args:
+                current_states: The current states to sample actions for.
+                done_mask: A boolean mask indicating which states are done.
+
+            Returns:
+                A tuple of the updated states, the updated done mask, the per-step
+                (exit-padded) step actions, the per-step (dummy-padded) recorded
+                actions, and the per-step sink masks, in that order.
+            """
+            local_step_actions: List[torch.Tensor] = []
+            local_recorded_actions: List[torch.Tensor] = []
+            local_sinks: List[torch.Tensor] = []
+
+            for _ in range(chunk_size):
+                if bool(done_mask.all().item()):
+                    break
+
+                state_view = current_states
+                features = policy.fast_features(
+                    state_view,
+                    forward_masks=None,
+                    backward_masks=None,
+                    conditions=conditions,
+                )
+                dist = policy.fast_distribution(
+                    features,
+                    forward_masks=None,
+                    backward_masks=None,
+                    states_tensor=state_view,
+                    **policy_kwargs,
+                )
+
+                actions_tensor = dist.sample()
+
+                if done_mask.any():
+                    # Broadcast the boolean mask and the per-env exit/dummy templates so
+                    # they match the estimator's sampled action tensor shape (covers
+                    # both scalar Discrete actions and potential multi-dim action
+                    # heads).
+                    mask = done_mask
+                    while mask.ndim < actions_tensor.ndim:
+                        mask = mask.unsqueeze(-1)
+
+                    exit_fill = exit_action_value.to(
+                        device=actions_tensor.device, dtype=actions_tensor.dtype
+                    )
+                    while exit_fill.ndim < actions_tensor.ndim:
+                        exit_fill = exit_fill.unsqueeze(0)
+
+                    dummy_fill = dummy_action_value.to(
+                        device=actions_tensor.device, dtype=actions_tensor.dtype
+                    )
+                    while dummy_fill.ndim < actions_tensor.ndim:
+                        dummy_fill = dummy_fill.unsqueeze(0)
+
+                    step_actions = torch.where(mask, exit_fill, actions_tensor)
+                    record_actions = torch.where(mask, dummy_fill, actions_tensor)
+                else:
+                    step_actions = actions_tensor
+                    record_actions = actions_tensor
+
+                # Only the step actions (exit-padded) are used to advance the tensor
+                # env. The recorded actions (dummy-padded) are used to reconstruct the
+                # state stack after the chunk loop.
+ step_res = env.step_tensor(current_states, step_actions) + current_states = step_res.next_states + sinks = step_res.is_sink_state + if sinks is None: + sinks = env.states_from_tensor(current_states).is_sink_state + + done_mask = done_mask | sinks + local_step_actions.append(step_actions) + local_recorded_actions.append(record_actions) + local_sinks.append(sinks) + + return ( + current_states, + done_mask, + local_step_actions, + local_recorded_actions, + local_sinks, + ) + + # Fallback to the non-compiled version if compilation fails. + chunk_fn = _chunk_loop + if hasattr(torch, "compile"): + try: + chunk_fn = torch.compile(_chunk_loop, mode=self.compile_mode) # type: ignore[arg-type] + except Exception: + # If compilation fails, use the non-compiled version. + warnings.warn( + "Compilation of chunk_loop failed, using non-compiled version.", + stacklevel=2, + ) + chunk_fn = _chunk_loop + + # Main loop: call the compiled function until all states are done. + while not bool(done.all().item()): + ( + curr_states, + done, + step_actions_chunk, + recorded_actions_chunk, + sinks_chunk, + ) = chunk_fn(curr_states, done) + if step_actions_chunk: + step_actions_seq.extend(step_actions_chunk) + recorded_actions_seq.extend(recorded_actions_chunk) + sink_seq.extend(sinks_chunk) + + if recorded_actions_seq: + actions_tsr = torch.stack(recorded_actions_seq, dim=0) + T = actions_tsr.shape[0] + + s = states_obj.tensor + states_stack = [s] + for t in range(T): + # Re-simulate using the true step actions so reconstructed states match + # the chunk rollout exactly even though padded (recorded) actions may + # differ. + step = env.step_tensor(s, step_actions_seq[t]) + s = step.next_states + states_stack.append(s) + states_tsr = torch.stack(states_stack, dim=0) + + sinks_tsr = torch.stack(sink_seq, dim=0) + first_sink = torch.argmax(sinks_tsr.to(torch.long), dim=0) + never_sink = ~sinks_tsr.any(dim=0) + first_sink = torch.where( + never_sink, + torch.tensor(T - 1, device=device), + first_sink, + ) + terminating_idx = first_sink + 1 + else: + states_tsr = states_obj.tensor.unsqueeze(0) + actions_tsr = env.actions_from_batch_shape((0, batch)).tensor + terminating_idx = torch.zeros(batch, dtype=torch.long, device=device) + + # Ensure the stacked (dummy-padded) actions respect the environment's action + # shape before wrapping them into an Actions container (e.g., discrete envs + # expect (..., 1)). Without this guard DB/SubTB estimators would fail when the + # chunk sampler returns rank-1 tensors. + action_shape = getattr(env, "action_shape", None) + if action_shape: + tail_shape = tuple(actions_tsr.shape[-len(action_shape) :]) + if tail_shape != tuple(action_shape): + if tuple(action_shape) == (1,): + actions_tsr = actions_tsr.unsqueeze(-1) + else: + raise ValueError( + "CompiledChunkSampler produced actions with shape " + f"{actions_tsr.shape}, expected trailing dims {action_shape}." 
+ ) + + trajectories = Trajectories( + env=env, + states=env.states_from_tensor(states_tsr), + conditions=conditions, + actions=env.actions_from_tensor(actions_tsr), + terminating_idx=terminating_idx, + is_backward=policy.is_backward, + log_rewards=None, + log_probs=None, + estimator_outputs=None, + ) + return trajectories diff --git a/testing/test_environments.py b/testing/test_environments.py index 42abc54f..f889447e 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, cast import numpy as np import pytest @@ -7,7 +7,8 @@ from gfn.actions import GraphActions, GraphActionType from gfn.env import Env, NonValidActionsError -from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym import BitSequence, BitSequencePlus, Box, DiscreteEBM, HyperGrid, Line +from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.gym.graph_building import GraphBuilding from gfn.gym.perfect_tree import PerfectBinaryTree from gfn.gym.set_addition import SetAddition @@ -133,6 +134,45 @@ def test_HyperGrid_bwd_step(): states = env._backward_step(states, failing_actions) +def test_HyperGrid_fast_path_matches_legacy(): + NDIM = 3 + ENV_HEIGHT = 5 + BATCH_SIZE = 64 + + env = HyperGrid(ndim=NDIM, height=ENV_HEIGHT) + states = env.reset(batch_shape=BATCH_SIZE, random=True, seed=123) + + assert states.forward_masks is not None + tensor_masks = env.forward_action_masks_tensor(states.tensor) + assert states.forward_masks is not None + legacy_forward_masks = cast(torch.Tensor, states.forward_masks) + assert torch.equal(tensor_masks, legacy_forward_masks) + + action_dist = torch.distributions.Categorical( + probs=states.forward_masks.to(dtype=torch.float32) + ) + actions_tensor = action_dist.sample().unsqueeze(-1) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + legacy_step_forward = cast(torch.Tensor, legacy_next.forward_masks) + legacy_step_backward = cast(torch.Tensor, legacy_next.backward_masks) + + fast = env.step_tensor(states.tensor, actions_tensor) + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert fast.is_sink_state is not None + fast_is_sink = cast(torch.Tensor, fast.is_sink_state) + assert torch.equal(fast_is_sink, legacy_next.is_sink_state) + assert fast.forward_masks is not None + fast_forward_masks = cast(torch.Tensor, fast.forward_masks) + assert fast.backward_masks is not None + fast_backward_masks = cast(torch.Tensor, fast.backward_masks) + assert torch.equal(fast_forward_masks, legacy_step_forward) + assert torch.equal(fast_backward_masks, legacy_step_backward) + + def test_DiscreteEBM_fwd_step(): NDIM = 2 BATCH_SIZE = 4 @@ -193,6 +233,45 @@ def test_DiscreteEBM_bwd_step(): states = env._backward_step(states, failing_actions) +def test_DiscreteEBM_fast_path_matches_legacy(): + NDIM = 5 + BATCH_SIZE = 48 + env = DiscreteEBM(ndim=NDIM) + states_tensor = torch.randint( + -1, 2, (BATCH_SIZE, NDIM), dtype=torch.long, device=env.device + ) + states = env.states_from_tensor(states_tensor.clone()) + assert states.forward_masks is not None + forward_masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = torch.distributions.Categorical( + probs=forward_masks.to(dtype=torch.float32) + ).sample() + actions_tensor = actions_tensor.unsqueeze(-1) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert 
legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + legacy_forward = cast(torch.Tensor, legacy_next.forward_masks) + legacy_backward = cast(torch.Tensor, legacy_next.backward_masks) + + fast = env.step_tensor(states.tensor, actions_tensor) + assert fast.forward_masks is not None + assert fast.backward_masks is not None + assert fast.is_sink_state is not None + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state) + + non_sink = ~legacy_next.is_sink_state + if non_sink.any(): + assert torch.equal( + cast(torch.Tensor, fast.forward_masks)[non_sink], legacy_forward[non_sink] + ) + fast_backward = cast(torch.Tensor, fast.backward_masks)[non_sink, : 2 * NDIM] + legacy_backward_trim = legacy_backward[non_sink, : 2 * NDIM] + assert torch.equal(fast_backward, legacy_backward_trim) + + @pytest.mark.parametrize("delta", [0.1, 0.5, 1.0]) def test_box_fwd_step(delta: float): env = Box(delta=delta) @@ -592,6 +671,28 @@ def test_graph_env(): assert states.tensor.x.shape == (0, 1) +def test_Line_fast_path_matches_legacy(): + BATCH_SIZE = 32 + env = Line( + mus=[0.0, 2.0], + sigmas=[0.5, 0.75], + init_value=0.1, + n_steps_per_trajectory=5, + ) + states = env.reset(batch_shape=BATCH_SIZE) + actions_tensor = torch.randn(BATCH_SIZE, 1, device=states.device) + exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.25 + actions_tensor[exit_mask] = env.exit_action.item() + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + fast = env.step_tensor(states.tensor, actions_tensor) + + assert torch.allclose(fast.next_states, legacy_next.tensor) + assert fast.is_sink_state is not None + fast_sink = cast(torch.Tensor, fast.is_sink_state) + assert torch.equal(fast_sink, legacy_next.is_sink_state) + + def test_set_addition_fwd_step(): N_ITEMS = 4 MAX_ITEMS = 3 @@ -648,6 +749,147 @@ def test_set_addition_fwd_step(): assert torch.allclose(rewards, expected_rewards) +def test_box_fast_path_matches_legacy(): + BATCH_SIZE = 48 + DELTA = 0.2 + env = Box(delta=DELTA) + states = env.reset(batch_shape=BATCH_SIZE) + actions_tensor = torch.zeros( + BATCH_SIZE, 2, dtype=torch.get_default_dtype(), device=states.device + ) + exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.2 + actions_tensor[exit_mask] = env.exit_action.to(actions_tensor.device).to( + actions_tensor.dtype + ) + non_exit_idx = (~exit_mask).nonzero(as_tuple=True)[0] + if len(non_exit_idx) > 0: + radii = torch.rand(len(non_exit_idx), device=states.device) * DELTA + angles = torch.rand(len(non_exit_idx), device=states.device) * 2 * torch.pi + actions_tensor[non_exit_idx, 0] = radii * torch.cos(angles) + actions_tensor[non_exit_idx, 1] = radii * torch.sin(angles) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + fast = env.step_tensor(states.tensor, actions_tensor) + + assert torch.allclose(fast.next_states, legacy_next.tensor) + assert fast.is_sink_state is not None + assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state) + + +def test_diffusion_sampling_fast_path_matches_legacy(): + BATCH_SIZE = 16 + env = DiffusionSampling( + target_str="gmm2", target_kwargs={}, num_discretization_steps=8.0 + ) + states = env.reset(batch_shape=BATCH_SIZE) + actions_tensor = torch.randn( + BATCH_SIZE, env.dim, device=states.device, dtype=states.tensor.dtype + ) + exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.2 + if exit_mask.any(): + exit_action = 
env.exit_action.to(device=states.device, dtype=states.tensor.dtype) + actions_tensor[exit_mask] = exit_action + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + fast = env.step_tensor(states.tensor, actions_tensor) + + assert torch.allclose(fast.next_states, legacy_next.tensor) + assert fast.is_sink_state is not None + assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state) + + +def test_bitsequence_fast_path_matches_legacy(): + BATCH_SIZE = 32 + env = BitSequence(word_size=2, seq_size=8, n_modes=4, temperature=1.0) + states = env.reset(batch_shape=BATCH_SIZE) + for _ in range(3): + assert states.forward_masks is not None + masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = ( + torch.distributions.Categorical(probs=masks.to(dtype=torch.float32)) + .sample() + .unsqueeze(-1) + ) + states = env._step(states, env.actions_from_tensor(actions_tensor)) + + assert states.forward_masks is not None + forward_masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = ( + torch.distributions.Categorical(probs=forward_masks.to(dtype=torch.float32)) + .sample() + .unsqueeze(-1) + ) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + + fast = env.step_tensor(states.tensor, actions_tensor) + assert fast.forward_masks is not None + assert fast.backward_masks is not None + assert fast.is_sink_state is not None + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state) + + non_sink = ~legacy_next.is_sink_state + if non_sink.any(): + assert torch.equal( + cast(torch.Tensor, fast.forward_masks)[non_sink], + cast(torch.Tensor, legacy_next.forward_masks)[non_sink], + ) + assert torch.equal( + cast(torch.Tensor, fast.backward_masks)[non_sink], + cast(torch.Tensor, legacy_next.backward_masks)[non_sink], + ) + + +def test_bitsequence_plus_fast_path_matches_legacy(): + BATCH_SIZE = 24 + env = BitSequencePlus(word_size=2, seq_size=16, n_modes=5, temperature=1.0) + states = env.reset(batch_shape=BATCH_SIZE) + for _ in range(4): + assert states.forward_masks is not None + masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = ( + torch.distributions.Categorical(probs=masks.to(dtype=torch.float32)) + .sample() + .unsqueeze(-1) + ) + states = env._step(states, env.actions_from_tensor(actions_tensor)) + + assert states.forward_masks is not None + forward_masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = ( + torch.distributions.Categorical(probs=forward_masks.to(dtype=torch.float32)) + .sample() + .unsqueeze(-1) + ) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + + fast = env.step_tensor(states.tensor, actions_tensor) + assert fast.forward_masks is not None + assert fast.backward_masks is not None + assert fast.is_sink_state is not None + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state) + + non_sink = ~legacy_next.is_sink_state + if non_sink.any(): + assert torch.equal( + cast(torch.Tensor, fast.forward_masks)[non_sink], + cast(torch.Tensor, legacy_next.forward_masks)[non_sink], + ) + assert torch.equal( + cast(torch.Tensor, fast.backward_masks)[non_sink], + 
cast(torch.Tensor, legacy_next.backward_masks)[non_sink], + ) + + def test_set_addition_bwd_step(): N_ITEMS = 5 MAX_ITEMS = 4 @@ -692,6 +934,49 @@ def test_set_addition_bwd_step(): assert torch.all(states.is_initial_state) +def test_set_addition_fast_path_matches_legacy(): + N_ITEMS = 6 + MAX_ITEMS = 4 + BATCH_SIZE = 48 + + env = SetAddition( + n_items=N_ITEMS, max_items=MAX_ITEMS, reward_fn=lambda s: s.sum(-1) + ) + states_tensor = torch.randint( + 0, 2, (BATCH_SIZE, N_ITEMS), dtype=torch.get_default_dtype() + ) + states = env.states_from_tensor(states_tensor.clone()) + + assert states.forward_masks is not None + forward_masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = torch.distributions.Categorical( + probs=forward_masks.to(dtype=torch.float32) + ).sample() + actions_tensor = actions_tensor.unsqueeze(-1) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + legacy_forward = cast(torch.Tensor, legacy_next.forward_masks) + legacy_backward = cast(torch.Tensor, legacy_next.backward_masks) + + fast = env.step_tensor(states.tensor, actions_tensor) + assert fast.forward_masks is not None + assert fast.backward_masks is not None + assert fast.is_sink_state is not None + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state) + + fast_forward = cast(torch.Tensor, fast.forward_masks) + fast_backward = cast(torch.Tensor, fast.backward_masks)[..., : env.n_items] + + non_sink = ~legacy_next.is_sink_state + if non_sink.any(): + assert torch.equal(fast_forward[non_sink], legacy_forward[non_sink]) + assert torch.equal(fast_backward[non_sink], legacy_backward[non_sink]) + + def test_perfect_binary_tree_fwd_step(): DEPTH = 3 BATCH_SIZE = 2 @@ -781,6 +1066,49 @@ def test_perfect_binary_tree_bwd_step(): assert torch.all(states.is_initial_state) +def test_perfect_binary_tree_fast_path_matches_legacy(): + DEPTH = 4 + BATCH_SIZE = 64 + + env = PerfectBinaryTree( + depth=DEPTH, + reward_fn=lambda s: s.to(torch.get_default_dtype()) + 1, + ) + states_tensor = torch.randint( + 0, env.n_nodes, (BATCH_SIZE, 1), dtype=torch.long, device=env.device + ) + states = env.states_from_tensor(states_tensor.clone()) + assert states.forward_masks is not None + forward_masks = cast(torch.Tensor, states.forward_masks) + actions_tensor = torch.distributions.Categorical( + probs=forward_masks.to(dtype=torch.float32) + ).sample() + actions_tensor = actions_tensor.unsqueeze(-1) + + legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone())) + assert legacy_next.forward_masks is not None + assert legacy_next.backward_masks is not None + legacy_forward = cast(torch.Tensor, legacy_next.forward_masks) + legacy_backward = cast(torch.Tensor, legacy_next.backward_masks) + + fast = env.step_tensor(states.tensor, actions_tensor) + assert fast.forward_masks is not None + assert fast.backward_masks is not None + assert fast.is_sink_state is not None + + assert torch.equal(fast.next_states, legacy_next.tensor) + assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state) + + non_sink = ~legacy_next.is_sink_state + if non_sink.any(): + assert torch.equal( + cast(torch.Tensor, fast.forward_masks)[non_sink], legacy_forward[non_sink] + ) + assert torch.equal( + cast(torch.Tensor, fast.backward_masks)[non_sink], legacy_backward[non_sink] + ) + + # 
----------------------------------------------------------------------------- # Tests for default sf fill value based on dtype # ----------------------------------------------------------------------------- diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py index 8daad1c4..211e2edf 100644 --- a/tutorials/examples/train_hypergrid_optimized.py +++ b/tutorials/examples/train_hypergrid_optimized.py @@ -1,6 +1,7 @@ #!/usr/bin/env python r""" -Optimized HyperGrid training script with optional torch.compile, vmap, and benchmarking. +Optimized multi-environment (HyperGrid + Diffusion) training/benchmark script with +optional torch.compile, vmap, and chunked sampling across several GFlowNet variants. """ from __future__ import annotations @@ -8,33 +9,249 @@ import argparse import statistics import time +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, cast +from typing import Any, Dict, Iterable, List, Literal, cast import torch from torch.func import vmap from tqdm import tqdm from gfn.containers import Trajectories -from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator +from gfn.env import Env, EnvFastPathMixin +from gfn.estimators import ( + DiscretePolicyEstimator, + FastPolicyMixin, + PinnedBrownianMotionBackward, + PinnedBrownianMotionForward, + ScalarEstimator, +) +from gfn.gflownet import PFBasedGFlowNet, SubTBGFlowNet from gfn.gflownet.detailed_balance import DBGFlowNet -from gfn.gflownet.flow_matching import FMGFlowNet from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.gym import HyperGrid -from gfn.preprocessors import KHotPreprocessor -from gfn.samplers import Sampler -from gfn.states import DiscreteStates +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor +from gfn.samplers import CompiledChunkSampler, Sampler +from gfn.states import DiscreteStates, States from gfn.utils.common import set_seed from gfn.utils.compile import try_compile_gflownet -from gfn.utils.modules import MLP, DiscreteUniform +from gfn.utils.modules import ( + MLP, + DiffusionFixedBackwardModule, + DiffusionPISGradNetForward, +) from gfn.utils.training import validate +# Default HyperGrid configuration (easy to extend to multiple envs later on). 
+HYPERGRID_KWARGS: Dict[str, Any] = { + "ndim": 2, + "height": 32, + "reward_fn_str": "original", + "reward_fn_kwargs": {"R0": 0.1, "R1": 0.5, "R2": 2.0}, + "calculate_partition": False, + "store_all_states": False, + "check_action_validity": __debug__, +} + +DEFAULT_CHUNK_SIZE = 32 +DEFAULT_COMPILE_MODE = "reduce-overhead" + + +@dataclass +class ScenarioConfig: + name: str + description: str + sampler: Literal["standard", "compiled_chunk", "script_chunk"] + use_script_env: bool + use_compile: bool + use_vmap: bool + + +@dataclass(frozen=True) +class FlowVariant: + key: Literal["tb", "dbg", "subtb"] + label: str + description: str + requires_logf: bool + supports_vmap: bool + + +HYPERGRID_SCENARIOS: list[ScenarioConfig] = [ + ScenarioConfig( + name="Baseline (core)", + description="Stock library path: standard env + sampler, no compilation.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=False, + ), + ScenarioConfig( + name="Library Fast Path", + description="Core EnvFastPath + CompiledChunkSampler + compile + vmap TB.", + sampler="compiled_chunk", + use_script_env=False, + use_compile=True, + use_vmap=True, + ), + ScenarioConfig( + name="Script Env + Compiled Chunk", + description="Script-local tensor env + library CompiledChunkSampler + compile/vmap.", + sampler="compiled_chunk", + use_script_env=True, + use_compile=True, + use_vmap=True, + ), + ScenarioConfig( + name="Script Fast Path", + description="Script-local tensor env/sampler, compile, and vmap TB.", + sampler="script_chunk", + use_script_env=True, + use_compile=True, + use_vmap=True, + ), +] + +DIFFUSION_SCENARIOS: list[ScenarioConfig] = [ + ScenarioConfig( + name="Diffusion Baseline", + description="Pinned Brownian sampler without compilation or chunking.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=False, + ), + ScenarioConfig( + name="Diffusion Library Fast Path", + description="EnvFastPath + CompiledChunkSampler (library implementation).", + sampler="compiled_chunk", + use_script_env=False, + use_compile=True, + use_vmap=False, + ), + ScenarioConfig( + name="Diffusion Script Fast Path", + description="Script-local tensor sampler tailored to diffusion states.", + sampler="script_chunk", + use_script_env=False, + use_compile=True, + use_vmap=False, + ), +] + + +FLOW_VARIANTS: dict[str, FlowVariant] = { + "tb": FlowVariant( + key="tb", + label="TBGFlowNet", + description="Trajectory Balance baseline with optional torch.compile/vmap.", + requires_logf=False, + supports_vmap=True, + ), + "dbg": FlowVariant( + key="dbg", + label="DBGFlowNet", + description="Detailed Balance loss with learned log-state flows.", + requires_logf=True, + supports_vmap=False, + ), + "subtb": FlowVariant( + key="subtb", + label="SubTBGFlowNet", + description="Sub-trajectory balance variant with configurable weighting.", + requires_logf=True, + supports_vmap=False, + ), +} + +DEFAULT_FLOW_ORDER = ["tb", "dbg", "subtb"] + +# Plot styling: consistent colors for GFlowNet variants, linestyles for scenarios. 
+VARIANT_COLORS: dict[str, str] = { + "tb": "#000000", # Trajectory Balance -> black + "subtb": "#d62728", # SubTB -> red + "dbg": "#1f77b4", # Detailed Balance -> blue +} +SCENARIO_LINESTYLES: dict[str, Any] = { + "Baseline (core)": "-", + "Library Fast Path": "--", # fast-path compiled + "Script Env + Compiled Chunk": "dashdot", + "Script Fast Path": ":", + "Diffusion Baseline": "-", + "Diffusion Library Fast Path": "--", + "Diffusion Script Fast Path": ":", +} +LOSS_LINE_ALPHA = 0.5 + + +@dataclass +class EnvironmentBenchmark: + key: Literal["hypergrid", "diffusion"] + label: str + description: str + color: str + scenarios: list[ScenarioConfig] + supported_flows: list[str] + supports_validation: bool + + +ENVIRONMENT_BENCHMARKS: dict[str, EnvironmentBenchmark] = { + "hypergrid": EnvironmentBenchmark( + key="hypergrid", + label="HyperGrid", + description="High-dimensional discrete lattice with known reward landscape.", + color="#4a90e2", + scenarios=HYPERGRID_SCENARIOS, + supported_flows=list(DEFAULT_FLOW_ORDER), + supports_validation=True, + ), + "diffusion": EnvironmentBenchmark( + key="diffusion", + label="Diffusion Sampling", + description="Continuous-state diffusion sampling benchmark (Pinned Brownian).", + color="#a17be7", + scenarios=DIFFUSION_SCENARIOS, + supported_flows=list(DEFAULT_FLOW_ORDER), + supports_validation=False, + ), +} +DEFAULT_ENV_ORDER = ["hypergrid", "diffusion"] + + +def _normalize_flow_keys(requested: list[str]) -> list[str]: + normalized: list[str] = [] + for key in requested: + alias = key.lower() + if alias not in FLOW_VARIANTS: + supported = ", ".join(sorted(FLOW_VARIANTS)) + raise ValueError( + f"Unsupported GFlowNet variant '{key}'. Choose from {supported}." + ) + if alias not in normalized: + normalized.append(alias) + return normalized + + +def _normalize_env_keys(requested: list[str]) -> list[str]: + normalized: list[str] = [] + available = ENVIRONMENT_BENCHMARKS + for key in requested: + alias = key.lower() + if alias not in available: + supported = ", ".join(sorted(available)) + raise ValueError( + f"Unsupported environment '{key}'. Choose from {supported}." 
+ ) + if alias not in normalized: + normalized.append(alias) + return normalized or list(DEFAULT_ENV_ORDER) + # Local subclasses for benchmarking-only optimizations (no core library changes) class HyperGridWithTensorStep(HyperGrid): def step_tensor( self, states: torch.Tensor, actions: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Env.TensorStepResult: assert states.dtype == torch.long device = states.device batch = states.shape[0] @@ -67,7 +284,14 @@ def step_tensor( ) next_forward_masks[:, :ndim] = next_states != (self.height - 1) next_forward_masks[:, ndim] = True - return next_states, next_forward_masks, is_exit + backward_masks = next_states != 0 + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_exit, + forward_masks=next_forward_masks, + backward_masks=backward_masks, + ) def forward_action_masks(self, states_tensor: torch.Tensor) -> torch.Tensor: """Returns forward-action masks for a batch of state tensors.""" @@ -120,7 +344,33 @@ def sample_trajectories( # noqa: C901 batch = curr_states.shape[0] device = curr_states.device - forward_masks = env.forward_action_masks(curr_states) + def compute_forward_masks(states_tensor: torch.Tensor) -> torch.Tensor: + if hasattr(env, "forward_action_masks"): + return env.forward_action_masks(states_tensor) + if hasattr(env, "forward_action_masks_tensor"): + return env.forward_action_masks_tensor(states_tensor) + raise TypeError( + "HyperGrid environment must expose forward action masks for fast path." + ) + + def step_tensor( + states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + step_result = env.step_tensor(states_tensor, actions_tensor) + if isinstance(step_result, Env.TensorStepResult): + next_states = step_result.next_states + next_masks = step_result.forward_masks + if next_masks is None: + next_masks = compute_forward_masks(next_states) + is_exit_states = step_result.is_sink_state + if is_exit_states is None: + is_exit_states = env.states_from_tensor(next_states).is_sink_state + return next_states, next_masks, is_exit_states + assert isinstance(step_result, tuple) and len(step_result) == 3 + next_states, next_masks, is_exit_states = step_result + return next_states, next_masks, is_exit_states + + forward_masks = compute_forward_masks(curr_states) done = torch.zeros(batch, dtype=torch.bool, device=device) actions_seq: List[torch.Tensor] = [] dones_seq: List[torch.Tensor] = [] @@ -158,15 +408,14 @@ def _chunk_loop( current_masks = current_masks.clone() current_masks[done_mask] = False current_masks[done_mask, exit_idx] = True + states_for_encoding = torch.clamp(current_states, min=0) khot = torch.nn.functional.one_hot( - current_states, num_classes=height + states_for_encoding, num_classes=height ).to(dtype=torch.get_default_dtype()) khot = khot.view(current_states.shape[0], -1) logits = module(khot) actions = sample_actions_from_logits(logits, current_masks, epsilon) - next_states, next_masks, is_exit = env.step_tensor( - current_states, actions - ) + next_states, next_masks, is_exit = step_tensor(current_states, actions) record_actions = actions.clone() # Replace actions for already-finished trajectories with the dummy @@ -190,7 +439,7 @@ def _chunk_loop( chunk_fn = _chunk_loop if hasattr(torch, "compile"): try: - chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore + chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] except Exception: pass @@ -204,11 +453,22 
@@ def _chunk_loop( if actions_seq: actions_tsr = torch.stack([a for a in actions_seq], dim=0) + replay_actions = actions_tsr.clone() + dummy_val = env.dummy_action.to(device=device, dtype=replay_actions.dtype) + exit_val = env.exit_action.to(device=device, dtype=replay_actions.dtype) + + mask = replay_actions == dummy_val + if mask.any(): + exit_fill = exit_val + while exit_fill.ndim < replay_actions.ndim: + exit_fill = exit_fill.unsqueeze(0) + replay_actions = torch.where(mask, exit_fill, replay_actions) + T = actions_tsr.shape[0] s = states_obj.tensor states_stack = [s] for t in range(T): - s, fm, is_exit = env.step_tensor(s, actions_tsr[t]) + s, _, _ = step_tensor(s, replay_actions[t]) states_stack.append(s) states_tsr = torch.stack(states_stack, dim=0) is_exit_seq = torch.stack(dones_seq, dim=0) @@ -239,62 +499,419 @@ def _chunk_loop( return trajectories +class ChunkedDiffusionSampler(Sampler): + """Chunked fast-path sampler specialized for DiffusionSampling states.""" + + def __init__(self, estimator: PinnedBrownianMotionForward, chunk_size: int): + super().__init__(estimator) + self.chunk_size = int(chunk_size) + + def sample_trajectories( # noqa: C901 + self, + env: DiffusionSampling, + n: int | None = None, + states: States | None = None, + conditions: torch.Tensor | None = None, + save_estimator_outputs: bool = False, + save_logprobs: bool = False, + **policy_kwargs: Any, + ) -> Trajectories: + if save_estimator_outputs or save_logprobs: + raise NotImplementedError( + "ChunkedDiffusionSampler does not record estimator outputs/log-probs yet." + ) + if not isinstance(env, EnvFastPathMixin): + raise TypeError( + "ChunkedDiffusionSampler requires environments with tensor fast paths." + ) + if not isinstance(self.estimator, FastPolicyMixin): + raise TypeError( + "ChunkedDiffusionSampler requires a FastPolicy-compatible estimator." 
+            )
+
+        policy = cast(FastPolicyMixin, self.estimator)
+        chunk_size = max(1, self.chunk_size)
+
+        if states is None:
+            assert n is not None
+            states_obj = env.reset(batch_shape=(n,))
+        else:
+            states_obj = states
+
+        curr_states = states_obj.tensor
+        done = states_obj.is_sink_state.clone()
+        exit_action_value = env.exit_action.to(device=curr_states.device)
+        dummy_action_value = env.dummy_action.to(device=curr_states.device)
+
+        step_actions_seq: List[torch.Tensor] = []
+        recorded_actions_seq: List[torch.Tensor] = []
+        sink_seq: List[torch.Tensor] = []
+
+        def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
+            torch.Tensor,
+            torch.Tensor,
+            List[torch.Tensor],
+            List[torch.Tensor],
+            List[torch.Tensor],
+        ]:
+            local_step_actions: List[torch.Tensor] = []
+            local_recorded_actions: List[torch.Tensor] = []
+            local_sinks: List[torch.Tensor] = []
+
+            for _ in range(chunk_size):
+                if bool(done_mask.all().item()):
+                    break
+
+                features = policy.fast_features(
+                    current_states,
+                    forward_masks=None,
+                    backward_masks=None,
+                    conditions=conditions,
+                )
+                dist = policy.fast_distribution(
+                    features,
+                    forward_masks=None,
+                    backward_masks=None,
+                    states_tensor=current_states,
+                    **policy_kwargs,
+                )
+                sampled_actions = dist.sample()
+
+                if done_mask.any():
+                    mask = done_mask
+                    while mask.ndim < sampled_actions.ndim:
+                        mask = mask.unsqueeze(-1)
+
+                    exit_fill = exit_action_value.to(
+                        device=sampled_actions.device, dtype=sampled_actions.dtype
+                    )
+                    while exit_fill.ndim < sampled_actions.ndim:
+                        exit_fill = exit_fill.unsqueeze(0)
+
+                    dummy_fill = dummy_action_value.to(
+                        device=sampled_actions.device, dtype=sampled_actions.dtype
+                    )
+                    while dummy_fill.ndim < sampled_actions.ndim:
+                        dummy_fill = dummy_fill.unsqueeze(0)
+
+                    step_actions = torch.where(mask, exit_fill, sampled_actions)
+                    record_actions = torch.where(mask, dummy_fill, sampled_actions)
+                else:
+                    step_actions = sampled_actions
+                    record_actions = sampled_actions
+
+                step_res = env.step_tensor(current_states, step_actions)
+                current_states = step_res.next_states
+                sinks = step_res.is_sink_state
+                if sinks is None:
+                    sinks = env.states_from_tensor(current_states).is_sink_state
+
+                done_mask = done_mask | sinks
+                local_step_actions.append(step_actions)
+                local_recorded_actions.append(record_actions)
+                local_sinks.append(sinks)
+
+            return (
+                current_states,
+                done_mask,
+                local_step_actions,
+                local_recorded_actions,
+                local_sinks,
+            )
+
+        chunk_fn = _chunk_loop
+        device_type = curr_states.device.type
+        # torch.compile of the chunk loop is known to fail on the MPS backend,
+        # so compilation is only attempted on CUDA and CPU.
+        if hasattr(torch, "compile") and device_type in ("cuda", "cpu"):
+            try:
+                chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead")  # type: ignore[arg-type]
+            except Exception:
+                # Best-effort: fall back to the eager chunk loop on failure.
+                chunk_fn = _chunk_loop
+
+        while not bool(done.all().item()):
+            (
+                curr_states,
+                done,
+                step_actions_chunk,
+                recorded_actions_chunk,
+                sinks_chunk,
+            ) = chunk_fn(curr_states, done)
+            if step_actions_chunk:
+                step_actions_seq.extend(step_actions_chunk)
+                recorded_actions_seq.extend(recorded_actions_chunk)
+                sink_seq.extend(sinks_chunk)
+
+        if recorded_actions_seq:
+            actions_tsr = torch.stack(recorded_actions_seq, dim=0)
+            T = actions_tsr.shape[0]
+
+            s = states_obj.tensor
+            states_stack = [s]
+            for t in range(T):
+                step = env.step_tensor(s, step_actions_seq[t])
+                s = step.next_states
+                states_stack.append(s)
+            states_tsr = torch.stack(states_stack, dim=0)
+
+            sinks_tsr = torch.stack(sink_seq, dim=0)
+            first_sink = torch.argmax(sinks_tsr.to(torch.long), dim=0)
+            never_sink = ~sinks_tsr.any(dim=0)
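+            # Trajectories that never reached a sink are clamped to the last
+            # recorded step; terminating_idx then points one past first_sink.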
+            first_sink = torch.where(
+                never_sink,
+                torch.tensor(T - 1, device=curr_states.device),
+                first_sink,
+            )
+            terminating_idx = first_sink + 1
+        else:
+            states_tsr = states_obj.tensor.unsqueeze(0)
+            actions_tsr = env.actions_from_batch_shape((0, states_tsr.shape[1])).tensor
+            terminating_idx = torch.zeros(
+                states_tsr.shape[1], dtype=torch.long, device=curr_states.device
+            )
+
+        trajectories = Trajectories(
+            env=env,
+            states=env.states_from_tensor(states_tsr),
+            conditions=conditions,
+            actions=env.actions_from_tensor(actions_tsr),
+            terminating_idx=terminating_idx,
+            is_backward=False,
+            log_rewards=None,
+            log_probs=None,
+            estimator_outputs=None,
+        )
+        return trajectories
+
+
+class FastKHotDiscretePolicyEstimator(FastPolicyMixin, DiscretePolicyEstimator):
+    """Discrete forward policy with tensor-only helpers for HyperGrid."""
+
+    def __init__(
+        self,
+        env: HyperGrid,
+        module: torch.nn.Module,
+        preprocessor: KHotPreprocessor,
+    ) -> None:
+        super().__init__(
+            module=module,
+            n_actions=env.n_actions,
+            preprocessor=preprocessor,
+            is_backward=False,
+        )
+        self.height = int(env.height)
+        self.ndim = int(env.ndim)
+        self.exit_idx = env.n_actions - 1
+
+    def fast_features(
+        self,
+        states_tensor: torch.Tensor,
+        *,
+        forward_masks: torch.Tensor | None = None,
+        backward_masks: torch.Tensor | None = None,
+        conditions: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert states_tensor.dtype == torch.long
+        khot = torch.nn.functional.one_hot(states_tensor, num_classes=self.height).to(
+            dtype=torch.get_default_dtype()
+        )
+        return khot.view(states_tensor.shape[0], -1)
+
+    def fast_distribution(
+        self,
+        features: torch.Tensor,
+        *,
+        states_tensor: torch.Tensor | None = None,
+        forward_masks: torch.Tensor | None = None,
+        backward_masks: torch.Tensor | None = None,
+        epsilon: float = 0.0,
+        **policy_kwargs: Any,
+    ) -> torch.distributions.Categorical:
+        if states_tensor is None:
+            raise ValueError(
+                "states_tensor is required for FastKHotDiscretePolicyEstimator."
+            )
+
+        logits = self.module(features)
+        batch = states_tensor.shape[0]
+        masks = torch.zeros(
+            batch,
+            self.ndim + 1,
+            dtype=torch.bool,
+            device=states_tensor.device,
+        )
+        masks[:, : self.ndim] = states_tensor < (self.height - 1)
+        masks[:, self.exit_idx] = True
+
+        masked_logits = logits.masked_fill(~masks, float("-inf"))
+        probs = torch.softmax(masked_logits, dim=-1)
+
+        if epsilon > 0.0:
+            valid_counts = masks.sum(dim=-1, keepdim=True).clamp_min(1)
+            uniform = masks.to(probs.dtype) / valid_counts.to(probs.dtype)
+            probs = (1.0 - epsilon) * probs + epsilon * uniform
+
+        return torch.distributions.Categorical(probs=probs)
+
+
 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--loss", choices=["FM", "TB", "DB"], default="TB")
-    parser.add_argument("--ndim", type=int, default=2)
-    parser.add_argument("--height", type=int, default=32)
-    parser.add_argument("--R0", type=float, default=0.1)
-    parser.add_argument("--R1", type=float, default=0.5)
-    parser.add_argument("--R2", type=float, default=2.0)
-    parser.add_argument("--seed", type=int, default=0)
+    parser = argparse.ArgumentParser(
+        description="Compare baseline vs. fast-path HyperGrid training pipelines."
+ ) + parser.add_argument("--n-iterations", type=int, default=50, dest="n_iterations") + parser.add_argument("--batch-size", type=int, default=16, dest="batch_size") + parser.add_argument("--warmup-iters", type=int, default=25, dest="warmup_iters") parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--lr_logz", type=float, default=1e-1) - parser.add_argument("--uniform_pb", action="store_true") - parser.add_argument("--n_iterations", type=int, default=100) - parser.add_argument("--validation_interval", type=int, default=100) - parser.add_argument("--validation_samples", type=int, default=200_000) - parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--lr-logz", type=float, default=1e-1, dest="lr_logz") + parser.add_argument("--lr-logf", type=float, default=1e-3, dest="lr_logf") parser.add_argument("--epsilon", type=float, default=0.0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--environments", + nargs="+", + choices=sorted(ENVIRONMENT_BENCHMARKS), + default=list(DEFAULT_ENV_ORDER), + help="Benchmark environments to include (e.g., hypergrid diffusion).", + ) + parser.add_argument( + "--validation-interval", type=int, default=100, dest="validation_interval" + ) + parser.add_argument( + "--validation-samples", type=int, default=200_000, dest="validation_samples" + ) parser.add_argument( "--device", choices=["auto", "cpu", "mps", "cuda"], default="auto", help="Device to run on; auto prefers CUDA>MPS>CPU.", ) - parser.add_argument("--compile", action="store_true", help="Enable torch.compile.") parser.add_argument( - "--compile-mode", - choices=["default", "reduce-overhead", "max-autotune"], - default="reduce-overhead", - help="Mode passed to torch.compile.", + "--benchmark-output", + type=str, + default=str(Path.home() / "hypergrid_benchmark.png"), + help="Output path for optional benchmark plot.", ) - parser.add_argument("--use-vmap", action="store_true", help="Use vmap TB loss.") - parser.add_argument("--benchmark", action="store_true", help="Run benchmark mode.") parser.add_argument( - "--chunk-size", - type=int, - default=0, - help="Enable chunked sampler fast path when > 0.", + "--skip-plot", + action="store_true", + help="Skip writing the benchmark plot (still prints the summary).", ) parser.add_argument( - "--benchmark-output", + "--gflownets", + nargs="+", + default=DEFAULT_FLOW_ORDER, + help="GFlowNet variants to benchmark (any of: tb, dbg, subtb).", + ) + parser.add_argument( + "--subtb-weighting", + choices=[ + "DB", + "ModifiedDB", + "TB", + "geometric", + "equal", + "equal_within", + ], + default="ModifiedDB", + dest="subtb_weighting", + help="Weighting strategy for SubTBGFlowNet runs.", + ) + parser.add_argument( + "--subtb-lamda", + type=float, + default=0.9, + dest="subtb_lamda", + help="Lambda discount factor for SubTBGFlowNet geometric weighting.", + ) + # Diffusion-specific knobs (ignored unless `diffusion` is selected). 
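+    # Hypothetical invocation (flags defined below; not the only valid set):
+    #   python train_hypergrid_optimized.py --environments diffusion \
+    #       --diffusion-target gmm2 --diffusion-num-steps 32 --diffusion-sigma 5.0
+    # benchmarks the DiffusionSampling environment with the default gmm2
+    # target, 32 discretization steps, and a diffusion coefficient of 5.0.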
+ parser.add_argument( + "--diffusion-target", type=str, - default=str(Path.home() / "hypergrid_benchmark.png"), - help="Output path for benchmark plot.", + default="gmm2", + help="Diffusion target alias (see DiffusionSampling.DIFFUSION_TARGETS).", + ) + parser.add_argument( + "--diffusion-dim", + type=int, + default=None, + help="Override target dimensionality when supported.", + ) + parser.add_argument( + "--diffusion-num-components", + type=int, + default=None, + help="Override mixture component count for Gaussian targets.", + ) + parser.add_argument( + "--diffusion-target-seed", + type=int, + default=2, + help="Seed controlling random targets (centers, covariances, etc.).", + ) + parser.add_argument( + "--diffusion-num-steps", + type=int, + default=32, + help="Number of discretization steps for the diffusion process.", + ) + parser.add_argument( + "--diffusion-sigma", + type=float, + default=5.0, + help="Pinned Brownian motion diffusion coefficient.", + ) + parser.add_argument( + "--diffusion-harmonics-dim", + type=int, + default=64, + help="Harmonics embedding dimension for DiffusionPISGradNetForward.", + ) + parser.add_argument( + "--diffusion-t-emb-dim", + type=int, + default=64, + help="Temporal embedding dimension for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-s-emb-dim", + type=int, + default=64, + help="State embedding dimension for diffusion forward model.", ) parser.add_argument( - "--warmup-iters", + "--diffusion-hidden-dim", type=int, - default=50, - help="Warmup iterations before timing (benchmark mode).", + default=64, + help="Hidden dimension for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-joint-layers", + type=int, + default=2, + help="Joint layers count for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-zero-init", + action="store_true", + help="Initialize diffusion forward model heads to zero.", ) return parser.parse_args() def init_metrics() -> Dict[str, Any]: return { - "validation_info": {"l1_dist": float("inf")}, + "validation_info": {"l1_dist": float("nan")}, "discovered_modes": set(), "total_steps": 0, "measured_steps": 0, @@ -305,89 +922,95 @@ def main() -> None: args = parse_args() device = resolve_device(args.device) - if args.benchmark: - base_scenarios: list[tuple[str, bool, bool, bool]] = [ - ("Baseline", False, False, False), - (f"Compile ({args.compile_mode})", True, False, False), - ("Vmap", False, True, False), - (f"Compile+Vmap ({args.compile_mode})", True, True, False), + flow_keys = _normalize_flow_keys(args.gflownets) + env_keys = _normalize_env_keys(args.environments) + if not flow_keys: + raise ValueError("At least one GFlowNet variant must be specified.") + + results: list[dict[str, Any]] = [] + grouped_results: dict[str, dict[str, list[dict[str, Any]]]] = {} + + for env_key in env_keys: + env_cfg = ENVIRONMENT_BENCHMARKS[env_key] + env_flow_keys = [ + flow_key for flow_key in flow_keys if flow_key in env_cfg.supported_flows ] - if args.chunk_size > 0: - base_scenarios += [ - (f"Chunk ({args.chunk_size})", False, False, True), - ( - f"Compile+Chunk ({args.compile_mode},{args.chunk_size})", - True, - False, - True, - ), - (f"Chunk+Vmap ({args.chunk_size})", False, True, True), - ( - f"Compile+Chunk+Vmap ({args.compile_mode},{args.chunk_size})", - True, - True, - True, - ), - ] - scenarios = base_scenarios - results: list[dict[str, Any]] = [] - for label, enable_compile, use_vmap, use_chunk in scenarios: - result = train_with_options( - args, - device, - 
enable_compile=enable_compile, - use_vmap=use_vmap, - warmup_iters=args.warmup_iters, - quiet=True, - timing=True, - record_history=True, - use_chunk=use_chunk, + if not env_flow_keys: + print( + f"\nSkipping environment '{env_cfg.label}' " + f"(no compatible flows among {', '.join(flow_keys)})." + ) + continue + + grouped_results.setdefault(env_key, {}) + print(f"\n### Environment: {env_cfg.label} ###\n" f"{env_cfg.description}\n") + + for flow_key in env_flow_keys: + flow_variant = FLOW_VARIANTS[flow_key] + grouped_results[env_key].setdefault(flow_key, []) + print( + f"\n=== GFlowNet Variant: {flow_variant.label} " + f"@ {env_cfg.label} ===\n{flow_variant.description}\n" ) - result["label"] = label - results.append(result) - - baseline_elapsed = results[0]["elapsed"] - print("Benchmark summary (speedups vs baseline):") - for result in results: - speedup = ( - baseline_elapsed / result["elapsed"] - if result["elapsed"] - else float("inf") + for scenario in env_cfg.scenarios: + print( + f"\n--- Scenario: {scenario.name} | " + f"{flow_variant.label} ({env_cfg.label}) ---\n" + f"{scenario.description}\n" + ) + result = run_scenario(args, device, scenario, flow_variant, env_cfg) + result["label"] = scenario.name + result["description"] = scenario.description + result["env_key"] = env_cfg.key + result["env_label"] = env_cfg.label + results.append(result) + grouped_results[env_key][flow_key].append(result) + + print("\nBenchmark summary (speedups vs. per-environment baselines):") + for env_key in env_keys: + env_cfg = ENVIRONMENT_BENCHMARKS.get(env_key) + if env_cfg is None: + continue + env_flow_results = grouped_results.get(env_key, {}) + if not env_flow_results: + continue + + baseline_name = env_cfg.scenarios[0].name if env_cfg.scenarios else "baseline" + print(f"\n[{env_cfg.label}] scenario baseline = {baseline_name}") + + for flow_key, flow_results in env_flow_results.items(): + if not flow_results: + continue + flow_variant = FLOW_VARIANTS[flow_key] + baseline_candidate = next( + (res for res in flow_results if res.get("label") == baseline_name), + flow_results[0], ) + baseline_time = baseline_candidate.get("elapsed", 0.0) or 1.0 print( - f"- {result['label']}: {result['elapsed']:.2f}s " - f"({speedup:.2f}x) | compile_mode={result['compile_mode']} " - f"| vmap={'on' if result['effective_vmap'] else 'off'} " - f"| chunk={'on' if result.get('chunk_size_effective', 0) > 0 else 'off'}" + f"\n - {flow_variant.label}: " + f"{baseline_time:.2f}s baseline ({baseline_candidate['label']})" ) + for result in flow_results: + elapsed = result["elapsed"] + speedup = baseline_time / elapsed if elapsed else float("inf") + print( + f" • {result['label']}: {elapsed:.2f}s ({speedup:.2f}x) | " + f"compile={'yes' if result['use_compile'] else 'no'} | " + f"vmap={'yes' if result['use_vmap'] else 'no'} | " + f"sampler={result['sampler']}" + ) + if not args.skip_plot: plot_benchmark(results, args.benchmark_output) - return - - train_with_options( - args, - device, - enable_compile=args.compile, - use_vmap=args.use_vmap, - warmup_iters=0, - quiet=False, - timing=False, - record_history=False, - use_chunk=args.chunk_size > 0, - ) -def train_with_options( +def run_scenario( args: argparse.Namespace, device: torch.device, - *, - enable_compile: bool, - use_vmap: bool, - warmup_iters: int, - quiet: bool, - timing: bool, - record_history: bool, - use_chunk: bool = False, + scenario: ScenarioConfig, + flow_variant: FlowVariant, + env_cfg: EnvironmentBenchmark, ) -> dict[str, Any]: set_seed(args.seed) ( @@ -396,30 
+1019,22 @@ def train_with_options( sampler, optimizer, visited_states, - ) = build_training_components(args, device, use_chunk=use_chunk) + ) = build_training_components(args, device, scenario, flow_variant, env_cfg) metrics = init_metrics() + use_vmap = scenario.use_vmap and flow_variant.supports_vmap - compile_mode = args.compile_mode if enable_compile else "none" - if enable_compile: + if scenario.use_compile: compile_results = try_compile_gflownet( gflownet, - mode=args.compile_mode, + mode=DEFAULT_COMPILE_MODE, ) - if not quiet: - formatted = ", ".join( - f"{name}:{'✓' if success else 'x'}" - for name, success in compile_results.items() - ) - print(f"[compile] {formatted}") - - requested_vmap = use_vmap - if use_vmap and not isinstance(gflownet, TBGFlowNet): - if not quiet: - print("vmap is currently only supported for TBGFlowNet; ignoring flag.") - use_vmap = False - effective_vmap = use_vmap + formatted = ", ".join( + f"{name}:{'✓' if success else 'x'}" + for name, success in compile_results.items() + ) + print(f"[compile] {formatted}") - if warmup_iters > 0: + if args.warmup_iters > 0: run_iterations( env, gflownet, @@ -428,12 +1043,13 @@ def train_with_options( visited_states, metrics, args, - n_iters=warmup_iters, + n_iters=args.warmup_iters, use_vmap=use_vmap, quiet=True, collect_metrics=False, track_time=False, record_history=False, + supports_validation=env_cfg.supports_validation, ) elapsed, history = run_iterations( @@ -446,39 +1062,48 @@ def train_with_options( args, n_iters=args.n_iterations, use_vmap=use_vmap, - quiet=quiet, + quiet=False, collect_metrics=True, - track_time=timing, - record_history=record_history, + track_time=True, + record_history=True, + supports_validation=env_cfg.supports_validation, ) - if not quiet: - validation_info = metrics["validation_info"] - l1 = validation_info.get("l1_dist", float("nan")) - print( - f"Finished training | iterations={metrics['measured_steps']} | " - f"modes={len(metrics['discovered_modes'])} / {env.n_modes} | " - f"L1 distance={l1:.6f}" - ) + validation_info = metrics["validation_info"] + l1 = validation_info.get("l1_dist", float("nan")) + modes_total = getattr(env, "n_modes", None) + if modes_total is None: + modes_str = "modes=n/a" + else: + modes_str = f"modes={len(metrics['discovered_modes'])} / {modes_total}" + if env_cfg.supports_validation: + validation_str = f"L1 distance={l1:.6f}" + else: + validation_str = "validation=skipped" + print( + f"Finished training ({env_cfg.label}) | " + f"iterations={metrics['measured_steps']} | " + f"{modes_str} | {validation_str}" + ) return { "elapsed": elapsed or 0.0, "losses": history["losses"] if history else None, "iter_times": history["iter_times"] if history else None, - "compile_mode": compile_mode, - "use_compile": enable_compile, - "requested_vmap": requested_vmap, - "effective_vmap": effective_vmap, - "chunk_size_effective": (args.chunk_size if use_chunk else 0), + "use_compile": scenario.use_compile, + "use_vmap": use_vmap, + "sampler": scenario.sampler, + "gflownet_key": flow_variant.key, + "gflownet_label": flow_variant.label, } def run_iterations( - env: HyperGrid, - gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, + env: Env, + gflownet: PFBasedGFlowNet, sampler: Sampler, optimizer: torch.optim.Optimizer, - visited_states: DiscreteStates, + visited_states, metrics: Dict[str, Any], args: argparse.Namespace, *, @@ -488,6 +1113,7 @@ def run_iterations( collect_metrics: bool, track_time: bool, record_history: bool, + supports_validation: bool, ) -> tuple[float | None, 
Dict[str, list[float]] | None]: if n_iters <= 0: empty_history = {"losses": [], "iter_times": []} if record_history else None @@ -514,7 +1140,7 @@ def run_iterations( epsilon=args.epsilon, ) - terminating_states = cast(DiscreteStates, trajectories.terminating_states) + terminating_states = cast(States, trajectories.terminating_states) visited_states.extend(terminating_states) optimizer.zero_grad() @@ -532,8 +1158,8 @@ def run_iterations( last_loss = loss.item() if ( record_history - and (losses_history is not None) - and (iter_time_history is not None) + and losses_history is not None + and iter_time_history is not None ): losses_history.append(last_loss) iter_duration = ( @@ -541,9 +1167,9 @@ def run_iterations( ) iter_time_history.append(iter_duration) - if collect_metrics: + if collect_metrics and supports_validation: run_validation_if_needed( - env, + cast(HyperGrid, env), gflownet, visited_states, metrics, @@ -562,7 +1188,8 @@ def run_iterations( ) if track_time: - synchronize_if_needed(env.device) + env_device = getattr(env, "device", torch.device("cpu")) + synchronize_if_needed(env_device) assert start_time is not None elapsed_time = time.perf_counter() - start_time else: @@ -579,14 +1206,16 @@ def run_iterations( def compute_loss( - gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, - env: HyperGrid, + gflownet: PFBasedGFlowNet, + env: Env, trajectories, *, use_vmap: bool, ) -> torch.Tensor: - if use_vmap and isinstance(gflownet, TBGFlowNet): - return trajectory_balance_loss_vmap(gflownet, trajectories) + if use_vmap: + if not isinstance(gflownet, TBGFlowNet): + raise ValueError("vmap trajectory balance loss requires a TBGFlowNet.") + return trajectory_balance_loss_vmap(cast(TBGFlowNet, gflownet), trajectories) return gflownet.loss_from_trajectories( env, trajectories, recalculate_all_logprobs=False @@ -615,18 +1244,11 @@ def tb_residual( log_rewards, ) - log_z = gflownet.logZ - if isinstance(log_z, ScalarEstimator): - if trajectories.conditions is None: - raise ValueError("Conditional logZ requires conditions tensor.") - log_z_value = log_z(trajectories.conditions) + log_z_value = gflownet.logZ + if not isinstance(log_z_value, torch.Tensor): + log_z_tensor = torch.as_tensor(log_z_value, device=residuals.device) else: - log_z_value = log_z - - if isinstance(log_z_value, torch.Tensor): log_z_tensor = log_z_value - else: - log_z_tensor = torch.as_tensor(log_z_value, device=residuals.device) log_z_tensor = log_z_tensor.squeeze() scores = (residuals + log_z_tensor).pow(2) @@ -635,7 +1257,7 @@ def tb_residual( def run_validation_if_needed( env: HyperGrid, - gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet, + gflownet: PFBasedGFlowNet, visited_states: DiscreteStates, metrics: Dict[str, Any], args: argparse.Namespace, @@ -671,97 +1293,212 @@ def run_validation_if_needed( def build_training_components( - args: argparse.Namespace, device: torch.device, *, use_chunk: bool = False -) -> tuple[ - HyperGrid, - TBGFlowNet | DBGFlowNet | FMGFlowNet, - Sampler, - torch.optim.Optimizer, - DiscreteStates, -]: - EnvClass = ( - HyperGridWithTensorStep if (use_chunk and args.chunk_size > 0) else HyperGrid - ) - env = EnvClass( - ndim=args.ndim, - height=args.height, - reward_fn_str="original", - reward_fn_kwargs={ - "R0": args.R0, - "R1": args.R1, - "R2": args.R2, - }, - device=device, - calculate_partition=True, - store_all_states=True, - check_action_validity=__debug__, - ) + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, + env_cfg: 
EnvironmentBenchmark, +) -> tuple[Env, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, States]: + if env_cfg.key == "hypergrid": + return _build_hypergrid_components(args, device, scenario, flow_variant) + if env_cfg.key == "diffusion": + return _build_diffusion_components(args, device, scenario, flow_variant) + raise ValueError(f"Unsupported environment key: {env_cfg.key}") + + +def _build_hypergrid_components( + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, +) -> tuple[HyperGrid, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, DiscreteStates]: + env_kwargs = dict(HYPERGRID_KWARGS) + env_kwargs["device"] = device + EnvClass = HyperGridWithTensorStep if scenario.use_script_env else HyperGrid + env = EnvClass(**env_kwargs) preprocessor = KHotPreprocessor(height=env.height, ndim=env.ndim) - module_PF = MLP( + module_pf = MLP( input_dim=preprocessor.output_dim, output_dim=env.n_actions, ) - if not args.uniform_pb: - module_PB = MLP( + module_pb = MLP( + input_dim=preprocessor.output_dim, + output_dim=env.n_actions - 1, + trunk=module_pf.trunk, + ) + + if scenario.sampler == "compiled_chunk": + pf_estimator = FastKHotDiscretePolicyEstimator(env, module_pf, preprocessor) + else: + pf_estimator = DiscretePolicyEstimator( + module_pf, env.n_actions, preprocessor=preprocessor, is_backward=False + ) + pb_estimator = DiscretePolicyEstimator( + module_pb, env.n_actions, preprocessor=preprocessor, is_backward=True + ) + + logF_estimator: ScalarEstimator | None = None + if flow_variant.requires_logf: + logF_module = MLP( input_dim=preprocessor.output_dim, - output_dim=env.n_actions - 1, - trunk=module_PF.trunk, + output_dim=1, + ) + logF_estimator = ScalarEstimator(module=logF_module, preprocessor=preprocessor) + + if flow_variant.key == "tb": + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + elif flow_variant.key == "dbg": + assert logF_estimator is not None + gflownet = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator) + elif flow_variant.key == "subtb": + assert logF_estimator is not None + gflownet = SubTBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + logF=logF_estimator, + weighting=args.subtb_weighting, + lamda=args.subtb_lamda, ) else: - module_PB = DiscreteUniform(output_dim=env.n_actions - 1) + raise ValueError(f"Unsupported GFlowNet variant: {flow_variant.key}") + + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + + logz_params = getattr(gflownet, "logz_parameters", None) + if callable(logz_params): + params = logz_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logz}) + + logf_params = getattr(gflownet, "logF_parameters", None) + if callable(logf_params): + params = logf_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logf}) + + if scenario.sampler == "compiled_chunk": + sampler: Sampler = CompiledChunkSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + elif scenario.sampler == "script_chunk": + sampler = ChunkedHyperGridSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + else: + sampler = Sampler(estimator=pf_estimator) - if args.loss == "FM": - logF_estimator = DiscretePolicyEstimator( - module=module_PF, - n_actions=env.n_actions, - preprocessor=preprocessor, + visited_states = env.states_from_batch_shape((0,)) + return env, gflownet, sampler, optimizer, visited_states + + +def _build_diffusion_components( + args: 
argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, +) -> tuple[DiffusionSampling, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, States]: + target_kwargs: dict[str, Any] = {"seed": args.diffusion_target_seed} + if args.diffusion_dim is not None: + target_kwargs["dim"] = args.diffusion_dim + if args.diffusion_num_components is not None: + target_kwargs["num_components"] = args.diffusion_num_components + + env = DiffusionSampling( + target_str=args.diffusion_target, + target_kwargs=target_kwargs, + num_discretization_steps=args.diffusion_num_steps, + device=device, + check_action_validity=False, + ) + + s_dim = env.dim + pf_module = DiffusionPISGradNetForward( + s_dim=s_dim, + harmonics_dim=args.diffusion_harmonics_dim, + t_emb_dim=args.diffusion_t_emb_dim, + s_emb_dim=args.diffusion_s_emb_dim, + hidden_dim=args.diffusion_hidden_dim, + joint_layers=args.diffusion_joint_layers, + zero_init=args.diffusion_zero_init, + ) + pb_module = DiffusionFixedBackwardModule(s_dim=s_dim) + + pf_estimator = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=args.diffusion_sigma, + num_discretization_steps=args.diffusion_num_steps, + ) + pb_estimator = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=args.diffusion_sigma, + num_discretization_steps=args.diffusion_num_steps, + ) + + logF_estimator: ScalarEstimator | None = None + if flow_variant.requires_logf: + logF_module = MLP( + input_dim=env.state_shape[-1], + output_dim=1, ) - gflownet: TBGFlowNet | DBGFlowNet | FMGFlowNet = FMGFlowNet(logF_estimator).to( - device + logF_preprocessor = IdentityPreprocessor(output_dim=env.state_shape[-1]) + logF_estimator = ScalarEstimator( + module=logF_module, preprocessor=logF_preprocessor ) - optimizer = torch.optim.Adam(gflownet.logF.parameters(), lr=args.lr) - sampler = ( - ChunkedHyperGridSampler(estimator=logF_estimator, chunk_size=args.chunk_size) - if use_chunk and args.chunk_size > 0 - else Sampler(estimator=logF_estimator) + + if flow_variant.key == "tb": + gflownet: PFBasedGFlowNet = TBGFlowNet( + pf=pf_estimator, pb=pb_estimator, init_logZ=0.0 ) - else: - pf_estimator = DiscretePolicyEstimator( - module_PF, env.n_actions, preprocessor=preprocessor, is_backward=False + elif flow_variant.key == "dbg": + assert logF_estimator is not None + gflownet = DBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + logF=logF_estimator, ) - pb_estimator = DiscretePolicyEstimator( - module_PB, env.n_actions, preprocessor=preprocessor, is_backward=True + elif flow_variant.key == "subtb": + assert logF_estimator is not None + gflownet = SubTBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + logF=logF_estimator, + weighting=args.subtb_weighting, + lamda=args.subtb_lamda, + ) + else: + raise ValueError( + f"Unsupported GFlowNet variant for diffusion: {flow_variant.key}" ) - if args.loss == "DB": - logF_module = MLP( - input_dim=preprocessor.output_dim, - output_dim=1, - ) - logF_estimator = ScalarEstimator( - module=logF_module, - preprocessor=preprocessor, - ) - gflownet = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator) - else: - gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) - gflownet = gflownet.to(device) - optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) - if isinstance(gflownet, DBGFlowNet): - optimizer.add_param_group( - {"params": 
gflownet.logF.parameters(), "lr": args.lr}
-        )
-    else:
-        optimizer.add_param_group(
-            {"params": gflownet.logz_parameters(), "lr": args.lr_logz}
-        )
-    sampler = (
-        ChunkedHyperGridSampler(estimator=pf_estimator, chunk_size=args.chunk_size)
-        if use_chunk and args.chunk_size > 0
-        else Sampler(estimator=pf_estimator)
+    logz_params = getattr(gflownet, "logz_parameters", None)
+    if callable(logz_params):
+        params = logz_params()
+        if params:
+            optimizer.add_param_group({"params": params, "lr": args.lr_logz})
+
+    logf_params = getattr(gflownet, "logF_parameters", None)
+    if callable(logf_params):
+        params = logf_params()
+        if params:
+            optimizer.add_param_group({"params": params, "lr": args.lr_logf})
+
+    if scenario.sampler == "compiled_chunk":
+        sampler: Sampler = CompiledChunkSampler(
+            estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE
         )
+    elif scenario.sampler == "script_chunk":
+        sampler = ChunkedDiffusionSampler(
+            estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE
+        )
+    else:
+        sampler = Sampler(estimator=pf_estimator)
 
     visited_states = env.states_from_batch_shape((0,))
     return env, gflownet, sampler, optimizer, visited_states
@@ -773,11 +1510,12 @@ def _mps_backend_available() -> bool:
 
 
 def resolve_device(requested: str) -> torch.device:
+    """Resolve the requested device string to a ``torch.device``.
+
+    ``auto`` prefers CUDA and otherwise falls back to CPU; MPS is never
+    auto-selected because the Diffusion Sampling environment does not
+    support the MPS backend.
+    """
     if requested == "auto":
         if torch.cuda.is_available():
            return torch.device("cuda")
-        if _mps_backend_available():
-            return torch.device("mps")
+        # MPS is intentionally skipped here (see docstring).
         return torch.device("cpu")
 
     device = torch.device(requested)
@@ -795,84 +1533,149 @@ def synchronize_if_needed(device: torch.device) -> None:
         torch.mps.synchronize()
 
 
-def plot_benchmark(results: list[Dict[str, Any]], output_path: str) -> None:
-    try:
-        import matplotlib.pyplot as plt
-    except ImportError as exc:
-        raise RuntimeError(
-            "matplotlib is required for plotting; install it or omit --benchmark."
- ) from exc +def _summarize_iteration_times(times: list[float]) -> tuple[float, float]: + if not times: + return 0.0, 0.0 + mean_time = statistics.fmean(times) + std_time = statistics.pstdev(times) if len(times) > 1 else 0.0 + return mean_time, std_time - def summarize_iteration_times(times: list[float]) -> tuple[float, float]: - if not times: - return 0.0, 0.0 - mean_time = statistics.fmean(times) - std_time = statistics.pstdev(times) if len(times) > 1 else 0.0 - return mean_time, std_time - - labels = [res.get("label", f"Run {idx+1}") for idx, res in enumerate(results)] - times = [res["elapsed"] for res in results] - losses_list = [res.get("losses") or [] for res in results] - iter_times_list = [res.get("iter_times") or [] for res in results] - - fig, axes = plt.subplots(1, 3, figsize=(20, 5)) - - # Subplot 1: total time comparison - colors = ["#6c757d", "#1f77b4", "#2ca02c", "#d62728", "#9467bd", "#8c564b"] - bar_colors = [colors[i % len(colors)] for i in range(len(results))] - bars = axes[0].bar(labels, times, color=bar_colors) - axes[0].set_ylabel("Wall-clock time (s)") - axes[0].set_title("Total Training Time") - baseline_time = times[0] if times else 1.0 - for i, (bar, value) in enumerate(zip(bars, times)): - speedup = baseline_time / value if value else float("inf") - axes[0].text( + +def _render_env_row( + row_axes, + env_results: list[Dict[str, Any]], + env_cfg: EnvironmentBenchmark | None, + palette: list[str], +) -> None: + env_label = env_cfg.label if env_cfg else env_results[0].get("env_label", "Env") + labels = [ + f"{res.get('label', f'Run {idx+1}')} [{res.get('gflownet_label', '?')}]" + for idx, res in enumerate(env_results) + ] + times = [res.get("elapsed", 0.0) for res in env_results] + bar_colors = [palette[i % len(palette)] for i in range(len(env_results))] + + baseline_name = env_cfg.scenarios[0].name if env_cfg and env_cfg.scenarios else None + + # Determine per-flow baselines (default to the baseline scenario if present, else first run). 
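+    # Worked example with hypothetical numbers: given
+    #   env_results = [
+    #       {"gflownet_key": "tb", "label": "baseline", "elapsed": 2.0},
+    #       {"gflownet_key": "tb", "label": "compile", "elapsed": 1.6},
+    #   ]
+    # the two passes below produce flow_baselines == {"tb": 2.0}, so the
+    # compile bar is annotated (2.0 / 1.6 - 1.0) * 100 = +25.0%.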
+    flow_baselines: dict[str, float] = {}
+    for res in env_results:
+        flow_key = res.get("gflownet_key")
+        if flow_key is None or flow_key in flow_baselines:
+            continue
+        if baseline_name is not None and res.get("label") == baseline_name:
+            flow_baselines[flow_key] = res.get("elapsed", 0.0) or 0.0
+    for res in env_results:
+        flow_key = res.get("gflownet_key")
+        if flow_key is None:
+            continue
+        flow_baselines.setdefault(flow_key, res.get("elapsed", 0.0) or 0.0)
+
+    bars = row_axes[0].bar(labels, times, color=bar_colors)
+    row_axes[0].set_ylabel("Wall-clock time (s)")
+    row_axes[0].set_title(f"{env_label} | Total Training Time")
+
+    for bar, value, res in zip(bars, times, env_results):
+        if value == 0.0:
+            continue
+        flow_key = res.get("gflownet_key", "")
+        flow_baseline = flow_baselines.get(flow_key, value) or value
+        # value > 0 here: zero-elapsed bars were skipped above.
+        pct_speedup = (flow_baseline / value - 1.0) * 100.0
+        row_axes[0].text(
             bar.get_x() + bar.get_width() / 2,
             value,
-            f"{value:.2f}s\n{speedup:.2f}x",
+            f"{value:.2f}s\n{pct_speedup:+.1f}%",
             ha="center",
             va="bottom",
+            color="black",
+            fontsize=9,
         )
 
     # Subplot 2: training curves
-    line_styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1)), (0, (5, 5))]
-
-    for idx, losses in enumerate(losses_list):
+    loss_ax = row_axes[1]
+    for idx, res in enumerate(env_results):
+        losses = res.get("losses") or []
         if not losses:
             continue
-        axes[1].plot(
+        variant_key = res.get("gflownet_key", "")
+        scenario_label = res.get("label", "")
+        color = VARIANT_COLORS.get(variant_key, palette[idx % len(palette)])
+        linestyle = SCENARIO_LINESTYLES.get(scenario_label, "-")
+        loss_ax.plot(
             range(1, len(losses) + 1),
             losses,
             label=labels[idx],
-            color=bar_colors[idx],
-            linestyle=line_styles[idx % len(line_styles)],
+            color=color,
+            linestyle=linestyle,
             linewidth=2.0,
-            alpha=0.5,
+            alpha=LOSS_LINE_ALPHA,
         )
-    axes[1].set_title("Training Loss")
-    axes[1].set_xlabel("Iteration")
-    axes[1].set_ylabel("Loss")
-    axes[1].legend()
+    loss_ax.set_title(f"{env_label} | Training Loss")
+    loss_ax.set_xlabel("Iteration")
+    loss_ax.set_ylabel("Loss")
+    if loss_ax.lines:
+        loss_ax.legend(fontsize="small")
 
     # Subplot 3: per-iteration timing with error bars
-    summary_stats = [summarize_iteration_times(times) for times in iter_times_list]
-    means_ms = [mean * 1000.0 for mean, _ in summary_stats]
-    stds_ms = [std * 1000.0 for _, std in summary_stats]
-    axes[2].bar(
+
+    iter_ax = row_axes[2]
+    iter_stats = [
+        _summarize_iteration_times(res.get("iter_times") or []) for res in env_results
+    ]
+    means_ms = [mean * 1000.0 for mean, _ in iter_stats]
+    stds_ms = [std * 1000.0 for _, std in iter_stats]
+    iter_ax.bar(
         labels,
         means_ms,
        yerr=stds_ms,
        capsize=6,
        color=bar_colors,
    )
-    axes[2].set_ylabel("Per-iteration time (ms)")
-    axes[2].set_title("Iteration Timing (mean ± std)")
+    iter_ax.set_ylabel("Per-iteration time (ms)")
+    iter_ax.set_title(f"{env_label} | Iteration Timing")
 
-    for ax in axes:
+    for ax in row_axes:
         for label in ax.get_xticklabels():
             label.set_rotation(30)
             label.set_ha("right")
 
+
+def plot_benchmark(results: list[Dict[str, Any]], output_path: str) -> None:
+    try:
+        import matplotlib.pyplot as plt
+    except ImportError as exc:
+        raise RuntimeError(
+            "matplotlib is required for plotting; install it or pass --skip-plot."
+ ) from exc + + if not results: + print("No benchmark results to plot.") + return + + env_order: list[str] = [] + for res in results: + env_key = res.get("env_key", "unknown") + if env_key not in env_order: + env_order.append(env_key) + + n_rows = max(1, len(env_order)) + fig, axes = plt.subplots(n_rows, 3, figsize=(20, 5 * n_rows)) + if n_rows == 1: + axes = [axes] # type: ignore[list-item] + + palette = ["#6c757d", "#1f77b4", "#2ca02c", "#d62728", "#9467bd", "#8c564b"] + + for row_idx, env_key in enumerate(env_order): + env_results = [res for res in results if res.get("env_key") == env_key] + if not env_results: + continue + + env_cfg = ENVIRONMENT_BENCHMARKS.get(env_key) + row_axes = axes[row_idx] + _render_env_row(row_axes, env_results, env_cfg, palette) + output = Path(output_path) output.parent.mkdir(parents=True, exist_ok=True) fig.tight_layout() diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 4a5d492d..5fbfc8f8 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -7,7 +7,7 @@ from torch.distributions.independent import Independent from tqdm import trange -from gfn.estimators import Estimator, PolicyMixin +from gfn.estimators import Estimator, FastPolicyMixin from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet from gfn.gym.line import Line from gfn.states import States @@ -168,13 +168,14 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: return out -class StepEstimator(Estimator, PolicyMixin): +class StepEstimator(FastPolicyMixin, Estimator): """Estimator for PF and PB of the Line environment.""" def __init__(self, env: Line, module: torch.nn.Module, backward: bool): super().__init__(module, is_backward=backward) self.backward = backward self.n_steps_per_trajectory = env.n_steps_per_trajectory + self.env = env @property def expected_output_dim(self) -> int: @@ -207,6 +208,31 @@ def to_probability_distribution( n_steps=self.n_steps_per_trajectory, ) + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for StepEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output, **policy_kwargs) + def train( gflownet, diff --git a/tutorials/notebooks/torch_compile_discrete_states.ipynb b/tutorials/notebooks/torch_compile_discrete_states.ipynb new file mode 100644 index 00000000..84b993ee --- /dev/null +++ b/tutorials/notebooks/torch_compile_discrete_states.ipynb @@ -0,0 +1,4889 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# torch.compile with `DiscreteStates`\n", + "\n", + "This short experiment shows that a `DiscreteStates` wrapper can safely flow through `torch.compile`. 
We instantiate a simple environment, grab its states/actions, and compare the eager and compiled results of a single `_step` call.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jdv/code/torchgfn/src/gfn/gym/hypergrid.py:90: UserWarning: + Warning: height <= 4 can lead to unsolvable environments.\n", + " warnings.warn(\"+ Warning: height <= 4 can lead to unsolvable environments.\")\n", + "/Users/jdv/code/torchgfn/src/gfn/env.py:495: UserWarning: You're using advanced parameters: (sf). These are only needed for custom action handling. For basic environments, you can omit these.\n", + " warnings.warn(\n", + "[W1126 23:25:08.072794000 unwind.cpp:12] Warning: record_context_cpp is not support on non-linux non-x86_64 platforms (function operator())\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break from `Tensor.item()`, consider setting:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] torch._dynamo.config.capture_scalar_outputs = True\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] or:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] to include these operations in the captured graph.\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break: from user code at:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] File \"/Users/jdv/code/torchgfn/src/gfn/env.py\", line 309, in torch_dynamo_resume_in__step_at_307\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] if not self.is_action_valid(valid_states, valid_actions):\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] File \"/Users/jdv/code/torchgfn/src/gfn/env.py\", line 643, in is_action_valid\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] return bool(torch.gather(masks_tensor, 1, actions.tensor).all().item())\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outputs match: True\n", + "Output device: cpu\n", + "Example compiled output:\n", + " tensor([[0, 1],\n", + " [0, 1],\n", + " [0, 1],\n", + " [0, 1]])\n" + ] + } + ], + "source": [ + "import torch\n", + "from gfn.gym.hypergrid import HyperGrid\n", + "\n", + "# Resolve device (CUDA if available, else CPU)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Instantiate a small environment and grab states/actions.\n", + "env = HyperGrid(ndim=2, height=4, device=device)\n", + "states = env.reset(batch_shape=4)\n", + "actions = env.actions_from_batch_shape((4,))\n", + "actions.tensor = torch.ones((4, 1), dtype=torch.long, 
device=device)\n", + "\n", + "# Define a helper that takes raw tensors, rebuilds the wrappers, and returns the step result.\n", + "def step_once(states_tensor: torch.Tensor, actions_tensor: torch.Tensor) -> torch.Tensor:\n", + " s = env.States(states_tensor)\n", + " a = env.Actions(actions_tensor)\n", + " return env._step(s, a).tensor\n", + "\n", + "compiled_step = torch.compile(step_once, dynamic=True)\n", + "\n", + "eager_out = step_once(states.tensor, actions.tensor)\n", + "compiled_out = compiled_step(states.tensor, actions.tensor)\n", + "\n", + "print(\"Outputs match:\", torch.equal(eager_out, compiled_out))\n", + "print(\"Output device:\", compiled_out.device)\n", + "print(\"Example compiled output:\\n\", compiled_out)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Microbenchmark harness\n", + "\n", + "The cells below build a small timing helper so we can compare `step_once` in eager mode vs the `torch.compile(..., dynamic=True)` variant under identical inputs. We run everything on CPU for consistency.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import statistics\n", + "import warnings\n", + "from typing import Callable, Dict\n", + "\n", + "import torch.utils.benchmark as benchmark\n", + "\n", + "\n", + "def _sync_if_needed() -> None:\n", + " if torch.cuda.is_available():\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "def benchmark_step_fn(\n", + " step_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n", + " label: str,\n", + " states_tensor: torch.Tensor,\n", + " actions_tensor: torch.Tensor,\n", + " *,\n", + " iters: int = 200,\n", + ") -> Dict[str, float]:\n", + " \"\"\"Time repeated calls to `step_fn` under identical inputs.\"\"\"\n", + "\n", + " torch.manual_seed(0)\n", + " if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(0)\n", + "\n", + " warmup_iters = max(5, iters // 10)\n", + " for _ in range(warmup_iters):\n", + " step_fn(states_tensor, actions_tensor)\n", + " _sync_if_needed()\n", + "\n", + " timer = benchmark.Timer(\n", + " stmt=\"fn(states_tensor, actions_tensor)\",\n", + " globals={\n", + " \"fn\": step_fn,\n", + " \"states_tensor\": states_tensor,\n", + " \"actions_tensor\": actions_tensor,\n", + " },\n", + " label=label,\n", + " sub_label=f\"device={states_tensor.device}\",\n", + " description=\"step_once microbenchmark\",\n", + " )\n", + " result = timer.timeit(iters)\n", + " std_ms = statistics.pstdev(result.raw_times) * 1000 if result.raw_times else float(\"nan\")\n", + " run_count = len(result.raw_times) if result.raw_times else iters\n", + " return {\n", + " \"label\": label,\n", + " \"mean_ms\": result.mean * 1000,\n", + " \"std_ms\": std_ms,\n", + " \"iters\": run_count,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'Eager step_once',\n", + " 'mean_ms': 0.10937216250458733,\n", + " 'std_ms': 0.0,\n", + " 'iters': 1},\n", + " {'label': 'torch.compile(step_once)',\n", + " 'mean_ms': 0.3456614166498184,\n", + " 'std_ms': 0.0,\n", + " 'iters': 1}]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "benchmark_iters = 20000\n", + "results = []\n", + "\n", + "results.append(\n", + " benchmark_step_fn(\n", + " step_once,\n", + " label=\"Eager step_once\",\n", + " states_tensor=states.tensor,\n", + " actions_tensor=actions.tensor,\n", + " 
iters=benchmark_iters,\n", + " )\n", + ")\n", + "\n", + "with warnings.catch_warnings(record=True) as caught:\n", + " warnings.simplefilter(\"always\")\n", + " results.append(\n", + " benchmark_step_fn(\n", + " compiled_step,\n", + " label=\"torch.compile(step_once)\",\n", + " states_tensor=states.tensor,\n", + " actions_tensor=actions.tensor,\n", + " iters=benchmark_iters,\n", + " )\n", + " )\n", + " compile_warning_messages = sorted({str(w.message) for w in caught})\n", + "\n", + "results\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mode mean (ms) std (ms) iters\n", + "-----------------------------------------------------------------\n", + "Eager step_once 0.1094 0.0000 1\n", + "torch.compile(step_once) 0.3457 0.0000 1\n", + "\n", + "Speedup (eager / compiled): 0.316x\n", + "\n", + "Dynamo summary -> Graphs: 12, Graph breaks: 11, Break reasons: ['Dynamic shape operator', 'Unsupported Tensor.item() call with capture_scalar_outputs=False']\n", + "\n", + "Warnings during compiled execution: none captured\n" + ] + } + ], + "source": [ + "import torch._dynamo as dynamo\n", + "\n", + "\n", + "def _format_results(rows):\n", + " header = f\"{'Mode':<30} {'mean (ms)':>12} {'std (ms)':>12} {'iters':>8}\"\n", + " lines = [header, \"-\" * len(header)]\n", + " for row in rows:\n", + " lines.append(\n", + " f\"{row['label']:<30} {row['mean_ms']:>12.4f} {row['std_ms']:>12.4f} {row['iters']:>8d}\"\n", + " )\n", + " return \"\\n\".join(lines)\n", + "\n", + "\n", + "def _extract_count(report: str, prefix: str) -> int:\n", + " for line in report.splitlines():\n", + " if line.startswith(prefix):\n", + " return int(line.split(\":\", 1)[1].strip())\n", + " return -1\n", + "\n", + "\n", + "print(_format_results(results))\n", + "\n", + "eager_mean = next(r for r in results if r[\"label\"] == \"Eager step_once\")[\"mean_ms\"]\n", + "compiled_mean = next(r for r in results if \"torch.compile\" in r[\"label\"])[\"mean_ms\"]\n", + "speedup = eager_mean / compiled_mean if compiled_mean else float(\"nan\")\n", + "print(f\"\\nSpeedup (eager / compiled): {speedup:.3f}x\")\n", + "\n", + "compiled_report = dynamo.explain(step_once)(states.tensor, actions.tensor)\n", + "compiled_report_text = str(compiled_report)\n", + "\n", + "graph_count = _extract_count(compiled_report_text, \"Graph Count\")\n", + "graph_breaks = _extract_count(compiled_report_text, \"Graph Break Count\")\n", + "break_reasons = sorted(\n", + " {\n", + " line.strip().split(\":\", 1)[1].strip()\n", + " for line in compiled_report_text.splitlines()\n", + " if line.strip().startswith(\"Reason:\")\n", + " }\n", + ")\n", + "\n", + "print(\n", + " f\"\\nDynamo summary -> Graphs: {graph_count}, Graph breaks: {graph_breaks}, \"\n", + " f\"Break reasons: {break_reasons or ['None']}\"\n", + ")\n", + "\n", + "if compile_warning_messages:\n", + " print(\"\\nWarnings during compiled execution:\")\n", + " for msg in compile_warning_messages:\n", + " print(f\" - {msg}\")\n", + "else:\n", + " print(\"\\nWarnings during compiled execution: none captured\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Full GFlowNet benchmark\n", + "\n", + "The cell below reuses `train_hypergrid_optimized.py`'s benchmarking entry-point so we can time a larger training loop (Baseline vs compiled) directly from this notebook.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "calculated tensor of all states in 0.0009723345438639323 minutes\n", + "+ Environment has 1024 states\n", + "+ Environment log partition is 5.711750507354736\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jdv/code/torchgfn/src/gfn/env.py:495: UserWarning: You're using advanced parameters: (sf). These are only needed for custom action handling. For basic environments, you can omit these.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "calculated tensor of all states in 0.0007189313570658366 minutes\n", + "+ Environment has 1024 states\n", + "+ Environment log partition is 5.711750507354736\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'elapsed': 6.105575958034024,\n", + " 'losses': [7.572955131530762,\n", + " 2.491760015487671,\n", + " 3.2979516983032227,\n", + " 3.0194754600524902,\n", + " 0.9618801474571228,\n", + " 0.8673862218856812,\n", + " 1.6195870637893677,\n", + " 0.5269477367401123,\n", + " 1.0297844409942627,\n", + " 1.332466959953308,\n", + " 0.6973802447319031,\n", + " 1.6610175371170044,\n", + " 0.47196799516677856,\n", + " 1.980200171470642,\n", + " 2.484879970550537,\n", + " 1.153307557106018,\n", + " 0.5622124671936035,\n", + " 0.7249877452850342,\n", + " 1.2468236684799194,\n", + " 1.9157636165618896,\n", + " 1.5802578926086426,\n", + " 0.9950146675109863,\n", + " 0.9827088713645935,\n", + " 0.9594094753265381,\n", + " 2.0273141860961914,\n", + " 1.0678741931915283,\n", + " 1.7654989957809448,\n", + " 1.8363938331604004,\n", + " 0.5704580545425415,\n", + " 2.0948450565338135,\n", + " 0.8548241853713989,\n", + " 4.518639087677002,\n", + " 1.0827535390853882,\n", + " 1.2317500114440918,\n", + " 0.6395683288574219,\n", + " 1.3933279514312744,\n", + " 1.7131190299987793,\n", + " 1.1856663227081299,\n", + " 1.428055763244629,\n", + " 0.8084158897399902,\n", + " 0.37907153367996216,\n", + " 1.583935260772705,\n", + " 2.161365270614624,\n", + " 1.4849199056625366,\n", + " 1.6980212926864624,\n", + " 0.4082474708557129,\n", + " 1.0781633853912354,\n", + " 0.6617383360862732,\n", + " 0.8540241718292236,\n", + " 0.7804931998252869,\n", + " 1.5323201417922974,\n", + " 1.175217628479004,\n", + " 0.4573594331741333,\n", + " 1.7341632843017578,\n", + " 1.3420581817626953,\n", + " 0.6467706561088562,\n", + " 0.7274094223976135,\n", + " 0.4892826974391937,\n", + " 0.4271280765533447,\n", + " 1.8384681940078735,\n", + " 0.7235350608825684,\n", + " 1.1163365840911865,\n", + " 1.350170612335205,\n", + " 0.42495375871658325,\n", + " 1.4814311265945435,\n", + " 0.9633447527885437,\n", + " 0.7441744804382324,\n", + " 0.49172350764274597,\n", + " 0.8439239859580994,\n", + " 1.7822604179382324,\n", + " 2.700016975402832,\n", + " 1.2268513441085815,\n", + " 1.689005732536316,\n", + " 0.8238610029220581,\n", + " 1.1699678897857666,\n", + " 0.5300710201263428,\n", + " 0.3418184518814087,\n", + " 0.6137208342552185,\n", + " 0.7867594957351685,\n", + " 2.185892343521118,\n", + " 0.521285355091095,\n", + " 1.2726870775222778,\n", + " 0.450527161359787,\n", + " 1.4032469987869263,\n", + " 0.4590965509414673,\n", + " 0.4651802182197571,\n", + " 0.8082336187362671,\n", + " 0.7147279977798462,\n", + " 0.5436726212501526,\n", + " 1.4215888977050781,\n", + " 1.2204718589782715,\n", + " 0.25053414702415466,\n", + " 1.3869221210479736,\n", + " 1.0469372272491455,\n", + " 1.4329015016555786,\n", + " 0.6535708904266357,\n", + " 1.5620619058609009,\n", + " 
2.5244903564453125,\n", + " 0.7888399362564087,\n", + " 5.477772235870361,\n", + " 2.438631057739258,\n", + " 1.3707287311553955,\n", + " 0.9398314952850342,\n", + " 0.46227526664733887,\n", + " 0.5834711790084839,\n", + " 1.6012723445892334,\n", + " 1.7451685667037964,\n", + " 0.9019611477851868,\n", + " 0.7608010172843933,\n", + " 1.2909115552902222,\n", + " 0.4646669626235962,\n", + " 0.4738759398460388,\n", + " 0.7258164882659912,\n", + " 1.7222124338150024,\n", + " 2.4299848079681396,\n", + " 2.2367324829101562,\n", + " 2.0737838745117188,\n", + " 0.5498713254928589,\n", + " 1.791583776473999,\n", + " 1.1461840867996216,\n", + " 0.2634006142616272,\n", + " 0.8089989423751831,\n", + " 0.283713698387146,\n", + " 0.14579910039901733,\n", + " 0.8403635621070862,\n", + " 0.5281050205230713,\n", + " 0.3584972620010376,\n", + " 0.6050671935081482,\n", + " 0.4479628801345825,\n", + " 0.5756915211677551,\n", + " 1.256803035736084,\n", + " 0.7478235960006714,\n", + " 0.375349760055542,\n", + " 0.35438239574432373,\n", + " 1.180328369140625,\n", + " 1.3136284351348877,\n", + " 3.27740478515625,\n", + " 1.0790156126022339,\n", + " 1.540788173675537,\n", + " 1.0326931476593018,\n", + " 0.9449985027313232,\n", + " 3.155139684677124,\n", + " 0.9995787143707275,\n", + " 0.7784771919250488,\n", + " 1.443956732749939,\n", + " 0.8618024587631226,\n", + " 0.3689582049846649,\n", + " 0.4708964228630066,\n", + " 1.133431315422058,\n", + " 1.1145482063293457,\n", + " 0.4921965003013611,\n", + " 0.415180504322052,\n", + " 1.5828590393066406,\n", + " 2.7756614685058594,\n", + " 0.6000064611434937,\n", + " 0.7194350957870483,\n", + " 0.8013563752174377,\n", + " 1.2213236093521118,\n", + " 1.1368153095245361,\n", + " 1.3761565685272217,\n", + " 0.346245139837265,\n", + " 0.3820178508758545,\n", + " 0.4246070384979248,\n", + " 0.5360602140426636,\n", + " 0.6117574572563171,\n", + " 1.0365396738052368,\n", + " 0.191411092877388,\n", + " 0.6832408905029297,\n", + " 1.0592424869537354,\n", + " 0.6425431966781616,\n", + " 0.498931348323822,\n", + " 0.774171769618988,\n", + " 0.32929566502571106,\n", + " 0.42261868715286255,\n", + " 0.3470837473869324,\n", + " 0.4950379729270935,\n", + " 0.5027860403060913,\n", + " 0.35800668597221375,\n", + " 1.288243293762207,\n", + " 0.6900274753570557,\n", + " 1.4558058977127075,\n", + " 1.1142147779464722,\n", + " 0.2911098003387451,\n", + " 0.7661944031715393,\n", + " 1.0826146602630615,\n", + " 1.19940984249115,\n", + " 0.884093701839447,\n", + " 0.5238901972770691,\n", + " 0.6807741522789001,\n", + " 0.5270069241523743,\n", + " 0.43598586320877075,\n", + " 0.31679433584213257,\n", + " 0.7662327885627747,\n", + " 0.4052656292915344,\n", + " 0.4683819115161896,\n", + " 0.4934506416320801,\n", + " 0.17495952546596527,\n", + " 0.5440036654472351,\n", + " 0.5274096131324768,\n", + " 0.6581551432609558],\n", + " 'iter_times': [0.029843291034922004,\n", + " 0.02665925002656877,\n", + " 0.026591749861836433,\n", + " 0.02578529203310609,\n", + " 0.0249330410733819,\n", + " 0.02374075003899634,\n", + " 0.02827424998395145,\n", + " 0.016017750138416886,\n", + " 0.025069749914109707,\n", + " 0.02290912508033216,\n", + " 0.01798625010997057,\n", + " 0.022280167089775205,\n", + " 0.021828250028192997,\n", + " 0.025819166796281934,\n", + " 0.023717750096693635,\n", + " 0.027398916892707348,\n", + " 0.0268092080950737,\n", + " 0.02269370900467038,\n", + " 0.028668625047430396,\n", + " 0.025364749832078815,\n", + " 0.026112499879673123,\n", + " 0.03050654218532145,\n", + " 
0.026381792034953833,\n", + " 0.02492345799691975,\n", + " 0.024277166929095984,\n", + " 0.025802375050261617,\n", + " 0.023120166966691613,\n", + " 0.022838874952867627,\n", + " 0.02152595785446465,\n", + " 0.02243420807644725,\n", + " 0.026837416924536228,\n", + " 0.02476345794275403,\n", + " 0.027901625027880073,\n", + " 0.02595195802859962,\n", + " 0.029563874937593937,\n", + " 0.02456283406354487,\n", + " 0.023599833017215133,\n", + " 0.027492624940350652,\n", + " 0.022965540876612067,\n", + " 0.02508420799858868,\n", + " 0.028139542089775205,\n", + " 0.025828625075519085,\n", + " 0.028102665906772017,\n", + " 0.026154458988457918,\n", + " 0.027773250127211213,\n", + " 0.025903834030032158,\n", + " 0.028939666924998164,\n", + " 0.024504374945536256,\n", + " 0.02638370799832046,\n", + " 0.025218249997124076,\n", + " 0.023709542118012905,\n", + " 0.03099012514576316,\n", + " 0.021817500004544854,\n", + " 0.025243084179237485,\n", + " 0.031247250037267804,\n", + " 0.026461832923814654,\n", + " 0.02370858401991427,\n", + " 0.02921529207378626,\n", + " 0.02194329211488366,\n", + " 0.02969479188323021,\n", + " 0.02624008315615356,\n", + " 0.025315249804407358,\n", + " 0.030012167058885098,\n", + " 0.032295542070642114,\n", + " 0.029138999991118908,\n", + " 0.02831891691312194,\n", + " 0.02835141704417765,\n", + " 0.02719200006686151,\n", + " 0.027825874974951148,\n", + " 0.024933042004704475,\n", + " 0.03004787489771843,\n", + " 0.0257573330309242,\n", + " 0.02414166694507003,\n", + " 0.028785540955141187,\n", + " 0.029787667095661163,\n", + " 0.03012562496587634,\n", + " 0.0186317500192672,\n", + " 0.026522708125412464,\n", + " 0.025528999976813793,\n", + " 0.02337670815177262,\n", + " 0.028197707841172814,\n", + " 0.028563749976456165,\n", + " 0.023193708853796124,\n", + " 0.023533583153039217,\n", + " 0.023977917153388262,\n", + " 0.025529957842081785,\n", + " 0.025657958118245006,\n", + " 0.028129874961450696,\n", + " 0.0250042078550905,\n", + " 0.024240750120952725,\n", + " 0.02651733416132629,\n", + " 0.03236395795829594,\n", + " 0.03119733394123614,\n", + " 0.023247458040714264,\n", + " 0.02916412497870624,\n", + " 0.03493220801465213,\n", + " 0.026215665973722935,\n", + " 0.0358433339279145,\n", + " 0.03177366708405316,\n", + " 0.03824399993754923,\n", + " 0.03338041715323925,\n", + " 0.034712209133431315,\n", + " 0.03465991700068116,\n", + " 0.02921774983406067,\n", + " 0.02734670788049698,\n", + " 0.032816499937325716,\n", + " 0.02972904103808105,\n", + " 0.032379542011767626,\n", + " 0.03371258289553225,\n", + " 0.03127762512303889,\n", + " 0.02719833399169147,\n", + " 0.025000832974910736,\n", + " 0.03524758294224739,\n", + " 0.03120354190468788,\n", + " 0.03501374996267259,\n", + " 0.03635912504978478,\n", + " 0.0354501660913229,\n", + " 0.03287862497381866,\n", + " 0.029975875047966838,\n", + " 0.03692404204048216,\n", + " 0.02574570896103978,\n", + " 0.03072458296082914,\n", + " 0.03142145904712379,\n", + " 0.034465207951143384,\n", + " 0.032932084053754807,\n", + " 0.03766591614112258,\n", + " 0.031669416930526495,\n", + " 0.03097916697151959,\n", + " 0.024389415979385376,\n", + " 0.026578041957691312,\n", + " 0.028716041008010507,\n", + " 0.032391542103141546,\n", + " 0.03359537501819432,\n", + " 0.029333041980862617,\n", + " 0.03852045815438032,\n", + " 0.03522874996997416,\n", + " 0.039978290908038616,\n", + " 0.03800995904020965,\n", + " 0.03813070897012949,\n", + " 0.03231212496757507,\n", + " 0.039890124928206205,\n", + " 0.03974391706287861,\n", + " 
0.040384375024586916,\n", + " 0.03584462497383356,\n", + " 0.03564045880921185,\n", + " 0.03651641705073416,\n", + " 0.037402458023279905,\n", + " 0.03648191690444946,\n", + " 0.03885470796376467,\n", + " 0.03411237499676645,\n", + " 0.03694487502798438,\n", + " 0.02758374996483326,\n", + " 0.03918912494555116,\n", + " 0.03915116610005498,\n", + " 0.03766429191455245,\n", + " 0.034370541106909513,\n", + " 0.03439179202541709,\n", + " 0.03841937496326864,\n", + " 0.039793833857402205,\n", + " 0.03862112481147051,\n", + " 0.03574187494814396,\n", + " 0.03246379108168185,\n", + " 0.036486207973212004,\n", + " 0.03896020818501711,\n", + " 0.03466787491925061,\n", + " 0.0354886669665575,\n", + " 0.03665566607378423,\n", + " 0.03877758304588497,\n", + " 0.03936316608451307,\n", + " 0.04036858305335045,\n", + " 0.03711925004608929,\n", + " 0.03579937503673136,\n", + " 0.0394124158192426,\n", + " 0.03487608302384615,\n", + " 0.033077584113925695,\n", + " 0.03582812496460974,\n", + " 0.03510708408430219,\n", + " 0.03271691710688174,\n", + " 0.024238749872893095,\n", + " 0.0317631671205163,\n", + " 0.037570209009572864,\n", + " 0.037720915861427784,\n", + " 0.03836658294312656,\n", + " 0.03753329208120704,\n", + " 0.03856374998576939,\n", + " 0.03874720796011388,\n", + " 0.038331500021740794,\n", + " 0.03494950011372566,\n", + " 0.03934533311985433,\n", + " 0.034773917170241475,\n", + " 0.0353236251976341,\n", + " 0.03172524995170534,\n", + " 0.03164450009353459,\n", + " 0.02425049990415573,\n", + " 0.03322154190391302,\n", + " 0.036195874912664294,\n", + " 0.03410979197360575,\n", + " 0.03734816703945398,\n", + " 0.039303625002503395,\n", + " 0.03886774997226894],\n", + " 'compile_mode': 'none',\n", + " 'use_compile': False,\n", + " 'requested_vmap': False,\n", + " 'effective_vmap': False,\n", + " 'chunk_size_effective': 0,\n", + " 'label': 'Eager'},\n", + " {'elapsed': 6.3454663751181215,\n", + " 'losses': [7.57296085357666,\n", + " 2.4917588233947754,\n", + " 3.297954797744751,\n", + " 3.0194761753082275,\n", + " 0.9618798494338989,\n", + " 0.8673862218856812,\n", + " 1.61958646774292,\n", + " 0.5269474983215332,\n", + " 1.0297842025756836,\n", + " 1.332466721534729,\n", + " 0.6973803639411926,\n", + " 1.6610198020935059,\n", + " 0.471968412399292,\n", + " 1.9802007675170898,\n", + " 2.4848790168762207,\n", + " 1.1533082723617554,\n", + " 0.5622122883796692,\n", + " 0.7249880433082581,\n", + " 1.2468247413635254,\n", + " 1.9157638549804688,\n", + " 1.580257773399353,\n", + " 0.9950148463249207,\n", + " 0.9827099442481995,\n", + " 0.9594108462333679,\n", + " 2.0273146629333496,\n", + " 1.0678751468658447,\n", + " 1.7654999494552612,\n", + " 1.8363943099975586,\n", + " 0.5704579949378967,\n", + " 2.0948455333709717,\n", + " 0.8548235893249512,\n", + " 4.518638610839844,\n", + " 1.0827548503875732,\n", + " 1.2317492961883545,\n", + " 0.6395676732063293,\n", + " 1.3933277130126953,\n", + " 1.7131195068359375,\n", + " 1.1856666803359985,\n", + " 1.4280558824539185,\n", + " 0.8084155917167664,\n", + " 0.3790717124938965,\n", + " 1.5839354991912842,\n", + " 2.1613659858703613,\n", + " 1.4849202632904053,\n", + " 1.6980226039886475,\n", + " 0.4082470238208771,\n", + " 1.0781641006469727,\n", + " 0.6617385149002075,\n", + " 0.8540250062942505,\n", + " 0.7804922461509705,\n", + " 1.532320499420166,\n", + " 1.1752188205718994,\n", + " 0.45735907554626465,\n", + " 1.7341620922088623,\n", + " 1.3420584201812744,\n", + " 0.6467709541320801,\n", + " 0.7274090647697449,\n", + " 0.4892843961715698,\n", + " 
0.42712870240211487,\n", + " 1.8384690284729004,\n", + " 0.7235339879989624,\n", + " 1.1163374185562134,\n", + " 1.3501720428466797,\n", + " 0.42495307326316833,\n", + " 1.4814316034317017,\n", + " 0.9633446335792542,\n", + " 0.7441746592521667,\n", + " 0.49172279238700867,\n", + " 0.8439255952835083,\n", + " 1.7822625637054443,\n", + " 2.70001482963562,\n", + " 1.226853370666504,\n", + " 1.6889946460723877,\n", + " 0.8238599300384521,\n", + " 1.1699568033218384,\n", + " 0.5300586223602295,\n", + " 0.34182092547416687,\n", + " 0.6136969327926636,\n", + " 0.7866867780685425,\n", + " 2.1859169006347656,\n", + " 0.5212369561195374,\n", + " 1.2727247476577759,\n", + " 0.45056644082069397,\n", + " 1.4032206535339355,\n", + " 0.4590488374233246,\n", + " 0.4650927782058716,\n", + " 0.8082565665245056,\n", + " 0.7144088745117188,\n", + " 0.543725311756134,\n", + " 1.4212666749954224,\n", + " 1.2204252481460571,\n", + " 0.25085633993148804,\n", + " 1.3868451118469238,\n", + " 1.0468615293502808,\n", + " 1.4329462051391602,\n", + " 0.6538652777671814,\n", + " 1.561804175376892,\n", + " 2.5266776084899902,\n", + " 0.7883028984069824,\n", + " 5.479814529418945,\n", + " 2.439664840698242,\n", + " 1.371835470199585,\n", + " 0.9409430027008057,\n", + " 0.46193727850914,\n", + " 0.5832473039627075,\n", + " 1.6056299209594727,\n", + " 1.7503899335861206,\n", + " 0.9054805040359497,\n", + " 0.7594331502914429,\n", + " 1.2933785915374756,\n", + " 0.4637310206890106,\n", + " 0.47456109523773193,\n", + " 0.7240238189697266,\n", + " 1.725350260734558,\n", + " 2.4346182346343994,\n", + " 2.2380924224853516,\n", + " 2.0717313289642334,\n", + " 0.5502132177352905,\n", + " 1.7969448566436768,\n", + " 1.1502264738082886,\n", + " 0.26300373673439026,\n", + " 0.8080339431762695,\n", + " 0.28202709555625916,\n", + " 0.145426943898201,\n", + " 0.8435786962509155,\n", + " 0.5270522236824036,\n", + " 0.3560110628604889,\n", + " 0.609140157699585,\n", + " 0.4524146318435669,\n", + " 0.5814225077629089,\n", + " 1.2689967155456543,\n", + " 0.5928022861480713,\n", + " 0.2952418625354767,\n", + " 0.24432311952114105,\n", + " 0.7377331256866455,\n", + " 0.8195648193359375,\n", + " 1.2766600847244263,\n", + " 0.9409489035606384,\n", + " 0.830878496170044,\n", + " 0.4374370574951172,\n", + " 0.3859502971172333,\n", + " 2.2971324920654297,\n", + " 0.41110262274742126,\n", + " 0.7398536801338196,\n", + " 0.43272972106933594,\n", + " 0.5752124190330505,\n", + " 0.3000510632991791,\n", + " 0.8548444509506226,\n", + " 0.373995304107666,\n", + " 1.8866313695907593,\n", + " 1.077983021736145,\n", + " 0.5132556557655334,\n", + " 1.1473032236099243,\n", + " 0.6178485155105591,\n", + " 0.4333594739437103,\n", + " 1.4737694263458252,\n", + " 0.9747253060340881,\n", + " 1.5080058574676514,\n", + " 1.314931869506836,\n", + " 0.9588497877120972,\n", + " 0.39719900488853455,\n", + " 1.0430501699447632,\n", + " 1.1309837102890015,\n", + " 0.43614932894706726,\n", + " 0.58064204454422,\n", + " 1.1400829553604126,\n", + " 0.3988802433013916,\n", + " 0.963148832321167,\n", + " 2.17482328414917,\n", + " 1.3901969194412231,\n", + " 0.3156747817993164,\n", + " 0.5436887741088867,\n", + " 0.36854708194732666,\n", + " 0.37455135583877563,\n", + " 0.2726321220397949,\n", + " 0.3139721155166626,\n", + " 0.5012255311012268,\n", + " 0.820480227470398,\n", + " 1.0951125621795654,\n", + " 0.6761919856071472,\n", + " 0.790934145450592,\n", + " 0.9907330274581909,\n", + " 0.8022286891937256,\n", + " 0.3866922855377197,\n", + " 0.7084116339683533,\n", + " 
0.866324245929718,\n", + " 0.46766072511672974,\n", + " 0.26419714093208313,\n", + " 0.32584092020988464,\n", + " 1.2846601009368896,\n", + " 0.39885473251342773,\n", + " 1.0205860137939453,\n", + " 0.27573293447494507,\n", + " 0.24224549531936646,\n", + " 0.6909810900688171,\n", + " 0.3044925034046173,\n", + " 0.25011563301086426,\n", + " 0.44614750146865845,\n", + " 0.6451624035835266,\n", + " 0.5779326558113098],\n", + " 'iter_times': [0.03169812494888902,\n", + " 0.027584708062931895,\n", + " 0.02858670800924301,\n", + " 0.026709083002060652,\n", + " 0.0258675420191139,\n", + " 0.02468925016000867,\n", + " 0.030730166006833315,\n", + " 0.016805625054985285,\n", + " 0.026752999983727932,\n", + " 0.023861791007220745,\n", + " 0.01803375012241304,\n", + " 0.022661666851490736,\n", + " 0.02204920817166567,\n", + " 0.02710441709496081,\n", + " 0.02583591709844768,\n", + " 0.027296457905322313,\n", + " 0.027498583076521754,\n", + " 0.023009249940514565,\n", + " 0.029670832911506295,\n", + " 0.026142749935388565,\n", + " 0.026309833861887455,\n", + " 0.032059666933491826,\n", + " 0.02789283310994506,\n", + " 0.02576716709882021,\n", + " 0.02486624987795949,\n", + " 0.027173375012353063,\n", + " 0.02338912500999868,\n", + " 0.024146082811057568,\n", + " 0.02213795785792172,\n", + " 0.023384332889690995,\n", + " 0.02873941697180271,\n", + " 0.0250537081155926,\n", + " 0.028922915924340487,\n", + " 0.02695966698229313,\n", + " 0.03005241695791483,\n", + " 0.025074792094528675,\n", + " 0.02407074999064207,\n", + " 0.028803542023524642,\n", + " 0.024359500035643578,\n", + " 0.025547208031639457,\n", + " 0.02867558295838535,\n", + " 0.02669191686436534,\n", + " 0.02979383314959705,\n", + " 0.026735665975138545,\n", + " 0.028341416968032718,\n", + " 0.026234874967485666,\n", + " 0.030128249898552895,\n", + " 0.02529741614125669,\n", + " 0.027294500032439828,\n", + " 0.026511665899306536,\n", + " 0.024816541001200676,\n", + " 0.03207574994303286,\n", + " 0.02337020798586309,\n", + " 0.026867209002375603,\n", + " 0.03262049984186888,\n", + " 0.026773500023409724,\n", + " 0.025173667119815946,\n", + " 0.030965832993388176,\n", + " 0.023638332961127162,\n", + " 0.03076474997214973,\n", + " 0.026863333070650697,\n", + " 0.025366832967847586,\n", + " 0.03145316708832979,\n", + " 0.03426345810294151,\n", + " 0.03046545898541808,\n", + " 0.029700499959290028,\n", + " 0.029783915961161256,\n", + " 0.02838366595096886,\n", + " 0.029486834071576595,\n", + " 0.02670641685836017,\n", + " 0.031350333942100406,\n", + " 0.027266208082437515,\n", + " 0.025731500005349517,\n", + " 0.030091083142906427,\n", + " 0.031152040930464864,\n", + " 0.03152075014077127,\n", + " 0.01954770809970796,\n", + " 0.026827292051166296,\n", + " 0.027664915891364217,\n", + " 0.024289875058457255,\n", + " 0.02911649993620813,\n", + " 0.029995207907631993,\n", + " 0.02480362495407462,\n", + " 0.024720708839595318,\n", + " 0.024827249813824892,\n", + " 0.027033083140850067,\n", + " 0.027515250025317073,\n", + " 0.029572209110483527,\n", + " 0.02584737492725253,\n", + " 0.02511425013653934,\n", + " 0.02771108318120241,\n", + " 0.03237133310176432,\n", + " 0.032289124792441726,\n", + " 0.024745125090703368,\n", + " 0.03070554183796048,\n", + " 0.0387380828615278,\n", + " 0.027526458026841283,\n", + " 0.038875750033184886,\n", + " 0.033374999882653356,\n", + " 0.04179483302868903,\n", + " 0.03558491705916822,\n", + " 0.03673141589388251,\n", + " 0.03694891603663564,\n", + " 0.031565249897539616,\n", + " 0.029297207947820425,\n", + " 
0.03449812508188188,\n", + " 0.032412667060270905,\n", + " 0.03472112491726875,\n", + " 0.03546695807017386,\n", + " 0.03327808412723243,\n", + " 0.028670792002230883,\n", + " 0.02604879206046462,\n", + " 0.037921792129054666,\n", + " 0.03335845796391368,\n", + " 0.03821487491950393,\n", + " 0.038532041013240814,\n", + " 0.038210750091820955,\n", + " 0.034492208855226636,\n", + " 0.03084229095838964,\n", + " 0.03950933297164738,\n", + " 0.026532666059210896,\n", + " 0.03208570904098451,\n", + " 0.03370637493208051,\n", + " 0.03659145790152252,\n", + " 0.03467929200269282,\n", + " 0.03974362509325147,\n", + " 0.03336887480691075,\n", + " 0.03240949986502528,\n", + " 0.025664541870355606,\n", + " 0.02777195884846151,\n", + " 0.030273291980847716,\n", + " 0.039200125029310584,\n", + " 0.02961912495084107,\n", + " 0.030783750116825104,\n", + " 0.03524379199370742,\n", + " 0.039315874921157956,\n", + " 0.04285550001077354,\n", + " 0.03660220908932388,\n", + " 0.04133837507106364,\n", + " 0.03899241704493761,\n", + " 0.03553324984386563,\n", + " 0.03786025010049343,\n", + " 0.032377874944359064,\n", + " 0.028963749995455146,\n", + " 0.03383512492291629,\n", + " 0.028710374841466546,\n", + " 0.03110875003039837,\n", + " 0.030619458062574267,\n", + " 0.03356295800767839,\n", + " 0.03945783409290016,\n", + " 0.039811542024835944,\n", + " 0.03152520814910531,\n", + " 0.040032291086390615,\n", + " 0.038056124933063984,\n", + " 0.03658033302053809,\n", + " 0.03990195784717798,\n", + " 0.038062167121097445,\n", + " 0.041290540946647525,\n", + " 0.040220499970018864,\n", + " 0.036755625158548355,\n", + " 0.03777787508442998,\n", + " 0.04015366593375802,\n", + " 0.0368197918869555,\n", + " 0.038513916078954935,\n", + " 0.03783300006762147,\n", + " 0.04108829190954566,\n", + " 0.038560917135328054,\n", + " 0.04114095913246274,\n", + " 0.04210358299314976,\n", + " 0.04152379208244383,\n", + " 0.03358387481421232,\n", + " 0.03409449988976121,\n", + " 0.03836424998007715,\n", + " 0.03440766618587077,\n", + " 0.0377882921602577,\n", + " 0.03659795899875462,\n", + " 0.037964458810165524,\n", + " 0.03920579212717712,\n", + " 0.04200679203495383,\n", + " 0.03882258292287588,\n", + " 0.039652334060519934,\n", + " 0.037086208118125796,\n", + " 0.04016891587525606,\n", + " 0.04108100011944771,\n", + " 0.038782832911238074,\n", + " 0.037558333948254585,\n", + " 0.04046133300289512,\n", + " 0.03616187488660216,\n", + " 0.03772379201836884,\n", + " 0.040625750087201595,\n", + " 0.0360937500372529,\n", + " 0.039527124958112836,\n", + " 0.02647212496958673,\n", + " 0.03991162497550249,\n", + " 0.034061457961797714,\n", + " 0.03235474997200072,\n", + " 0.025128833018243313,\n", + " 0.03827429190278053,\n", + " 0.03287583403289318,\n", + " 0.03880300000309944],\n", + " 'compile_mode': 'reduce-overhead',\n", + " 'use_compile': True,\n", + " 'requested_vmap': False,\n", + " 'effective_vmap': False,\n", + " 'chunk_size_effective': 0,\n", + " 'label': 'Compiled'}]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import importlib\n", + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add project root to path (notebook is in tutorials/notebooks/)\n", + "project_root = Path.cwd().parent.parent\n", + "sys.path.append(str(project_root))\n", + "from tutorials.examples import train_hypergrid_optimized as hypergrid_train\n", + "\n", + "\n", + "# Reload to pick up local edits without restarting the kernel.\n", + 
"importlib.reload(hypergrid_train)\n", + "\n", + "\n", + "def notebook_benchmark_run(\n", + " *,\n", + " compile_mode: str = \"none\",\n", + " use_compile: bool = False,\n", + " chunk_size: int = 0,\n", + " n_iterations: int = 200,\n", + " warmup_iters: int = 50,\n", + " seed: int = 0,\n", + " device: str = \"cpu\",\n", + " label: str,\n", + ") -> dict:\n", + " argv_backup = sys.argv\n", + " try:\n", + " sys.argv = [sys.argv[0]]\n", + " args = hypergrid_train.parse_args()\n", + " finally:\n", + " sys.argv = argv_backup\n", + " args.compile = use_compile\n", + " args.compile_mode = compile_mode\n", + " args.chunk_size = chunk_size\n", + " args.n_iterations = n_iterations\n", + " args.warmup_iters = warmup_iters\n", + " args.seed = seed\n", + " args.device = device\n", + " args.benchmark = True\n", + " args.use_vmap = False\n", + " args.loss = \"TB\"\n", + " args.batch_size = 16\n", + " args.height = 32\n", + " args.ndim = 2\n", + "\n", + " result = hypergrid_train.train_with_options(\n", + " args,\n", + " device=hypergrid_train.resolve_device(device),\n", + " enable_compile=use_compile,\n", + " use_vmap=False,\n", + " warmup_iters=warmup_iters,\n", + " quiet=True,\n", + " timing=True,\n", + " record_history=True,\n", + " use_chunk=(chunk_size > 0),\n", + " )\n", + " result[\"label\"] = label\n", + " result[\"compile_mode\"] = compile_mode if use_compile else \"none\"\n", + " return result\n", + "\n", + "\n", + "scenarios = [\n", + " dict(label=\"Eager\", use_compile=False),\n", + " dict(label=\"Compiled\", use_compile=True, compile_mode=\"reduce-overhead\"),\n", + "]\n", + "\n", + "benchmark_runs = []\n", + "for scenario in scenarios:\n", + " run_result = notebook_benchmark_run(**scenario)\n", + " benchmark_runs.append(run_result)\n", + "\n", + "benchmark_runs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelelapsedcompile_modeeffective_vmapchunk_size_effective
0Eager6.105576noneFalse0
1Compiled6.345466reduce-overheadFalse0
\n", + "
" + ], + "text/plain": [ + " label elapsed compile_mode effective_vmap chunk_size_effective\n", + "0 Eager 6.105576 none False 0\n", + "1 Compiled 6.345466 reduce-overhead False 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline label: Eager elapsed: 6.11s\n", + "Compiled elapsed=6.35s (0.96x vs baseline), compile_mode=reduce-overhead\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "benchmark_df = pd.DataFrame(benchmark_runs)\n", + "display(\n", + " benchmark_df[\n", + " [\n", + " \"label\",\n", + " \"elapsed\",\n", + " \"compile_mode\",\n", + " \"effective_vmap\",\n", + " \"chunk_size_effective\",\n", + " ]\n", + " ]\n", + ")\n", + "\n", + "baseline = benchmark_df.iloc[0]\n", + "print(\"Baseline label:\", baseline[\"label\"], \"elapsed:\", f\"{baseline['elapsed']:.2f}s\")\n", + "for idx in range(1, len(benchmark_df)):\n", + " row = benchmark_df.iloc[idx]\n", + " speedup = baseline[\"elapsed\"] / row[\"elapsed\"] if row[\"elapsed\"] else float(\"inf\")\n", + " print(\n", + " f\"{row['label']} elapsed={row['elapsed']:.2f}s \"\n", + " f\"({speedup:.2f}x vs baseline), compile_mode={row['compile_mode']}\"\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dynamo trace analysis\n", + "\n", + "`torch._dynamo.explain` gives a per-graph summary: captured ops, guards, and where graph breaks (if any) occur. The cell below reuses the state/action tensors above and prints the explanation so you can confirm there is only one graph and zero breaks.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Graph Count: 16\n", + "Graph Break Count: 15\n", + "Op Count: 24\n", + "Break Reasons:\n", + " Break Reason 1:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 2:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 3:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 4:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " 
Break Reason 5:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 6:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " Break Reason 7:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 8:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " Break Reason 9:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 10:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + "Ops per Graph:\n", + " Ops 1:\n", + " \n", + " \n", + " Ops 2:\n", + " \n", + " Ops 3:\n", + " \n", + " Ops 4:\n", + " \n", + " Ops 5:\n", + " Ops 6:\n", + " \n", + " Ops 7:\n", + " \n", + " Ops 8:\n", + " Ops 9:\n", + " \n", + " Ops 10:\n", + " \n", + " Ops 11:\n", + " Ops 12:\n", + " Ops 13:\n", + " \n", + " aten._assert_async.msg\n", + " \n", + " \n", + " Ops 14:\n", + " Ops 15:\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " Ops 16:\n", + " \n", + " \n", + " \n", + " \n", + "Out Guards:\n", + " Guard 1:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 2:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 
3:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['states_tensor'].size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 4:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 5:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 6:\n", + " Name: \"L['states_tensor']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states_tensor'], 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 7:\n", + " Name: \"G['env'].States.s0\"\n", + " Source: global\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(G['env'].States.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 8:\n", + " Name: \"G['env']\"\n", + " Source: global\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(G['env'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 9:\n", + " Name: \"L['actions_tensor']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions_tensor'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 10:\n", + " Name: \"G['__builtins_dict___66']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 11:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 12:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 13:\n", + " Name: \"G['__builtins_dict___66']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 14:\n", + " Name: \"G['__builtins_dict___66']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 15:\n", + " Name: \"G['__builtins_dict___66']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", 
+ " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 16:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 17:\n", + " Name: \"L['states_tensor']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states_tensor'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 18:\n", + " Name: \"G['env'].States.sf\"\n", + " Source: global\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(G['env'].States.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 19:\n", + " Name: \"G['__builtins_dict___66']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 20:\n", + " Name: \"G['env'].States.n_actions\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].States.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 21:\n", + " Name: \"G['__import_gfn_dot_states'].torch.ones\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.ones, 4428225552)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 22:\n", + " Name: \"G['env'].Actions.action_shape\"\n", + " Source: global\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(G['env'].Actions.action_shape, 4305555088)\", \"len(G['env'].Actions.action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 23:\n", + " Name: \"G['env'].States.state_shape\"\n", + " Source: global\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(G['env'].States.state_shape, 4305555088)\", \"len(G['env'].States.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 24:\n", + " Name: \"G['env'].Actions.action_shape[0]\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].Actions.action_shape[0] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 25:\n", + " Name: \"G['env'].States.state_shape[0]\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].States.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 26:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", 
+ " Guard 27:\n", + " Name: \"L['states_tensor'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['states_tensor'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 28:\n", + " Name: \"G['env'].States\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['env'].States, 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 29:\n", + " Name: \"G['env'].Actions\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['env'].Actions, 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 30:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 31:\n", + " Name: \"G['__builtins_dict___70']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 32:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 33:\n", + " Name: \"G['__builtins_dict___70']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 34:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 35:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 36:\n", + " Name: \"G['__builtins_dict___70']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 37:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 38:\n", + " Name: \"G['__builtins_dict___70']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 39:\n", + " Name: \"L['actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['actions'], 6146095712)\"]\n", + " Object Weakref: \n", + 
" Guarded Class Weakref: \n", + " Guard 40:\n", + " Name: \"G['__import_gfn_dot_states'].torch.Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 41:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 42:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 43:\n", + " Name: \"L['states'].__class__.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].__class__.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 44:\n", + " Name: \"L['actions'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 45:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: DUPLICATE_INPUT\n", + " Guard Types: ['DUPLICATE_INPUT']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch is G['torch']\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 46:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 47:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 48:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 49:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 50:\n", + " Name: \"L['self'].check_action_validity\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['BOOL_MATCH']\n", + " Code List: [\"L['self'].check_action_validity == True\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 51:\n", + " Name: \"L['states'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].state_shape, 4305555088)\", \"len(L['states'].state_shape) == 1\"]\n", + " 
Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 52:\n", + " Name: \"G['__import_gfn_dot_states'].ensure_same_device\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 53:\n", + " Name: \"L['states'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].tensor.shape, 4891320080)\", \"len(L['states'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 54:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 55:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 56:\n", + " Name: \"L['actions'].action_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['actions'].action_shape, 4305555088)\", \"len(L['actions'].action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 57:\n", + " Name: \"L['states'].state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['states'].state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 58:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 59:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 60:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 61:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 62:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 63:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 64:\n", + " Name: \"G['__builtins_dict___74']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___74']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 65:\n", + " Name: \"L['index']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['index'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 66:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 67:\n", + " Name: \"G['torch'].Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 68:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 69:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 70:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 71:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 72:\n", + " Name: \"L['self'].action_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].action_shape, 4305555088)\", \"len(L['self'].action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 73:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 74:\n", + " Name: \"G['__builtins_dict___77']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___77']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 75:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 76:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 77:\n", + " Name: \"G['__builtins_dict___77']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___77']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 78:\n", + " Name: \"G['boolean_mask_select']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 79:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 80:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 81:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 82:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 83:\n", + " Name: \"L['mask'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['mask'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 84:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 85:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 86:\n", + " Name: \"G['torch'].arange\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].arange, 4428197584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 87:\n", + " Name: \"G['torch'].repeat_interleave\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].repeat_interleave, 4428289808)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 88:\n", + " Name: \"G['_expand_mask_to_batch']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 89:\n", + " Name: \"L['batch_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['batch_shape'][0] == 4\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 90:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType, 
4305497480)\"]\n", + "        Object Weakref: \n", + "        Guarded Class Weakref: \n",
+ "    [... Guards 91-339 elided: hundreds of near-identical entries repeating the same guard kinds -- TENSOR_MATCH / TYPE_MATCH / SEQUENCE_LENGTH / CONSTANT_MATCH guards on locals such as L['data'], L['mask'], L['flat_data'], L['batch_shape'], L['value_shape'], L['device'], and L['self'].tensor; ID_MATCH / BUILTIN_MATCH / CLOSURE_MATCH guards on G['torch'], the dynamo polyfills module, and builtins; and per-graph global-state guards (SHAPE_ENV, GRAD_MODE, DETERMINISTIC_ALGORITHMS, AUTOGRAD_SAVED_TENSORS_HOOKS, TORCH_FUNCTION_STATE, DEFAULT_DEVICE). Several guards specialize on concrete runtime values, e.g. L['batch_shape'][0] == 4 and L['device'] == device(type='cpu'), so changing the batch size or device invalidates the compiled graph and forces recompilation. ...]\n",
+ "    Guard 340:\n", + "        Name: ''\n", + "        Source: global\n", + "        Create Function: DETERMINISTIC_ALGORITHMS\n", + " 
Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 341:\n", + " Name: \"L['___stack0'].__class__.state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['___stack0'].__class__.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 342:\n", + " Name: \"L['___stack0'].backward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].backward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 343:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 344:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 345:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 346:\n", + " Name: \"L['___stack0'].__class__.n_actions\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['___stack0'].__class__.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 347:\n", + " Name: \"L['___stack0'].__class__.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].__class__.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 348:\n", + " Name: \"G['__builtins_dict___122']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 349:\n", + " Name: \"L['actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['actions'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 350:\n", + " Name: \"L['___stack0'].backward_masks.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['___stack0'].backward_masks, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 351:\n", + " Name: \"L['___stack0'].__class__.s0\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].__class__.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 352:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class 
Weakref: None\n", + " Guard 353:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 354:\n", + " Name: \"G['__builtins_dict___122']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 355:\n", + " Name: \"L['___stack0'].__class__.state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['___stack0'].__class__.state_shape, 4305555088)\", \"len(L['___stack0'].__class__.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 356:\n", + " Name: \"L['new_valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['new_valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 357:\n", + " Name: \"L['___stack0'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 358:\n", + " Name: \"G['__builtins_dict___122']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 359:\n", + " Name: \"G['__builtins_dict___122']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 360:\n", + " Name: \"G['__builtins_dict___122']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 361:\n", + " Name: \"L['___stack0'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 362:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 363:\n", + " Name: \"L['___stack0'].tensor.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 364:\n", + " Name: \"L['___stack0'].forward_masks.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code 
List: [\"hasattr(L['___stack0'].forward_masks, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 365:\n", + " Name: \"L['self'].States.state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].States.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 366:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 367:\n", + " Name: \"G['__builtins_dict___125']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 368:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 369:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 370:\n", + " Name: \"L['not_done_states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['not_done_states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 371:\n", + " Name: \"G['__builtins_dict___125']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 372:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['states'].tensor.size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 373:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 374:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 375:\n", + " Name: \"G['__builtins_dict___125']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 376:\n", + " Name: \"G['__import_gfn_dot_states'].torch.Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " 
Guarded Class Weakref: \n", + " Guard 377:\n", + " Name: \"L['self'].States.s0\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].States.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 378:\n", + " Name: \"G['__builtins_dict___125']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 379:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 380:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 381:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 382:\n", + " Name: \"L['self'].States.sf.repeat\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['self'].States.sf, 'repeat')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 383:\n", + " Name: \"L['self'].States\"\n", + " Source: local\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(L['self'].States, 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 384:\n", + " Name: \"L['not_done_states'].tensor.scatter\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['not_done_states'].tensor, 'scatter')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 385:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 386:\n", + " Name: \"G['States']\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['States'], 6146032752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 387:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 388:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 389:\n", + " Name: \"L['states'].tensor\"\n", + " 
Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 390:\n", + " Name: \"L['self'].States.state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].States.state_shape, 4305555088)\", \"len(L['self'].States.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 391:\n", + " Name: \"L['states'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].state_shape, 4305555088)\", \"len(L['states'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 392:\n", + " Name: \"L['new_valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['new_valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 393:\n", + " Name: \"L['self'].States.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].States.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 394:\n", + " Name: \"L['self'].States.n_actions\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].States.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 395:\n", + " Name: \"G['__import_gfn_dot_states'].torch.ones\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.ones, 4428225552)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 396:\n", + " Name: \"L['not_done_states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['not_done_states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 397:\n", + " Name: \"L['___stack0'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 398:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 399:\n", + " Name: \"L['states'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].tensor.shape, 4891320080)\", \"len(L['states'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 400:\n", + " Name: \"G['__builtins_dict___125']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___125']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 401:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 402:\n", + " Name: \"G['torch'].zeros\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].zeros, 4428230352)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 403:\n", + " Name: \"G['__builtins_dict___127']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___127']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 404:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 405:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 406:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 407:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 408:\n", + " Name: \"L['cond']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['cond'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 409:\n", + " Name: \"L['self'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].state_shape, 4305555088)\", \"len(L['self'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 410:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 411:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 412:\n", + " Name: \"G['torch'].cat\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].cat, 4428310752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 413:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: 
None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 414:\n", + " Name: \"L['allow_exit']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['BOOL_MATCH']\n", + " Code List: [\"L['allow_exit'] == True\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 415:\n", + " Name: \"L['self'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 416:\n", + " Name: \"G['__builtins_dict___127']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___127']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 417:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + "Compile Times: TorchDynamo compilation metrics:\n", + "Function, Runtimes (s)\n", + "_compile.compile_inner, 0.0755, 0.0391, 0.0135, 0.0140, 0.0205, 0.0074, 0.0033, 0.0027, 0.0187, 0.0188, 0.0175, 0.0315, 0.0340, 0.0261, 0.0341, 0.0167, 0.0232, 0.0155, 0.0428, 0.1327, 0.0327, 0.0096\n", + "compile_attempt_0, 0.0499, 0.0194, 0.0083, 0.0081, 0.0069, 0.0045, 0.0030, 0.0024, 0.0108, 0.0124, 0.0096, 0.0136, 0.0277, 0.0127, 0.0168, 0.0121, 0.0115, 0.0118, 0.0223, 0.1121, 0.0279, 0.0063\n", + "bytecode_tracing, 0.0484, 0.0159, 0.0188, 0.0110, 0.0079, 0.0011, 0.0077, 0.0016, 0.0064, 0.0053, 0.0027, 0.0026, 0.0020, 0.0103, 0.0001, 0.0118, 0.0013, 0.0092, 0.0019, 0.0131, 0.0104, 0.0253, 0.0122, 0.0045, 0.0161, 0.0096, 0.0101, 0.0110, 0.0038, 0.0096, 0.0218, 0.0104, 0.1115, 0.0103, 0.0242, 0.0044\n", + "compile_attempt_1, 0.0200, 0.0156, 0.0023, 0.0028, 0.0097, 0.0051, 0.0035, 0.0043, 0.0127, 0.0090, 0.0120, 0.0081, 0.0161, 0.0166\n", + "OutputGraph.call_user_compiler, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001\n", + "build_guards, 0.0046, 0.0035, 0.0023, 0.0026, 0.0034, 0.0025, 0.0023, 0.0024, 0.0031, 0.0046, 0.0054, 0.0038, 0.0046, 0.0041, 0.0032, 0.0032, 0.0037, 0.0033, 0.0042, 0.0028\n", + "gc, 0.0006, 0.0003, 0.0002, 0.0002, 0.0002, 0.0001, 0.0001, 0.0001, 0.0003, 0.0002, 0.0003, 0.0003, 0.0004, 0.0003, 0.0003, 0.0002, 0.0002, 0.0002, 0.0005, 0.0003, 0.0002, 0.0002\n", + "pgo.dynamic_whitelist, 0.0000, 0.0000, 0.0000, 0.0000\n", + "\n" + ] + } + ], + "source": [ + "import torch._dynamo as dynamo\n", + "\n", + "explanation = dynamo.explain(step_once)(states.tensor, actions.tensor)\n", + "print(explanation)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torchgfn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 1aa2bf4ee674ce5f0e71b0fa0bb1997adff8810e Mon Sep 17 
From 1aa2bf4ee674ce5f0e71b0fa0bb1997adff8810e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 20:06:04 -0500 Subject: [PATCH 06/15] removed plan --- multi.plan.md | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 multi.plan.md diff --git a/multi.plan.md b/multi.plan.md deleted file mode 100644 index 0df97141..00000000 --- a/multi.plan.md +++ /dev/null @@ -1,34 +0,0 @@
-
-# Plan: Extend Benchmark to Diffusion Sampling
-
-## 1. Refactor scenario management
-
-- Extract the existing HyperGrid logic into a reusable `EnvironmentBenchmark` structure (e.g., a dataclass with a name, color, scenario list, and builder function); a minimal sketch is appended after this plan.
-- Keep HyperGrid’s current scenarios (baseline / library fast path / script fast path) but register them under the new structure.
-- Update the main loop to iterate over environments sequentially, collecting per-env results (including histories) and tagging each record with both env and scenario identifiers.
-
-## 2. Add diffusion sampling environment support
-
-- Review `tutorials/examples/train_diffusion_sampler.py` to reuse its estimator construction (`DiffusionSampling`, `DiffusionPISGradNetForward`, `DiffusionFixedBackwardModule`, `PinnedBrownianMotionForward/Backward`).
-- Implement a new `DiffusionEnvConfig` builder under `build_training_components` (or a dedicated helper) that creates the env, forward/backward estimators, optimizer groups, and default hyperparameters, mirroring the standalone script.
-- Define diffusion-specific scenarios:
-  - Baseline: standard sampler, no compilation.
-  - Library Fast Path: use `CompiledChunkSampler` (the env already inherits `EnvFastPathMixin`).
-  - Script Fast Path: implement a local chunked sampler analogous to `ChunkedHyperGridSampler`, but operating on diffusion states/tensors (handle continuous actions, exit padding, dummy actions). Expose it only for diffusion.
-
-## 3. Integrate new sampler/env wiring
-
-- Update `build_training_components` to dispatch based on the environment key (hypergrid vs diffusion) so each path can select the correct preprocessor, estimator modules, sampler type, and optimizer parameter groups.
-- Ensure the diffusion path still returns metrics compatible with the existing training loop (does it need `validate`? If validation is unavailable for diffusion, skip it or emit a stub message).
-
-## 4. Expand plotting to multi-row layout
-
-- Adjust `plot_benchmark` to group results by environment and create one row per environment (the HyperGrid row retains three scenarios; the Diffusion row shows its two or three variants).
-- Reuse the existing color mapping for GFlowNet variants; introduce per-environment scenario linestyles (or reuse existing names when overlapping).
-- Update subplot titles/labels to mention the environment name so viewers can distinguish rows easily.
-
-## 5. Final polish
-
-- Update CLI help text to mention multi-environment benchmarking and any diffusion-specific knobs (e.g., target selection, number of steps) if exposed; otherwise, explain the defaults in the docstring/comments.
-- Verify histories are recorded for both environments so the new loss/timing plots aren’t empty.
-- Refresh documentation/comments at the top of the script to describe the new diffusion benchmark capability.
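The plan names an `EnvironmentBenchmark` structure but was deleted before pinning down its fields. Purely as a point of reference, a minimal sketch of what such a dataclass could look like follows; every field and function name below is an assumption for illustration, not an API that exists in the repository:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class EnvironmentBenchmark:
    """One benchmarked environment plus its scenario list (hypothetical sketch)."""

    name: str  # e.g. "HyperGrid" or "Diffusion"
    color: str  # base color for this environment's row in the benchmark plot
    scenarios: List[str]  # e.g. ["baseline", "library_fast_path", "script_fast_path"]
    # Maps a scenario name to the training components (env, gflownet, optimizer, ...).
    build: Callable[[str], Tuple[Any, ...]]
    results: List[Dict[str, Any]] = field(default_factory=list)


def run_all(benchmarks: List[EnvironmentBenchmark]) -> None:
    # The main loop the plan describes: iterate environments sequentially and
    # tag every result record with both the env and the scenario identifier.
    for bench in benchmarks:
        for scenario in bench.scenarios:
            components = bench.build(scenario)
            # Real code would train and time with `components` here (omitted).
            bench.results.append({"env": bench.name, "scenario": scenario})
```

Registering HyperGrid and the diffusion environment as two such objects would let the benchmark loop treat them uniformly, which is what sections 1-3 of the plan ask for.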
\ No newline at end of file From 914283c771f1b074bceb621016e0142a890b623b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 21:30:13 -0500 Subject: [PATCH 07/15] fixes --- .../examples/train_hypergrid_optimized.py | 320 +++++++++++------- 1 file changed, 195 insertions(+), 125 deletions(-) diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py index 211e2edf..759255e5 100644 --- a/tutorials/examples/train_hypergrid_optimized.py +++ b/tutorials/examples/train_hypergrid_optimized.py @@ -9,6 +9,7 @@ import argparse import statistics import time +import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Literal, cast @@ -268,16 +269,20 @@ def step_tensor( next_states = states.clone() non_exit_mask = ~is_exit - if torch.any(non_exit_mask): - sel_states = next_states[non_exit_mask] - sel_actions = actions_idx[non_exit_mask] - sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add") - next_states[non_exit_mask] = sel_states - if torch.any(is_exit): - # Ensure exit actions land exactly on the sink state so downstream - # `is_sink_state` masks match the action padding semantics assumed - # by `Trajectories` and probability calculations. - next_states[is_exit] = self.sf.to(device=device) + non_exit_mask_exp = non_exit_mask.unsqueeze(-1) + safe_actions = torch.where( + non_exit_mask_exp, actions_idx, torch.zeros_like(actions_idx) + ) + delta = torch.zeros_like(next_states) + delta = delta.scatter(-1, safe_actions, 1, reduce="add") + delta = delta * non_exit_mask_exp.to(next_states.dtype) + next_states = next_states + delta + + # Ensure exit actions land exactly on the sink state so downstream + # `is_sink_state` masks match the action padding semantics assumed + # by `Trajectories` and probability calculations. + sink_state = self.sf.to(device=device).unsqueeze(0).expand_as(next_states) + next_states = torch.where(is_exit.unsqueeze(-1), sink_state, next_states) next_forward_masks = torch.ones( (batch, self.n_actions), dtype=torch.bool, device=device @@ -326,7 +331,8 @@ def sample_trajectories( # noqa: C901 ): assert self.chunk_size > 0 assert hasattr(env, "step_tensor") - epsilon = float(policy_kwargs.get("epsilon", 0.0)) + policy_kwargs = dict(policy_kwargs) + epsilon = float(policy_kwargs.pop("epsilon", 0.0)) if states is None: assert n is not None @@ -334,10 +340,12 @@ def sample_trajectories( # noqa: C901 else: states_obj = states - estimator = self.estimator - module = getattr(estimator, "module", None) - assert module is not None - height = int(env.height) + if not isinstance(self.estimator, FastPolicyMixin): + raise TypeError( + "ChunkedHyperGridSampler requires a FastPolicy-compatible estimator." 
+ ) + policy = cast(FastPolicyMixin, self.estimator) + chunk_size = max(1, self.chunk_size) exit_idx = env.n_actions - 1 curr_states = states_obj.tensor @@ -370,31 +378,28 @@ def step_tensor( next_states, next_masks, is_exit_states = step_result return next_states, next_masks, is_exit_states + exit_action_value = env.exit_action.to(device=device) + dummy_action_value = env.dummy_action.to(device=device) + forward_masks = compute_forward_masks(curr_states) done = torch.zeros(batch, dtype=torch.bool, device=device) actions_seq: List[torch.Tensor] = [] dones_seq: List[torch.Tensor] = [] - - def sample_actions_from_logits( - logits: torch.Tensor, masks: torch.Tensor, eps: float - ) -> torch.Tensor: - masked_logits = logits.masked_fill(~masks, float("-inf")) - probs = torch.softmax(masked_logits, dim=-1) - - if eps > 0.0: - valid_counts = masks.sum(dim=-1, keepdim=True).clamp_min(1) - uniform = masks.to(probs.dtype) / valid_counts.to(probs.dtype) - probs = (1.0 - eps) * probs + eps * uniform - - # Ensure exit actions have probability 1.0 so that they land exactly on - # the sink state and downstream `is_sink_state` masks match the action - # padding semantics assumed by `Trajectories` and probability calculations. - nan_rows = torch.isnan(probs).any(dim=-1) - if nan_rows.any(): - probs[nan_rows] = 0.0 - probs[nan_rows, exit_idx] = 1.0 - - return torch.multinomial(probs, 1) + states_stack: List[torch.Tensor] = [curr_states.clone()] + + def _expand_front(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = (1,) * expand_dims + tuple(tensor.shape) + return tensor.view(view_shape) + + def _expand_back(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = tuple(tensor.shape) + (1,) * expand_dims + return tensor.view(view_shape) def _chunk_loop( current_states: torch.Tensor, @@ -403,38 +408,72 @@ def _chunk_loop( ): actions_list: List[torch.Tensor] = [] dones_list: List[torch.Tensor] = [] - for _ in range(self.chunk_size): + states_list: List[torch.Tensor] = [] + for _ in range(chunk_size): + if bool(done_mask.all().item()): + break + + masks = current_masks if done_mask.any(): - current_masks = current_masks.clone() - current_masks[done_mask] = False - current_masks[done_mask, exit_idx] = True - states_for_encoding = torch.clamp(current_states, min=0) - khot = torch.nn.functional.one_hot( - states_for_encoding, num_classes=height - ).to(dtype=torch.get_default_dtype()) - khot = khot.view(current_states.shape[0], -1) - logits = module(khot) - actions = sample_actions_from_logits(logits, current_masks, epsilon) - next_states, next_masks, is_exit = step_tensor(current_states, actions) - record_actions = actions.clone() - - # Replace actions for already-finished trajectories with the dummy - # action so that their timeline matches the padded semantics expected - # by Trajectories (actions.is_dummy aligns with states.is_sink_state[:-1]). 
+ masks = masks.clone() + masks[done_mask] = False + masks[done_mask, exit_idx] = True + + features = policy.fast_features( + current_states, + forward_masks=masks, + backward_masks=None, + conditions=conditions, + ) + dist = policy.fast_distribution( + features, + forward_masks=masks, + backward_masks=None, + states_tensor=current_states, + epsilon=epsilon, + **policy_kwargs, + ) + sampled_actions = dist.sample() + step_actions = sampled_actions + record_actions = sampled_actions + if done_mask.any(): - dummy_val = env.dummy_action.to(device=device) - record_actions[done_mask] = dummy_val + mask = _expand_back(done_mask, sampled_actions.ndim) + exit_fill = _expand_front( + exit_action_value.to( + device=sampled_actions.device, dtype=sampled_actions.dtype + ), + sampled_actions.ndim, + ) + dummy_fill = _expand_front( + dummy_action_value.to( + device=sampled_actions.device, dtype=sampled_actions.dtype + ), + sampled_actions.ndim, + ) + step_actions = torch.where(mask, exit_fill, sampled_actions) + record_actions = torch.where(mask, dummy_fill, sampled_actions) + + next_states, next_masks, is_exit = step_tensor( + current_states, step_actions + ) + actions_list.append(record_actions) dones_list.append(is_exit) + states_list.append(next_states.clone()) current_states = next_states current_masks = next_masks done_mask = done_mask | is_exit - if bool(done_mask.all().item()): - break - - return current_states, current_masks, done_mask, actions_list, dones_list + return ( + current_states, + current_masks, + done_mask, + actions_list, + dones_list, + states_list, + ) chunk_fn = _chunk_loop if hasattr(torch, "compile"): @@ -444,34 +483,35 @@ def _chunk_loop( pass while not bool(done.all().item()): - curr_states, forward_masks, done, actions_chunk, dones_chunk = chunk_fn( - curr_states, forward_masks, done - ) + ( + curr_states, + forward_masks, + done, + actions_chunk, + dones_chunk, + states_chunk, + ) = chunk_fn(curr_states, forward_masks, done) if actions_chunk: actions_seq.extend(actions_chunk) dones_seq.extend(dones_chunk) + states_stack.extend(states_chunk) if actions_seq: - actions_tsr = torch.stack([a for a in actions_seq], dim=0) - replay_actions = actions_tsr.clone() - dummy_val = env.dummy_action.to(device=device, dtype=replay_actions.dtype) - exit_val = env.exit_action.to(device=device, dtype=replay_actions.dtype) - - mask = replay_actions == dummy_val - if mask.any(): - exit_fill = exit_val - while exit_fill.ndim < replay_actions.ndim: - exit_fill = exit_fill.unsqueeze(0) - replay_actions = torch.where(mask, exit_fill, replay_actions) - - T = actions_tsr.shape[0] - s = states_obj.tensor - states_stack = [s] - for t in range(T): - s, _, _ = step_tensor(s, replay_actions[t]) - states_stack.append(s) + actions_tsr = torch.stack(actions_seq, dim=0) states_tsr = torch.stack(states_stack, dim=0) + action_shape = getattr(env, "action_shape", None) + if action_shape: + tail_shape = tuple(actions_tsr.shape[-len(action_shape):]) + if tail_shape != tuple(action_shape): + if tuple(action_shape) == (1,): + actions_tsr = actions_tsr.unsqueeze(-1) + else: + raise ValueError( + "ChunkedHyperGridSampler produced actions with shape " + f"{actions_tsr.shape}, expected trailing dims {action_shape}." 
+ ) is_exit_seq = torch.stack(dones_seq, dim=0) + T = actions_tsr.shape[0] first_exit = torch.argmax(is_exit_seq.to(torch.long), dim=0) never_exited = ~is_exit_seq.any(dim=0) first_exit = torch.where( @@ -543,9 +583,40 @@ def sample_trajectories( # noqa: C901 exit_action_value = env.exit_action.to(device=curr_states.device) dummy_action_value = env.dummy_action.to(device=curr_states.device) - step_actions_seq: List[torch.Tensor] = [] recorded_actions_seq: List[torch.Tensor] = [] sink_seq: List[torch.Tensor] = [] + states_stack: List[torch.Tensor] = [curr_states.clone()] + + exit_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + dummy_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + + def _expand_front(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = (1,) * expand_dims + tuple(tensor.shape) + return tensor.view(view_shape) + + def _expand_back(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = tuple(tensor.shape) + (1,) * expand_dims + return tensor.view(view_shape) + + def _get_template( + cache: dict[tuple[int, torch.dtype], torch.Tensor], + base_value: torch.Tensor, + target_ndim: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + key = (target_ndim, dtype) + tensor = cache.get(key) + if tensor is None: + tensor = _expand_front(base_value.to(device=device, dtype=dtype), target_ndim) + cache[key] = tensor + return tensor def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ torch.Tensor, @@ -554,12 +625,12 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ List[torch.Tensor], List[torch.Tensor], ]: - local_step_actions: List[torch.Tensor] = [] local_recorded_actions: List[torch.Tensor] = [] local_sinks: List[torch.Tensor] = [] + local_states: List[torch.Tensor] = [] for _ in range(chunk_size): - if bool(done_mask.all().item()): + if torch.all(done_mask): break features = policy.fast_features( @@ -576,29 +647,27 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ **policy_kwargs, ) sampled_actions = dist.sample() - - if done_mask.any(): - mask = done_mask - while mask.ndim < sampled_actions.ndim: - mask = mask.unsqueeze(-1) - - exit_fill = exit_action_value.to( - device=sampled_actions.device, dtype=sampled_actions.dtype + step_actions = sampled_actions + record_actions = sampled_actions + + if torch.any(done_mask): + mask = _expand_back(done_mask, sampled_actions.ndim) + exit_fill = _get_template( + exit_template_cache, + exit_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, ) - while exit_fill.ndim < sampled_actions.ndim: - exit_fill = exit_fill.unsqueeze(0) - - dummy_fill = dummy_action_value.to( - device=sampled_actions.device, dtype=sampled_actions.dtype + dummy_fill = _get_template( + dummy_template_cache, + dummy_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, ) - while dummy_fill.ndim < sampled_actions.ndim: - dummy_fill = dummy_fill.unsqueeze(0) - step_actions = torch.where(mask, exit_fill, sampled_actions) record_actions = torch.where(mask, dummy_fill, sampled_actions) - else: - step_actions = sampled_actions - record_actions = sampled_actions step_res = env.step_tensor(current_states, step_actions) current_states = step_res.next_states @@ -607,16 
+676,16 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ sinks = env.states_from_tensor(current_states).is_sink_state done_mask = done_mask | sinks - local_step_actions.append(step_actions) local_recorded_actions.append(record_actions) local_sinks.append(sinks) + local_states.append(current_states.clone()) return ( current_states, done_mask, - local_step_actions, local_recorded_actions, local_sinks, + local_states, ) chunk_fn = _chunk_loop @@ -624,36 +693,41 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ if hasattr(torch, "compile") and device_type in ("cuda", "cpu"): try: chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] - except Exception: - raise RuntimeError( - "Compilation of _chunk_loop for Diffusion Sampling fails on MPS" + except Exception as exc: # pragma: no cover - compile fallback + warnings.warn( + f"Compilation of diffusion chunk loop failed ({exc}); using eager version.", + stacklevel=2, ) + chunk_fn = _chunk_loop while not bool(done.all().item()): ( curr_states, done, - step_actions_chunk, recorded_actions_chunk, sinks_chunk, + states_chunk, ) = chunk_fn(curr_states, done) - if step_actions_chunk: - step_actions_seq.extend(step_actions_chunk) + if recorded_actions_chunk: recorded_actions_seq.extend(recorded_actions_chunk) sink_seq.extend(sinks_chunk) + states_stack.extend(states_chunk) if recorded_actions_seq: actions_tsr = torch.stack(recorded_actions_seq, dim=0) - T = actions_tsr.shape[0] - - s = states_obj.tensor - states_stack = [s] - for t in range(T): - step = env.step_tensor(s, step_actions_seq[t]) - s = step.next_states - states_stack.append(s) states_tsr = torch.stack(states_stack, dim=0) - + action_shape = getattr(env, "action_shape", None) + if action_shape: + tail_shape = tuple(actions_tsr.shape[-len(action_shape):]) + if tail_shape != tuple(action_shape): + if tuple(action_shape) == (1,): + actions_tsr = actions_tsr.unsqueeze(-1) + else: + raise ValueError( + "ChunkedDiffusionSampler produced actions with shape " + f"{actions_tsr.shape}, expected trailing dims {action_shape}." 
+ ) + T = actions_tsr.shape[0] sinks_tsr = torch.stack(sink_seq, dim=0) first_sink = torch.argmax(sinks_tsr.to(torch.long), dim=0) never_sink = ~sinks_tsr.any(dim=0) @@ -1156,11 +1230,7 @@ def run_iterations( metrics["measured_steps"] += 1 last_loss = loss.item() - if ( - record_history - and losses_history is not None - and iter_time_history is not None - ): + if record_history and losses_history is not None and iter_time_history is not None: losses_history.append(last_loss) iter_duration = ( (time.perf_counter() - iter_start) if iter_start is not None else 0.0 @@ -1328,7 +1398,7 @@ def _build_hypergrid_components( trunk=module_pf.trunk, ) - if scenario.sampler == "compiled_chunk": + if scenario.sampler in {"compiled_chunk", "script_chunk"}: pf_estimator = FastKHotDiscretePolicyEstimator(env, module_pf, preprocessor) else: pf_estimator = DiscretePolicyEstimator( From f008b89d0e28cc5809ebf41b866d182d5284595d Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 21:33:44 -0500 Subject: [PATCH 08/15] fixes --- src/gfn/gym/hypergrid.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index e2d65553..db6d97b8 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -217,14 +217,20 @@ def step_tensor( next_states = states_tensor.clone() non_exit_mask = ~is_exit_action - if torch.any(non_exit_mask): - sel_states = next_states[non_exit_mask] - sel_actions = actions_idx[non_exit_mask] - sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add") - next_states[non_exit_mask] = sel_states - - if torch.any(is_exit_action): - next_states[is_exit_action] = self.sf.to(device=states_tensor.device) + non_exit_mask_exp = non_exit_mask.unsqueeze(-1) + safe_actions = torch.where( + non_exit_mask_exp, actions_idx, torch.zeros_like(actions_idx) + ) + delta = torch.zeros_like(next_states) + delta = delta.scatter(-1, safe_actions, 1, reduce="add") + delta = delta * non_exit_mask_exp.to(next_states.dtype) + next_states = next_states + delta + + sink_state = self.sf.to(device=states_tensor.device) + while sink_state.ndim < next_states.ndim: + sink_state = sink_state.unsqueeze(0) + sink_state = sink_state.expand_as(next_states) + next_states = torch.where(is_exit_action.unsqueeze(-1), sink_state, next_states) forward_masks = self.forward_action_masks_tensor(next_states) backward_masks = next_states != 0 From d2fd0436cb5c2cd43e5abe1caa8a694af0b2a203 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 27 Nov 2025 22:09:40 -0500 Subject: [PATCH 09/15] fixed most problems... --- src/gfn/samplers.py | 50 +++- .../examples/train_hypergrid_optimized.py | 251 +++++++++++------- 2 files changed, 198 insertions(+), 103 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 053ef290..d5ec0fe0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Callable, List, Optional, Tuple, cast import torch @@ -13,6 +13,15 @@ from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs +def _mark_cudagraph_step() -> None: + compiler = getattr(torch, "compiler", None) + if compiler is None: + return + marker = getattr(compiler, "cudagraph_mark_step_begin", None) + if callable(marker): + marker() + + class Sampler: """Estimator‑driven sampler for GFlowNet environments. 
@@ -906,6 +915,7 @@ def __init__( super().__init__(estimator) self.chunk_size = int(chunk_size) self.compile_mode = compile_mode + self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {} def sample_trajectories( self, @@ -1080,21 +1090,35 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ local_sinks, ) - # Fallback to the non-compiled version if compilation fails. - chunk_fn = _chunk_loop - if hasattr(torch, "compile"): - try: - chunk_fn = torch.compile(_chunk_loop, mode=self.compile_mode) # type: ignore[arg-type] - except Exception: - # If compilation fails, use the non-compiled version. - warnings.warn( - "Compilation of chunk_loop failed, using non-compiled version.", - stacklevel=2, - ) - chunk_fn = _chunk_loop + chunk_fn: Callable = _chunk_loop + chunk_fn_compiled = False + device_type = curr_states.device.type + compile_allowed = ( + hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs + ) + cache_key = (id(env), device_type) + if compile_allowed: + cached = self._compiled_chunk_cache.get(cache_key) + if cached is not None: + chunk_fn = cached + chunk_fn_compiled = True + else: + try: + compiled = torch.compile(_chunk_loop, mode=self.compile_mode) # type: ignore[arg-type] + self._compiled_chunk_cache[cache_key] = compiled + chunk_fn = compiled + chunk_fn_compiled = True + except Exception: + warnings.warn( + "Compilation of chunk_loop failed, using non-compiled version.", + stacklevel=2, + ) + chunk_fn = _chunk_loop # Main loop: call the compiled function until all states are done. while not bool(done.all().item()): + if chunk_fn_compiled: + _mark_cudagraph_step() ( curr_states, done, diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py index 759255e5..306271d5 100644 --- a/tutorials/examples/train_hypergrid_optimized.py +++ b/tutorials/examples/train_hypergrid_optimized.py @@ -12,7 +12,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Literal, cast +from typing import Any, Callable, Dict, Iterable, List, Literal, cast import torch from torch.func import vmap @@ -44,6 +44,24 @@ ) from gfn.utils.training import validate + +def _mark_cudagraph_step() -> None: + compiler = getattr(torch, "compiler", None) + if compiler is None: + return + marker = getattr(compiler, "cudagraph_mark_step_begin", None) + if callable(marker): + marker() + + +def _fill_like_reference(reference: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + """Broadcasts `value` to the shape/dtype/device of `reference`.""" + fill = value.to(device=reference.device, dtype=reference.dtype) + while fill.ndim < reference.ndim: + fill = fill.unsqueeze(0) + return fill.expand_as(reference).clone() + + # Default HyperGrid configuration (easy to extend to multiple envs later on). 
HYPERGRID_KWARGS: Dict[str, Any] = { "ndim": 2, @@ -318,6 +336,7 @@ class ChunkedHyperGridSampler(Sampler): def __init__(self, estimator, chunk_size: int): super().__init__(estimator) self.chunk_size = int(chunk_size) + self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {} def sample_trajectories( # noqa: C901 self, @@ -401,88 +420,116 @@ def _expand_back(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: view_shape = tuple(tensor.shape) + (1,) * expand_dims return tensor.view(view_shape) - def _chunk_loop( - current_states: torch.Tensor, - current_masks: torch.Tensor, - done_mask: torch.Tensor, - ): - actions_list: List[torch.Tensor] = [] - dones_list: List[torch.Tensor] = [] - states_list: List[torch.Tensor] = [] - for _ in range(chunk_size): - if bool(done_mask.all().item()): - break - - masks = current_masks - if done_mask.any(): - masks = masks.clone() - masks[done_mask] = False - masks[done_mask, exit_idx] = True - - features = policy.fast_features( - current_states, - forward_masks=masks, - backward_masks=None, - conditions=conditions, - ) - dist = policy.fast_distribution( - features, - forward_masks=masks, - backward_masks=None, - states_tensor=current_states, - epsilon=epsilon, - **policy_kwargs, - ) - sampled_actions = dist.sample() - step_actions = sampled_actions - record_actions = sampled_actions - - if done_mask.any(): - mask = _expand_back(done_mask, sampled_actions.ndim) - exit_fill = _expand_front( - exit_action_value.to( - device=sampled_actions.device, dtype=sampled_actions.dtype - ), - sampled_actions.ndim, + device_type = curr_states.device.type + compile_allowed = ( + hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs + ) + compile_key = (id(env), device_type) + chunk_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple] | None = None + chunk_fn_compiled = False + if compile_allowed: + chunk_fn = self._compiled_chunk_cache.get(compile_key) + if chunk_fn is not None: + chunk_fn_compiled = True + + if chunk_fn is None: + + def _chunk_loop( + current_states: torch.Tensor, + current_masks: torch.Tensor, + done_mask: torch.Tensor, + ): + actions_list: List[torch.Tensor] = [] + dones_list: List[torch.Tensor] = [] + states_list: List[torch.Tensor] = [] + action_template: torch.Tensor | None = None + steps_taken = 0 + for _ in range(chunk_size): + if bool(done_mask.all().item()): + assert action_template is not None + pad_actions = _fill_like_reference(action_template, dummy_action_value) + actions_list.append(pad_actions) + dones_list.append(done_mask.clone()) + states_list.append(current_states.clone()) + continue + + masks = current_masks + if done_mask.any(): + masks = masks.clone() + masks[done_mask] = False + masks[done_mask, exit_idx] = True + + features = policy.fast_features( + current_states, + forward_masks=masks, + backward_masks=None, + conditions=conditions, ) - dummy_fill = _expand_front( - dummy_action_value.to( - device=sampled_actions.device, dtype=sampled_actions.dtype - ), - sampled_actions.ndim, + dist = policy.fast_distribution( + features, + forward_masks=masks, + backward_masks=None, + states_tensor=current_states, + epsilon=epsilon, + **policy_kwargs, ) - step_actions = torch.where(mask, exit_fill, sampled_actions) - record_actions = torch.where(mask, dummy_fill, sampled_actions) + sampled_actions = dist.sample() + step_actions = sampled_actions + record_actions = sampled_actions + + if done_mask.any(): + mask = _expand_back(done_mask, sampled_actions.ndim) + exit_fill 
= _expand_front( + exit_action_value.to( + device=sampled_actions.device, dtype=sampled_actions.dtype + ), + sampled_actions.ndim, + ) + dummy_fill = _expand_front( + dummy_action_value.to( + device=sampled_actions.device, dtype=sampled_actions.dtype + ), + sampled_actions.ndim, + ) + step_actions = torch.where(mask, exit_fill, sampled_actions) + record_actions = torch.where(mask, dummy_fill, sampled_actions) - next_states, next_masks, is_exit = step_tensor( - current_states, step_actions - ) + next_states, next_masks, is_exit = step_tensor( + current_states, step_actions + ) - actions_list.append(record_actions) - dones_list.append(is_exit) - states_list.append(next_states.clone()) + actions_list.append(record_actions) + action_template = record_actions.detach() + dones_list.append(is_exit) + states_list.append(next_states.clone()) - current_states = next_states - current_masks = next_masks - done_mask = done_mask | is_exit + current_states = next_states + current_masks = next_masks + done_mask = done_mask | is_exit + steps_taken += 1 - return ( - current_states, - current_masks, - done_mask, - actions_list, - dones_list, - states_list, - ) + return ( + current_states, + current_masks, + done_mask, + actions_list, + dones_list, + states_list, + torch.tensor(steps_taken, device=current_states.device), + ) - chunk_fn = _chunk_loop - if hasattr(torch, "compile"): - try: - chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] - except Exception: - pass + chunk_fn = _chunk_loop + if compile_allowed: + try: + chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] + self._compiled_chunk_cache[compile_key] = chunk_fn + chunk_fn_compiled = True + except Exception: + chunk_fn = _chunk_loop while not bool(done.all().item()): + if chunk_fn_compiled: + _mark_cudagraph_step() ( curr_states, forward_masks, @@ -490,11 +537,13 @@ def _chunk_loop( actions_chunk, dones_chunk, states_chunk, + steps_taken_tensor, ) = chunk_fn(curr_states, forward_masks, done) - if actions_chunk: - actions_seq.extend(actions_chunk) - dones_seq.extend(dones_chunk) - states_stack.extend(states_chunk) + steps_taken = int(steps_taken_tensor.item()) + if steps_taken: + actions_seq.extend(actions_chunk[:steps_taken]) + dones_seq.extend(dones_chunk[:steps_taken]) + states_stack.extend(states_chunk[:steps_taken]) if actions_seq: actions_tsr = torch.stack(actions_seq, dim=0) @@ -545,6 +594,7 @@ class ChunkedDiffusionSampler(Sampler): def __init__(self, estimator: PinnedBrownianMotionForward, chunk_size: int): super().__init__(estimator) self.chunk_size = int(chunk_size) + self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {} def sample_trajectories( # noqa: C901 self, @@ -688,19 +738,34 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ local_states, ) - chunk_fn = _chunk_loop + chunk_fn: Callable = _chunk_loop + chunk_fn_compiled = False device_type = curr_states.device.type - if hasattr(torch, "compile") and device_type in ("cuda", "cpu"): - try: - chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] - except Exception as exc: # pragma: no cover - compile fallback - warnings.warn( - f"Compilation of diffusion chunk loop failed ({exc}); using eager version.", - stacklevel=2, - ) - chunk_fn = _chunk_loop + compile_allowed = ( + hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs + ) + cache_key = (id(env), device_type) + if compile_allowed: + 
From 338a7835ad238d8d637c096d203a0a843491be87 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 22:11:58 -0500
Subject: [PATCH 10/15] fixed sizes for CUDA

---
 src/gfn/samplers.py                             | 33 ++++++++++++++++---
 .../examples/train_hypergrid_optimized.py       | 25 ++++++++++----
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index d5ec0fe0..070c4965 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -22,6 +22,13 @@ def _mark_cudagraph_step() -> None:
         marker()
 
 
+def _fill_like_reference(reference: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    fill = value.to(device=reference.device, dtype=reference.dtype)
+    while fill.ndim < reference.ndim:
+        fill = fill.unsqueeze(0)
+    return fill.expand_as(reference).clone()
+
+
 class Sampler:
     """Estimator-driven sampler for GFlowNet environments.
 
@@ -998,6 +1005,7 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
             List[torch.Tensor],
             List[torch.Tensor],
             List[torch.Tensor],
+            torch.Tensor,
         ]:
             """
             This function is the core of the chunked sampler. It is responsible for
@@ -1019,10 +1027,19 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
             local_step_actions: List[torch.Tensor] = []
             local_recorded_actions: List[torch.Tensor] = []
             local_sinks: List[torch.Tensor] = []
+            step_template: torch.Tensor | None = None
+            record_template: torch.Tensor | None = None
+            steps_taken = 0
 
             for _ in range(chunk_size):
                 if bool(done_mask.all().item()):
-                    break
+                    assert step_template is not None and record_template is not None
+                    pad_step = _fill_like_reference(step_template, exit_action_value)
+                    pad_record = _fill_like_reference(record_template, dummy_action_value)
+                    local_step_actions.append(pad_step)
+                    local_recorded_actions.append(pad_record)
+                    local_sinks.append(done_mask.clone())
+                    continue
 
                 state_view = current_states
                 features = policy.fast_features(
@@ -1080,7 +1097,10 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 done_mask = done_mask | sinks
                 local_step_actions.append(step_actions)
                 local_recorded_actions.append(record_actions)
+                step_template = step_actions.detach()
+                record_template = record_actions.detach()
                 local_sinks.append(sinks)
+                steps_taken += 1
 
             return (
                 current_states,
@@ -1088,6 +1108,7 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 local_step_actions,
                 local_recorded_actions,
                 local_sinks,
+                torch.tensor(steps_taken, device=current_states.device),
             )
 
         chunk_fn: Callable = _chunk_loop
@@ -1125,11 +1146,13 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 step_actions_chunk,
                 recorded_actions_chunk,
                 sinks_chunk,
+                steps_taken_tensor,
             ) = chunk_fn(curr_states, done)
-            if step_actions_chunk:
-                step_actions_seq.extend(step_actions_chunk)
-                recorded_actions_seq.extend(recorded_actions_chunk)
-                sink_seq.extend(sinks_chunk)
+            steps_taken = int(steps_taken_tensor.item())
+            if steps_taken:
+                step_actions_seq.extend(step_actions_chunk[:steps_taken])
+                recorded_actions_seq.extend(recorded_actions_chunk[:steps_taken])
+                sink_seq.extend(sinks_chunk[:steps_taken])
 
         if recorded_actions_seq:
             actions_tsr = torch.stack(recorded_actions_seq, dim=0)
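
_fill_like_reference exists so that, once every trajectory has finished, the chunk loop keeps appending correctly shaped padding instead of breaking out early: a fixed iteration count with fixed tensor shapes is what lets torch.compile(mode="reduce-overhead") reuse its CUDA graphs. A standalone usage sketch (the helper body is copied from the hunk above; the surrounding example is illustrative):

import torch


def _fill_like_reference(reference: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Broadcast a scalar-like action value to the shape/device/dtype of a
    # reference batch, so finished rows can be padded without changing the
    # tensor shapes the compiled loop sees.
    fill = value.to(device=reference.device, dtype=reference.dtype)
    while fill.ndim < reference.ndim:
        fill = fill.unsqueeze(0)
    return fill.expand_as(reference).clone()


sampled = torch.zeros(4, 1, dtype=torch.long)  # batch of 4 one-dim actions
dummy = torch.tensor(-1)                       # dummy action value
pad = _fill_like_reference(sampled, dummy)
assert pad.shape == sampled.shape and bool((pad == -1).all())
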
diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index 306271d5..0439523d 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -674,14 +674,22 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
             List[torch.Tensor],
             List[torch.Tensor],
             List[torch.Tensor],
+            torch.Tensor,
         ]:
             local_recorded_actions: List[torch.Tensor] = []
             local_sinks: List[torch.Tensor] = []
             local_states: List[torch.Tensor] = []
+            action_template: torch.Tensor | None = None
+            steps_taken = 0
 
             for _ in range(chunk_size):
-                if torch.all(done_mask):
-                    break
+                if bool(done_mask.all().item()):
+                    assert action_template is not None
+                    pad_actions = _fill_like_reference(action_template, dummy_action_value)
+                    local_recorded_actions.append(pad_actions)
+                    local_sinks.append(done_mask.clone())
+                    local_states.append(current_states.clone())
+                    continue
 
                 features = policy.fast_features(
                     current_states,
@@ -727,8 +735,10 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 done_mask = done_mask | sinks
 
                 local_recorded_actions.append(record_actions)
+                action_template = record_actions.detach()
                 local_sinks.append(sinks)
                 local_states.append(current_states.clone())
+                steps_taken += 1
 
             return (
                 current_states,
@@ -736,6 +746,7 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 local_recorded_actions,
                 local_sinks,
                 local_states,
+                torch.tensor(steps_taken, device=current_states.device),
             )
 
         chunk_fn: Callable = _chunk_loop
@@ -772,11 +783,13 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
                 recorded_actions_chunk,
                 sinks_chunk,
                 states_chunk,
+                steps_taken_tensor,
             ) = chunk_fn(curr_states, done)
-            if recorded_actions_chunk:
-                recorded_actions_seq.extend(recorded_actions_chunk)
-                sink_seq.extend(sinks_chunk)
-                states_stack.extend(states_chunk)
+            steps_taken = int(steps_taken_tensor.item())
+            if steps_taken:
+                recorded_actions_seq.extend(recorded_actions_chunk[:steps_taken])
+                sink_seq.extend(sinks_chunk[:steps_taken])
+                states_stack.extend(states_chunk[:steps_taken])
 
         if recorded_actions_seq:
             actions_tsr = torch.stack(recorded_actions_seq, dim=0)
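
Returning steps_taken as a 0-dim tensor is what keeps the compiled chunk body free of Python control flow: the loop always runs chunk_size iterations, pads the tail, and the caller pays a single host sync per chunk to slice off the padding. A reduced eager sketch of that contract, assuming a toy update rule (the real loop samples actions and steps the environment):

import torch


def chunk_loop(x: torch.Tensor, chunk_size: int = 8) -> tuple[list[torch.Tensor], torch.Tensor]:
    outputs = []
    steps_taken = 0
    for _ in range(chunk_size):
        if bool((x >= 1.0).all().item()):
            outputs.append(torch.zeros_like(x))  # pad: shapes stay fixed
            continue
        x = x + 0.25
        outputs.append(x.clone())
        steps_taken += 1
    return outputs, torch.tensor(steps_taken)


chunk, n = chunk_loop(torch.zeros(3))
kept = chunk[: int(n.item())]  # one host sync per chunk, padding dropped
assert len(kept) == 4
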
From f6b3d094e84945c661f3bb863b9e66834fa2765a Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 22:39:35 -0500
Subject: [PATCH 11/15] sync

---
 .../examples/train_hypergrid_optimized.py | 44 ++++++++++++++++---
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index 0439523d..72cdeaa3 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -106,18 +114,26 @@ class FlowVariant:
         use_vmap=False,
     ),
     ScenarioConfig(
-        name="Library Fast Path",
-        description="Core EnvFastPath + CompiledChunkSampler + compile + vmap TB.",
-        sampler="compiled_chunk",
+        name="VMap Only",
+        description="Vmap-accelerated TB loss; standard sampler, no compile.",
+        sampler="standard",
+        use_script_env=False,
+        use_compile=False,
+        use_vmap=True,
+    ),
+    ScenarioConfig(
+        name="Compile Only (core)",
+        description="Standard env + sampler with torch.compile but no chunking.",
+        sampler="standard",
         use_script_env=False,
         use_compile=True,
         use_vmap=True,
     ),
     ScenarioConfig(
-        name="Script Env + Compiled Chunk",
-        description="Script-local tensor env + library CompiledChunkSampler + compile/vmap.",
+        name="Library Fast Path",
+        description="Core EnvFastPath + CompiledChunkSampler + compile + vmap TB.",
         sampler="compiled_chunk",
-        use_script_env=True,
+        use_script_env=False,
         use_compile=True,
         use_vmap=True,
     ),
@@ -140,6 +148,22 @@ class FlowVariant:
         use_compile=False,
         use_vmap=False,
     ),
+    ScenarioConfig(
+        name="Diffusion VMap Only",
+        description="Pinned Brownian sampler without compilation or chunking.",
+        sampler="standard",
+        use_script_env=False,
+        use_compile=False,
+        use_vmap=True,
+    ),
+    ScenarioConfig(
+        name="Diffusion Compile Only",
+        description="Standard diffusion sampler with torch.compile but no chunking.",
+        sampler="standard",
+        use_script_env=False,
+        use_compile=True,
+        use_vmap=True,
+    ),
     ScenarioConfig(
         name="Diffusion Library Fast Path",
         description="EnvFastPath + CompiledChunkSampler (library implementation).",
@@ -193,10 +217,11 @@ class FlowVariant:
 }
 SCENARIO_LINESTYLES: dict[str, Any] = {
     "Baseline (core)": "-",
+    "Compile Only (core)": "-.",
     "Library Fast Path": "--",  # fast-path compiled
-    "Script Env + Compiled Chunk": "dashdot",
     "Script Fast Path": ":",
     "Diffusion Baseline": "-",
+    "Diffusion Compile Only": "-.",
     "Diffusion Library Fast Path": "--",
     "Diffusion Script Fast Path": ":",
 }
@@ -1208,6 +1233,7 @@ def run_scenario(
         track_time=False,
         record_history=False,
         supports_validation=env_cfg.supports_validation,
+        mark_compiled_step=scenario.use_compile,
    )
 
     elapsed, history = run_iterations(
@@ -1225,6 +1251,7 @@
         track_time=True,
         record_history=True,
         supports_validation=env_cfg.supports_validation,
+        mark_compiled_step=scenario.use_compile,
     )
 
     validation_info = metrics["validation_info"]
@@ -1272,6 +1299,7 @@ def run_iterations(
     track_time: bool,
     record_history: bool,
     supports_validation: bool,
+    mark_compiled_step: bool = False,
 ) -> tuple[float | None, Dict[str, list[float]] | None]:
     if n_iters <= 0:
         empty_history = {"losses": [], "iter_times": []} if record_history else None
@@ -1290,6 +1318,8 @@ def run_iterations(
 
     for _ in iterator:
         iter_start = time.perf_counter() if (track_time or record_history) else None
+        if mark_compiled_step:
+            _mark_cudagraph_step()
         trajectories = sampler.sample_trajectories(
             env,
             n=args.batch_size,
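
Each scenario added above toggles one axis (vmap, compile, chunked sampling) relative to its baseline, so a regression can be attributed to a single change. A minimal sketch of the same ablation-grid idea, with hypothetical names:

from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Scenario:
    use_compile: bool
    use_vmap: bool
    sampler: str


def ablation_grid() -> list[Scenario]:
    # One scenario per combination of the boolean axes, plus the full fast
    # path that additionally swaps in the chunked sampler.
    grid = [
        Scenario(use_compile, use_vmap, sampler="standard")
        for use_compile, use_vmap in product([False, True], repeat=2)
    ]
    grid.append(Scenario(True, True, sampler="compiled_chunk"))
    return grid


for scenario in ablation_grid():
    print(scenario)
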
From 41095a805af42f9be06841ee43c4c2554d5b5286 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 22:58:28 -0500
Subject: [PATCH 12/15] sync

---
 .../examples/train_hypergrid_optimized.py | 234 ++++++++++++------
 1 file changed, 161 insertions(+), 73 deletions(-)

diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index 72cdeaa3..39afe71b 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -293,51 +293,79 @@ def _normalize_env_keys(requested: list[str]) -> list[str]:
 # Local subclasses for benchmarking-only optimizations (no core library changes)
 class HyperGridWithTensorStep(HyperGrid):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._unit_step_cache: dict[
+            tuple[torch.device, torch.dtype], torch.Tensor
+        ] = {}
+        self._sink_state_cache: dict[
+            tuple[torch.device, torch.dtype], torch.Tensor
+        ] = {}
+
+    def _get_unit_steps(
+        self, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
+        key = (device, dtype)
+        cached = self._unit_step_cache.get(key)
+        if cached is None:
+            identity = torch.eye(
+                self.ndim, dtype=dtype, device=device, requires_grad=False
+            )
+            zero_row = torch.zeros(
+                (1, self.ndim), dtype=dtype, device=device, requires_grad=False
+            )
+            cached = torch.cat([identity, zero_row], dim=0)
+            self._unit_step_cache[key] = cached
+        return cached
+
+    def _get_sink_state(
+        self, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
+        key = (device, dtype)
+        cached = self._sink_state_cache.get(key)
+        if cached is None:
+            cached = self.sf.to(device=device, dtype=dtype)
+            self._sink_state_cache[key] = cached
+        return cached
+
     def step_tensor(
         self, states: torch.Tensor, actions: torch.Tensor
     ) -> Env.TensorStepResult:
         assert states.dtype == torch.long
         device = states.device
         batch = states.shape[0]
-        ndim = self.ndim
         exit_idx = self.n_actions - 1
 
         if actions.ndim == 1:
-            actions_idx = actions.view(-1, 1)
+            action_idx = actions
         else:
-            assert actions.shape[-1] == 1
-            actions_idx = actions
+            action_idx = actions.view(-1)
+
+        action_idx = action_idx.to(torch.long)
+        is_exit = action_idx == exit_idx
 
-        is_exit = actions_idx.squeeze(-1) == exit_idx
+        unit_steps = self._get_unit_steps(device, states.dtype)
+        deltas = unit_steps.index_select(0, action_idx.clamp(max=exit_idx))
+        next_states = states + deltas
 
-        next_states = states.clone()
-        non_exit_mask = ~is_exit
-        non_exit_mask_exp = non_exit_mask.unsqueeze(-1)
-        safe_actions = torch.where(
-            non_exit_mask_exp, actions_idx, torch.zeros_like(actions_idx)
+        sink_state = self._get_sink_state(device, states.dtype).view(1, -1)
+        next_states = torch.where(
+            is_exit.view(-1, 1), sink_state.expand_as(next_states), next_states
         )
-        delta = torch.zeros_like(next_states)
-        delta = delta.scatter(-1, safe_actions, 1, reduce="add")
-        delta = delta * non_exit_mask_exp.to(next_states.dtype)
-        next_states = next_states + delta
-
-        # Ensure exit actions land exactly on the sink state so downstream
-        # `is_sink_state` masks match the action padding semantics assumed
-        # by `Trajectories` and probability calculations.
-        sink_state = self.sf.to(device=device).unsqueeze(0).expand_as(next_states)
-        next_states = torch.where(is_exit.unsqueeze(-1), sink_state, next_states)
-
-        next_forward_masks = torch.ones(
-            (batch, self.n_actions), dtype=torch.bool, device=device
+
+        forward_masks = torch.cat(
+            [
+                next_states != (self.height - 1),
+                torch.ones((batch, 1), dtype=torch.bool, device=device),
+            ],
+            dim=-1,
         )
-        next_forward_masks[:, :ndim] = next_states != (self.height - 1)
-        next_forward_masks[:, ndim] = True
         backward_masks = next_states != 0
 
         return self.TensorStepResult(
             next_states=next_states,
             is_sink_state=is_exit,
-            forward_masks=next_forward_masks,
+            forward_masks=forward_masks,
             backward_masks=backward_masks,
         )
 
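
The rewritten step_tensor replaces the masked scatter with a row lookup: action i selects row i of an (ndim + 1, ndim) table whose last row is zeros, so the exit action becomes a no-op move that the sink overwrite then replaces. A standalone sketch of the trick on a 2-D grid (the -1 sink sentinel is an assumption matching fast_features above):

import torch

ndim = 2
unit_steps = torch.cat(
    [torch.eye(ndim, dtype=torch.long), torch.zeros(1, ndim, dtype=torch.long)],
    dim=0,
)
exit_idx = ndim  # actions 0..ndim-1 move along an axis; ndim is "exit"

states = torch.tensor([[0, 0], [1, 2], [3, 3]])
actions = torch.tensor([0, 1, exit_idx])

deltas = unit_steps.index_select(0, actions.clamp(max=exit_idx))
next_states = states + deltas
sink = torch.full((1, ndim), -1)  # assumed sink sentinel, as in HyperGrid
next_states = torch.where((actions == exit_idx).view(-1, 1), sink, next_states)
assert next_states.tolist() == [[1, 0], [1, 3], [-1, -1]]
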
@@ -429,21 +457,37 @@ def step_tensor(
         done = torch.zeros(batch, dtype=torch.bool, device=device)
         actions_seq: List[torch.Tensor] = []
         dones_seq: List[torch.Tensor] = []
-        states_stack: List[torch.Tensor] = [curr_states.clone()]
+        states_seq: List[torch.Tensor] = [curr_states.clone().unsqueeze(0)]
 
-        def _expand_front(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor:
-            expand_dims = target_ndim - tensor.ndim
-            if expand_dims <= 0:
-                return tensor
-            view_shape = (1,) * expand_dims + tuple(tensor.shape)
-            return tensor.view(view_shape)
+        exit_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {}
+        dummy_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {}
 
-        def _expand_back(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor:
-            expand_dims = target_ndim - tensor.ndim
-            if expand_dims <= 0:
-                return tensor
-            view_shape = tuple(tensor.shape) + (1,) * expand_dims
-            return tensor.view(view_shape)
+        def _broadcast_done_mask(
+            mask: torch.Tensor, target_ndim: int
+        ) -> torch.Tensor:
+            view_shape = mask.shape + (1,) * (target_ndim - mask.ndim)
+            return mask.view(view_shape)
+
+        def _get_template(
+            cache: dict[tuple[int, torch.dtype], torch.Tensor],
+            base_value: torch.Tensor,
+            target_ndim: int,
+            dtype: torch.dtype,
+            device: torch.device,
+        ) -> torch.Tensor:
+            key = (target_ndim, dtype)
+            tensor = cache.get(key)
+            if tensor is None:
+                tensor = base_value.to(device=device, dtype=dtype)
+                if tensor.ndim > target_ndim:
+                    raise ValueError(
+                        f"Base action tensor has ndim={tensor.ndim}, "
+                        f"but target_ndim={target_ndim}."
+                    )
+                leading = (1,) * (target_ndim - tensor.ndim)
+                tensor = tensor.view(leading + tuple(tensor.shape))
+                cache[key] = tensor
+            return tensor
 
         device_type = curr_states.device.type
         compile_allowed = (
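
One plausible motivation for the (ndim, dtype)-keyed template cache above: rebuilding the fill tensor with .to() and view on every loop iteration allocates small tensors inside the hot loop, while a cached template makes each fill an allocation-free expand. A reduced sketch of the pattern:

import torch

_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {}


def get_template(base: torch.Tensor, target_ndim: int, dtype: torch.dtype,
                 device: torch.device) -> torch.Tensor:
    # Build the broadcast-ready template once per (ndim, dtype) key.
    key = (target_ndim, dtype)
    tensor = _template_cache.get(key)
    if tensor is None:
        tensor = base.to(device=device, dtype=dtype)
        leading = (1,) * (target_ndim - tensor.ndim)
        tensor = tensor.view(leading + tuple(tensor.shape))
        _template_cache[key] = tensor
    return tensor


ref = torch.zeros(5, 1, dtype=torch.long)
fill = get_template(torch.tensor(-1), ref.ndim, ref.dtype, ref.device).expand_as(ref)
assert fill.shape == ref.shape
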
@@ -464,18 +508,21 @@ def _chunk_loop(
             current_masks: torch.Tensor,
             done_mask: torch.Tensor,
         ):
-            actions_list: List[torch.Tensor] = []
-            dones_list: List[torch.Tensor] = []
-            states_list: List[torch.Tensor] = []
-            action_template: torch.Tensor | None = None
+            actions_buf: torch.Tensor | None = None
+            dones_buf: torch.Tensor | None = None
+            states_buf: torch.Tensor | None = None
+            pad_template: torch.Tensor | None = None
             steps_taken = 0
-            for _ in range(chunk_size):
+
+            for step in range(chunk_size):
                 if bool(done_mask.all().item()):
-                    assert action_template is not None
-                    pad_actions = _fill_like_reference(action_template, dummy_action_value)
-                    actions_list.append(pad_actions)
-                    dones_list.append(done_mask.clone())
-                    states_list.append(current_states.clone())
+                    assert actions_buf is not None
+                    assert dones_buf is not None
+                    assert states_buf is not None
+                    assert pad_template is not None
+                    actions_buf[step].copy_(pad_template)
+                    dones_buf[step].copy_(done_mask)
+                    states_buf[step].copy_(current_states)
                     continue
 
                 masks = current_masks
@@ -503,19 +550,21 @@ def _chunk_loop(
                 record_actions = sampled_actions
 
                 if done_mask.any():
-                    mask = _expand_back(done_mask, sampled_actions.ndim)
-                    exit_fill = _expand_front(
-                        exit_action_value.to(
-                            device=sampled_actions.device, dtype=sampled_actions.dtype
-                        ),
+                    mask = _broadcast_done_mask(done_mask, sampled_actions.ndim)
+                    exit_fill = _get_template(
+                        exit_template_cache,
+                        exit_action_value,
                         sampled_actions.ndim,
-                    )
-                    dummy_fill = _expand_front(
-                        dummy_action_value.to(
-                            device=sampled_actions.device, dtype=sampled_actions.dtype
-                        ),
+                        sampled_actions.dtype,
+                        sampled_actions.device,
+                    ).expand_as(sampled_actions)
+                    dummy_fill = _get_template(
+                        dummy_template_cache,
+                        dummy_action_value,
                         sampled_actions.ndim,
-                    )
+                        sampled_actions.dtype,
+                        sampled_actions.device,
+                    ).expand_as(sampled_actions)
                     step_actions = torch.where(mask, exit_fill, sampled_actions)
                     record_actions = torch.where(mask, dummy_fill, sampled_actions)
 
@@ -523,23 +572,62 @@ def _chunk_loop(
                     current_states, step_actions
                 )
 
-                actions_list.append(record_actions)
-                action_template = record_actions.detach()
-                dones_list.append(is_exit)
-                states_list.append(next_states.clone())
+                if actions_buf is None:
+                    actions_buf = record_actions.new_empty(
+                        (chunk_size,) + tuple(record_actions.shape)
+                    )
+                    dones_buf = is_exit.new_empty(
+                        (chunk_size,) + tuple(is_exit.shape)
+                    )
+                    states_buf = next_states.new_empty(
+                        (chunk_size,) + tuple(next_states.shape)
+                    )
+                    pad_template = _get_template(
+                        dummy_template_cache,
+                        dummy_action_value,
+                        record_actions.ndim,
+                        record_actions.dtype,
+                        record_actions.device,
+                    ).expand_as(record_actions)
+
+                assert actions_buf is not None
+                assert dones_buf is not None
+                assert states_buf is not None
+
+                actions_buf[step].copy_(record_actions)
+                dones_buf[step].copy_(is_exit)
+                states_buf[step].copy_(next_states)
 
                 current_states = next_states
                 current_masks = next_masks
                 done_mask = done_mask | is_exit
                 steps_taken += 1
 
+            if actions_buf is None:
+                batch = current_states.shape[0]
+                empty_actions = env.actions_from_batch_shape(
+                    (0, batch)
+                ).tensor.to(device=current_states.device)
+                actions_out = empty_actions
+            else:
+                actions_out = actions_buf[:steps_taken]
+
+            empty_dones = done_mask.new_empty((0,) + done_mask.shape)
+            empty_states = current_states.new_empty((0,) + current_states.shape)
+            dones_out = (
+                dones_buf[:steps_taken] if dones_buf is not None else empty_dones
+            )
+            states_out = (
+                states_buf[:steps_taken] if states_buf is not None else empty_states
+            )
+
             return (
                 current_states,
                 current_masks,
                 done_mask,
-                actions_list,
-                dones_list,
-                states_list,
+                actions_out,
+                dones_out,
+                states_out,
                 torch.tensor(steps_taken, device=current_states.device),
             )
 
@@ -566,13 +654,13 @@ def _chunk_loop(
             ) = chunk_fn(curr_states, forward_masks, done)
             steps_taken = int(steps_taken_tensor.item())
             if steps_taken:
-                actions_seq.extend(actions_chunk[:steps_taken])
-                dones_seq.extend(dones_chunk[:steps_taken])
-                states_stack.extend(states_chunk[:steps_taken])
+                actions_seq.append(actions_chunk)
+                dones_seq.append(dones_chunk)
+                states_seq.append(states_chunk)
 
         if actions_seq:
-            actions_tsr = torch.stack(actions_seq, dim=0)
-            states_tsr = torch.stack(states_stack, dim=0)
+            actions_tsr = torch.cat(actions_seq, dim=0)
+            states_tsr = torch.cat(states_seq, dim=0)
             action_shape = getattr(env, "action_shape", None)
             if action_shape:
                 tail_shape = tuple(actions_tsr.shape[-len(action_shape):])
@@ -584,7 +672,7 @@ def _chunk_loop(
                     "ChunkedHyperGridSampler produced actions with shape "
                     f"{actions_tsr.shape}, expected trailing dims {action_shape}."
                 )
-        is_exit_seq = torch.stack(dones_seq, dim=0)
+        is_exit_seq = torch.cat(dones_seq, dim=0)
         T = actions_tsr.shape[0]
         first_exit = torch.argmax(is_exit_seq.to(torch.long), dim=0)
         never_exited = ~is_exit_seq.any(dim=0)
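
The buffer rewrite above swaps per-step Python lists for a preallocated (chunk_size, ...) tensor filled in place with copy_, then concatenates whole chunks with torch.cat instead of stacking many small per-step tensors. A reduced sketch comparing the two layouts:

import torch

batch, chunk_size, steps = 4, 8, 5

# List-of-steps layout: one small tensor per step, stacked at the end.
per_step = [torch.full((batch,), i) for i in range(steps)]
stacked = torch.stack(per_step, dim=0)

# Buffer layout: allocate once, fill in place, slice off the padding.
buf = torch.empty(chunk_size, batch, dtype=stacked.dtype)
for i in range(steps):
    buf[i].copy_(torch.full((batch,), i))
trimmed = buf[:steps]

assert torch.equal(stacked, trimmed)
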
From 29d6a26a6c009bc87d53449e23a424793a64001b Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 23:01:08 -0500
Subject: [PATCH 13/15] sync

---
 .../examples/train_hypergrid_optimized.py | 43 ++++++++++-----------
 1 file changed, 17 insertions(+), 26 deletions(-)

diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index 39afe71b..ac0a8527 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -18,6 +18,13 @@
 from torch.func import vmap
 from tqdm import tqdm
 
+try:  # Enable scalar captures for torch.compile to avoid graph breaks on .item().
+    import torch._dynamo as _torch_dynamo
+
+    _torch_dynamo.config.capture_scalar_outputs = True
+except Exception:  # pragma: no cover - defensive fallback on older PyTorch
+    _torch_dynamo = None
+
 from gfn.containers import Trajectories
 from gfn.env import Env, EnvFastPathMixin
 from gfn.estimators import (
@@ -295,38 +302,20 @@ def _normalize_env_keys(requested: list[str]) -> list[str]:
 class HyperGridWithTensorStep(HyperGrid):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._unit_step_cache: dict[
-            tuple[torch.device, torch.dtype], torch.Tensor
-        ] = {}
-        self._sink_state_cache: dict[
-            tuple[torch.device, torch.dtype], torch.Tensor
-        ] = {}
+        eye = torch.eye(self.ndim, dtype=torch.long)
+        zero_row = torch.zeros((1, self.ndim), dtype=torch.long)
+        self._unit_step_template = torch.cat([eye, zero_row], dim=0)
+        self._sink_state_template = self.sf.to(dtype=torch.long).clone()
 
     def _get_unit_steps(
         self, device: torch.device, dtype: torch.dtype
     ) -> torch.Tensor:
-        key = (device, dtype)
-        cached = self._unit_step_cache.get(key)
-        if cached is None:
-            identity = torch.eye(
-                self.ndim, dtype=dtype, device=device, requires_grad=False
-            )
-            zero_row = torch.zeros(
-                (1, self.ndim), dtype=dtype, device=device, requires_grad=False
-            )
-            cached = torch.cat([identity, zero_row], dim=0)
-            self._unit_step_cache[key] = cached
-        return cached
+        return self._unit_step_template.to(device=device, dtype=dtype)
 
     def _get_sink_state(
         self, device: torch.device, dtype: torch.dtype
     ) -> torch.Tensor:
-        key = (device, dtype)
-        cached = self._sink_state_cache.get(key)
-        if cached is None:
-            cached = self.sf.to(device=device, dtype=dtype)
-            self._sink_state_cache[key] = cached
-        return cached
+        return self._sink_state_template.to(device=device, dtype=dtype)
 
     def step_tensor(
         self, states: torch.Tensor, actions: torch.Tensor
@@ -1293,12 +1282,14 @@ def run_scenario(
     ) = build_training_components(args, device, scenario, flow_variant, env_cfg)
     metrics = init_metrics()
     use_vmap = scenario.use_vmap and flow_variant.supports_vmap
+    compiled_any = False
 
     if scenario.use_compile:
         compile_results = try_compile_gflownet(
             gflownet,
             mode=DEFAULT_COMPILE_MODE,
        )
+        compiled_any = any(compile_results.values())
         formatted = ", ".join(
             f"{name}:{'✓' if success else 'x'}"
             for name, success in compile_results.items()
@@ -1321,7 +1312,7 @@ def run_scenario(
         track_time=False,
         record_history=False,
         supports_validation=env_cfg.supports_validation,
-        mark_compiled_step=scenario.use_compile,
+        mark_compiled_step=compiled_any,
     )
 
     elapsed, history = run_iterations(
@@ -1339,7 +1330,7 @@ def run_scenario(
         track_time=True,
         record_history=True,
         supports_validation=env_cfg.supports_validation,
-        mark_compiled_step=scenario.use_compile,
+        mark_compiled_step=compiled_any,
     )
 
     validation_info = metrics["validation_info"]
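
capture_scalar_outputs lets Dynamo trace Tensor.item() as a symbolic scalar instead of splitting the graph there, which is what the chunk loops need for their steps_taken and done scalars. A minimal illustration (whether this compiles as a single graph depends on your PyTorch version):

import torch
import torch._dynamo as dynamo

dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def scale_by_count(x: torch.Tensor) -> torch.Tensor:
    n = (x > 0).sum().item()  # would graph-break without the config flag
    return x * n


print(scale_by_count(torch.tensor([1.0, -1.0, 2.0])))  # tensor([ 2., -2.,  4.])
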
From 81eb3b2bbd468383310a3be8c18d7af23afeecc1 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 23:04:41 -0500
Subject: [PATCH 14/15] sync

---
 tutorials/examples/train_hypergrid_optimized.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index ac0a8527..1195e6ea 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -2,6 +2,17 @@
 r"""
 Optimized multi-environment (HyperGrid + Diffusion) training/benchmark script
 with optional torch.compile, vmap, and chunked sampling across several GFlowNet
 variants.
+
+
+TODO:
+
+We need actual profiling on CUDA (start by wrapping chunk_fn in torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])) to see the kernel counts and copy sizes. If compile is failing, we must inspect the Dynamo logs to see which op blocks it (maybe the env.actions_from_batch_shape call inside _chunk_loop still triggers Python). If compile succeeds, then the GPU is simply overwhelmed by host-device copies and we should keep the script fast path on CPU.
+Next actions I recommend:
+1. Run the benchmark with TORCH_LOGS="graph_breaks" and TORCHDYNAMO_VERBOSE=1 so we can see why the chunk loops bail out of compilation. Share those snippets if possible.
+2. Profile the "Library Fast Path" on CUDA (PyTorch profiler) to find the hottest ops. If the C++ chunk sampler is dominated by scatter/where, we might need to batch those calls or increase chunk size to amortize kernel launches.
+3. For the script sampler, try forcing chunk_fn_compiled=False to confirm whether the slowdown is due to torch.compile overhead; if it runs faster without compilation, we'll know the compiled graph is recomputing templates or copying more than expected.
+
+
 """
 
 from __future__ import annotations
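
The TODO above asks for kernel-level numbers. A sketch of the current torch.profiler API for wrapping whatever callable you want to measure; train_step here is a placeholder, not part of the script:

import torch
from torch.profiler import ProfilerActivity, profile


def train_step() -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(512, 512, device=device)
    return (x @ x).sum()


activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    for _ in range(5):
        train_step()

# Sort by whichever column matters (e.g. "self_cuda_time_total" on GPU).
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
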
From 0a1ff4b60f4fad4a7827a5ead76fb1e211405694 Mon Sep 17 00:00:00 2001
From: Joseph Viviano
Date: Thu, 27 Nov 2025 23:09:09 -0500
Subject: [PATCH 15/15] sync

---
 .../examples/train_hypergrid_optimized.py | 56 +++++++++++--------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
index 1195e6ea..1b9cf93f 100644
--- a/tutorials/examples/train_hypergrid_optimized.py
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -318,14 +318,10 @@ def __init__(self, *args, **kwargs) -> None:
         self._unit_step_template = torch.cat([eye, zero_row], dim=0)
         self._sink_state_template = self.sf.to(dtype=torch.long).clone()
 
-    def _get_unit_steps(
-        self, device: torch.device, dtype: torch.dtype
-    ) -> torch.Tensor:
+    def _get_unit_steps(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
         return self._unit_step_template.to(device=device, dtype=dtype)
 
-    def _get_sink_state(
-        self, device: torch.device, dtype: torch.dtype
-    ) -> torch.Tensor:
+    def _get_sink_state(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
         return self._sink_state_template.to(device=device, dtype=dtype)
 
     def step_tensor(
@@ -462,9 +458,7 @@ def step_tensor(
         exit_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {}
         dummy_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {}
 
-        def _broadcast_done_mask(
-            mask: torch.Tensor, target_ndim: int
-        ) -> torch.Tensor:
+        def _broadcast_done_mask(mask: torch.Tensor, target_ndim: int) -> torch.Tensor:
             view_shape = mask.shape + (1,) * (target_ndim - mask.ndim)
             return mask.view(view_shape)
 
@@ -491,10 +485,15 @@ def _get_template(
 
         device_type = curr_states.device.type
         compile_allowed = (
-            hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs
+            hasattr(torch, "compile")
+            and device_type in ("cuda", "cpu")
+            and conditions is None
+            and not policy_kwargs
         )
         compile_key = (id(env), device_type)
-        chunk_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple] | None = None
+        chunk_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple] | None = (
+            None
+        )
         chunk_fn_compiled = False
         if compile_allowed:
             chunk_fn = self._compiled_chunk_cache.get(compile_key)
@@ -605,9 +604,9 @@
 
         if actions_buf is None:
             batch = current_states.shape[0]
-            empty_actions = env.actions_from_batch_shape(
-                (0, batch)
-            ).tensor.to(device=current_states.device)
+            empty_actions = env.actions_from_batch_shape((0, batch)).tensor.to(
+                device=current_states.device
+            )
             actions_out = empty_actions
         else:
             actions_out = actions_buf[:steps_taken]
@@ -663,7 +662,7 @@ def _chunk_loop(
             states_tsr = torch.cat(states_seq, dim=0)
             action_shape = getattr(env, "action_shape", None)
             if action_shape:
-                tail_shape = tuple(actions_tsr.shape[-len(action_shape):])
+                tail_shape = tuple(actions_tsr.shape[-len(action_shape) :])
                 if tail_shape != tuple(action_shape):
                     if tuple(action_shape) == (1,):
                         actions_tsr = actions_tsr.unsqueeze(-1)
@@ -777,7 +776,9 @@ def _get_template(
             key = (target_ndim, dtype)
             tensor = cache.get(key)
             if tensor is None:
-                tensor = _expand_front(base_value.to(device=device, dtype=dtype), target_ndim)
+                tensor = _expand_front(
+                    base_value.to(device=device, dtype=dtype), target_ndim
+                )
                 cache[key] = tensor
             return tensor
 
@@ -798,7 +799,9 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
         for _ in range(chunk_size):
             if bool(done_mask.all().item()):
                 assert action_template is not None
-                pad_actions = _fill_like_reference(action_template, dummy_action_value)
+                pad_actions = _fill_like_reference(
+                    action_template, dummy_action_value
+                )
                 local_recorded_actions.append(pad_actions)
                 local_sinks.append(done_mask.clone())
                 local_states.append(current_states.clone())
@@ -866,7 +869,10 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
         chunk_fn_compiled = False
         device_type = curr_states.device.type
         compile_allowed = (
-            hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs
+            hasattr(torch, "compile")
+            and device_type in ("cuda", "cpu")
+            and conditions is None
+            and not policy_kwargs
         )
         cache_key = (id(env), device_type)
         if compile_allowed:
@@ -909,7 +915,7 @@ def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[
             states_tsr = torch.stack(states_stack, dim=0)
             action_shape = getattr(env, "action_shape", None)
             if action_shape:
-                tail_shape = tuple(actions_tsr.shape[-len(action_shape):])
+                tail_shape = tuple(actions_tsr.shape[-len(action_shape) :])
                 if tail_shape != tuple(action_shape):
                     if tuple(action_shape) == (1,):
                         actions_tsr = actions_tsr.unsqueeze(-1)
@@ -992,9 +998,9 @@ def fast_features(
         safe_states = torch.where(
             sink_mask, torch.zeros_like(states_tensor), states_tensor
         )
-        khot = torch.nn.functional.one_hot(
-            safe_states, num_classes=self.height
-        ).to(dtype=torch.get_default_dtype())
+        khot = torch.nn.functional.one_hot(safe_states, num_classes=self.height).to(
+            dtype=torch.get_default_dtype()
+        )
         if sink_mask.any():
             khot = khot * (~sink_mask).unsqueeze(-1).to(khot.dtype)
         return khot.view(states_tensor.shape[0], -1)
@@ -1434,7 +1440,11 @@ def run_iterations(
         metrics["measured_steps"] += 1
         last_loss = loss.item()
 
-        if record_history and losses_history is not None and iter_time_history is not None:
+        if (
+            record_history
+            and losses_history is not None
+            and iter_time_history is not None
+        ):
             losses_history.append(last_loss)
             iter_duration = (
                 (time.perf_counter() - iter_start) if iter_start is not None else 0.0