diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 49a6980f..331bbb10 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -511,6 +511,146 @@ def to_states_container(self) -> StatesContainer: # FIXME: Add log_probs and estimator_outputs. ) + def splice_from_reconstruction( + self, + n_prevs: torch.Tensor, + recon: "Trajectories", + debug: bool = False, + ) -> "Trajectories": + """Splice this (prefix) trajectories with a reconstructed suffix. + + This mirrors the sampler's combine logic: keep the first ``n_prevs[i]`` steps + from ``self`` for each trajectory, then append the reconstructed portion from + ``recon``. When provided, the PF/PB per-step tensors are spliced identically. + + Args: + n_prevs: How many prefix steps to take from ``self`` per trajectory. + recon: Trajectories representing the reconstructed suffix per trajectory. + debug: If True, validates the vectorized splicing against a for-loop. + + Returns: + new_trajectories + """ + # Basic validations to avoid silent shape/device mismatches. + assert self.env is recon.env + bs = self.n_trajectories + assert bs == recon.n_trajectories + device = self.states.device + + # Determine prefix/suffix lengths and maxima for sizing outputs. + max_n_prev = n_prevs.max() + n_recons = recon.terminating_idx + max_n_recon = n_recons.max() + + # Episodic reward is taken from the reconstructed suffix. + new_log_rewards = recon.log_rewards + new_dones = n_prevs + n_recons + max_traj_len = int(new_dones.max().item()) + + # Helper indices and masks over (time, batch). + idx = ( + torch.arange(max_traj_len + 1, device=n_prevs.device) + .unsqueeze(1) + .expand(-1, bs) + ) + prev_mask = idx < n_prevs + state_recon_mask = (idx >= n_prevs) * (idx <= n_prevs + n_recons) + state_recon_mask2 = idx[: max_n_recon + 1] <= n_recons + action_recon_mask = (idx[:-1] >= n_prevs) * (idx[:-1] <= n_prevs + n_recons - 1) + action_recon_mask2 = idx[:max_n_recon] <= n_recons - 1 + + # Transpose to (batch, time, ...) for efficient advanced indexing. + prev_states_tsr = self.states.tensor.transpose(0, 1) + prev_actions_tsr = self.actions.tensor.transpose(0, 1) + recon_states_tsr = recon.states.tensor.transpose(0, 1) + recon_actions_tsr = recon.actions.tensor.transpose(0, 1) + prev_mask = prev_mask.transpose(0, 1) + state_recon_mask = state_recon_mask.transpose(0, 1) + state_recon_mask2 = state_recon_mask2.transpose(0, 1) + action_recon_mask = action_recon_mask.transpose(0, 1) + action_recon_mask2 = action_recon_mask2.transpose(0, 1) + + # Prepare output tensors in transposed shapes, initialized to sink/dummy. + new_states_tsr = self.env.sf.repeat(bs, max_traj_len + 1, 1).to( + self.states.tensor + ) + new_actions_tsr = self.env.dummy_action.repeat(bs, max_traj_len, 1).to( + self.actions.tensor + ) + + # Fill prefix from ``self`` using ``prev_mask``. + prev_mask_truc = prev_mask[:, :max_n_prev] + new_states_tsr[prev_mask] = prev_states_tsr[:, :max_n_prev][prev_mask_truc] + new_actions_tsr[prev_mask[:, :-1]] = prev_actions_tsr[:, :max_n_prev][ + prev_mask_truc + ] + + # Fill suffix from ``recon`` using recon masks. + new_states_tsr[state_recon_mask] = recon_states_tsr[state_recon_mask2] + new_actions_tsr[action_recon_mask] = recon_actions_tsr[action_recon_mask2] + + # Transpose back to (time, batch, ...). + new_states_tsr = new_states_tsr.transpose(0, 1) + new_actions_tsr = new_actions_tsr.transpose(0, 1) + + # Optionally combine PF from container log_probs when both are present. 
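+        # Illustrative sketch (comments only, not executed): for trajectory i with
+        # n_prevs[i] = 2 and recon.terminating_idx[i] = 3, the spliced per-step
+        # log_pf row would be
+        #   [self.log_probs[0, i], self.log_probs[1, i],
+        #    recon.log_probs[0, i], recon.log_probs[1, i], recon.log_probs[2, i], 0, ...],
+        # reusing the same prefix/suffix masks that spliced the actions above.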
+ new_log_pf = None + if (self.log_probs is not None) and (recon.log_probs is not None): + plp = self.log_probs.transpose(0, 1) + rlp = recon.log_probs.transpose(0, 1) + new_log_pf = torch.full((bs, max_traj_len), 0.0).to( + device=device, dtype=plp.dtype # type: ignore[arg-type] + ) + new_log_pf[prev_mask[:, :-1]] = plp[:, :max_n_prev][prev_mask_truc] + new_log_pf[action_recon_mask] = rlp[action_recon_mask2] + new_log_pf = new_log_pf.transpose(0, 1) + + # ------------------------------ DEBUG ------------------------------ + if debug: + _states = self.env.sf.repeat(max_traj_len + 1, bs, 1).to(self.states.tensor) + _actions = self.env.dummy_action.repeat(max_traj_len, bs, 1).to( + self.actions.tensor + ) + # PF/PB debug-specific tensors are omitted; we only validate states/actions. + + for i in range(bs): + _n_prev = n_prevs[i] + # Prefix copy + _states[: _n_prev + 1, i] = self.states.tensor[: _n_prev + 1, i] + _actions[:_n_prev, i] = self.actions.tensor[:_n_prev, i] + # Suffix copy + _len_recon = recon.terminating_idx[i] + _states[_n_prev + 1 : _n_prev + _len_recon + 1, i] = recon.states.tensor[ + 1 : _len_recon + 1, + i, + ] + _actions[_n_prev : _n_prev + _len_recon, i] = recon.actions.tensor[ + :_len_recon, + i, + ] + + # PF/PB debug path is intentionally omitted. + + assert torch.all(_states == new_states_tsr) + assert torch.all(_actions == new_actions_tsr) + + # PF/PB debug comparisons are intentionally omitted. + + # Materialize the forward-time trajectories and carry over conditioning + # and episodic rewards from the reconstructed suffix. + new_trajs = Trajectories( + env=self.env, + states=self.env.states_from_tensor(new_states_tsr), + conditioning=self.conditioning, + actions=self.env.actions_from_tensor(new_actions_tsr), + terminating_idx=new_dones, + is_backward=False, + log_rewards=new_log_rewards, + log_probs=new_log_pf, + ) + + return new_trajs + def reverse_backward_trajectories(self) -> Trajectories: """Returns a reversed version of the backward trajectories.""" assert self.is_backward, "Trajectories must be backward." diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 686ef924..4aaf0008 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -375,6 +375,148 @@ def __init__( super().__init__(pf_estimator) self.backward_sampler = Sampler(pb_estimator) + @staticmethod + def _compute_back_steps( + terminating_idx: torch.Tensor, + back_steps: torch.Tensor | None, + back_ratio: float | None, + ) -> torch.Tensor: + """Compute per-trajectory backtrack length K with validation and clamping. + + This centralizes the logic for deriving K used in local search. + The behavior mirrors the inline implementation: + - When ``back_steps`` is None, require ``0 < back_ratio <= 1`` and set + ``K = ceil(back_ratio * (terminating_idx - 1))``. + - Otherwise, clamp provided ``back_steps`` to ``terminating_idx``. + """ + if back_steps is None: + assert ( + back_ratio is not None and 0 < back_ratio <= 1 + ), "Either kwarg `back_steps` or `back_ratio` must be specified" + return torch.ceil(back_ratio * (terminating_idx - 1)).long() + return torch.where( + back_steps > terminating_idx, + terminating_idx, + back_steps, + ) + + def _reconstruct_from_junctions( + self, + env: Env, + prev_trajectories: Trajectories, + n_prevs: torch.Tensor, + conditioning: torch.Tensor | None, + save_estimator_outputs: bool, + save_logprobs: bool, + **policy_kwargs: Any, + ) -> Trajectories: + """Gather junction states and reconstruct forward suffixes with self.estimator. 
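+
+        For trajectory ``i`` the junction is ``prev_trajectories.states.tensor[n_prevs[i], i]``,
+        i.e. the state reached after keeping the first ``n_prevs[i]`` prefix steps;
+        one forward rollout is sampled from each such junction.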
+ + This isolates the PolicyMixin-dependent sampling. + """ + # Derive junction positions and gather the junction states (one per traj). + junction_states_tsr = torch.gather( + prev_trajectories.states.tensor, + 0, + (n_prevs) + .view(1, -1, 1) + .expand(-1, -1, *prev_trajectories.states.state_shape), + ).squeeze(0) + + # Reconstruct forward suffixes starting from the junction states using the + # forward policy estimator owned by `self`. + recon_trajectories = super().sample_trajectories( + env, + states=env.states_from_tensor(junction_states_tsr), + conditioning=conditioning, + save_estimator_outputs=save_estimator_outputs, + save_logprobs=save_logprobs, + **policy_kwargs, + ) + + return recon_trajectories + + @staticmethod + def _metropolis_hastings_accept( + prev_log_rewards: torch.Tensor, + new_log_rewards: torch.Tensor, + prev_log_pf: torch.Tensor, + new_log_pf: torch.Tensor, + prev_log_pb: torch.Tensor, + new_log_pb: torch.Tensor, + ) -> torch.Tensor: + """Compute MH acceptance mask for candidate trajectories. + + The acceptance probability is + min{1, R(x') p_B(x->s') p_F(s'->x') / [R(x) p_B(x'->s') p_F(s'->x)]}. + """ + log_accept_ratio = torch.clamp_max( + new_log_rewards + + prev_log_pb.sum(0) + + new_log_pf.sum(0) + - prev_log_rewards + - new_log_pb.sum(0) + - prev_log_pf.sum(0), + 0.0, + ) + return torch.rand( + new_log_rewards.shape[0], device=log_accept_ratio.device + ) < torch.exp(log_accept_ratio) + + @staticmethod + def _splice_pf( + n_prevs: torch.Tensor, + prev_log_pf: torch.Tensor, + n_recons: torch.Tensor, + recon_log_pf: torch.Tensor, + T_new: int, + device: torch.device, + ) -> torch.Tensor: + """Splice per-step PF log-probs of prefix and reconstructed suffix. + + Args: + n_prevs: Number of prefix steps kept from prev trajectories (N,). + prev_log_pf: Per-step PF for prev trajectories (T_prev, N). + n_recons: Number of reconstruction steps per trajectory (N,). + recon_log_pf: Per-step PF for reconstructed trajectories (T_recon, N). + T_new: Maximum trajectory length of the spliced trajectories. + device: Torch device for the output tensor. + + Returns: + Spliced per-step PF log-probs of shape (T_new, N). + """ + bs = int(n_prevs.shape[0]) + + # Determine maxima for mask construction + max_n_prev = n_prevs.max() + max_n_recon = n_recons.max() + + # Build masks over states time (T_new + 1), then adapt to per-step PF (T_new) + idx_states = ( + torch.arange(T_new + 1, device=n_prevs.device).unsqueeze(1).expand(-1, bs) + ) + prev_mask = (idx_states < n_prevs).transpose(0, 1) # (bs, T_new+1) + action_recon_mask = ( + (idx_states[:-1] >= n_prevs) & (idx_states[:-1] <= (n_prevs + n_recons - 1)) + ).transpose( + 0, 1 + ) # (bs, T_new) + action_recon_mask2 = (idx_states[:max_n_recon] <= (n_recons - 1)).transpose( + 0, 1 + ) # (bs, max_n_recon) + + # Transpose PF tensors to (bs, time) + prev_pf_t = prev_log_pf.transpose(0, 1) + recon_pf_t = recon_log_pf.transpose(0, 1) + + # Allocate and fill spliced PF + new_pf_t = torch.full((bs, T_new), 0.0).to(device=device, dtype=prev_pf_t.dtype) + prev_mask_trunc = prev_mask[:, :max_n_prev] + new_pf_t[prev_mask[:, :-1]] = prev_pf_t[:, :max_n_prev][prev_mask_trunc] + new_pf_t[action_recon_mask] = recon_pf_t[action_recon_mask2] + + return new_pf_t.transpose(0, 1) + def local_search( self, env: Env, @@ -392,9 +534,10 @@ def local_search( This method implements the local search algorithm by: 1. For each trajectory, performing K backward steps to reach a junction state - 2. 
Reconstructing the trajectory from the junction state using the forward policy - 3. Optionally applying Metropolis-Hastings acceptance criterion to decide whether - to accept the new trajectory. + 2. Reconstructing the trajectory from the junction state using the forward + policy estimator. + 3. Optionally applying Metropolis-Hastings acceptance criterion to decide + whether to accept the new trajectory. Args: env: The environment to sample trajectories from. @@ -421,26 +564,38 @@ def local_search( - A Trajectories object refined by local search - A boolean tensor indicating which trajectories were updated """ + # High-level outline: + # 1) Choose the backtrack length K per-trajectory (either from `back_steps` or + # via a ratio of current lengths). This determines how much prefix to keep. + # 2) Sample backward trajectories from terminal states using the backward + # policy estimator, then reverse them into forward-time to obtain the prefix + # trajectories. + # 3) Extract the junction states at step `n_prevs = L - K - 1` and reconstruct + # forward suffixes using the forward policy starting from those junctions. + # 4) Optionally compute PF/PB per-step log-probabilities for MH acceptance. + # 5) Splice prefix and suffix into candidate new trajectories. + # 6) Accept/reject (MH or greedy by reward), return the candidates and update + # mask. + # TODO: Implement local search for GraphStates. - if isinstance(env.States, GraphStates): + # Guard against graph-based states; not yet supported. + if issubclass(env.States, GraphStates): raise NotImplementedError("Local search is not implemented for GraphStates.") + # Ensure PF/PB log-probabilities are computed when MH acceptance is requested. save_logprobs = save_logprobs or use_metropolis_hastings - # K-step backward sampling with the backward estimator, - # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361. - if back_steps is None: - assert ( - back_ratio is not None and 0 < back_ratio <= 1 - ), "Either kwarg `back_steps` or `back_ratio` must be specified" - K = torch.ceil(back_ratio * (trajectories.terminating_idx - 1)).long() - else: - K = torch.where( - back_steps > trajectories.terminating_idx, - trajectories.terminating_idx, - back_steps, - ) + # 1) K-step backward sampling with the backward estimator, where K is the + # number of backward steps. When specified via `back_ratio`, K is proportional + # to the previous trajectory length; otherwise clamp the provided `back_steps` + # to valid bounds. This is used in https://arxiv.org/abs/2202.01361. + # Compute per-trajectory backtrack length K. + K = self._compute_back_steps( + trajectories.terminating_idx, back_steps, back_ratio + ) + + # 1) Backward sampling from terminal states (PolicyMixin-driven Sampler). prev_trajectories = self.backward_sampler.sample_trajectories( env, states=trajectories.terminating_states, @@ -454,26 +609,24 @@ def local_search( # This is called `prev_trajectories` since they are the trajectories before # the local search. The `new_trajectories` will be obtained by performing local # search on them. + # Convert backward trajectories to forward-time ordering (s0 -> ... -> sf). prev_trajectories = prev_trajectories.reverse_backward_trajectories() assert prev_trajectories.log_rewards is not None - # Reconstructing with self.estimator + # Reconstruct suffixes from junction states using self.estimator. 
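+        # For example, a prefix trajectory with terminating_idx = 6 and K = 2 gives
+        # n_prevs = 6 - 2 - 1 = 3: the first 3 actions of the prefix are kept and the
+        # forward policy re-rolls the rest from the junction state states.tensor[3].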
n_prevs = prev_trajectories.terminating_idx - K - 1 - junction_states_tsr = torch.gather( - prev_trajectories.states.tensor, - 0, - (n_prevs).view(1, -1, 1).expand(-1, -1, *trajectories.states.state_shape), - ).squeeze(0) - recon_trajectories = super().sample_trajectories( + recon_trajectories = self._reconstruct_from_junctions( env, - states=env.states_from_tensor(junction_states_tsr), - conditioning=conditioning, - save_estimator_outputs=save_estimator_outputs, - save_logprobs=save_logprobs, + prev_trajectories, + n_prevs, + conditioning, + save_estimator_outputs, + save_logprobs, **policy_kwargs, ) # Calculate the log probabilities as needed. + # 4) PF on prefix and reconstructed suffix (needed for MH or for logging). prev_trajectories_log_pf = ( get_trajectory_pfs(pf=self.estimator, trajectories=prev_trajectories) if save_logprobs @@ -484,6 +637,7 @@ def local_search( if save_logprobs else None ) + # 5) PB on prefix and reconstructed suffix (needed only for MH acceptance). prev_trajectories_log_pb = ( get_trajectory_pbs( pb=self.backward_sampler.estimator, trajectories=prev_trajectories @@ -499,43 +653,89 @@ def local_search( else None ) - ( - new_trajectories, - new_trajectories_log_pf, - new_trajectories_log_pb, - ) = self._combine_prev_and_recon_trajectories( + # 6) Splice prefix and suffix into candidate trajectories. + new_trajectories = prev_trajectories.splice_from_reconstruction( n_prevs=n_prevs, - prev_trajectories=prev_trajectories, - recon_trajectories=recon_trajectories, - prev_trajectories_log_pf=prev_trajectories_log_pf, - recon_trajectories_log_pf=recon_trajectories_log_pf, - prev_trajectories_log_pb=prev_trajectories_log_pb, - recon_trajectories_log_pb=recon_trajectories_log_pb, + recon=recon_trajectories, debug=debug, ) - if use_metropolis_hastings: - assert ( - prev_trajectories_log_pb is not None - and new_trajectories_log_pf is not None - and new_trajectories_log_pb is not None - and prev_trajectories_log_pf is not None - and new_trajectories.log_rewards is not None + # Build PF by splicing prev/recon PF log-probs. + if save_logprobs: + assert prev_trajectories_log_pf is not None + assert recon_trajectories_log_pf is not None + T_new = int(new_trajectories.max_length) + n_recons = recon_trajectories.terminating_idx + new_trajectories.log_probs = self._splice_pf( + n_prevs=n_prevs, + prev_log_pf=prev_trajectories_log_pf, + n_recons=n_recons, + recon_log_pf=recon_trajectories_log_pf, + T_new=T_new, + device=new_trajectories.states.device, ) + # Compute PF/PB sums for MH without building full per-step spliced tensors. + if use_metropolis_hastings: + assert prev_trajectories_log_pb is not None + assert prev_trajectories_log_pf is not None + assert recon_trajectories_log_pb is not None + assert recon_trajectories_log_pf is not None + + # Sum over prefix/suffix per trajectory using n_prevs and recon lengths. 
+ # Prefix sums: [0:n_prev) + sum_prev_pf = prev_trajectories_log_pf.cumsum(0) + sum_prev_pb = prev_trajectories_log_pb.cumsum(0) + prefix_idx = (n_prevs - 1).clamp_min(0).view(1, -1) + prefix_pf = sum_prev_pf.gather(0, prefix_idx).squeeze(0) + prefix_pb = sum_prev_pb.gather(0, prefix_idx).squeeze(0) + zero_prefix = torch.zeros_like(prefix_pf) + prefix_pf = torch.where(n_prevs > 0, prefix_pf, zero_prefix) + prefix_pb = torch.where(n_prevs > 0, prefix_pb, zero_prefix) + + # Suffix sums from recon: [0:n_recon) + n_recons = recon_trajectories.terminating_idx + sum_recon_pf = recon_trajectories_log_pf.cumsum(0) + sum_recon_pb = recon_trajectories_log_pb.cumsum(0) + suffix_idx = (n_recons - 1).clamp_min(0).view(1, -1) + suffix_pf = sum_recon_pf.gather(0, suffix_idx).squeeze(0) + suffix_pb = sum_recon_pb.gather(0, suffix_idx).squeeze(0) + zero_suffix = torch.zeros_like(suffix_pf) + suffix_pf = torch.where(n_recons > 0, suffix_pf, zero_suffix) + suffix_pb = torch.where(n_recons > 0, suffix_pb, zero_suffix) + + # 7) Accept/reject. With MH, accept with probability: + # min\{1, R(x') p_B(x->s') p_F(s'->x') / [R(x) p_B(x'->s') p_F(s'->x)]\}. + # Without MH, accept when the episodic reward improves (ties accepted). + if use_metropolis_hastings: + assert prev_trajectories_log_pb is not None + assert prev_trajectories_log_pf is not None + assert recon_trajectories_log_pb is not None + assert recon_trajectories_log_pf is not None + assert prev_trajectories.log_rewards is not None + assert new_trajectories.log_rewards is not None + # The acceptance ratio is: min(1, R(x')p(x->s'->x') / R(x)p(x'->s'-> x)) # Also, note this: # p(x->s'->x') / p(x'->s'-> x)) # = p_B(x->s')p_F(s'->x') / p_B(x'->s')p_F(s'->x) # = p_B(x->s'->s0)p_F(s0->s'->x') / p_B(x'->s'->s0)p_F(s0->s'->x) # = p_B(tau|x)p_F(tau') / p_B(tau'|x')p_F(tau) + # Combine episodic reward and log-prob sums, clamp at 0 (min with 1 in prob + # space). + prev_total_pf = prev_trajectories_log_pf.sum(0) + prev_total_pb = prev_trajectories_log_pb.sum(0) + assert isinstance(prefix_pf, torch.Tensor) + assert isinstance(prefix_pb, torch.Tensor) + new_total_pf = prefix_pf + suffix_pf + new_total_pb = prefix_pb + suffix_pb log_accept_ratio = torch.clamp_max( new_trajectories.log_rewards - + prev_trajectories_log_pb.sum(0) - + new_trajectories_log_pf.sum(0) + + prev_total_pb + + new_total_pf - prev_trajectories.log_rewards - - new_trajectories_log_pb.sum(0) - - prev_trajectories_log_pf.sum(0), + - new_total_pb + - prev_total_pf, 0.0, ) is_updated = torch.rand( @@ -601,6 +801,9 @@ def sample_trajectories( The final trajectories container contains both the initial trajectories and the improved trajectories from local search. """ + # Roll out an initial batch with the forward policy, then perform + # `n_local_search_loops` rounds of refinement. Each round appends + # one candidate per original seed trajectory to the container. trajectories = super().sample_trajectories( env, n, @@ -614,10 +817,12 @@ def sample_trajectories( if n is None: n = int(trajectories.n_trajectories) + # Indices referring to the current seed trajectories within the container. + # Initially these are the first `n` entries (the initial batch). search_indices = torch.arange(n, device=trajectories.states.device) for it in range(n_local_search_loops): - # Search phase + # Run a single local-search refinement on the current seeds. 
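+            # local_search returns one candidate per seed, plus a boolean mask of
+            # which candidates were accepted (MH or greedy-by-reward).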
ls_trajectories, is_updated = self.local_search( env, trajectories[search_indices], @@ -629,262 +834,14 @@ def sample_trajectories( use_metropolis_hastings, **policy_kwargs, ) + # Append refined candidates; they occupy a new contiguous block at the end. trajectories.extend(ls_trajectories) + # Map accepted seeds to the indices of the just-appended block so that + # the next round uses the latest accepted candidates as seeds. last_indices = torch.arange( n * it, n * (it + 1), device=trajectories.states.device ) search_indices[is_updated] = last_indices[is_updated] return trajectories - - @staticmethod - def _combine_prev_and_recon_trajectories( # noqa: C901 - n_prevs: torch.Tensor, - prev_trajectories: Trajectories, - recon_trajectories: Trajectories, - prev_trajectories_log_pf: torch.Tensor | None = None, - recon_trajectories_log_pf: torch.Tensor | None = None, - prev_trajectories_log_pb: torch.Tensor | None = None, - recon_trajectories_log_pb: torch.Tensor | None = None, - debug: bool = False, - ) -> tuple[Trajectories, torch.Tensor | None, torch.Tensor | None]: - """Combines previous and reconstructed trajectories to create new trajectories. - - This static method combines two trajectory segments: `prev_trajectories` and - `recon_trajectories` to create `new_trajectories`. Specifically, - `new_trajectories` is constructed by replacing certain portion of the - `prev_trajectories` with `recon_trajectories`. See self.local_search for how - to generate `prev_trajectories` and `recon_trajectories`. - - Args: - n_prevs: Tensor indicating how many steps to take from prev_trajectories - for each trajectory in the batch. - prev_trajectories: Trajectories obtained from backward sampling. - recon_trajectories: Trajectories obtained from forward reconstruction. - prev_trajectories_log_pf: Optional log probabilities for forward policy - on `prev_trajectories`. - recon_trajectories_log_pf: Optional log probabilities for forward policy - on `recon_trajectories`. - prev_trajectories_log_pb: Optional log probabilities for backward policy - on `prev_trajectories`. - recon_trajectories_log_pb: Optional log probabilities for backward policy - on `recon_trajectories`. - debug: If True, performs additional validation checks for debugging. - - Returns: - A tuple containing: - - the `new_trajectories` Trajectories object with the combined trajectories - - the `new_trajectories_log_pf` tensor of combined forward log probabilities - - the `new_trajectories_log_pb` tensor of combined backward log probabilities - - Note: - This method performs complex tensor operations to efficiently combine - trajectory segments. The debug mode compares the vectorized approach - with a for-loop implementation to ensure correctness. - """ - new_trajectories_log_pf = None - new_trajectories_log_pb = None - - bs = prev_trajectories.n_trajectories - device = prev_trajectories.states.device - env = prev_trajectories.env - - # Obtain full trajectories by concatenating the backward and forward parts. 
- max_n_prev = n_prevs.max() - n_recons = recon_trajectories.terminating_idx - max_n_recon = n_recons.max() - - new_trajectories_log_rewards = recon_trajectories.log_rewards # Episodic reward - new_trajectories_dones = n_prevs + n_recons - max_traj_len = int(new_trajectories_dones.max().item()) - - # Create helper indices and masks - idx = torch.arange(max_traj_len + 1).unsqueeze(1).expand(-1, bs).to(n_prevs) - - prev_mask = idx < n_prevs - state_recon_mask = (idx >= n_prevs) * (idx <= n_prevs + n_recons) - state_recon_mask2 = idx[: max_n_recon + 1] <= n_recons - action_recon_mask = (idx[:-1] >= n_prevs) * (idx[:-1] <= n_prevs + n_recons - 1) - action_recon_mask2 = idx[:max_n_recon] <= n_recons - 1 - - # Transpose for easier indexing - prev_trajectories_states_tsr = prev_trajectories.states.tensor.transpose(0, 1) - prev_trajectories_actions_tsr = prev_trajectories.actions.tensor.transpose(0, 1) - recon_trajectories_states_tsr = recon_trajectories.states.tensor.transpose(0, 1) - recon_trajectories_actions_tsr = recon_trajectories.actions.tensor.transpose( - 0, 1 - ) - prev_mask = prev_mask.transpose(0, 1) - state_recon_mask = state_recon_mask.transpose(0, 1) - state_recon_mask2 = state_recon_mask2.transpose(0, 1) - action_recon_mask = action_recon_mask.transpose(0, 1) - action_recon_mask2 = action_recon_mask2.transpose(0, 1) - - # Prepare the new states and actions - # Note that these are initialized in transposed shapes - new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to( - prev_trajectories.states.tensor - ) - new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to( - prev_trajectories.actions.tensor - ) - - # Assign the first part (backtracked from backward policy) of the trajectory - prev_mask_truc = prev_mask[:, :max_n_prev] - new_trajectories_states_tsr[prev_mask] = prev_trajectories_states_tsr[ - :, :max_n_prev - ][prev_mask_truc] - new_trajectories_actions_tsr[prev_mask[:, :-1]] = prev_trajectories_actions_tsr[ - :, :max_n_prev - ][prev_mask_truc] - - # Assign the second part (reconstructed from forward policy) of the trajectory - new_trajectories_states_tsr[state_recon_mask] = recon_trajectories_states_tsr[ - state_recon_mask2 - ] - new_trajectories_actions_tsr[action_recon_mask] = recon_trajectories_actions_tsr[ - action_recon_mask2 - ] - - # Transpose back - new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1) - new_trajectories_actions_tsr = new_trajectories_actions_tsr.transpose(0, 1) - - # Similarly, combine log_pf and log_pb if needed - if ( - prev_trajectories_log_pf is not None - and recon_trajectories_log_pf is not None - ): - prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1) - recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1) - new_trajectories_log_pf = torch.full((bs, max_traj_len), 0.0).to( - device=device, dtype=prev_trajectories_log_pf.dtype # type: ignore - ) - new_trajectories_log_pf[prev_mask[:, :-1]] = prev_trajectories_log_pf[ # type: ignore - :, :max_n_prev - ][ - prev_mask_truc - ] - new_trajectories_log_pf[action_recon_mask] = recon_trajectories_log_pf[ # type: ignore - action_recon_mask2 - ] - new_trajectories_log_pf = new_trajectories_log_pf.transpose(0, 1) - if ( - prev_trajectories_log_pb is not None - and recon_trajectories_log_pb is not None - ): - prev_trajectories_log_pb = prev_trajectories_log_pb.transpose(0, 1) - recon_trajectories_log_pb = recon_trajectories_log_pb.transpose(0, 1) - new_trajectories_log_pb = torch.full((bs, max_traj_len), 0.0).to( 
- device=device, dtype=prev_trajectories_log_pb.dtype # type: ignore - ) - new_trajectories_log_pb[prev_mask[:, :-1]] = prev_trajectories_log_pb[ # type: ignore - :, :max_n_prev - ][ - prev_mask_truc - ] - new_trajectories_log_pb[action_recon_mask] = recon_trajectories_log_pb[ # type: ignore - action_recon_mask2 - ] - new_trajectories_log_pb = new_trajectories_log_pb.transpose(0, 1) - - # ------------------------------ DEBUG ------------------------------ - # If `debug` is True (expected only when testing), compare the - # vectorized approach's results (above) to the for-loop results (below). - if debug: - _new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to( - prev_trajectories.states.tensor - ) - _new_trajectories_actions_tsr = env.dummy_action.repeat( - max_traj_len, bs, 1 - ).to(prev_trajectories.actions.tensor) - - if ( - prev_trajectories_log_pf is not None - and recon_trajectories_log_pf is not None - ): - _new_trajectories_log_pf = torch.full((max_traj_len, bs), 0.0).to( - device=device, dtype=prev_trajectories_log_pf.dtype - ) - prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1) - recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1) - - if ( - prev_trajectories_log_pb is not None - and recon_trajectories_log_pb is not None - ): - _new_trajectories_log_pb = torch.full((max_traj_len, bs), 0.0).to( - device=device, dtype=prev_trajectories_log_pb.dtype - ) - prev_trajectories_log_pb = prev_trajectories_log_pb.transpose(0, 1) - recon_trajectories_log_pb = recon_trajectories_log_pb.transpose(0, 1) - - for i in range(bs): - _n_prev = n_prevs[i] - - # Backward part - _new_trajectories_states_tsr[: _n_prev + 1, i] = ( - prev_trajectories.states.tensor[: _n_prev + 1, i] - ) - _new_trajectories_actions_tsr[:_n_prev, i] = ( - prev_trajectories.actions.tensor[:_n_prev, i] - ) - - # Forward part - _len_recon = recon_trajectories.terminating_idx[i] - _new_trajectories_states_tsr[ - _n_prev + 1 : _n_prev + _len_recon + 1, i - ] = recon_trajectories.states.tensor[1 : _len_recon + 1, i] - _new_trajectories_actions_tsr[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories.actions.tensor[:_len_recon, i] - ) - - if ( - prev_trajectories_log_pf is not None - and recon_trajectories_log_pf is not None - ): - _new_trajectories_log_pf[:_n_prev, i] = prev_trajectories_log_pf[ - :_n_prev, i - ] - _new_trajectories_log_pf[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories_log_pf[:_len_recon, i] - ) - if ( - prev_trajectories_log_pb is not None - and recon_trajectories_log_pb is not None - ): - _new_trajectories_log_pb[:_n_prev, i] = prev_trajectories_log_pb[ - :_n_prev, i - ] - _new_trajectories_log_pb[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories_log_pb[:_len_recon, i] - ) - - assert torch.all(_new_trajectories_states_tsr == new_trajectories_states_tsr) - assert torch.all( - _new_trajectories_actions_tsr == new_trajectories_actions_tsr - ) - if ( - prev_trajectories_log_pf is not None - and recon_trajectories_log_pf is not None - ): - assert torch.all(_new_trajectories_log_pf == new_trajectories_log_pf) - if ( - prev_trajectories_log_pb is not None - and recon_trajectories_log_pb is not None - ): - assert torch.all(_new_trajectories_log_pb == new_trajectories_log_pb) - - new_trajectories = Trajectories( - env=env, - states=env.states_from_tensor(new_trajectories_states_tsr), - conditioning=prev_trajectories.conditioning, - actions=env.actions_from_tensor(new_trajectories_actions_tsr), - terminating_idx=new_trajectories_dones, 
- is_backward=False, - log_rewards=new_trajectories_log_rewards, - log_probs=new_trajectories_log_pf, - ) - - return new_trajectories, new_trajectories_log_pf, new_trajectories_log_pb diff --git a/testing/test_local_search_sampler.py b/testing/test_local_search_sampler.py new file mode 100644 index 00000000..29d463d5 --- /dev/null +++ b/testing/test_local_search_sampler.py @@ -0,0 +1,387 @@ +# testing/test_local_search_sampler.py +import pytest +import torch + +from gfn.containers import Trajectories +from gfn.estimators import DiscretePolicyEstimator +from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP +from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor +from gfn.samplers import LocalSearchSampler, Sampler +from gfn.utils.modules import MLP +from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs + + +def _make_env_estimators(env_name: str): + torch.manual_seed(0) + if env_name == "HyperGrid": + env = HyperGrid(ndim=2, height=5) + preproc = KHotPreprocessor(env.height, env.ndim) + assert isinstance(preproc.output_dim, int) + pf_module = MLP(preproc.output_dim, env.n_actions) + pb_module = MLP(preproc.output_dim, env.n_actions - 1) + pf = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preproc, + ) + pb = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preproc, + ) + + elif env_name == "DiscreteEBM": + env = DiscreteEBM(ndim=5) + preproc = IdentityPreprocessor(output_dim=env.state_shape[-1]) + assert isinstance(preproc.output_dim, int) + pf_module = MLP(preproc.output_dim, env.n_actions) + pb_module = MLP(preproc.output_dim, env.n_actions - 1) + pf = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preproc, + ) + pb = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preproc, + ) + + elif env_name == "Box": + env = Box(delta=0.1) + pf_module = BoxPFMLP( + hidden_dim=16, + n_hidden_layers=1, + n_components=1, + n_components_s0=1, + ) + pb_module = BoxPBMLP( + hidden_dim=16, + n_hidden_layers=1, + n_components=1, + trunk=pf_module.trunk, + ) + pf = BoxPFEstimator(env=env, module=pf_module, n_components=1, n_components_s0=1) + pb = BoxPBEstimator(env=env, module=pb_module, n_components=1) + + else: + raise ValueError(env_name) + return env, pf, pb + + +def _reference_local_search( + env, + trajectories: Trajectories, + pf_estimator, + pb_estimator, + *, + conditioning=None, + back_steps: torch.Tensor | None = None, + back_ratio: float | None = None, + use_metropolis_hastings: bool = True, + save_logprobs: bool = True, +): + # K selection identical to production + if back_steps is None: + assert (back_ratio is not None) and (0 < back_ratio <= 1) + K = torch.ceil(back_ratio * (trajectories.terminating_idx - 1)).long() + else: + K = torch.where( + back_steps > trajectories.terminating_idx, + trajectories.terminating_idx, + back_steps, + ) + + # Backward sampling from terminal states + bw_sampler = Sampler(pb_estimator) + prev_traj_bw = bw_sampler.sample_trajectories( + env, + states=trajectories.terminating_states, + conditioning=conditioning, + save_logprobs=save_logprobs or use_metropolis_hastings, + ) + # Reverse to forward-time + prev_traj = prev_traj_bw.reverse_backward_trajectories() + assert prev_traj.log_rewards is not None + + # 
Junctions + n_prevs = prev_traj.terminating_idx - K - 1 + jx_states = torch.gather( + prev_traj.states.tensor, + 0, + n_prevs.view(1, -1, 1).expand(-1, -1, *trajectories.states.state_shape), + ).squeeze(0) + # Reconstruct from junctions + fw_sampler = Sampler(pf_estimator) + recon_traj = fw_sampler.sample_trajectories( + env, + states=env.states_from_tensor(jx_states), + conditioning=conditioning, + save_logprobs=save_logprobs or use_metropolis_hastings, + ) + + # Per-step PF/PB for acceptance + prev_pf = ( + get_trajectory_pfs(pf_estimator, prev_traj) + if (save_logprobs or use_metropolis_hastings) + else None + ) + recon_pf = ( + get_trajectory_pfs(pf_estimator, recon_traj) + if (save_logprobs or use_metropolis_hastings) + else None + ) + prev_pb = ( + get_trajectory_pbs(pb_estimator, prev_traj) if use_metropolis_hastings else None + ) + recon_pb = ( + get_trajectory_pbs(pb_estimator, recon_traj) if use_metropolis_hastings else None + ) + + # For-loop splice to new trajectories (and log_pf/pb) to mirror production semantics + bs = prev_traj.n_trajectories + n_recons = recon_traj.terminating_idx + dones = (n_prevs + n_recons).to(torch.long) + max_len = int(dones.max().item()) + + new_states = env.States.make_sink_states( + (max_len + 1, bs), device=prev_traj.states.device + ) + new_actions = env.Actions.make_dummy_actions( + (max_len, bs), device=prev_traj.actions.device + ) + # Work with raw tensors during splice, then wrap once at the end. + new_states_tsr = new_states.tensor + new_actions_tsr = new_actions.tensor + new_log_pf = ( + torch.full((max_len, bs), 0.0, device=prev_traj.states.device) + if (prev_pf is not None and recon_pf is not None) + else None + ) + new_log_pb = ( + torch.full((max_len, bs), 0.0, device=prev_traj.states.device) + if (prev_pb is not None and recon_pb is not None) + else None + ) + + for i in range(bs): + npv = int(n_prevs[i].item()) + nrc = int(n_recons[i].item()) + + # prefix from prev_traj + new_states_tsr[: npv + 1, i] = prev_traj.states.tensor[: npv + 1, i] + new_actions_tsr[:npv, i] = prev_traj.actions.tensor[:npv, i] + # suffix from recon_traj (skip junction state duplication) + new_states_tsr[npv + 1 : npv + nrc + 1, i] = recon_traj.states.tensor[ + 1 : nrc + 1, i + ] + new_actions_tsr[npv : npv + nrc, i] = recon_traj.actions.tensor[:nrc, i] + + if new_log_pf is not None: + new_log_pf[:npv, i] = prev_pf[:npv, i] # type: ignore + new_log_pf[npv : npv + nrc, i] = recon_pf[:nrc, i] # type: ignore + if new_log_pb is not None: + new_log_pb[:npv, i] = prev_pb[:npv, i] # type: ignore + new_log_pb[npv : npv + nrc, i] = recon_pb[:nrc, i] # type: ignore + + new_traj = Trajectories( + env=env, + states=env.states_from_tensor(new_states_tsr), + conditioning=prev_traj.conditioning, + actions=env.actions_from_tensor(new_actions_tsr), + terminating_idx=dones, + is_backward=False, + log_rewards=recon_traj.log_rewards, # episodic reward per splice + log_probs=new_log_pf, + ) + + # Acceptance + if use_metropolis_hastings: + assert ( + prev_pf is not None + and recon_pf is not None + and prev_pb is not None + and recon_pb is not None + ) + log_accept_ratio = torch.clamp_max( + new_traj.log_rewards # type: ignore + + prev_pb.sum(0) + + (new_log_pf if new_log_pf is not None else recon_pf).sum(0) + - prev_traj.log_rewards + - (new_log_pb if new_log_pb is not None else recon_pb).sum(0) + - prev_pf.sum(0), + 0.0, + ) + accept = torch.rand(bs, device=log_accept_ratio.device) < torch.exp( + log_accept_ratio + ) + else: + accept = prev_traj.log_rewards <= new_traj.log_rewards 
# type: ignore + + return new_traj, accept + + +@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM", "Box"]) +def test_local_search_back_steps_vs_back_ratio(env_name: str): + env, pf, pb = _make_env_estimators(env_name) + sampler = LocalSearchSampler(pf, pb) + + torch.manual_seed(123) + base = sampler.sample_trajectories(env, n=4) + + min_len = int(base.terminating_idx.min().item()) + desired_k = max(min_len // 2, 1) + back_steps = torch.full((base.n_trajectories,), desired_k, device=base.states.device) + back_ratio = float(desired_k / max(min_len - 1, 1)) + + traj_steps, upd_steps = sampler.local_search( + env, + base, + save_logprobs=True, + back_steps=back_steps, + use_metropolis_hastings=False, + ) + traj_ratio, upd_ratio = sampler.local_search( + env, + base, + save_logprobs=True, + back_ratio=back_ratio, + use_metropolis_hastings=False, + ) + + assert traj_steps.n_trajectories == base.n_trajectories + assert traj_ratio.n_trajectories == base.n_trajectories + assert upd_steps.shape == (base.n_trajectories,) + assert upd_ratio.shape == (base.n_trajectories,) + assert traj_steps.actions.batch_shape == traj_steps.log_probs.shape # type: ignore + assert traj_ratio.actions.batch_shape == traj_ratio.log_probs.shape # type: ignore + + +@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM", "Box"]) +@pytest.mark.parametrize("use_mh", [True, False]) +def test_local_search_acceptance_mask_and_shapes(env_name: str, use_mh: bool): + env, pf, pb = _make_env_estimators(env_name) + sampler = LocalSearchSampler(pf, pb) + + torch.manual_seed(321) + base = sampler.sample_trajectories(env, n=3, save_logprobs=True) + new_traj, is_updated = sampler.local_search( + env, + base, + save_logprobs=True, + back_ratio=0.5, + use_metropolis_hastings=use_mh, + debug=True, + ) + assert isinstance(is_updated, torch.Tensor) and is_updated.dtype == torch.bool + assert is_updated.shape == (new_traj.n_trajectories,) + if use_mh: + lp_pf = get_trajectory_pfs(pf, new_traj, recalculate_all_logprobs=False) + lp_pb = get_trajectory_pbs(pb, new_traj) + assert lp_pf.shape == new_traj.actions.batch_shape + assert lp_pb.shape == new_traj.actions.batch_shape + + +@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM", "Box"]) +def test_sample_trajectories_with_local_search_loops(env_name: str): + env, pf, pb = _make_env_estimators(env_name) + sampler = LocalSearchSampler(pf, pb) + + torch.manual_seed(999) + n = 5 + trajs = sampler.sample_trajectories( + env, + n=n, + save_logprobs=False, + n_local_search_loops=2, + back_ratio=0.5, + use_metropolis_hastings=False, + ) + assert trajs.n_trajectories == n * (1 + 2) + assert trajs.actions.batch_shape[1] == trajs.n_trajectories + assert trajs.log_rewards is not None + + +@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM", "Box"]) +def test_local_search_large_back_steps_are_handled_when_safe(env_name: str): + env, pf, pb = _make_env_estimators(env_name) + sampler = LocalSearchSampler(pf, pb) + + torch.manual_seed(7) + base = sampler.sample_trajectories(env, n=4) + # Very large back_steps; adjust to ensure K <= L-1 so that n_prevs >= 0 + # (the current implementation requires this to avoid negative gather indices). 
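+    # In the edge case K = L - 1, n_prevs = 0 and the whole trajectory is
+    # re-sampled from s0 (the junction is the initial state).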
+ back_steps = base.terminating_idx + 100 + back_steps = torch.minimum(back_steps, base.terminating_idx - 1) + traj, is_updated = sampler.local_search( + env, + base, + back_steps=back_steps, + save_logprobs=True, + use_metropolis_hastings=False, + ) + assert traj.n_trajectories == base.n_trajectories + assert is_updated.shape == (base.n_trajectories,) + + +@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM", "Box"]) +@pytest.mark.parametrize("use_mh", [True, False]) +def test_local_search_reference_impl_end_to_end_match(env_name: str, use_mh: bool): + """ + End-to-end regression: the production local_search must match a + reference (for-loop splice) implementation when RNG is identical. + This guards against behavior drift during refactors. + """ + env, pf, pb = _make_env_estimators(env_name) + sampler = LocalSearchSampler(pf, pb) + + torch.manual_seed(42) + base = sampler.sample_trajectories(env, n=3, save_logprobs=True) + + # Ensure identical RNG for both production and reference runs + rng = torch.get_rng_state() + + prod_traj, prod_upd = sampler.local_search( + env, + base, + save_logprobs=True, + back_ratio=0.5, + use_metropolis_hastings=use_mh, + debug=False, + ) + + torch.set_rng_state(rng) + ref_traj, ref_upd = _reference_local_search( + env, + base, + pf_estimator=pf, + pb_estimator=pb, + conditioning=None, + back_ratio=0.5, + use_metropolis_hastings=use_mh, + save_logprobs=True, + ) + + # Compare tensors (align dtypes if needed) + assert torch.allclose( + prod_traj.states.tensor.to(ref_traj.states.tensor.dtype), ref_traj.states.tensor + ) + assert torch.allclose( + prod_traj.actions.tensor.to(ref_traj.actions.tensor.dtype), + ref_traj.actions.tensor, + ) + assert torch.equal(prod_upd, ref_upd) + # log_rewards and log_probs comparability + assert torch.allclose(prod_traj.log_rewards.to(ref_traj.log_rewards.dtype), ref_traj.log_rewards) # type: ignore + if prod_traj.log_probs is None: + assert ref_traj.log_probs is None + else: + assert ref_traj.log_probs is not None + assert torch.allclose( + prod_traj.log_probs.to(ref_traj.log_probs.dtype), ref_traj.log_probs + ) diff --git a/tutorials/examples/train_hypergrid_exploration_examples.py b/tutorials/examples/train_hypergrid_exploration_examples.py index 67da4d5d..a53bb5c9 100644 --- a/tutorials/examples/train_hypergrid_exploration_examples.py +++ b/tutorials/examples/train_hypergrid_exploration_examples.py @@ -19,7 +19,6 @@ import matplotlib.pyplot as plt import pandas as pd -import seaborn as sns import torch from tqdm import tqdm @@ -433,6 +432,8 @@ def main(args): # Adjust layout and save to home directory. if args.plot: + import seaborn as sns + # Create a figure with 3 subplots arranged horizontally fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))