@@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
         logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
             flow of the states.
         forward_looking: Whether to use the forward-looking GFN loss.
-        log_reward_clip_min: If finite, clips log rewards to this value.
-        safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-            to avoid numerical instability, otherwise uses -1e38.
         constant_pb: Whether to ignore the backward policy estimator, e.g., if the
             gflownet DAG is a tree, and pb is therefore always 1.
+        log_reward_clip_min: If finite, clips log rewards to this value.
     """

     def __init__(
@@ -73,9 +71,8 @@ def __init__(
         pb: Estimator | None,
         logF: ScalarEstimator | ConditionalScalarEstimator,
         forward_looking: bool = False,
-        log_reward_clip_min: float = -float("inf"),
-        safe_log_prob_min: bool = True,
         constant_pb: bool = False,
+        log_reward_clip_min: float = -float("inf"),
     ) -> None:
         """Initializes a DBGFlowNet instance.

@@ -86,19 +83,19 @@ def __init__(
             logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
                 flow of the states.
             forward_looking: Whether to use the forward-looking GFN loss.
-            log_reward_clip_min: If finite, clips log rewards to this value.
-            safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
-                to avoid numerical instability, otherwise uses -1e38.
             constant_pb: Whether to ignore the backward policy estimator, e.g., if the
                 gflownet DAG is a tree, and pb is therefore always 1. Must be set
                 explicitly by user to ensure that pb is an Estimator except under this
                 special case.
+            log_reward_clip_min: If finite, clips log rewards to this value.

         """
-        super().__init__(pf, pb, constant_pb=constant_pb)
+        super().__init__(
+            pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
+        )

         # Disallow recurrent PF for transition-based DB
-        from gfn.estimators import RecurrentDiscretePolicyEstimator  # type: ignore
+        from gfn.estimators import RecurrentDiscretePolicyEstimator

         if isinstance(self.pf, RecurrentDiscretePolicyEstimator):
             raise TypeError(
@@ -112,11 +109,6 @@ def __init__(

         self.logF = logF
         self.forward_looking = forward_looking
-        self.log_reward_clip_min = log_reward_clip_min
-        if safe_log_prob_min:
-            self.log_prob_min = -1e10
-        else:
-            self.log_prob_min = -1e38

     def logF_named_parameters(self) -> dict[str, torch.Tensor]:
         """Returns a dictionary of named parameters containing 'logF' in their name.
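
For callers, the net effect of the constructor changes above is that `safe_log_prob_min` is gone (along with the `self.log_prob_min` attribute) and `log_reward_clip_min` now follows `constant_pb` and is forwarded to the base class. A minimal, hypothetical usage sketch; the estimator objects are placeholders built elsewhere, and only the keyword changes reflect this diff:

```python
# Hedged sketch: constructing DBGFlowNet after this change.
# `pf_estimator`, `pb_estimator`, and `logF_estimator` are hypothetical
# placeholders for estimators created elsewhere in user code.
gflownet = DBGFlowNet(
    pf=pf_estimator,
    pb=pb_estimator,
    logF=logF_estimator,
    forward_looking=False,
    constant_pb=False,  # set True explicitly only when pb is always 1 (tree DAG)
    log_reward_clip_min=-10.0,  # clipping is applied only when this value is finite
    # safe_log_prob_min is no longer accepted; the log-prob clamping it
    # controlled was removed in this diff.
)
```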
@@ -191,14 +183,15 @@ def get_scores(
         if len(states) == 0:
             return torch.tensor(0.0, device=transitions.device)

-        # uncomment next line for debugging
-        # assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
         check_compatibility(states, actions, transitions)
+        assert (
+            not transitions.states.is_sink_state.any()
+        ), "Transition from sink state is not allowed. This is a bug."

-        log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
-            transitions, recalculate_all_logprobs
-        )
+        ### Compute log_pf and log_pb
+        log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs)

+        ### Compute log_F_s
         # LogF is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
@@ -207,50 +200,65 @@ def get_scores(
             with no_conditions_exception_handler("logF", self.logF):
                 log_F_s = self.logF(states).squeeze(-1)

-        if self.forward_looking:
-            log_rewards = env.log_reward(states)
-            if math.isfinite(self.log_reward_clip_min):
-                log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
-            log_F_s = log_F_s + log_rewards
-
-        preds = log_pf_actions + log_F_s
-
-        # uncomment next line for debugging
-        # assert transitions.next_states.is_sink_state.equal(transitions.is_terminating)
-
-        # automatically removes invalid transitions (i.e. s_f -> s_f)
-        valid_next_states = transitions.next_states[~transitions.is_terminating]
-        valid_transitions_is_terminating = transitions.is_terminating[
-            ~transitions.states.is_sink_state
-        ]
+        ### Compute log_F_s_next
+        log_F_s_next = torch.zeros_like(log_F_s)
+        is_terminating = transitions.is_terminating
+        is_intermediate = ~is_terminating

-        if len(valid_next_states) == 0:
-            return torch.tensor(0.0, device=transitions.device)
-
-        # LogF is potentially a conditional computation.
+        # Assign log_F_s_next for intermediate next states
+        interm_next_states = transitions.next_states[is_intermediate]
+        # log_F is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(
-                    valid_next_states,
-                    transitions.conditions[~transitions.is_terminating],
+                log_F_s_next[is_intermediate] = self.logF(
+                    interm_next_states,
+                    transitions.conditions[is_intermediate],
                 ).squeeze(-1)
         else:
             with no_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)
-
-        log_F_s_next = torch.zeros_like(log_pb_actions)
-        log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
-        assert transitions.log_rewards is not None
-        valid_transitions_log_rewards = transitions.log_rewards[
-            ~transitions.states.is_sink_state
-        ]
-        log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
-            valid_transitions_is_terminating
-        ]
-        targets = log_pb_actions + log_F_s_next
+                log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1)

-        scores = preds - targets
+        # Apply forward-looking if applicable
+        if self.forward_looking:
+            import warnings
+
+            warnings.warn(
+                "Rewards should be defined over edges in forward-looking settings. "
+                "The current implementation is a special case of this, where the edge "
+                "reward is defined as the difference between the reward of two states "
+                "that the edge connects. If your environment is not the case, "
+                "forward-looking may be inappropriate."
+            )
+
+            # Reward calculation can also be conditional.
+            if transitions.conditions is not None:
+                log_rewards_state = env.log_reward(states, transitions.conditions)  # type: ignore
+                log_rewards_next = env.log_reward(
+                    interm_next_states, transitions.conditions[is_intermediate]  # type: ignore
+                )
+            else:
+                log_rewards_state = env.log_reward(states)
+                log_rewards_next = env.log_reward(interm_next_states)
+            if math.isfinite(self.log_reward_clip_min):
+                log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min)
+                log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min)

+            log_F_s = log_F_s + log_rewards_state
+            log_F_s_next[is_intermediate] = (
+                log_F_s_next[is_intermediate] + log_rewards_next
+            )
+
+        # Assign log_F_s_next for terminating transitions as log_rewards
+        log_rewards = transitions.log_rewards
+        assert log_rewards is not None
+        if math.isfinite(self.log_reward_clip_min):
+            log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
+        log_F_s_next[is_terminating] = log_rewards[is_terminating]
+
+        ### Compute scores
+        preds = log_pf + log_F_s
+        targets = log_pb + log_F_s_next
+        scores = preds - targets
         assert scores.shape == (transitions.n_transitions,)
         return scores

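For orientation, the rewritten `get_scores` computes the log-space residual of the detailed balance condition, with `log_F_s_next` filled from the flow estimator for intermediate next states and from `log_rewards` for terminating ones. A sketch of the identity being enforced, in illustrative notation not taken from the source:

```latex
% Detailed balance residual returned as `scores`:
%   preds   = log P_F(s' | s) + log F(s)
%   targets = log P_B(s | s') + log F(s')
\delta(s \to s')
  = \bigl[\log P_F(s' \mid s) + \log F(s)\bigr]
  - \bigl[\log P_B(s \mid s') + \log F(s')\bigr],
\qquad
\log F(s') := \log R(s') \quad \text{if } s' \text{ is terminating}.
```

With `forward_looking=True`, `log F(s)` is augmented to `log F(s) + log R(s)` and intermediate `log F(s')` to `log F(s') + log R(s')`, i.e. the flow is reparametrized as `F(s) * R(s)`; as the added warning notes, this implicitly treats the edge reward as `log R(s') - log R(s)`.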