diff --git a/src/gfn/containers/states_container.py b/src/gfn/containers/states_container.py index f94ced55..96b444e0 100644 --- a/src/gfn/containers/states_container.py +++ b/src/gfn/containers/states_container.py @@ -1,9 +1,12 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, cast import torch +from gfn.env import ConditionalEnv + if TYPE_CHECKING: from gfn.env import Env from gfn.states import States @@ -71,7 +74,8 @@ def __init__( self.conditions = conditions assert self.conditions is None or ( - self.conditions.shape[: len(batch_shape)] == batch_shape + len(self.conditions.shape) == 2 + and self.conditions.shape[0] == len(self.states) ) self.is_terminating = ( @@ -178,7 +182,15 @@ def log_rewards(self) -> torch.Tensor: fill_value=-float("inf"), device=self.states.device, ) - self._log_rewards[self.is_terminating] = self.env.log_reward( + if isinstance(self.env, ConditionalEnv): + assert self.conditions is not None + log_reward_fn = partial( + self.env.log_reward, + conditions=self.conditions[self.is_terminating], + ) + else: + log_reward_fn = self.env.log_reward + self._log_rewards[self.is_terminating] = log_reward_fn( self.terminating_states ) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 309f727e..6869bd2c 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import Sequence import torch @@ -8,7 +9,7 @@ from gfn.containers.base import Container from gfn.containers.states_container import StatesContainer from gfn.containers.transitions import Transitions -from gfn.env import Env +from gfn.env import ConditionalEnv, Env from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.common import ensure_same_device, is_int_dtype @@ -109,8 +110,8 @@ def __init__( self.conditions = conditions assert self.conditions is None or ( - self.conditions.shape[: len(self.states.batch_shape)] - == self.states.batch_shape + len(self.conditions.shape) == 2 + and self.conditions.shape[0] == self.n_trajectories ) self.actions = ( @@ -245,7 +246,12 @@ def log_rewards(self) -> torch.Tensor | None: return None if self._log_rewards is None: - self._log_rewards = self.env.log_reward(self.terminating_states) + if isinstance(self.env, ConditionalEnv): + assert self.conditions is not None + log_reward_fn = partial(self.env.log_reward, conditions=self.conditions) + else: + log_reward_fn = self.env.log_reward + self._log_rewards = log_reward_fn(self.terminating_states) assert self._log_rewards.shape == (self.n_trajectories,) return self._log_rewards @@ -266,7 +272,7 @@ def __getitem__( terminating_idx = self.terminating_idx[index] new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0 states = self.states[:, index] - conditions = self.conditions[:, index] if self.conditions is not None else None + conditions = self.conditions[index] if self.conditions is not None else None actions = self.actions[:, index] states = states[: 1 + new_max_length] actions = actions[:new_max_length] @@ -311,14 +317,8 @@ def extend(self, other: Trajectories) -> None: log_rewards). Args: - Another Trajectories to append. + other: Another Trajectories to append. """ - if self.conditions is not None: - # TODO: Support the case - raise NotImplementedError( - "`extend` is not implemented for conditional Trajectories." 
- ) - if len(other) == 0: return @@ -346,6 +346,12 @@ def extend(self, other: Trajectories) -> None: (self.terminating_idx, other.terminating_idx), dim=0 ) + # Concatenate conditions of the trajectories. + if self.conditions is not None and other.conditions is not None: + self.conditions = torch.cat((self.conditions, other.conditions), dim=0) + else: + self.conditions = None + # Concatenate log_rewards of the trajectories. if self._log_rewards is not None and other._log_rewards is not None: self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0) @@ -387,18 +393,21 @@ def to_transitions(self) -> Transitions: A Transitions object with the same states, actions, and log_rewards as the current Trajectories. """ + valid_action_mask = ~self.actions.is_dummy if self.conditions is not None: - # The conditions tensor has shape (max_length, n_trajectories, 1) - # The actions have shape (max_length, n_trajectories) - # We need to index the conditions tensor to match the actions - # The actions exclude the last step, so we need to exclude the last step from conditions - conditions = self.conditions[:-1][~self.actions.is_dummy] + # The conditions tensor has shape (n_trajectories, condition_vector_dim) + # The actions have batch shape (max_length, n_trajectories) + # We need to repeat the condition vector tensor to match the actions + conditions = self.conditions.repeat(self.actions.batch_shape[0], 1, 1) + assert conditions.shape[:2] == self.actions.batch_shape + # Then we mask it with the valid action mask. + conditions = conditions[valid_action_mask] else: conditions = None - states = self.states[:-1][~self.actions.is_dummy] - next_states = self.states[1:][~self.actions.is_dummy] - actions = self.actions[~self.actions.is_dummy] + states = self.states[:-1][valid_action_mask] + next_states = self.states[1:][valid_action_mask] + actions = self.actions[valid_action_mask] is_terminating = ( next_states.is_sink_state if not self.is_backward @@ -454,9 +463,7 @@ def to_states_container(self) -> StatesContainer: ) is_terminating[self.terminating_idx - 1, torch.arange(len(self))] = True - states = self.states.flatten() - is_terminating = is_terminating.flatten() - + states = self.states is_valid = ~states.is_sink_state & ( ~states.is_initial_state | (states.is_initial_state & is_terminating) ) @@ -465,40 +472,25 @@ def to_states_container(self) -> StatesContainer: conditions = None if self.conditions is not None: - # The conditions tensor has shape (max_length, n_trajectories, 1) - # We need to flatten it to match the flattened states - # First, we need to repeat it to match the flattened shape - # The flattened states have shape (max_length * n_trajectories,) - # So we need to repeat the conditions tensor accordingly - conditions = self.conditions.flatten(0, 1)[is_valid] + # The conditions tensor has shape (n_trajectories, condition_vector_dim) + # The states have batch shape (max_length, n_trajectories) + # We need to repeat the conditions to match the batch shape of the states. + conditions = self.conditions.repeat(self.states.batch_shape[0], 1, 1) + # (max_length, n_trajectories, condition_vector_dim) + assert conditions.shape[:2] == self.states.batch_shape + # Then we mask it with the valid state mask. 
+ conditions = conditions[is_valid] if self.log_rewards is None: log_rewards = None else: log_rewards = torch.full( - (len(states),), - fill_value=-float("inf"), - device=states.device, + self.states.batch_shape, fill_value=-float("inf"), device=states.device ) - # Get the original indices (before flattening and filtering). - orig_batch_indices = torch.arange( - self.states.batch_shape[0], device=states.device - ).repeat_interleave(self.states.batch_shape[1]) - orig_traj_indices = torch.arange( - self.states.batch_shape[1], device=states.device - ).repeat(self.states.batch_shape[0]) - - # Retain only the valid indices. - valid_batch_indices = orig_batch_indices[is_valid] - valid_traj_indices = orig_traj_indices[is_valid] - - # Assign rewards to valid terminating states. - terminating_mask = is_terminating & ( - valid_batch_indices == (self.terminating_idx[valid_traj_indices] - 1) + log_rewards[self.terminating_idx - 1, torch.arange(len(self))] = ( + self.log_rewards ) - log_rewards[terminating_mask] = self.log_rewards[ - valid_traj_indices[terminating_mask] - ] + log_rewards = log_rewards[is_valid] return StatesContainer[DiscreteStates]( env=self.env, diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index d2c444b0..6f4c4c0c 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -1,9 +1,12 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING, Sequence import torch +from gfn.env import ConditionalEnv + if TYPE_CHECKING: from gfn.actions import Actions from gfn.env import Env @@ -92,7 +95,8 @@ def __init__( self.conditions = conditions assert self.conditions is None or ( - self.conditions.shape[: len(batch_shape)] == batch_shape + len(self.conditions.shape) == 2 + and self.conditions.shape[0] == self.n_transitions ) self.actions = ( @@ -204,7 +208,15 @@ def log_rewards(self) -> torch.Tensor | None: fill_value=-float("inf"), device=self.states.device, ) - self._log_rewards[self.is_terminating] = self.env.log_reward( + if isinstance(self.env, ConditionalEnv): + assert self.conditions is not None + log_reward_fn = partial( + self.env.log_reward, + conditions=self.conditions[self.is_terminating], + ) + else: + log_reward_fn = self.env.log_reward + self._log_rewards[self.is_terminating] = log_reward_fn( self.terminating_states ) @@ -231,10 +243,15 @@ def all_log_rewards(self) -> torch.Tensor: fill_value=-float("inf"), device=self.states.device, ) - log_rewards[~is_sink_state, 0] = self.env.log_reward(self.states[~is_sink_state]) - log_rewards[~is_sink_state, 1] = self.env.log_reward( - self.next_states[~is_sink_state] - ) + if isinstance(self.env, ConditionalEnv): + assert self.conditions is not None + log_reward_fn = partial( + self.env.log_reward, conditions=self.conditions[~is_sink_state] + ) + else: + log_reward_fn = self.env.log_reward + log_rewards[~is_sink_state, 0] = log_reward_fn(self.states[~is_sink_state]) + log_rewards[~is_sink_state, 1] = log_reward_fn(self.next_states[~is_sink_state]) assert ( log_rewards.shape == (self.n_transitions, 2) @@ -282,12 +299,6 @@ def extend(self, other: Transitions) -> None: Args: Another Transitions object to append. """ - if self.conditions is not None: - # TODO: Support the case - raise NotImplementedError( - "`extend` is not implemented for conditional Transitions." 
- ) - if len(other) == 0: return @@ -314,6 +325,12 @@ def extend(self, other: Transitions) -> None: ) self.next_states.extend(other.next_states) + # Concatenate conditions of the transitions. + if self.conditions is not None and other.conditions is not None: + self.conditions = torch.cat((self.conditions, other.conditions), dim=0) + else: + self.conditions = None + # Concatenate log_rewards of the trajectories. if self._log_rewards is not None and other._log_rewards is not None: self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0) diff --git a/src/gfn/env.py b/src/gfn/env.py index 646094f0..89dfe4e9 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -414,7 +414,7 @@ def log_partition(self) -> float: The log partition function. """ raise NotImplementedError( - "The environment does not support enumeration of states" + "The environment does not support calculating the log partition" ) @property @@ -429,6 +429,84 @@ def true_dist(self) -> torch.Tensor: ) +class ConditionalEnv(Env, ABC): + """Base class for conditional environments. + + Conditional environments are environments with condition variables. + For now, we assume that the conditions only affect the rewards, not the dynamics of + the environment. + """ + + @abstractmethod + def sample_conditions(self, batch_size: int) -> torch.Tensor: + """Sample conditions for the environment. + + Args: + batch_size: The number of conditions to sample. + + Returns: + A tensor of shape (batch_size, condition_vector_dim) containing the conditions. + """ + raise NotImplementedError + + def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: + """Compute rewards for the conditional environment. + + Args: + states: The states to compute rewards for. + states.tensor.shape should be (batch_size, *state_shape) + conditions: The conditions to compute rewards for. + conditions.shape should be (batch_size, condition_vector_dim) + + Returns: + A tensor of shape (batch_size,) containing the rewards. + """ + raise NotImplementedError + + def log_reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: + """Compute log rewards for the conditional environment. + + Args: + states: The states to compute log rewards for. + states.tensor.shape should be (batch_size, *state_shape) + conditions: The conditions to compute log rewards for. + conditions.shape should be (batch_size, condition_vector_dim) + + Returns: + A tensor of shape (batch_size,) containing the log rewards. + """ + return torch.log(self.reward(states, conditions)) + + def log_partition(self, condition: torch.Tensor) -> float: + """Optional method to return the logarithm of the partition function for a + given condition. + + Args: + condition: The condition to compute the log partition for. + condition.shape should be (condition_vector_dim,) + + Returns: + The log partition function, as a float. + """ + raise NotImplementedError( + "The environment may not support enumeration of states" + ) + + def true_dist(self, condition: torch.Tensor) -> torch.Tensor: + """Optional method to return the true distribution for a given condition. + + Args: + condition: The condition to compute the true distribution for. + condition.shape should be (condition_vector_dim,) + + Returns: + The true distribution for the given condition as a 1-dimensional tensor. 
+ """ + raise NotImplementedError( + "The environment may not support enumeration of states" + ) + + class DiscreteEnv(Env, ABC): """Base class for discrete environments, where states are defined in a discrete space, and actions are represented by an integer in {0, ..., n_actions - 1}, the diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index c1a70d59..e3f7f32a 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -4,7 +4,7 @@ import torch from gfn.containers import StatesContainer, Trajectories -from gfn.env import DiscreteEnv +from gfn.env import Env from gfn.estimators import ( DiscretePolicyEstimator, PolicyMixin, @@ -59,7 +59,7 @@ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): def sample_trajectories( self, - env: DiscreteEnv, + env: Env, n: int, conditions: torch.Tensor | None = None, save_logprobs: bool = False, @@ -96,7 +96,7 @@ def sample_trajectories( def flow_matching_loss( self, - env: DiscreteEnv, + env: Env, states: DiscreteStates, conditions: torch.Tensor | None, reduction: str = "mean", @@ -194,7 +194,7 @@ def flow_matching_loss( def reward_matching_loss( self, - env: DiscreteEnv, + env: Env, terminating_states: DiscreteStates, conditions: torch.Tensor | None, log_rewards: torch.Tensor | None, @@ -231,7 +231,7 @@ def reward_matching_loss( def loss( self, - env: DiscreteEnv, + env: Env, states_container: StatesContainer[DiscreteStates], recalculate_all_logprobs: bool = True, reduction: str = "mean", @@ -254,6 +254,10 @@ def loss( The computed flow matching loss as a tensor. The shape depends on the reduction method. """ + if not env.is_discrete: + raise NotImplementedError( + "Flow Matching GFlowNet only supports discrete environments for now." + ) assert isinstance(states_container.intermediary_states, DiscreteStates) assert isinstance(states_container.terminating_states, DiscreteStates) if recalculate_all_logprobs: diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 37296c61..0fce3d77 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -5,7 +5,7 @@ import torch from gfn.containers import Trajectories -from gfn.env import Env +from gfn.env import ConditionalEnv, Env from gfn.estimators import ConditionalScalarEstimator, Estimator, ScalarEstimator from gfn.gflownet.base import TrajectoryBasedGFlowNet, loss_reduce from gfn.utils.handlers import ( @@ -265,21 +265,28 @@ def calculate_log_state_flows( if trajectories.conditions is not None: # Compute the condition matrix broadcast to match valid_states. - # The conditions tensor has shape (max_length, n_trajectories, 1) - # We need to index it to match the valid states - conditions = trajectories.conditions[mask] - + # The conditions tensor has shape (n_trajectories, condition_vector_dim) + # The states have batch shape (max_length, n_trajectories) + # We need to repeat the conditions to match the batch shape of the states. 
+ conditions = trajectories.conditions.repeat(states.batch_shape[0], 1, 1) + # (max_length, n_trajectories, condition_vector_dim) + assert conditions.shape[:2] == states.batch_shape + conditions = conditions[mask] with has_conditions_exception_handler("logF", self.logF): - log_F = self.logF(valid_states, conditions) + log_F = self.logF(valid_states, conditions).squeeze(-1) + + if self.forward_looking: + assert isinstance(env, ConditionalEnv) + log_F = log_F + env.log_reward(valid_states, conditions) + else: with no_conditions_exception_handler("logF", self.logF): log_F = self.logF(valid_states).squeeze(-1) - if self.forward_looking: - log_rewards = env.log_reward(states).unsqueeze(-1) - log_F = log_F + log_rewards + if self.forward_looking: + log_F = log_F + env.log_reward(valid_states) - log_state_flows[mask[:-1]] = log_F.squeeze() + log_state_flows[mask[:-1]] = log_F return log_state_flows def calculate_masks( diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 1215b5f5..0231a6b7 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -16,7 +16,6 @@ from gfn.actions import Actions from gfn.env import DiscreteEnv from gfn.states import DiscreteStates -from gfn.utils.common import ensure_same_device if platform.system() == "Windows": multiprocessing.set_start_method("spawn", force=True) @@ -113,17 +112,20 @@ def __init__( self._all_states_tensor = None # Populated optionally in init. self._log_partition = None # Populated optionally in init. self._true_dist = None # Populated at first request. - self.calculate_partition = calculate_partition + + # If we store all the states, the partition function is calculated automatically. + self.calculate_partition = calculate_partition or store_all_states self.store_all_states = store_all_states # Pre-computes these values when printing. + if self.store_all_states or self.calculate_partition: + self._enumerate_all_states_tensor() + if self.store_all_states: - self._store_all_states_tensor() assert self._all_states_tensor is not None print(f"+ Environment has {len(self._all_states_tensor)} states") - if self.calculate_partition: - self._calculate_log_partition() + assert self._log_partition is not None print(f"+ Environment log partition is {self._log_partition}") if isinstance(device, str): @@ -375,96 +377,57 @@ def n_terminating_states(self) -> int: """Returns the number of terminating states in the environment.""" return self.n_states - # Functions for calculating the true log partition function / state enumeration. - def _calculate_log_partition(self, batch_size: int = 20_000): - """Calculates the log partition of the complete hypergrid. + def _enumerate_all_states_tensor(self, batch_size: int = 20_000): + """Enumerates all states_tensor of the complete hypergrid. Args: batch_size: The batch size to use for the calculation. """ - if self._log_partition is None and self.calculate_partition: - if self._all_states_tensor is not None: - self._log_partition = ( - self.reward_fn(self._all_states_tensor).sum().log().item() - ) - return - - # The # of possible combinations (with repetition) of - numbers, where each - # number can be any integer from 0 to - (inclusive), is given by: - # n = (k + 1) ** n -- note that k in our case is height-1, as it represents - # a python index. - max_height_idx = self.height - 1 # Handles 0 indexing.
- n_expected = (max_height_idx + 1) ** self.ndim - n_found = 0 - start_time = time() - total_reward = 0 - - for batch in self._generate_combinations_in_batches( - self.ndim, - max_height_idx, - batch_size, - ): - batch = torch.LongTensor(list(batch)) - rewards = self.reward_fn( - batch - ) # Operates on raw tensors due to multiprocessing. - total_reward += rewards.sum().item() # Accumulate. - n_found += batch.shape[0] - - assert n_expected == n_found, "failed to compute reward of all indices!" - end_time = time() - total_log_reward = log(total_reward) - - print( - "log_partition = {}, calculated in {} minutes".format( - total_log_reward, - (end_time - start_time) / 60.0, - ) - ) - - self._log_partition = total_log_reward + # Check if we really need to enumerate + need_to_enumerate = ( + self.store_all_states and self._all_states_tensor is None + ) or (self.calculate_partition and self._log_partition is None) - def _store_all_states_tensor(self, batch_size: int = 20_000): - """Enumerates all states_tensor of the complete hypergrid. - - Args: - batch_size: The batch size to use for the calculation. - """ - if self._all_states_tensor is None: + if need_to_enumerate: start_time = time() all_states_tensor = [] + total_rewards = 0.0 for batch in self._generate_combinations_in_batches( self.ndim, self.height - 1, # Handles 0 indexing. batch_size, ): - all_states_tensor.append(torch.LongTensor(list(batch))) - - all_states_tensor = torch.cat(all_states_tensor, dim=0) + batch_tensor = torch.LongTensor(list(batch)) + if self.store_all_states: + all_states_tensor.append(batch_tensor) + if self.calculate_partition: + # Operates on raw tensors due to multiprocessing. + total_rewards += self.reward_fn(batch_tensor).sum().item() end_time = time() print( - "calculated tensor of all states in {} minutes".format( + "Enumerated all states in {} minutes".format( (end_time - start_time) / 60.0, ) ) - self._all_states_tensor = all_states_tensor + if self.store_all_states: + self._all_states_tensor = torch.cat(all_states_tensor, dim=0) + + if self.calculate_partition: + self._log_partition = log(total_rewards) @property def true_dist(self) -> torch.Tensor | None: """Returns the pmf over all states in the hypergrid.""" - if self._true_dist is None and self.all_states is not None: - assert torch.all( - self.get_states_indices(self.all_states) - == torch.arange(self.n_states, device=self.device) - ) - self._true_dist = self.reward(self.all_states) - self._true_dist /= self._true_dist.sum() + if self._true_dist is None: + assert ( + self.all_states is not None + ), "true_dist is not available without all_states" + all_rewards = self.reward(self.all_states) + self._true_dist = all_rewards / all_rewards.sum() return self._true_dist @@ -492,35 +455,31 @@ def log_partition(self) -> float | None: @property def all_states(self) -> DiscreteStates | None: """Returns a tensor of all hypergrid states as a `DiscreteStates` instance.""" + if not self.store_all_states: + return None + if self._all_states_tensor is None: - if not self.store_all_states: - return None - self._store_all_states_tensor() + self._enumerate_all_states_tensor() assert self._all_states_tensor is not None - try: - ensure_same_device(self._all_states_tensor.device, self.device) - except ValueError: - self._all_states_tensor = self._all_states_tensor.to(self.device) - - all_states = self.States(self._all_states_tensor) - return all_states + assert torch.all( + self.get_states_indices(self._all_states_tensor) + == torch.arange(self.n_states, 
device=self.device) + ) + self._all_states_tensor = self._all_states_tensor.to(self.device) + return self.States(self._all_states_tensor) @property def terminating_states(self) -> DiscreteStates | None: """Returns all terminating states of the environment.""" return self.all_states - # Helper methods for enumerating all possible states. - def _generate_combinations_chunk(self, numbers, n, start, end): - """Generate combinations with replacement for the specified range.""" - # islice accesses a subset of the full iterator - each job does unique work. - return itertools.islice(itertools.product(numbers, repeat=n), start, end) - def _worker(self, task): """Executes a single call to `generate_combinations_chunk`.""" numbers, n, start, end = task - return self._generate_combinations_chunk(numbers, n, start, end) + # Generate combinations with replacement for the specified range. + # islice accesses a subset of the full iterator - each job does unique work. + return itertools.islice(itertools.product(numbers, repeat=n), start, end) def _generate_combinations_in_batches(self, n, k, batch_size): """Uses Pool to collect subsets of the results of itertools.product in parallel.""" diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index a4d74409..49a7f08f 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -191,8 +191,8 @@ def sample_trajectories( # noqa: C901 device = states.device if conditions is not None: - assert states.batch_shape == conditions.shape[: len(states.batch_shape)] ensure_same_device(states.device, conditions.device) + assert conditions.shape[0] == n_trajectories if policy_estimator.is_backward: dones = states.is_initial_state @@ -310,25 +310,6 @@ def sample_trajectories( # noqa: C901 if len(stacked_estimator_outputs) == 0: stacked_estimator_outputs = None - # Broadcast condition tensor to match states batch shape if needed - if conditions is not None: - # The states have batch shape (max_length, n_trajectories). The - # conditions tensor should have shape (n_trajectories,) or - # (n_trajectories, 1). We need to broadcast it to (max_length, - # n_trajectories, 1) for the estimator - if len(conditions.shape) == 1: - # conditions has shape (n_trajectories,) - conditions = ( - conditions.unsqueeze(0) - .unsqueeze(-1) - .expand(stacked_states.batch_shape[0], -1, 1) - ) - elif len(conditions.shape) == 2 and conditions.shape[1] == 1: - # conditions has shape (n_trajectories, 1) - conditions = conditions.unsqueeze(0).expand( - stacked_states.batch_shape[0], -1, -1 - ) - trajectories = Trajectories( env=env, states=stacked_states, diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 1047dc98..ce524088 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -86,15 +86,6 @@ def get_trajectory_pfs( if trajectories.is_backward: raise ValueError("Backward trajectories are not supported") - state_mask = ~trajectories.states.is_sink_state - action_mask = ~trajectories.actions.is_dummy - - valid_states = trajectories.states[state_mask] - valid_actions = trajectories.actions[action_mask] - - if valid_states.batch_shape != valid_actions.batch_shape: - raise AssertionError("Something wrong happening with log_pf evaluations") - if trajectories.has_log_probs and not recalculate_all_logprobs: log_pf_trajectories = trajectories.log_probs assert log_pf_trajectories is not None @@ -116,13 +107,9 @@ def get_trajectory_pfs( # Per-step path. 
N = trajectories.n_trajectories device = trajectories.states.device - cond = trajectories.conditions - - # TODO: Why do we need this? - if cond is not None and len(cond.shape) >= 2: - cond = cond[0] + cond = trajectories.conditions # shape (N, cond_dim) - ctx = policy_pf.init_context(int(N), device, cond) # type: ignore[arg-type] + ctx = policy_pf.init_context(int(N), device, cond) T = trajectories.max_length log_pf_trajectories = torch.full( @@ -133,16 +120,18 @@ def get_trajectory_pfs( ) for t in range(T): - state_ok = ~trajectories.states.is_sink_state[t] - action_ok = ~trajectories.actions.is_dummy[t] - step_mask = state_ok & action_ok + step_states = trajectories.states[t] + step_actions = trajectories.actions[t] + + assert (step_states.is_sink_state == step_actions.is_dummy).all() + step_mask = ~step_states.is_sink_state + + valid_step_states = step_states[step_mask] + valid_step_actions = step_actions[step_mask] if not torch.any(step_mask): continue - step_states = trajectories.states[t][step_mask] - step_actions = trajectories.actions.tensor[t][step_mask] - # Optimization: forward cached estimator outputs when available if ( trajectories.estimator_outputs is not None @@ -158,25 +147,27 @@ def get_trajectory_pfs( ctx.current_estimator_output = None # Build distribution for active rows and compute step log-probs + # TODO: masking ctx with step_mask outside of compute_dist and log_probs, + # i.e., implement __getitem__ for ctx. (maybe we should contain only the + # tensors, and not additional metadata like the batch size, device, etc.) dist, ctx = policy_pf.compute_dist( - step_states, ctx, step_mask, **policy_kwargs + valid_step_states, ctx, step_mask, **policy_kwargs ) step_log_probs, ctx = policy_pf.log_probs( - step_actions, dist, ctx, step_mask, vectorized=False + valid_step_actions.tensor, dist, ctx, step_mask, vectorized=False ) - # Pad back to full batch size. - if fill_value != 0.0: - padded = torch.full( - (N,), fill_value, device=device, dtype=step_log_probs.dtype - ) - padded[step_mask] = step_log_probs[step_mask] - step_log_probs = padded - # Store in trajectory-level tensor. log_pf_trajectories[t] = step_log_probs else: + state_mask = ~trajectories.states.is_sink_state + action_mask = ~trajectories.actions.is_dummy + assert (state_mask[:-1] == action_mask).all() # state_mask[-1] is all False + + valid_states = trajectories.states[state_mask] + valid_actions = trajectories.actions[action_mask] + # Vectorized path. log_pf_trajectories = torch.full_like( trajectories.actions.tensor[..., 0], @@ -189,16 +180,12 @@ def get_trajectory_pfs( # Build conditions per-step shape to align with valid_states masked_cond = None - cond = trajectories.conditions - - if cond is not None: - T = trajectories.states.tensor.shape[0] - # If conditions already has time dim (T, N, ...), index directly. - if cond.shape[0] == T: - masked_cond = cond[state_mask] - else: - # Broadcast (N, ...) to (T, N, ...), then index. 
- masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + if trajectories.conditions is not None: + # trajectories.conditions shape: (N, cond_dim) + # Repeat it to (T, N, cond_dim) and then mask it with the state_mask + T = trajectories.max_length + 1 + masked_cond = trajectories.conditions.repeat(T, 1, 1) + masked_cond = masked_cond[state_mask] ctx_v = policy_pf.init_context( int(len(valid_states)), @@ -216,10 +203,7 @@ def get_trajectory_pfs( # Build distribution and compute vectorized log-probs dist, ctx_v = policy_pf.compute_dist( - valid_states, - ctx_v, - step_mask=None, - **policy_kwargs, + valid_states, ctx_v, step_mask=None, **policy_kwargs ) valid_log_pf_actions, _ = policy_pf.log_probs( valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True @@ -333,10 +317,8 @@ def get_trajectory_pbs( # Per-step pb evaluation (state at t+1, action at t) N = trajectories.n_trajectories device = trajectories.states.device - cond = trajectories.conditions - if cond is not None and len(cond.shape) >= 2: - cond_step0 = cond[0] # TODO: Why do we need this? - ctx = policy_pb.init_context(int(N), device, cond_step0) # type: ignore[arg-type] + cond = trajectories.conditions # shape (N, cond_dim) + ctx = policy_pb.init_context(int(N), device, cond) # Iterate per-step with masking (state at t+1, action at t) for t in range(trajectories.max_length): diff --git a/tutorials/examples/train_conditional.py b/tutorials/examples/train_conditional.py index b2f1b3a5..b31da791 100644 --- a/tutorials/examples/train_conditional.py +++ b/tutorials/examples/train_conditional.py @@ -4,7 +4,7 @@ This script demonstrates how to train conditional GFlowNets that learn different distributions based on a continuous condition variable on the HyperGrid environment. -The conditioning interpolates between two extremes: +The condition interpolates between two extremes: - Condition = 0: Uniform distribution (all states get reward R0+R1+R2) - Condition = 1: Original HyperGrid multi-modal distribution @@ -30,6 +30,7 @@ from torch.optim import Adam from tqdm import tqdm +from gfn.env import ConditionalEnv, DiscreteEnv from gfn.estimators import ( ConditionalDiscretePolicyEstimator, ConditionalLogZEstimator, @@ -52,7 +53,7 @@ DEFAULT_SEED: int = 4444 -class ConditionalHyperGrid(HyperGrid): +class ConditionalHyperGrid(HyperGrid, ConditionalEnv): """HyperGrid environment with condition-aware rewards. Condition values: @@ -62,63 +63,86 @@ class ConditionalHyperGrid(HyperGrid): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.conditions = None - self._original_reward_fn = self.reward_fn + self._original_reward_fn = self.reward_fn # Rename, just to avoid confusion + self._max_reward: float = ( + self.reward_fn_kwargs.get("R0", 0.1) + + self.reward_fn_kwargs.get("R1", 0.5) + + self.reward_fn_kwargs.get("R2", 2.0) + ) + self._log_partition_cache: dict[torch.Tensor, float] = {} + self._true_dist_cache: dict[torch.Tensor, torch.Tensor] = {} - def set_conditions(self, conditions: torch.Tensor): - """Set the conditions for the environment.""" - self.conditions = conditions + def sample_conditions(self, batch_size: int) -> torch.Tensor: + """Sample conditions for the environment.""" + return torch.rand((batch_size, 1), device=self.device) - def reward(self, states: DiscreteStates) -> torch.Tensor: + def reward(self, states: DiscreteStates, conditions: torch.Tensor) -> torch.Tensor: """Compute rewards based on current conditions. 
A condition is continuous from 0 to 1: - 0: Fully uniform reward (all states get R0+R1+R2) - 1: Fully original HyperGrid reward - In between: Linear interpolation between uniform and original + + Args: + states: The states to compute rewards for. + states.tensor.shape should be (batch_size, *state_shape) + conditions: The conditions to compute rewards for. + conditions.shape should be (batch_size, 1) + + Returns: + A tensor of shape (batch_size,) containing the rewards. """ # Get original rewards original_rewards = self._original_reward_fn(states.tensor) - - if self.conditions is None: - return original_rewards - - # Apply condition-based modification - # condition shape: (batch_size, 1) or (1, 1) - # original_rewards shape: (batch_size,) + # shape: (batch_size,) # Expand conditions to match batch shape if needed - cond = self.conditions.squeeze(-1) # Remove feature dim - - # Handle different scenarios for batch size mismatch - if cond.shape[0] != original_rewards.shape[0]: - if cond.shape[0] == 1: - # Single condition value, broadcast to all states - cond = cond.expand(original_rewards.shape[0]) - else: - # Multiple condition values but different batch size - # This can happen during DB loss calculation with transitions - # Use the first condition value as a fallback - if len(cond) > 0: - cond = cond[0].expand(original_rewards.shape[0]) - else: - # No condition available, return original rewards - return original_rewards + cond = conditions.squeeze(-1) # Remove feature dim; shape: (batch_size,) # For uniform, all states get the max reward (R0+R1+R2) - max_reward = ( - self.reward_fn_kwargs.get("R0", 0.1) - + self.reward_fn_kwargs.get("R1", 0.5) - + self.reward_fn_kwargs.get("R2", 2.0) - ) - uniform_rewards = torch.full_like(original_rewards, max_reward) + uniform_rewards = torch.full_like(original_rewards, self._max_reward) - # Linear interpolation between uniform and original based on condition - # rewards = (1 - cond) * uniform + cond * original + # Linear interpolation between uniform and original based on conditions rewards = (1 - cond) * uniform_rewards + cond * original_rewards - return rewards + def log_partition(self, condition: torch.Tensor) -> float: + """Compute the log partition for the given condition. + + Args: + condition: The condition to compute the log partition for. + condition.shape should be (1,) + + Returns: + The log partition function, as a float. + """ + if condition not in self._log_partition_cache: + assert self.all_states is not None + all_rewards = self.reward( + self.all_states, condition.repeat(self.n_states, 1) + ) + self._log_partition_cache[condition] = all_rewards.sum().log().item() + return self._log_partition_cache[condition] + + def true_dist(self, condition: torch.Tensor) -> torch.Tensor: + """Compute the true distribution for the given condition. + + Args: + condition: The condition to compute the true distribution for. + condition.shape should be (1,) + + Returns: + The true distribution for the given condition as a 1-dimensional tensor. 
+ """ + if condition not in self._true_dist_cache: + assert self.all_states is not None + all_rewards = self.reward( + self.all_states, condition.repeat(self.n_states, 1) + ) + self._true_dist_cache[condition] = all_rewards / all_rewards.sum() + return self._true_dist_cache[condition] + def build_conditional_pf_pb( env: HyperGrid, @@ -318,7 +342,7 @@ def build_subTB_gflownet(env): def train( - env, + env: ConditionalEnv, gflownet, seed, device, @@ -364,12 +388,7 @@ def train( final_loss = None for it in (pbar := tqdm(range(n_iterations), dynamic_ncols=True)): # Sample conditions uniformly from [0, 1] for this batch - conditions = torch.rand((batch_size,)).to(device) - # Keep as continuous value between 0 and 1 - conditions = conditions.unsqueeze(-1) # Add feature dimension for MLP - - # Set conditions in environment for reward calculation - env.set_conditions(conditions) + conditions = env.sample_conditions(batch_size) # Sample trajectories with conditions trajectories = gflownet.sample_trajectories( @@ -409,7 +428,6 @@ def train( conditions_val = torch.full( (validation_samples, 1), cond_val, device=device ) - env.set_conditions(conditions_val) # Sample fresh trajectories for this conditions value # This follows the validate function's approach but with conditions support @@ -428,38 +446,34 @@ def train( # Update discovered modes for condition=1 if cond_val == 1.0: - rewards = env.reward(sampled_states) + rewards = env.reward( + sampled_states, torch.tensor([cond_val], device=device) + ) modes = sampled_states[rewards >= mode_reward_threshold].tensor modes_found = set([tuple(s.tolist()) for s in modes]) discovered_modes.update(modes_found) # Compute empirical distribution using validate's helper function - empirical_dist = get_terminating_state_dist(env, sampled_states) - - # Compute true distribution for this condition value - uniform_dist = torch.ones(env.n_states, device=device) / env.n_states - # Get original HyperGrid true_dist - env.set_conditions(torch.ones((1, 1), device=device)) - original_true_dist = env.true_dist - # Interpolate - true_dist = (1 - cond_val) * uniform_dist + cond_val * original_true_dist - - # L1 distance as computed in validate function - l1_dist = (empirical_dist - true_dist).abs().mean().item() - l1_dists.append(l1_dist) + if isinstance(env, DiscreteEnv): + empirical_dist = env.get_terminating_state_dist(sampled_states) + # Compute true distribution for this condition values + true_conditional_dist = env.true_dist( + torch.tensor([cond_val], device=device) + ) + # L1 distance as computed in validate function + l1_dist = ( + (empirical_dist - true_conditional_dist).abs().mean().item() + ) + l1_dists.append(l1_dist) # Print concise results - print( - "Iter {}: L1=[{:.6f}, {:.6f}, {:.6f}, {:.6f}, {:.6f}] modes={}".format( - it + 1, - l1_dists[0], - l1_dists[1], - l1_dists[2], - l1_dists[3], - l1_dists[4], - len(discovered_modes) / n_pixels_per_mode, - ) - ) + log_str = f"Iter {it + 1}: " + if len(l1_dists) > 0: + l1_dists_str = [f"{l1_dist:.6f}" for l1_dist in l1_dists] + l1_dists_str = ", ".join(l1_dists_str) + log_str += f"L1=[{l1_dists_str}], " + log_str += f"modes={len(discovered_modes) / n_pixels_per_mode}" + print(log_str) print("\n" + "=" * 60) print("+ Training complete!") @@ -488,7 +502,6 @@ def evaluate_conditional_sampling(env, gflownet, device, n_eval_samples=10000): conditions = torch.full( (n_eval_samples, 1), cond_value, dtype=torch.float, device=device ) - env.set_conditions(conditions) # Sample without exploration print(f"Sampling 
{n_eval_samples} trajectories with condition={cond_value}...") @@ -505,21 +518,7 @@ def evaluate_conditional_sampling(env, gflownet, device, n_eval_samples=10000): term_states = cast(DiscreteStates, trajectories.terminating_states) empirical_dist = get_terminating_state_dist(env, term_states) - - # Get true distribution for this condition - # Linear interpolation between uniform and original - uniform_dist = torch.ones(env.n_states, device=device) / env.n_states - - # Temporarily set single condition to get true_dist for original hypergrid - original_cond = torch.ones((1, 1), device=device) - env.set_conditions(original_cond) - original_dist = env.true_dist - - # Compute interpolated true distribution - true_dist = (1 - cond_value) * uniform_dist + cond_value * original_dist - - # Restore condition for this batch - env.set_conditions(conditions) + true_dist = env.true_dist(torch.tensor([cond_value], device=device)) if cond_value == 0: dist_type = "Uniform" @@ -746,7 +745,7 @@ def main(args): parser.add_argument( "--batch_size", type=int, - default=500, + default=200, help="Batch size, i.e. number of trajectories to sample per training iteration", ) parser.add_argument(
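Note for reviewers, not part of the patch: below is a minimal sketch of how the new `ConditionalEnv` API introduced above is meant to be used, mirroring the `ConditionalHyperGrid` pattern from `tutorials/examples/train_conditional.py`. The class name `MyConditionalGrid`, the constructor arguments, and the flat reward level are illustrative assumptions, and the estimator/GFlowNet construction is elided.

```python
import torch

from gfn.env import ConditionalEnv
from gfn.gym import HyperGrid
from gfn.states import DiscreteStates


class MyConditionalGrid(HyperGrid, ConditionalEnv):
    """Toy conditional HyperGrid: a scalar condition in [0, 1] interpolates
    between a flat reward and the usual HyperGrid reward."""

    def sample_conditions(self, batch_size: int) -> torch.Tensor:
        # One condition vector per trajectory, shape (batch_size, 1).
        return torch.rand((batch_size, 1), device=self.device)

    def reward(self, states: DiscreteStates, conditions: torch.Tensor) -> torch.Tensor:
        original = self.reward_fn(states.tensor)  # (batch_size,)
        cond = conditions.squeeze(-1)             # (batch_size,)
        flat = torch.full_like(original, 1.0)     # illustrative flat reward level
        return (1 - cond) * flat + cond * original


env = MyConditionalGrid(ndim=2, height=8)
conditions = env.sample_conditions(batch_size=16)  # (16, 1)

# Containers store per-trajectory conditions of shape (n_trajectories, condition_vector_dim)
# and forward them to env.log_reward(states, conditions) when computing log-rewards, e.g.:
# trajectories = gflownet.sample_trajectories(env, n=16, conditions=conditions, save_logprobs=True)
# log_rewards = trajectories.log_rewards
```

With this in place, `Trajectories.log_rewards`, `Transitions.log_rewards`, and `StatesContainer.log_rewards` all dispatch on `isinstance(env, ConditionalEnv)` as in the hunks above.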
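The per-trajectory-to-per-step broadcasting used in `to_transitions`, `to_states_container`, and `calculate_log_state_flows` can be sanity-checked in isolation; the shapes below are made up for illustration.

```python
import torch

T, N, D = 3, 2, 4                       # max_length, n_trajectories, condition_vector_dim
conditions = torch.randn(N, D)          # per-trajectory conditions, as stored on the container
valid = torch.tensor([[True, True],
                      [True, False],
                      [False, False]])  # e.g. ~actions.is_dummy, shape (T, N)

broadcast = conditions.repeat(T, 1, 1)  # (T, N, D): repeat() on a 2-D tensor prepends the new dim
assert broadcast.shape == (T, N, D)
assert torch.equal(broadcast[1], conditions)

per_step = broadcast[valid]             # (valid.sum(), D), aligned with states[:-1][valid]
assert per_step.shape == (int(valid.sum()), D)
```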