-
Notifications
You must be signed in to change notification settings - Fork 52
Refactor Conditional GFlowNets #431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ea5e1f5
1c25f04
97ac4a8
915f171
74c85f6
b3496cd
9839e10
0d577a9
580d00a
5187a8a
e299081
8ef7951
7eb8255
0cae53e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,12 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import partial | ||
| from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, cast | ||
|
|
||
| import torch | ||
|
|
||
| from gfn.env import ConditionalEnv | ||
|
|
||
| if TYPE_CHECKING: | ||
| from gfn.env import Env | ||
| from gfn.states import States | ||
|
|
@@ -71,7 +74,8 @@ def __init__( | |
|
|
||
| self.conditions = conditions | ||
| assert self.conditions is None or ( | ||
| self.conditions.shape[: len(batch_shape)] == batch_shape | ||
| len(self.conditions.shape) == 2 | ||
| and self.conditions.shape[0] == len(self.states) | ||
| ) | ||
|
|
||
| self.is_terminating = ( | ||
|
|
@@ -178,7 +182,15 @@ def log_rewards(self) -> torch.Tensor: | |
| fill_value=-float("inf"), | ||
| device=self.states.device, | ||
| ) | ||
| self._log_rewards[self.is_terminating] = self.env.log_reward( | ||
| if isinstance(self.env, ConditionalEnv): | ||
| assert self.conditions is not None | ||
| log_reward_fn = partial( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! |
||
| self.env.log_reward, | ||
| conditions=self.conditions[self.is_terminating], | ||
| ) | ||
| else: | ||
| log_reward_fn = self.env.log_reward | ||
| self._log_rewards[self.is_terminating] = log_reward_fn( | ||
| self.terminating_states | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Sequence | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -8,7 +9,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.containers.base import Container | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.containers.states_container import StatesContainer | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.containers.transitions import Transitions | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.env import Env | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.env import ConditionalEnv, Env | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.states import DiscreteStates, GraphStates, States | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gfn.utils.common import ensure_same_device, is_int_dtype | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -109,8 +110,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.conditions = conditions | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.conditions is None or ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.conditions.shape[: len(self.states.batch_shape)] | ||||||||||||||||||||||||||||||||||||||||||||||||
| == self.states.batch_shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| len(self.conditions.shape) == 2 | ||||||||||||||||||||||||||||||||||||||||||||||||
| and self.conditions.shape[0] == self.n_trajectories | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.actions = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -245,7 +246,12 @@ def log_rewards(self) -> torch.Tensor | None: | |||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if self._log_rewards is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._log_rewards = self.env.log_reward(self.terminating_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(self.env, ConditionalEnv): | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.conditions is not None | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_reward_fn = partial(self.env.log_reward, conditions=self.conditions) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_reward_fn = self.env.log_reward | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._log_rewards = log_reward_fn(self.terminating_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| assert self._log_rewards.shape == (self.n_trajectories,) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return self._log_rewards | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -266,7 +272,7 @@ def __getitem__( | |||||||||||||||||||||||||||||||||||||||||||||||
| terminating_idx = self.terminating_idx[index] | ||||||||||||||||||||||||||||||||||||||||||||||||
| new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| states = self.states[:, index] | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions[:, index] if self.conditions is not None else None | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions[index] if self.conditions is not None else None | ||||||||||||||||||||||||||||||||||||||||||||||||
| actions = self.actions[:, index] | ||||||||||||||||||||||||||||||||||||||||||||||||
| states = states[: 1 + new_max_length] | ||||||||||||||||||||||||||||||||||||||||||||||||
| actions = actions[:new_max_length] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -311,14 +317,8 @@ def extend(self, other: Trajectories) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards). | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||
| Another Trajectories to append. | ||||||||||||||||||||||||||||||||||||||||||||||||
| other: Another Trajectories to append. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self.conditions is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: Support the case | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| "`extend` is not implemented for conditional Trajectories." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if len(other) == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -346,6 +346,12 @@ def extend(self, other: Trajectories) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||
| (self.terminating_idx, other.terminating_idx), dim=0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Concatenate conditions of the trajectories. | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self.conditions is not None and other.conditions is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.conditions = torch.cat((self.conditions, other.conditions), dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.conditions = None | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+349
to
+353
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we maybe add a test for extending with conditions, and then try common ops like get_item to check the output is as expected?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I will add one.
I have no idea what this means. Could you elaborate more?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean in the test, after calling extend, check if the extend operation gave the expected result. torchgfn/testing/test_states.py Lines 432 to 454 in c3f3096
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. I will add a test soon! |
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Concatenate log_rewards of the trajectories. | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self._log_rewards is not None and other._log_rewards is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -387,18 +393,21 @@ def to_transitions(self) -> Transitions: | |||||||||||||||||||||||||||||||||||||||||||||||
| A Transitions object with the same states, actions, and log_rewards as the | ||||||||||||||||||||||||||||||||||||||||||||||||
| current Trajectories. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| valid_action_mask = ~self.actions.is_dummy | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self.conditions is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The conditions tensor has shape (max_length, n_trajectories, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The actions have shape (max_length, n_trajectories) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # We need to index the conditions tensor to match the actions | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The actions exclude the last step, so we need to exclude the last step from conditions | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions[:-1][~self.actions.is_dummy] | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The conditions tensor has shape (n_trajectories, condition_vector_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The actions have batch shape (max_length, n_trajectories) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # We need to repeat the condition vector tensor to match the actions | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions.repeat(self.actions.batch_shape[0], 1, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert conditions.shape[:2] == self.actions.batch_shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Then we mask it with the valid action mask. | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = conditions[valid_action_mask] | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = None | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| states = self.states[:-1][~self.actions.is_dummy] | ||||||||||||||||||||||||||||||||||||||||||||||||
| next_states = self.states[1:][~self.actions.is_dummy] | ||||||||||||||||||||||||||||||||||||||||||||||||
| actions = self.actions[~self.actions.is_dummy] | ||||||||||||||||||||||||||||||||||||||||||||||||
| states = self.states[:-1][valid_action_mask] | ||||||||||||||||||||||||||||||||||||||||||||||||
| next_states = self.states[1:][valid_action_mask] | ||||||||||||||||||||||||||||||||||||||||||||||||
| actions = self.actions[valid_action_mask] | ||||||||||||||||||||||||||||||||||||||||||||||||
| is_terminating = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| next_states.is_sink_state | ||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.is_backward | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -454,9 +463,7 @@ def to_states_container(self) -> StatesContainer: | |||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| is_terminating[self.terminating_idx - 1, torch.arange(len(self))] = True | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| states = self.states.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||
| is_terminating = is_terminating.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| states = self.states | ||||||||||||||||||||||||||||||||||||||||||||||||
| is_valid = ~states.is_sink_state & ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| ~states.is_initial_state | (states.is_initial_state & is_terminating) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -465,40 +472,25 @@ def to_states_container(self) -> StatesContainer: | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self.conditions is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The conditions tensor has shape (max_length, n_trajectories, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # We need to flatten it to match the flattened states | ||||||||||||||||||||||||||||||||||||||||||||||||
| # First, we need to repeat it to match the flattened shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The flattened states have shape (max_length * n_trajectories,) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # So we need to repeat the conditions tensor accordingly | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions.flatten(0, 1)[is_valid] | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The conditions tensor has shape (n_trajectories, condition_vector_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # The states have batch shape (max_length, n_trajectories) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # We need to repeat the conditions to match the batch shape of the states. | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = self.conditions.repeat(self.states.batch_shape[0], 1, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # (max_length, n_trajectories, condition_vector_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert conditions.shape[:2] == self.states.batch_shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Then we mask it with the valid state mask. | ||||||||||||||||||||||||||||||||||||||||||||||||
| conditions = conditions[is_valid] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if self.log_rewards is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards = torch.full( | ||||||||||||||||||||||||||||||||||||||||||||||||
| (len(states),), | ||||||||||||||||||||||||||||||||||||||||||||||||
| fill_value=-float("inf"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| device=states.device, | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.states.batch_shape, fill_value=-float("inf"), device=states.device | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Get the original indices (before flattening and filtering). | ||||||||||||||||||||||||||||||||||||||||||||||||
| orig_batch_indices = torch.arange( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.states.batch_shape[0], device=states.device | ||||||||||||||||||||||||||||||||||||||||||||||||
| ).repeat_interleave(self.states.batch_shape[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| orig_traj_indices = torch.arange( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.states.batch_shape[1], device=states.device | ||||||||||||||||||||||||||||||||||||||||||||||||
| ).repeat(self.states.batch_shape[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Retain only the valid indices. | ||||||||||||||||||||||||||||||||||||||||||||||||
| valid_batch_indices = orig_batch_indices[is_valid] | ||||||||||||||||||||||||||||||||||||||||||||||||
| valid_traj_indices = orig_traj_indices[is_valid] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Assign rewards to valid terminating states. | ||||||||||||||||||||||||||||||||||||||||||||||||
| terminating_mask = is_terminating & ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| valid_batch_indices == (self.terminating_idx[valid_traj_indices] - 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards[self.terminating_idx - 1, torch.arange(len(self))] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. really nice cleanup here! |
||||||||||||||||||||||||||||||||||||||||||||||||
| self.log_rewards | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards[terminating_mask] = self.log_rewards[ | ||||||||||||||||||||||||||||||||||||||||||||||||
| valid_traj_indices[terminating_mask] | ||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_rewards = log_rewards[is_valid] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return StatesContainer[DiscreteStates]( | ||||||||||||||||||||||||||||||||||||||||||||||||
| env=self.env, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -414,7 +414,7 @@ def log_partition(self) -> float: | |
| The log partition function. | ||
| """ | ||
| raise NotImplementedError( | ||
| "The environment does not support enumeration of states" | ||
| "The environment does not support calculating the log partition" | ||
| ) | ||
|
|
||
| @property | ||
|
|
@@ -429,6 +429,84 @@ def true_dist(self) -> torch.Tensor: | |
| ) | ||
|
|
||
|
|
||
| class ConditionalEnv(Env, ABC): | ||
| """Base class for conditional environments. | ||
|
|
||
| Conditional environments are environments with condition variables. | ||
| For now, we assume that the conditions only affect the rewards, not the dynamics of | ||
| the environment. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def sample_conditions(self, batch_size: int) -> torch.Tensor: | ||
| """Sample conditions for the environment. | ||
|
|
||
| Args: | ||
| batch_size: The number of conditions to sample. | ||
|
|
||
| Returns: | ||
| A tensor of shape (batch_size, condition_vector_dim) containing the conditions. | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: | ||
| """Compute rewards for the conditional environment. | ||
|
|
||
| Args: | ||
| states: The states to compute rewards for. | ||
| states.tensor.shape should be (batch_size, *state_shape) | ||
| conditions: The conditions to compute rewards for. | ||
| conditions.shape should be (batch_size, condition_vector_dim) | ||
|
|
||
| Returns: | ||
| A tensor of shape (batch_size,) containing the rewards. | ||
| """ | ||
| raise NotImplementedError | ||
|
Comment on lines
+452
to
+464
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. aha, this is not a real subclass of Env, as conditions are mandatory (i.e. if you can't call this function pretending it is an env obj while it is ConditionEnv). Would it make sense to have a default condition?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
How could having a default condition solve the problem?
Maybe, but still we need a parent class that defines the default methods for Envs, like
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If we have a function like this: def get_reward(env: Env, states: States) -> torch.Tensor:
return env.reward(states)This should work with any Env object, given the interface of Env. However, currently, if I pass a ConditionEnv (which is an Env), this will fail as you need to specify the conditioning. If you have a default value for conditioning, now the get_reward function will work properly (indeed, with default, the reward function interface of ConditionEnv becomes a subtype of the one of Env)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts The env itself would only be conditional or not depending on the logic the user defines in the reward and step functions. No actual The estimators would also optionally use the conditioning information, if it's present, just like how it's done currently. |
||
|
|
||
| def log_reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: | ||
| """Compute log rewards for the conditional environment. | ||
|
|
||
| Args: | ||
| states: The states to compute log rewards for. | ||
| states.tensor.shape should be (batch_size, *state_shape) | ||
| conditions: The conditions to compute log rewards for. | ||
| conditions.shape should be (batch_size, condition_vector_dim) | ||
|
|
||
| Returns: | ||
| A tensor of shape (batch_size,) containing the log rewards. | ||
| """ | ||
| return torch.log(self.reward(states, conditions)) | ||
|
|
||
| def log_partition(self, condition: torch.Tensor) -> float: | ||
| """Optional method to return the logarithm of the partition function for a | ||
| given condition. | ||
|
|
||
| Args: | ||
| condition: The condition to compute the log partition for. | ||
| condition.shape should be (condition_vector_dim,) | ||
|
|
||
| Returns: | ||
| The log partition function, as a float. | ||
| """ | ||
| raise NotImplementedError( | ||
| "The environment may not support enumeration of states" | ||
| ) | ||
|
|
||
| def true_dist(self, condition: torch.Tensor) -> torch.Tensor: | ||
| """Optional method to return the true distribution for a given condition. | ||
|
|
||
| Args: | ||
| condition: The condition to compute the true distribution for. | ||
| condition.shape should be (condition_vector_dim,) | ||
|
|
||
| Returns: | ||
| The true distribution for the given condition as a 1-dimensional tensor. | ||
| """ | ||
| raise NotImplementedError( | ||
| "The environment may not support enumeration of states" | ||
| ) | ||
|
|
||
|
|
||
| class DiscreteEnv(Env, ABC): | ||
| """Base class for discrete environments, where states are defined in a discrete | ||
| space, and actions are represented by an integer in {0, ..., n_actions - 1}, the | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right, because we assume the conditioning would not change through the trajectory?