16 changes: 14 additions & 2 deletions src/gfn/containers/states_container.py
@@ -1,9 +1,12 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, cast

import torch

from gfn.env import ConditionalEnv

if TYPE_CHECKING:
from gfn.env import Env
from gfn.states import States
@@ -71,7 +74,8 @@ def __init__(

self.conditions = conditions
assert self.conditions is None or (
self.conditions.shape[: len(batch_shape)] == batch_shape
len(self.conditions.shape) == 2
Collaborator: right, because we assume the conditioning would not change through the trajectory?

and self.conditions.shape[0] == len(self.states)
)

self.is_terminating = (
@@ -178,7 +182,15 @@ def log_rewards(self) -> torch.Tensor:
fill_value=-float("inf"),
device=self.states.device,
)
self._log_rewards[self.is_terminating] = self.env.log_reward(
if isinstance(self.env, ConditionalEnv):
assert self.conditions is not None
log_reward_fn = partial(
Collaborator: nice!

self.env.log_reward,
conditions=self.conditions[self.is_terminating],
)
else:
log_reward_fn = self.env.log_reward
self._log_rewards[self.is_terminating] = log_reward_fn(
self.terminating_states
)

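The dispatch introduced above, in isolation: binding the conditions with functools.partial lets the conditional and unconditional branches share the same single-argument call at the end. A toy, self-contained sketch of the pattern (the log_reward below is made up for illustration; it is not the library's reward):

```python
from __future__ import annotations

from functools import partial

import torch

def log_reward(states: torch.Tensor, conditions: torch.Tensor | None = None) -> torch.Tensor:
    # Toy stand-in for env.log_reward: highest where the state sum matches the condition.
    target = conditions.squeeze(-1) if conditions is not None else 0.0
    return -(states.sum(-1) - target).abs()

states = torch.randn(4, 3)      # pretend these are terminating states
conditions = torch.randn(4, 1)  # (batch_size, condition_vector_dim)

# Conditional branch: bind the conditions once, then call with states only,
# mirroring the unconditional `log_reward_fn = self.env.log_reward` branch.
log_reward_fn = partial(log_reward, conditions=conditions)
assert torch.equal(log_reward_fn(states), log_reward(states, conditions))
```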
92 changes: 42 additions & 50 deletions src/gfn/containers/trajectories.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import Sequence

import torch
@@ -8,7 +9,7 @@
from gfn.containers.base import Container
from gfn.containers.states_container import StatesContainer
from gfn.containers.transitions import Transitions
from gfn.env import Env
from gfn.env import ConditionalEnv, Env
from gfn.states import DiscreteStates, GraphStates, States
from gfn.utils.common import ensure_same_device, is_int_dtype

@@ -109,8 +110,8 @@ def __init__(

self.conditions = conditions
assert self.conditions is None or (
self.conditions.shape[: len(self.states.batch_shape)]
== self.states.batch_shape
len(self.conditions.shape) == 2
and self.conditions.shape[0] == self.n_trajectories
)

self.actions = (
@@ -245,7 +246,12 @@ def log_rewards(self) -> torch.Tensor | None:
return None

if self._log_rewards is None:
self._log_rewards = self.env.log_reward(self.terminating_states)
if isinstance(self.env, ConditionalEnv):
assert self.conditions is not None
log_reward_fn = partial(self.env.log_reward, conditions=self.conditions)
else:
log_reward_fn = self.env.log_reward
self._log_rewards = log_reward_fn(self.terminating_states)

assert self._log_rewards.shape == (self.n_trajectories,)
return self._log_rewards
@@ -266,7 +272,7 @@ def __getitem__(
terminating_idx = self.terminating_idx[index]
new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0
states = self.states[:, index]
conditions = self.conditions[:, index] if self.conditions is not None else None
conditions = self.conditions[index] if self.conditions is not None else None
actions = self.actions[:, index]
states = states[: 1 + new_max_length]
actions = actions[:new_max_length]
@@ -311,14 +317,8 @@ def extend(self, other: Trajectories) -> None:
log_rewards).

Args:
Another Trajectories to append.
other: Another Trajectories to append.
"""
if self.conditions is not None:
# TODO: Support the case
raise NotImplementedError(
"`extend` is not implemented for conditional Trajectories."
)

if len(other) == 0:
return

@@ -346,6 +346,12 @@ def extend(self, other: Trajectories) -> None:
(self.terminating_idx, other.terminating_idx), dim=0
)

# Concatenate conditions of the trajectories.
if self.conditions is not None and other.conditions is not None:
self.conditions = torch.cat((self.conditions, other.conditions), dim=0)
else:
self.conditions = None
Comment on lines +349 to +353
Collaborator: can we maybe add a test for extending with conditions, and then try common ops like get_item to check the output is as expected?

Collaborator (Author):

> can we maybe add a test for extending with conditions

I will add one.

> and then try common ops like get_item to check the output is as expected?

I have no idea what this means. Could you elaborate more?

Collaborator: I mean in the test, after calling extend, check if the extend operation gave the expected result. Like here:

pre_extend_shape = state2.batch_shape
state1.extend(state2)
assert state2.batch_shape == pre_extend_shape
# Check final shape should be (max_len=3, B=4)
assert state1.batch_shape == (3, 4)
# The actual count might be higher due to padding with sink states
assert state1.tensor.x.size(0) == expected_nodes
assert state1.tensor.num_edges == expected_edges
# Check if states are extended as expected
assert (state1[0, 0].tensor.x == datas[0].x).all()
assert (state1[0, 1].tensor.x == datas[1].x).all()
assert (state1[0, 2].tensor.x == datas[4].x).all()
assert (state1[0, 3].tensor.x == datas[5].x).all()
assert (state1[1, 0].tensor.x == datas[2].x).all()
assert (state1[1, 1].tensor.x == datas[3].x).all()
assert (state1[1, 2].tensor.x == datas[6].x).all()
assert (state1[1, 3].tensor.x == datas[7].x).all()
assert (state1[2, 0].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 1].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 2].tensor.x == datas[8].x).all()
assert (state1[2, 3].tensor.x == datas[9].x).all()

Collaborator (Author): I see. I will add a test soon!
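A possible shape for that test, sketched against the behavior added in this PR (conditions concatenated along dim 0 in extend, per-trajectory indexing in __getitem__). The trajs1/trajs2 fixtures are placeholders for however the suite builds conditional Trajectories:

```python
import torch

def test_extend_with_conditions(trajs1, trajs2):
    # trajs1 / trajs2: Trajectories sampled from a conditional env, each carrying
    # a (n_trajectories, condition_vector_dim) conditions tensor.
    n1, n2 = trajs1.n_trajectories, trajs2.n_trajectories
    cond1, cond2 = trajs1.conditions.clone(), trajs2.conditions.clone()

    trajs1.extend(trajs2)

    # Conditions are concatenated along the trajectory dimension.
    assert trajs1.conditions.shape == (n1 + n2, cond1.shape[1])
    assert torch.equal(trajs1.conditions[:n1], cond1)
    assert torch.equal(trajs1.conditions[n1:], cond2)

    # __getitem__ keeps each trajectory aligned with its own condition.
    sub = trajs1[[0, n1]]
    assert torch.equal(sub.conditions[0], cond1[0])
    assert torch.equal(sub.conditions[1], cond2[0])
```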


# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0)
@@ -387,18 +393,21 @@ def to_transitions(self) -> Transitions:
A Transitions object with the same states, actions, and log_rewards as the
current Trajectories.
"""
valid_action_mask = ~self.actions.is_dummy
if self.conditions is not None:
# The conditions tensor has shape (max_length, n_trajectories, 1)
# The actions have shape (max_length, n_trajectories)
# We need to index the conditions tensor to match the actions
# The actions exclude the last step, so we need to exclude the last step from conditions
conditions = self.conditions[:-1][~self.actions.is_dummy]
# The conditions tensor has shape (n_trajectories, condition_vector_dim)
# The actions have batch shape (max_length, n_trajectories)
# We need to repeat the condition vector tensor to match the actions
conditions = self.conditions.repeat(self.actions.batch_shape[0], 1, 1)
assert conditions.shape[:2] == self.actions.batch_shape
# Then we mask it with the valid action mask.
conditions = conditions[valid_action_mask]
else:
conditions = None

states = self.states[:-1][~self.actions.is_dummy]
next_states = self.states[1:][~self.actions.is_dummy]
actions = self.actions[~self.actions.is_dummy]
states = self.states[:-1][valid_action_mask]
next_states = self.states[1:][valid_action_mask]
actions = self.actions[valid_action_mask]
is_terminating = (
next_states.is_sink_state
if not self.is_backward
@@ -454,9 +463,7 @@ def to_states_container(self) -> StatesContainer:
)
is_terminating[self.terminating_idx - 1, torch.arange(len(self))] = True

states = self.states.flatten()
is_terminating = is_terminating.flatten()

states = self.states
is_valid = ~states.is_sink_state & (
~states.is_initial_state | (states.is_initial_state & is_terminating)
)
@@ -465,40 +472,25 @@

conditions = None
if self.conditions is not None:
# The conditions tensor has shape (max_length, n_trajectories, 1)
# We need to flatten it to match the flattened states
# First, we need to repeat it to match the flattened shape
# The flattened states have shape (max_length * n_trajectories,)
# So we need to repeat the conditions tensor accordingly
conditions = self.conditions.flatten(0, 1)[is_valid]
# The conditions tensor has shape (n_trajectories, condition_vector_dim)
# The states have batch shape (max_length, n_trajectories)
# We need to repeat the conditions to match the batch shape of the states.
conditions = self.conditions.repeat(self.states.batch_shape[0], 1, 1)
# (max_length, n_trajectories, condition_vector_dim)
assert conditions.shape[:2] == self.states.batch_shape
# Then we mask it with the valid state mask.
conditions = conditions[is_valid]

if self.log_rewards is None:
log_rewards = None
else:
log_rewards = torch.full(
(len(states),),
fill_value=-float("inf"),
device=states.device,
self.states.batch_shape, fill_value=-float("inf"), device=states.device
)
# Get the original indices (before flattening and filtering).
orig_batch_indices = torch.arange(
self.states.batch_shape[0], device=states.device
).repeat_interleave(self.states.batch_shape[1])
orig_traj_indices = torch.arange(
self.states.batch_shape[1], device=states.device
).repeat(self.states.batch_shape[0])

# Retain only the valid indices.
valid_batch_indices = orig_batch_indices[is_valid]
valid_traj_indices = orig_traj_indices[is_valid]

# Assign rewards to valid terminating states.
terminating_mask = is_terminating & (
valid_batch_indices == (self.terminating_idx[valid_traj_indices] - 1)
log_rewards[self.terminating_idx - 1, torch.arange(len(self))] = (
Collaborator: really nice cleanup here!

self.log_rewards
)
log_rewards[terminating_mask] = self.log_rewards[
valid_traj_indices[terminating_mask]
]
log_rewards = log_rewards[is_valid]

return StatesContainer[DiscreteStates](
env=self.env,
41 changes: 29 additions & 12 deletions src/gfn/containers/transitions.py
@@ -1,9 +1,12 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Sequence

import torch

from gfn.env import ConditionalEnv

if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
@@ -92,7 +95,8 @@ def __init__(

self.conditions = conditions
assert self.conditions is None or (
self.conditions.shape[: len(batch_shape)] == batch_shape
len(self.conditions.shape) == 2
and self.conditions.shape[0] == self.n_transitions
)

self.actions = (
@@ -204,7 +208,15 @@ def log_rewards(self) -> torch.Tensor | None:
fill_value=-float("inf"),
device=self.states.device,
)
self._log_rewards[self.is_terminating] = self.env.log_reward(
if isinstance(self.env, ConditionalEnv):
assert self.conditions is not None
log_reward_fn = partial(
self.env.log_reward,
conditions=self.conditions[self.is_terminating],
)
else:
log_reward_fn = self.env.log_reward
self._log_rewards[self.is_terminating] = log_reward_fn(
self.terminating_states
)

@@ -231,10 +243,15 @@ def all_log_rewards(self) -> torch.Tensor:
fill_value=-float("inf"),
device=self.states.device,
)
log_rewards[~is_sink_state, 0] = self.env.log_reward(self.states[~is_sink_state])
log_rewards[~is_sink_state, 1] = self.env.log_reward(
self.next_states[~is_sink_state]
)
if isinstance(self.env, ConditionalEnv):
assert self.conditions is not None
log_reward_fn = partial(
self.env.log_reward, conditions=self.conditions[~is_sink_state]
)
else:
log_reward_fn = self.env.log_reward
log_rewards[~is_sink_state, 0] = log_reward_fn(self.states[~is_sink_state])
log_rewards[~is_sink_state, 1] = log_reward_fn(self.next_states[~is_sink_state])

assert (
log_rewards.shape == (self.n_transitions, 2)
@@ -282,12 +299,6 @@ def extend(self, other: Transitions) -> None:
Args:
Another Transitions object to append.
"""
if self.conditions is not None:
# TODO: Support the case
raise NotImplementedError(
"`extend` is not implemented for conditional Transitions."
)

if len(other) == 0:
return

@@ -314,6 +325,12 @@
)
self.next_states.extend(other.next_states)

# Concatenate conditions of the transitions.
if self.conditions is not None and other.conditions is not None:
self.conditions = torch.cat((self.conditions, other.conditions), dim=0)
else:
self.conditions = None

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0)
80 changes: 79 additions & 1 deletion src/gfn/env.py
@@ -414,7 +414,7 @@ def log_partition(self) -> float:
The log partition function.
"""
raise NotImplementedError(
"The environment does not support enumeration of states"
"The environment does not support calculating the log partition"
)

@property
@@ -429,6 +429,84 @@ def true_dist(self) -> torch.Tensor:
)


class ConditionalEnv(Env, ABC):
"""Base class for conditional environments.

Conditional environments are environments with condition variables.
For now, we assume that the conditions only affect the rewards, not the dynamics of
the environment.
"""

@abstractmethod
def sample_conditions(self, batch_size: int) -> torch.Tensor:
"""Sample conditions for the environment.

Args:
batch_size: The number of conditions to sample.

Returns:
A tensor of shape (batch_size, condition_vector_dim) containing the conditions.
"""
raise NotImplementedError

def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
"""Compute rewards for the conditional environment.

Args:
states: The states to compute rewards for.
states.tensor.shape should be (batch_size, *state_shape)
conditions: The conditions to compute rewards for.
conditions.shape should be (batch_size, condition_vector_dim)

Returns:
A tensor of shape (batch_size,) containing the rewards.
"""
raise NotImplementedError
Comment on lines +452 to +464
Collaborator: aha, this is not a real subclass of Env, as conditions are mandatory (i.e. you can't call this function pretending it is a plain Env object when it is actually a ConditionalEnv).

Would it make sense to have a default condition?
If not, this probably shouldn't inherit from Env.

Collaborator (Author):

> Would it make sense to have a default condition?

How could having a default condition solve the problem?

> If not, this shouldn't inherit from Env probably.

Maybe, but we still need a parent class that defines the default methods for Envs, like reward, step, etc.

Collaborator:

> How could having a default condition solve the problem?

If we have a function like this:

def get_reward(env: Env, states: States) -> torch.Tensor:
   return  env.reward(states)

This should work with any Env object, given the interface of Env.

However, currently, if I pass a ConditionalEnv (which is an Env), this will fail, as you need to specify the conditioning. If you have a default value for the conditioning, the get_reward function will work properly (indeed, with the default, the reward interface of ConditionalEnv becomes a subtype of the one of Env).
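A sketch of what that "default condition" could look like (illustrative only; nothing below is part of the current gfn interface). Giving conditions a default value keeps reward callable with states alone, so helpers like get_reward above keep working:

```python
from __future__ import annotations

import torch

class SketchEnvWithDefaultCondition:
    """Toy illustration of a conditional env with a default condition (not a gfn class)."""

    default_condition = torch.zeros(1)  # a fixed "neutral" condition vector

    def reward(
        self, states: torch.Tensor, conditions: torch.Tensor | None = None
    ) -> torch.Tensor:
        if conditions is None:
            # Fall back to the default so `env.reward(states)` matches the Env signature.
            conditions = self.default_condition.expand(states.shape[0], -1)
        return torch.exp(-(states.sum(-1) - conditions.squeeze(-1)).abs())

env = SketchEnvWithDefaultCondition()
states = torch.rand(8, 3)                        # toy stand-in for States.tensor
env.reward(states)                               # works like an unconditional Env
env.reward(states, conditions=torch.rand(8, 1))  # explicitly conditioned call
```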

Collaborator:
An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts States follows a different path when conditioning is present).

The env itself would only be conditional or not depending on the logic the user defines in the reward and step functions. No actual ConditionalEnv class would be required.

The estimators would also optionally use the conditioning information, if it's present, just like how it's done currently.
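A rough, purely illustrative sketch of that alternative (no conditioning field exists on the current States class; the classes below are stand-ins):

```python
from __future__ import annotations

import torch

class SketchStates:
    """Toy stand-in for States, carrying an optional conditioning tensor."""

    def __init__(self, tensor: torch.Tensor, conditioning: torch.Tensor | None = None):
        self.tensor = tensor
        # (batch_size, condition_vector_dim) when present, None for unconditional envs.
        self.conditioning = conditioning

def log_reward(states: SketchStates) -> torch.Tensor:
    # Anything that accepts states can branch on whether conditioning is present,
    # so no separate ConditionalEnv class is needed in the type hierarchy.
    total = states.tensor.sum(-1)
    if states.conditioning is None:
        return -total.abs()
    return -(total - states.conditioning.squeeze(-1)).abs()

log_reward(SketchStates(torch.rand(5, 3)))                    # unconditional path
log_reward(SketchStates(torch.rand(5, 3), torch.rand(5, 1)))  # conditional path
```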


def log_reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
"""Compute log rewards for the conditional environment.

Args:
states: The states to compute log rewards for.
states.tensor.shape should be (batch_size, *state_shape)
conditions: The conditions to compute log rewards for.
conditions.shape should be (batch_size, condition_vector_dim)

Returns:
A tensor of shape (batch_size,) containing the log rewards.
"""
return torch.log(self.reward(states, conditions))

def log_partition(self, condition: torch.Tensor) -> float:
"""Optional method to return the logarithm of the partition function for a
given condition.

Args:
condition: The condition to compute the log partition for.
condition.shape should be (condition_vector_dim,)

Returns:
The log partition function, as a float.
"""
raise NotImplementedError(
"The environment may not support enumeration of states"
)

def true_dist(self, condition: torch.Tensor) -> torch.Tensor:
"""Optional method to return the true distribution for a given condition.

Args:
condition: The condition to compute the true distribution for.
condition.shape should be (condition_vector_dim,)

Returns:
The true distribution for the given condition as a 1-dimensional tensor.
"""
raise NotImplementedError(
"The environment may not support enumeration of states"
)


class DiscreteEnv(Env, ABC):
"""Base class for discrete environments, where states are defined in a discrete
space, and actions are represented by an integer in {0, ..., n_actions - 1}, the
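For context on how the new interface is meant to be used end to end, here is a minimal, self-contained sketch of the ConditionalEnv contract added in this PR: sample_conditions returns a (batch_size, condition_vector_dim) tensor, and reward/log_reward consume it. The class below is hypothetical and skips the rest of the Env interface (step, States construction, etc.):

```python
import torch

class SketchConditionalEnv:
    """Toy illustration of the ConditionalEnv contract; not a full gfn environment."""

    condition_vector_dim = 1

    def sample_conditions(self, batch_size: int) -> torch.Tensor:
        # One scalar target per trajectory; conditions affect rewards only, not dynamics.
        return torch.rand(batch_size, self.condition_vector_dim)

    def reward(self, states: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor:
        # Reward peaks when the state sum matches the sampled target.
        return torch.exp(-(states.sum(-1) - conditions.squeeze(-1)).abs())

    def log_reward(self, states: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor:
        return torch.log(self.reward(states, conditions))

env = SketchConditionalEnv()
conditions = env.sample_conditions(batch_size=8)  # (8, 1)
states = torch.rand(8, 3)                         # toy stand-in for terminating states
assert env.log_reward(states, conditions).shape == (8,)
```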