
Commit a88b715

Merge pull request #432 from GFNOrg/refactor-db
Refactor Detailed Balance
2 parents c21f619 + 5f4ce7e commit a88b715

7 files changed: +150 additions, −130 deletions


src/gfn/containers/transitions.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def log_rewards(self) -> torch.Tensor | None:
         If not provided at initialization, log rewards are computed on demand for
         terminating transitions.
         """
-        if self.is_backward:
+        if self.is_backward:  # TODO: Why can't backward trajectories have log_rewards?
             return None

         if self._log_rewards is None:

src/gfn/gflownet/base.py

Lines changed: 9 additions & 28 deletions
@@ -168,13 +168,15 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC):
         pb: The backward policy estimator, or None if it can be ignored (e.g., the
             gflownet DAG is a tree, and pb is therefore always 1).
         constant_pb: Whether to ignore the backward policy estimator.
+        log_reward_clip_min: If finite, clips log rewards to this value.
     """

     def __init__(
         self,
         pf: Estimator,
         pb: Estimator | None,
         constant_pb: bool = False,
+        log_reward_clip_min: float = float("-inf"),
     ) -> None:
         """Initializes a PFBasedGFlowNet instance.

@@ -186,6 +188,7 @@ def __init__(
                 gflownet DAG is a tree, and pb is therefore always 1. Must be set
                 explicitly by user to ensure that pb is an Estimator except under this
                 special case.
+            log_reward_clip_min: If finite, clips log rewards to this value.

         """
         super().__init__()
@@ -215,11 +218,12 @@ def __init__(
         self.pf = pf
         self.pb = pb
         self.constant_pb = constant_pb
+        self.log_reward_clip_min = log_reward_clip_min

         # Advisory: recurrent PF with non-recurrent PB is unusual
         # (tree DAGs typically prefer pb=None with constant_pb=True).
         # Import locally to avoid circular imports during module import time.
-        from gfn.estimators import RecurrentDiscretePolicyEstimator  # type: ignore
+        from gfn.estimators import RecurrentDiscretePolicyEstimator

         if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance(
             self.pb, Estimator
@@ -288,7 +292,7 @@ def pf_pb_parameters(self) -> list[torch.Tensor]:
         return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]


-class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
+class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories], ABC):
     """A GFlowNet that operates on complete trajectories.

     Attributes:
@@ -297,32 +301,9 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
             pb is therefore always 1.
         constant_pb: Whether to ignore the backward policy estimator, e.g., if the
             gflownet DAG is a tree, and pb is therefore always 1.
+        log_reward_clip_min: If finite, clips log rewards to this value.
     """

-    def __init__(
-        self,
-        pf: Estimator,
-        pb: Estimator | None,
-        constant_pb: bool = False,
-    ) -> None:
-        """Initializes a TrajectoryBasedGFlowNet instance.
-
-        Args:
-            pf: The forward policy estimator.
-            pb: The backward policy estimator, or None if the gflownet DAG is a tree,
-                and pb is therefore always 1.
-            constant_pb: Whether to ignore the backward policy estimator, e.g., if the
-                gflownet DAG is a tree, and pb is therefore always 1. Must be set
-                explicitly by user to ensure that pb is an Estimator except under this
-                special case.
-
-        """
-        super().__init__(
-            pf,
-            pb,
-            constant_pb=constant_pb,
-        )
-
     def get_pfs_and_pbs(
         self,
         trajectories: Trajectories,
@@ -388,8 +369,9 @@ def get_scores(
         total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

         log_rewards = trajectories.log_rewards
+        assert log_rewards is not None

-        if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
+        if math.isfinite(self.log_reward_clip_min):
             log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

         if torch.any(torch.isinf(total_log_pf_trajectories)):
@@ -399,7 +381,6 @@ def get_scores(

         assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
         assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
-        assert log_rewards is not None
         return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards

     def to_training_samples(self, trajectories: Trajectories) -> Trajectories:
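Note: the net effect of this file's change is that log_reward_clip_min now lives on PFBasedGFlowNet, and get_scores asserts the log rewards exist before clamping. For reference, a minimal standalone sketch of the quantity get_scores returns (summed forward log-probs minus summed backward log-probs minus the clipped log-reward); the name trajectory_score and the tensor layout are illustrative assumptions, not library API.

import math

import torch


def trajectory_score(
    log_pf_per_step: torch.Tensor,  # (max_length, n_trajectories) forward log-probs
    log_pb_per_step: torch.Tensor,  # (max_length, n_trajectories) backward log-probs
    log_rewards: torch.Tensor,  # (n_trajectories,) log R(x) of the terminal states
    log_reward_clip_min: float = float("-inf"),
) -> torch.Tensor:
    # Sum the step-wise log-probabilities over the trajectory (time) dimension.
    total_log_pf = log_pf_per_step.sum(dim=0)
    total_log_pb = log_pb_per_step.sum(dim=0)
    # Mirror the refactor: log rewards must exist, and are clamped only when the
    # clip threshold is finite.
    assert log_rewards is not None
    if math.isfinite(log_reward_clip_min):
        log_rewards = log_rewards.clamp_min(log_reward_clip_min)
    return total_log_pf - total_log_pb - log_rewards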

src/gfn/gflownet/detailed_balance.py

Lines changed: 64 additions & 56 deletions
@@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
         logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
             flow of the states.
         forward_looking: Whether to use the forward-looking GFN loss.
-        log_reward_clip_min: If finite, clips log rewards to this value.
-        safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-            to avoid numerical instability, otherwise uses -1e38.
         constant_pb: Whether to ignore the backward policy estimator, e.g., if the
             gflownet DAG is a tree, and pb is therefore always 1.
+        log_reward_clip_min: If finite, clips log rewards to this value.
     """

     def __init__(
@@ -73,9 +71,8 @@ def __init__(
         pb: Estimator | None,
         logF: ScalarEstimator | ConditionalScalarEstimator,
         forward_looking: bool = False,
-        log_reward_clip_min: float = -float("inf"),
-        safe_log_prob_min: bool = True,
         constant_pb: bool = False,
+        log_reward_clip_min: float = -float("inf"),
     ) -> None:
         """Initializes a DBGFlowNet instance.

@@ -86,19 +83,19 @@ def __init__(
             logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
                 flow of the states.
             forward_looking: Whether to use the forward-looking GFN loss.
-            log_reward_clip_min: If finite, clips log rewards to this value.
-            safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-                to avoid numerical instability, otherwise uses -1e38.
             constant_pb: Whether to ignore the backward policy estimator, e.g., if the
                 gflownet DAG is a tree, and pb is therefore always 1. Must be set
                 explicitly by user to ensure that pb is an Estimator except under this
                 special case.
+            log_reward_clip_min: If finite, clips log rewards to this value.

         """
-        super().__init__(pf, pb, constant_pb=constant_pb)
+        super().__init__(
+            pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
+        )

         # Disallow recurrent PF for transition-based DB
-        from gfn.estimators import RecurrentDiscretePolicyEstimator  # type: ignore
+        from gfn.estimators import RecurrentDiscretePolicyEstimator

         if isinstance(self.pf, RecurrentDiscretePolicyEstimator):
             raise TypeError(
@@ -112,11 +109,6 @@ def __init__(

         self.logF = logF
         self.forward_looking = forward_looking
-        self.log_reward_clip_min = log_reward_clip_min
-        if safe_log_prob_min:
-            self.log_prob_min = -1e10
-        else:
-            self.log_prob_min = -1e38

     def logF_named_parameters(self) -> dict[str, torch.Tensor]:
         """Returns a dictionary of named parameters containing 'logF' in their name.
@@ -191,14 +183,15 @@ def get_scores(
         if len(states) == 0:
             return torch.tensor(0.0, device=transitions.device)

-        # uncomment next line for debugging
-        # assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
         check_compatibility(states, actions, transitions)
+        assert (
+            not transitions.states.is_sink_state.any()
+        ), "Transition from sink state is not allowed. This is a bug."

-        log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
-            transitions, recalculate_all_logprobs
-        )
+        ### Compute log_pf and log_pb
+        log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs)

+        ### Compute log_F_s
         # LogF is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
@@ -207,50 +200,65 @@ def get_scores(
             with no_conditions_exception_handler("logF", self.logF):
                 log_F_s = self.logF(states).squeeze(-1)

-        if self.forward_looking:
-            log_rewards = env.log_reward(states)
-            if math.isfinite(self.log_reward_clip_min):
-                log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
-            log_F_s = log_F_s + log_rewards
-
-        preds = log_pf_actions + log_F_s
-
-        # uncomment next line for debugging
-        # assert transitions.next_states.is_sink_state.equal(transitions.is_terminating)
-
-        # automatically removes invalid transitions (i.e. s_f -> s_f)
-        valid_next_states = transitions.next_states[~transitions.is_terminating]
-        valid_transitions_is_terminating = transitions.is_terminating[
-            ~transitions.states.is_sink_state
-        ]
+        ### Compute log_F_s_next
+        log_F_s_next = torch.zeros_like(log_F_s)
+        is_terminating = transitions.is_terminating
+        is_intermediate = ~is_terminating

-        if len(valid_next_states) == 0:
-            return torch.tensor(0.0, device=transitions.device)
-
-        # LogF is potentially a conditional computation.
+        # Assign log_F_s_next for intermediate next states
+        interm_next_states = transitions.next_states[is_intermediate]
+        # log_F is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(
-                    valid_next_states,
-                    transitions.conditions[~transitions.is_terminating],
+                log_F_s_next[is_intermediate] = self.logF(
+                    interm_next_states,
+                    transitions.conditions[is_intermediate],
                 ).squeeze(-1)
         else:
             with no_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)
-
-        log_F_s_next = torch.zeros_like(log_pb_actions)
-        log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
-        assert transitions.log_rewards is not None
-        valid_transitions_log_rewards = transitions.log_rewards[
-            ~transitions.states.is_sink_state
-        ]
-        log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
-            valid_transitions_is_terminating
-        ]
-        targets = log_pb_actions + log_F_s_next
+                log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1)

-        scores = preds - targets
+        # Apply forward-looking if applicable
+        if self.forward_looking:
+            import warnings
+
+            warnings.warn(
+                "Rewards should be defined over edges in forward-looking settings. "
+                "The current implementation is a special case of this, where the edge "
+                "reward is defined as the difference between the reward of two states "
+                "that the edge connects. If your environment is not the case, "
+                "forward-looking may be inappropriate."
+            )
+
+            # Reward calculation can also be conditional.
+            if transitions.conditions is not None:
+                log_rewards_state = env.log_reward(states, transitions.conditions)  # type: ignore
+                log_rewards_next = env.log_reward(
+                    interm_next_states, transitions.conditions[is_intermediate]  # type: ignore
+                )
+            else:
+                log_rewards_state = env.log_reward(states)
+                log_rewards_next = env.log_reward(interm_next_states)
+            if math.isfinite(self.log_reward_clip_min):
+                log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min)
+                log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min)

+            log_F_s = log_F_s + log_rewards_state
+            log_F_s_next[is_intermediate] = (
+                log_F_s_next[is_intermediate] + log_rewards_next
+            )
+
+        # Assign log_F_s_next for terminating transitions as log_rewards
+        log_rewards = transitions.log_rewards
+        assert log_rewards is not None
+        if math.isfinite(self.log_reward_clip_min):
+            log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
+        log_F_s_next[is_terminating] = log_rewards[is_terminating]
+
+        ### Compute scores
+        preds = log_pf + log_F_s
+        targets = log_pb + log_F_s_next
+        scores = preds - targets
         assert scores.shape == (transitions.n_transitions,)
         return scores
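Note: taken together, the refactored get_scores assembles the detailed-balance residual log P_F(s→s') + log F(s) − log P_B(s'→s) − log F(s'), where log F(s') is the flow estimate for intermediate transitions and the terminal log-reward for terminating ones. A minimal sketch of that assembly; db_residuals and its argument names are illustrative, not library API.

import torch


def db_residuals(
    log_pf: torch.Tensor,  # (n_transitions,) log P_F(s -> s')
    log_pb: torch.Tensor,  # (n_transitions,) log P_B(s' -> s)
    log_F_s: torch.Tensor,  # (n_transitions,) estimated log-flow of the source states
    log_F_interm_next: torch.Tensor,  # log-flow of the intermediate next states only
    log_rewards: torch.Tensor,  # (n_transitions,) log R, used at terminating transitions
    is_terminating: torch.Tensor,  # (n_transitions,) boolean mask
) -> torch.Tensor:
    # log F(s') is the flow estimate for intermediate transitions and the
    # terminal log-reward for terminating ones, mirroring the refactored loss.
    log_F_s_next = torch.zeros_like(log_F_s)
    log_F_s_next[~is_terminating] = log_F_interm_next
    log_F_s_next[is_terminating] = log_rewards[is_terminating]
    preds = log_pf + log_F_s
    targets = log_pb + log_F_s_next
    return preds - targets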

src/gfn/gflownet/sub_trajectory_balance.py

Lines changed: 3 additions & 2 deletions
@@ -102,15 +102,16 @@ def __init__(
                 special case.

         """
-        super().__init__(pf, pb, constant_pb=constant_pb)
+        super().__init__(
+            pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
+        )
         assert any(
             isinstance(logF, cls)
             for cls in [ScalarEstimator, ConditionalScalarEstimator]
         ), "logF must be a ScalarEstimator or derived"
         self.logF = logF
         self.weighting = weighting
         self.lamda = lamda
-        self.log_reward_clip_min = log_reward_clip_min
         self.forward_looking = forward_looking

     def logF_named_parameters(self) -> dict[str, torch.Tensor]:

src/gfn/gflownet/trajectory_balance.py

Lines changed: 11 additions & 21 deletions
@@ -35,8 +35,10 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
         pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
             pb is therefore always 1.
         logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
+        constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb
+            is therefore always 1. Must be set explicitly by user to ensure that pb
+            is an Estimator except under this special case.
         log_reward_clip_min: If finite, clips log rewards to this value.
-        constant_pb: Whether the gflownet DAG is a tree, and pb is therefore always 1.
     """

     def __init__(
@@ -45,8 +47,8 @@ def __init__(
         pb: Estimator | None,
         logZ: nn.Parameter | ScalarEstimator | None = None,
         init_logZ: float = 0.0,
-        log_reward_clip_min: float = -float("inf"),
         constant_pb: bool = False,
+        log_reward_clip_min: float = -float("inf"),
     ):
         """Initializes a TBGFlowNet instance.

@@ -57,15 +59,16 @@ def __init__(
             logZ: A learnable parameter or a ScalarEstimator instance (for
                 conditional GFNs).
             init_logZ: The initial value for the logZ parameter (used if logZ is None).
-            log_reward_clip_min: If finite, clips log rewards to this value.
             constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb
                 is therefore always 1. Must be set explicitly by user to ensure that pb
                 is an Estimator except under this special case.
+            log_reward_clip_min: If finite, clips log rewards to this value.
         """
-        super().__init__(pf, pb, constant_pb=constant_pb)
+        super().__init__(
+            pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
+        )

         self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
-        self.log_reward_clip_min = log_reward_clip_min

     def logz_named_parameters(self) -> dict[str, torch.Tensor]:
         """Returns a dictionary of named parameters containing 'logZ' in their name.
@@ -138,25 +141,12 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet):
     Attributes:
         pf: The forward policy estimator.
         pb: The backward policy estimator.
+        constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb
+            is therefore always 1. Must be set explicitly by user to ensure that pb
+            is an Estimator except under this special case.
         log_reward_clip_min: If finite, clips log rewards to this value.
     """

-    def __init__(
-        self,
-        pf: Estimator,
-        pb: Estimator,
-        log_reward_clip_min: float = -float("inf"),
-    ):
-        """Initializes a LogPartitionVarianceGFlowNet instance.
-
-        Args:
-            pf: The forward policy estimator.
-            pb: The backward policy estimator.
-            log_reward_clip_min: If finite, clips log rewards to this value.
-        """
-        super().__init__(pf, pb)
-        self.log_reward_clip_min = log_reward_clip_min
-
     def loss(
         self,
         env: Env,
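Note: the constructor changes above only reorder arguments and delegate log_reward_clip_min to the base class; the objective itself is untouched. For reference, a minimal sketch of the trajectory-balance residual that the logZ parameter enters, under the standard TB definition (tb_loss and its argument names are illustrative, not library API):

import torch


def tb_loss(
    log_Z: torch.Tensor,  # scalar learnable estimate of log Z
    total_log_pf: torch.Tensor,  # (n_trajectories,) summed forward log-probs
    total_log_pb: torch.Tensor,  # (n_trajectories,) summed backward log-probs
    log_rewards: torch.Tensor,  # (n_trajectories,) log R of the terminal states
) -> torch.Tensor:
    # Squared trajectory-balance residual, averaged over the batch.
    residual = log_Z + total_log_pf - total_log_pb - log_rewards
    return residual.pow(2).mean()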

tutorials/examples/train_hypergrid_gafn.py

Lines changed: 3 additions & 1 deletion
@@ -142,7 +142,9 @@ def __init__(
             flow_estimator: The flow estimator, required if use_edge_ri is True.
             log_reward_clip_min: If finite, clips log rewards to this value.
         """
-        super().__init__(pf, pb, logZ, init_logZ, log_reward_clip_min)
+        super().__init__(
+            pf, pb, logZ, init_logZ, log_reward_clip_min=log_reward_clip_min
+        )
         self.rnd = rnd
         self.use_edge_ri = use_edge_ri
         if use_edge_ri and flow_estimator is None:
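Note: the switch to a keyword argument matters because the refactor moved log_reward_clip_min after constant_pb in the TBGFlowNet signature, so a positional fifth argument would now bind to the wrong parameter. A small illustration of the hazard with simplified stand-in signatures (old_init and new_init are hypothetical, not the library's functions):

def old_init(pf, pb, logZ=None, init_logZ=0.0, log_reward_clip_min=float("-inf"), constant_pb=False):
    return {"clip": log_reward_clip_min, "constant_pb": constant_pb}


def new_init(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=float("-inf")):
    return {"clip": log_reward_clip_min, "constant_pb": constant_pb}


# A positional call written against the old order silently changes meaning under the new one:
print(old_init("pf", "pb", None, 0.0, -100.0))  # {'clip': -100.0, 'constant_pb': False}
print(new_init("pf", "pb", None, 0.0, -100.0))  # {'clip': -inf, 'constant_pb': -100.0}

# The keyword form used in the updated tutorial is robust to the reorder:
print(new_init("pf", "pb", None, 0.0, log_reward_clip_min=-100.0))  # {'clip': -100.0, 'constant_pb': False}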
