Add backward logprobs and estimator_outputs in Trajectories #396
Conversation
josephdviviano left a comment:
Some questions / comments, but I do like this direction quite a bit! Thank you!
```python
terminating_idx: torch.Tensor | None = None,
is_backward: bool = False,
log_rewards: torch.Tensor | None = None,
log_probs: torch.Tensor | None = None,
```
This is going to be an annoying change but I suspect we should be explicit about forward_log_probs and forward_estimator_outputs in the namespace as well.
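For concreteness, a hypothetical sketch of that explicit naming; the forward_* / backward_* parameter names are the suggestion above, not the merged API, and the other parameters are taken from the quoted diff:

```python
import torch

class Trajectories:
    # Sketch only: explicit forward_*/backward_* names instead of bare
    # log_probs / estimator_outputs.
    def __init__(
        self,
        terminating_idx: torch.Tensor | None = None,
        is_backward: bool = False,
        log_rewards: torch.Tensor | None = None,
        forward_log_probs: torch.Tensor | None = None,  # previously: log_probs
        forward_estimator_outputs: torch.Tensor | None = None,
        backward_log_probs: torch.Tensor | None = None,
        backward_estimator_outputs: torch.Tensor | None = None,
    ) -> None:
        self.terminating_idx = terminating_idx
        self.is_backward = is_backward
        self.log_rewards = log_rewards
        self.forward_log_probs = forward_log_probs
        self.forward_estimator_outputs = forward_estimator_outputs
        self.backward_log_probs = backward_log_probs
        self.backward_estimator_outputs = backward_estimator_outputs
```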
```python
    self.backward_estimator_outputs.shape[: len(self.states.batch_shape)]
    == self.actions.batch_shape
    and self.backward_estimator_outputs.is_floating_point()
)
```
We could probably write private _check_estimator_outputs() and _check_log_probs() methods that deduplicate this logic.
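For instance, a minimal sketch of the deduplicated check as a private method on Trajectories; the helper name follows the comment, and having it take the tensor to validate is an assumption:

```python
def _check_estimator_outputs(self, estimator_outputs: torch.Tensor) -> bool:
    """Shared shape/dtype validation for forward and backward estimator outputs."""
    return (
        estimator_outputs.shape[: len(self.states.batch_shape)]
        == self.actions.batch_shape
        and estimator_outputs.is_floating_point()
    )
```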
```diff
 def terminating_states(self) -> States:
-    """The terminating states of the trajectories.
+    """The terminating states of the trajectories. If backward, the terminating states
+    are in 0-th position.
```
Worth explicitly stating whether these are s0 or s0+1?
Worth specifying the terminating state of the backward trajectory is s0, and the terminating state of the reversed forward trajectory is NOT s0.
We need to be careful about the definitions of reversed and backward. A backward trajectory (evaluated under pb) runs in the opposite direction of a forward trajectory (evaluated under pf). Either can be reversed, but the reverse of a forward trajectory still corresponds to pf, and a reversed backward trajectory still corresponds to pb.
Can we add these definitions here and ensure that the language is consistent throughout?
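A schematic example of the proposed terminology, with plain Python lists standing in for trajectory objects (state names are illustrative only):

```python
# A forward trajectory is sampled and evaluated under pf.
forward = ["s0", "s1", "s2", "sf"]
reversed_forward = forward[::-1]  # still a pf object, just stored in reverse order

# A backward trajectory is sampled and evaluated under pb.
backward = ["s2", "s1", "s0"]
reversed_backward = backward[::-1]  # still a pb object, just stored in reverse order
```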
```python
    (max_len + 2, len(self)), device=states.device
)
# shape (max_len + 2, n_trajectories, *state_dim)
actions = self.actions  # shape (max_len, n_trajectories, *action_dim)
```
Is it now possible to replace reverse_backward_trajectories() with simply reverse() which works on either forward or backward trajectories?
We could keep reverse_backward_trajectories() as essentially an alias:
```python
def reverse_backward_trajectories(self):
    assert self.is_backward
    return self.reverse()
```
I'm wondering if we should have @property attributes on both Transitions and Trajectories, self.has_forward and self.has_backward, for readability.
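A minimal sketch of those properties, assuming the explicit forward_/backward_ attribute names floated earlier in this review; a shared mixin is just one way to avoid duplicating them, and all names here are hypothetical:

```python
import torch

class _DirectionalContainerMixin:
    # Hypothetical mixin shared by Transitions and Trajectories.
    forward_log_probs: torch.Tensor | None
    backward_log_probs: torch.Tensor | None

    @property
    def has_forward(self) -> bool:
        # True when forward log-probs were populated during sampling.
        return self.forward_log_probs is not None

    @property
    def has_backward(self) -> bool:
        # True when backward log-probs were populated during sampling.
        return self.backward_log_probs is not None
```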
```python
Raises:
    ValueError: If backward transitions are provided.
"""
if transitions.is_backward:
```
I think we need some kind of check here. For example, transitions.has_forward()
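For example, reusing the hypothetical has_forward property sketched above:

```python
if not transitions.has_forward:
    raise ValueError("Forward log-probs are required to compute this loss.")
```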
```diff
 optimizer.zero_grad()
-loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=True)
+loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=False)
```
This is because the forward logprobs are already populated? How?
```diff
 if self.is_backward:
-    # terminating states (can be passed to log_reward fn)
+    # [IMPORTANT ASSUMPTION] When backward sampling, all provided states are the
+    # *terminating* states (can be passed to log_reward fn)
```
Why do we need this assumption?
```python
    transition.
"""
if transitions.is_backward:
    raise ValueError("Backward transitions are not supported")
```
I'm confused about how this will work in practice. I suspect we still need some kind of assertion here.
Description
- Add backward_logprobs and backward_estimator_outputs in Trajectories, so that they don't have to be recomputed every time to get the loss.
- Support log_rewards for backward sampling.
- Remove the is_backward flag from Transitions, since we do not properly support it.

Discussion needed
We can't calculate loss with "backward" Trajectories. We always need to use .reverse_backward_trajectories before the loss calculation. The question is, is it worth supporting direct loss calculation using the backward Trajectories? This won't be too hard to implement.

This was also the case for the Transitions (although I removed the backward flag from Transitions; this can be easily reverted). If we want to support training with backward Trajectories, do we also want the same for the backward Transitions?
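For reference, a hedged sketch of the current workflow; pb_sampler, env, and gflownet stand in for objects from the training scripts, and the exact call signatures are assumptions:

```python
# Sample under pb, reverse into forward order, then compute the loss.
# The backward log-probs stored during sampling can be reused by the loss,
# hence recalculate_all_logprobs=False as in the diff above.
backward_trajectories = pb_sampler.sample_trajectories(env, n=16)
trajectories = backward_trajectories.reverse_backward_trajectories()
loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=False)
```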
TODO
Currently, we can store only one of the PF or PB during sampling (i.e., we always need to call our NN module for loss calculation). My next PR will resolve this by calculating both PF and PB within a single sampling process.