Conversation

@hyeok9855 commented Sep 25, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  1. Add backward_logprobs and backward_estimator_outputs to Trajectories, so that these quantities don't have to be recomputed every time the loss is calculated (a sketch of the resulting container follows below).
  2. Allow passing precomputed log_rewards for backward sampling.
  3. A few refactorings in Samplers.
  4. Remove the is_backward flag from Transitions, since we do not properly support it.
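
For illustration only, the resulting container might look roughly like this (hypothetical names and shapes, not the exact torchgfn signature):

from dataclasses import dataclass

import torch


@dataclass
class TrajectoriesSketch:
    # Illustration only: a container that carries precomputed PB quantities
    # alongside the usual PF ones, so the loss doesn't need to re-run the estimators.
    states: torch.Tensor  # (max_len + 1, n_trajectories, *state_dim)
    actions: torch.Tensor  # (max_len, n_trajectories, *action_dim)
    is_backward: bool = False
    log_rewards: torch.Tensor | None = None  # (n_trajectories,)
    log_probs: torch.Tensor | None = None  # PF log-probs
    backward_logprobs: torch.Tensor | None = None  # new: PB log-probs
    backward_estimator_outputs: torch.Tensor | None = None  # new: raw PB estimator outputs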

Discussion needed

We can't calculate the loss with "backward" Trajectories; we always need to call .reverse_backward_trajectories before the loss calculation (see the sketch below). The question is whether it is worth supporting direct loss calculation on backward Trajectories. This wouldn't be too hard to implement.
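
For illustration, the current workflow is roughly as follows (the sampler call is hypothetical; only reverse_backward_trajectories and gflownet.loss appear elsewhere in this PR):

backward_trajs = backward_sampler.sample_trajectories(env, n=16)  # hypothetical sampler call
forward_trajs = backward_trajs.reverse_backward_trajectories()
loss = gflownet.loss(env, forward_trajs)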

This was also the case for Transitions (although I removed the is_backward flag from Transitions; this can easily be reverted). If we want to support training with backward Trajectories, do we also want the same for backward Transitions?

TODO

Currently, we can store only one of PF or PB during sampling (i.e., we always need to call our NN module again for the loss calculation). My next PR will resolve this by computing both PF and PB within a single sampling pass.

@hyeok9855 self-assigned this Sep 25, 2025
@josephdviviano left a comment


Some questions / comments, but I do like this direction quite a bit! Thank you!

terminating_idx: torch.Tensor | None = None,
is_backward: bool = False,
log_rewards: torch.Tensor | None = None,
log_probs: torch.Tensor | None = None,

This is going to be an annoying change, but I suspect we should be explicit about forward_log_probs and forward_estimator_outputs in the namespace as well.

self.backward_estimator_outputs.shape[: len(self.states.batch_shape)]
== self.actions.batch_shape
and self.backward_estimator_outputs.is_floating_point()
)

We could probably write private _check_estimator_outputs() and _check_log_probs() methods that deduplicate this logic.
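
A rough sketch of such a helper (hypothetical, mirroring the shape/dtype check in the hunk above):

import torch


def _check_estimator_outputs(
    outputs: torch.Tensor | None, batch_shape: tuple[int, ...]
) -> None:
    # Hypothetical shared validator: forward and backward estimator outputs
    # must match the actions' batch shape and be floating point.
    if outputs is None:
        return
    assert (
        outputs.shape[: len(batch_shape)] == tuple(batch_shape)
        and outputs.is_floating_point()
    ), "estimator outputs must match the batch shape and be floating point"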

def terminating_states(self) -> States:
"""The terminating states of the trajectories.
"""The terminating states of the trajectories. If backward, the terminating states
are in 0-th position.

Worth explicitly stating whether these are s0 or s0+1?


Worth specifying that the terminating state of the backward trajectory is s0, and the terminating state of the reversed forward trajectory is NOT s0.


We need to be careful about the definitions of reversed and backward. A backward trajectory (evaluated under pb) runs in the opposite direction of a forward trajectory (evaluated under pf). Either can be reversed, but the reverse of a forward trajectory still corresponds to pf, and a reversed backward trajectory still corresponds to pb.

Can we add these definitions here, and ensure that the language is consistent throughout?

(max_len + 2, len(self)), device=states.device
)
# shape (max_len + 2, n_trajectories, *state_dim)
actions = self.actions # shape (max_len, n_trajectories *action_dim)

Is it now possible to replace reverse_backward_trajectories() with simply reverse() which works on either forward or backward trajectories?

We could keep reverse_backward_trajectories() as essentially an alias:

def reverse_backward_trajectories(self):
    assert self.is_backward
    return self.reverse()


I'm wondering if we should have @property attributes on both Transitions and Trajectories, self.has_forward and self.has_backward, for readability.
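
Something along these lines, as a hypothetical sketch (assuming log_probs/estimator_outputs hold the PF quantities and the backward_* attributes hold the PB ones):

@property
def has_forward(self) -> bool:
    # Hypothetical: True if PF quantities were stored at sampling time.
    return self.log_probs is not None or self.estimator_outputs is not None

@property
def has_backward(self) -> bool:
    # Hypothetical: True if PB quantities were stored at sampling time.
    return self.backward_logprobs is not None or self.backward_estimator_outputs is not None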

Raises:
ValueError: If backward transitions are provided.
"""
if transitions.is_backward:

I think we need some kind of check here. For example, transitions.has_forward()


optimizer.zero_grad()
loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=True)
loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=False)

This is because the forward logprobs are already populated? How?

# terminating states (can be passed to log_reward fn)
if self.is_backward:
# [IMPORTANT ASSUMPTION] When backward sampling, all provided states are the
# *terminating* states (can be passed to log_reward fn)

Why do we need this assumption?

transition.
"""
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

I'm confused about how this will work in practice. I suspect we still need some kind of assertion here.
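
As a rough illustration, the guard could look something like this (assuming the hypothetical has_forward property discussed above):

if not transitions.has_forward:
    raise ValueError("Transitions without forward log-probs are not supported here.")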

@josephdviviano self-assigned this Sep 25, 2025
@hyeok9855 marked this pull request as draft October 27, 2025 13:47
@hyeok9855 mentioned this pull request Nov 13, 2025