@josephdviviano josephdviviano commented Oct 14, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

WIP: refactoring the LocalSearchSampler logic to take better advantage of the PolicyMixin framework.

For now, this PR consists only of comments and tests (the tests will help ensure that any later refactoring does not break the method).

Some notes on the plan:

Design ideas to simplify LocalSearchSampler using PolicyMixin

  • Centralize “K” computation: Extract a helper on LocalSearchSampler (e.g., _compute_back_steps(terminating_idx, back_steps, back_ratio)) to unify clamping and validation. This makes the local_search body shorter and easier to follow.
  • Encapsulate “junction reconstruction”: Extract a helper (e.g., _reconstruct_from_junctions(junction_states, ...)) that calls super().sample_trajectories. This isolates the PolicyMixin-dependent sampling and clarifies that reconstruction is just a standard rollout starting from junction states.
  • Factor acceptance computation: Extract _metropolis_hastings_accept(log_r, log_pf_prev, log_pf_new, log_pb_prev, log_pb_new) that returns the boolean accept mask. This isolates the math (already correct) and makes local_search linear and readable.
  • Clarify PB/PF reuse semantics: When save_logprobs=True, we can set recalculate_all_logprobs=False in get_trajectory_pfs so cached per-step log-probs (already saved in Trajectories.log_probs via PolicyMixin) are reused. This avoids redundant evaluation without changing outputs.
  • Make splice operation a container method (optional): Move _combine_prev_and_recon_trajectories into a Trajectories.splice(prev_len, recon)-like API (keeping identical tensor ops). This shifts responsibility to the container and shortens the sampler. If you prefer to keep it static, at least rename and document it as “splice” for intent.
  • Graph states guard correctness: The current check uses isinstance(env.States, GraphStates), which never triggers because env.States is a class, not an instance. Replace it with issubclass(type(env.states), GraphStates) or issubclass(env.States, GraphStates) so that unsupported cases fail early. This is a correctness fix, not a logic change.
  • Conditioning handling consistency: Ensure we rely on the PolicyMixin broadcasting in both PF/PB calculators and reconstruction calls (we already do). Document that conditioning is passed through as-is and broadcast in Sampler.
  • Naming/structure: Use a small “pipeline” structure in local_search: compute K → backward sample → reverse → pick junctions → reconstruct → get logprobs → splice → accept. With the three helpers above, the main body becomes a straightforward sequence of calls.
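
To illustrate the graph states guard item above, here is a minimal sketch of why the isinstance check never fires. The classes below are hypothetical stand-ins defined only for this example (they are not the library's States, GraphStates, or environment classes); the only assumption is that env.States holds a class rather than an instance, as described in the notes.

```python
# Hypothetical stand-ins; these are NOT the real library classes.
class States:
    pass


class GraphStates(States):
    pass


class DiscreteStates(States):
    pass


class GraphEnv:
    # Assumption from the notes: env.States is a class, not an instance.
    States = GraphStates


class GridEnv:
    States = DiscreteStates


def broken_guard(env) -> bool:
    # isinstance asks "is this object an instance of GraphStates?".
    # env.States is a class object (an instance of `type`), so this is
    # False even when the environment uses graph states.
    return isinstance(env.States, GraphStates)


def fixed_guard(env) -> bool:
    # issubclass asks "is this class GraphStates or derived from it?",
    # which is the question the guard is meant to answer.
    return issubclass(env.States, GraphStates)


graph_env, grid_env = GraphEnv(), GridEnv()
print(broken_guard(graph_env), fixed_guard(graph_env), fixed_guard(grid_env))
```

With the fixed guard, a sampler could raise early (for example with NotImplementedError) whenever fixed_guard(env) is true, instead of silently continuing.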

@josephdviviano josephdviviano self-assigned this Oct 14, 2025
@josephdviviano josephdviviano marked this pull request as draft October 14, 2025 21:29
@saleml saleml left a comment

Solid PR - adds much-needed documentation and comprehensive tests for the local search sampler. Code is production-ready.

  • Excellent inline comments explaining the algorithm flow
  • Smart reference implementation for regression testing
  • Good edge case coverage

Can you please check whether the back-step clamping should use terminating_idx - 1?

assert (back_ratio is not None) and (0 < back_ratio <= 1)
K = torch.ceil(back_ratio * (trajectories.terminating_idx - 1)).long()
else:
K = torch.where(

Should this be trajectories.terminating_idx - 1 to ensure n_prevs >= 0?
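
A small worked example may make the concern concrete. This is a scalar sketch, not the sampler's batched tensor code: compute_back_steps and n_prevs are hypothetical helpers, and the relation n_prevs = terminating_idx - K - 1 is an assumption inferred from this comment rather than something shown in the excerpt.

```python
import math


def compute_back_steps(terminating_idx, back_steps=None, back_ratio=None):
    """Hypothetical scalar version of the K computation for one trajectory."""
    max_k = terminating_idx - 1  # the clamp proposed in this review
    if back_ratio is not None:
        assert 0 < back_ratio <= 1
        return math.ceil(back_ratio * max_k)
    assert back_steps is not None and back_steps >= 0
    return min(back_steps, max_k)


def n_prevs(terminating_idx, k):
    # ASSUMPTION: number of steps kept from the previous trajectory.
    return terminating_idx - k - 1


terminating_idx, requested = 5, 10

# Clamping to terminating_idx lets K consume one step too many.
k_loose = min(requested, terminating_idx)
print(k_loose, n_prevs(terminating_idx, k_loose))  # 5 -1

# Clamping to terminating_idx - 1 keeps n_prevs non-negative.
k_tight = compute_back_steps(terminating_idx, back_steps=requested)
print(k_tight, n_prevs(terminating_idx, k_tight))  # 4 0
```

Under that assumption, the ratio branch is already safe because back_ratio <= 1 bounds K by terminating_idx - 1, so only the back_steps branch needs the tighter clamp.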


# Adjust layout and save to home directory.
if args.plot:
import seaborn as sns
Collaborator Author

This is to fix an annoying auxiliary issue.
