104 changes: 103 additions & 1 deletion src/gfn/env.py
@@ -1,7 +1,14 @@
import warnings
from abc import ABC, abstractmethod
from collections import Counter
from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Tuple,
cast,
)

if TYPE_CHECKING:
from gfn.gflownet import GFlowNet
@@ -17,6 +24,25 @@
NonValidActionsError = type("NonValidActionsError", (ValueError,), {})


class EnvFastPathMixin:
"""Marker mixin for environments exposing tensor-only fast-path helpers.

Environments inheriting this mixin are expected to override:

- ``step_tensor``: vectorized transition operating purely on tensors.
- ``forward_action_masks_tensor``: tensor-based forward action masks.
- ``states_from_tensor_fast``: lightweight wrapper that avoids redundant
allocations when reconstructing ``States`` objects from raw tensors.

The mixin itself provides no implementations; it merely signals that the
environment intends to support the fast path, enabling nominal checks such
as ``isinstance(env, EnvFastPathMixin)`` without relying on structural
typing.
"""

fast_path_enabled: bool = True
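
A minimal sketch of the opt-in pattern (the environment class and helper below are hypothetical; only EnvFastPathMixin and fast_path_enabled come from this diff):

from gfn.env import DiscreteEnv, EnvFastPathMixin

class MyGridEnv(EnvFastPathMixin, DiscreteEnv):  # hypothetical environment
    # Opting in obliges the class to override step_tensor,
    # forward_action_masks_tensor, and states_from_tensor_fast.
    ...

def pick_rollout(env) -> str:
    # Nominal check: no hasattr probing or structural typing needed.
    if isinstance(env, EnvFastPathMixin) and env.fast_path_enabled:
        return "fast"
    return "legacy"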


class Env(ABC):
"""Base class for all environments.

@@ -37,6 +63,22 @@ class Env(ABC):

is_discrete: bool = False

@dataclass
class TensorStepResult:
"""Container returned by tensor-level step helpers.

Attributes:
next_states: Tensor containing the next states produced by the step.
is_sink_state: Optional boolean tensor indicating which rows are sink.
forward_masks: Optional boolean tensor with forward action masks.
backward_masks: Optional boolean tensor with backward action masks.
"""

next_states: torch.Tensor
is_sink_state: torch.Tensor | None = None
forward_masks: torch.Tensor | None = None
backward_masks: torch.Tensor | None = None
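
For illustration, a concrete override might populate the container as below; the additive transition and `self.horizon` are hypothetical placeholders, not part of this diff:

def step_tensor(self, states_tensor, actions_tensor):
    # Hypothetical vectorized transition: apply all actions in one shot.
    next_states = states_tensor + actions_tensor
    is_sink = next_states[:, -1] >= self.horizon  # placeholder sink test
    # Masks are optional: callers fall back to forward_action_masks_tensor
    # when these fields are left as None.
    return self.TensorStepResult(next_states=next_states, is_sink_state=is_sink)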

def __init__(
self,
s0: torch.Tensor | GeometricData,
@@ -145,6 +187,51 @@ def actions_from_batch_shape(self, batch_shape: Tuple) -> Actions:
"""
return self.Actions.make_dummy_actions(batch_shape, device=self.device)

@property
def has_tensor_fast_path(self) -> bool:
"""Whether this environment opts into the tensor-only fast API."""

return isinstance(self, EnvFastPathMixin)

def states_from_tensor_fast(self, tensor: torch.Tensor) -> States:
"""Fallback helper recreating ``States`` objects from tensors.

Fast-path environments can override this to avoid redundant mask
recomputation or to attach cached metadata. The default simply calls
``states_from_tensor``.
"""

return self.states_from_tensor(tensor)

def step_tensor(
self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor
) -> "Env.TensorStepResult":
"""Tensor equivalent of `_step` with default object-based fallback.

Environments can override this method to provide compiler-friendly
implementations that avoid constructing `States`/`Actions`. The default
fallback simply wraps tensors into the standard containers and delegates
to `_step`, ensuring parity with the legacy path.
"""

states = self.states_from_tensor(states_tensor.clone())
actions = self.actions_from_tensor(actions_tensor.clone())
new_states = self._step(states, actions)
return self.TensorStepResult(next_states=new_states.tensor.clone())
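
On the sampler side, one fast-path rollout step then reduces to plain tensor operations; a sketch, assuming `env`, `states_tensor`, and `actions_tensor` are already in hand:

result = env.step_tensor(states_tensor, actions_tensor)
states_tensor = result.next_states
if result.is_sink_state is not None:
    active = ~result.is_sink_state  # rows still being rolled out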

def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor:
"""Tensor helper returning forward masks for the supplied states.

Base environments do not provide a generic implementation because mask
semantics are environment-specific. Subclasses (e.g., ``DiscreteEnv``)
are expected to override this to expose a fallback compatible with the
fast sampler path.
"""

raise NotImplementedError(
f"{self.__class__.__name__} does not expose tensor forward masks."
)

@abstractmethod
def step(self, states: States, actions: Actions) -> States:
"""Forward transition function of the environment.
@@ -559,6 +646,21 @@ def states_from_batch_shape(
assert isinstance(out, DiscreteStates)
return out

def states_from_tensor_fast(self, tensor: torch.Tensor) -> DiscreteStates:
"""Return `DiscreteStates` without extra bookkeeping for fast paths."""

states = self.states_from_tensor(tensor)
assert isinstance(states, DiscreteStates)
return states

def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor:
"""Recompute forward masks for the supplied state tensor."""

states = self.states_from_tensor(states_tensor.clone())
self.update_masks(states)
assert states.forward_masks is not None
return states.forward_masks.clone()
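
Together these two fallbacks support a mask-aware sampling step with no States allocation; a sketch, assuming `logits` of shape (batch, n_actions) and a (batch, 1) discrete action tensor layout:

import torch

masks = env.forward_action_masks_tensor(states_tensor)  # bool, (batch, n_actions)
masked = torch.where(masks, logits, torch.full_like(logits, float("-inf")))
actions = torch.distributions.Categorical(logits=masked).sample()  # (batch,)
states_tensor = env.step_tensor(states_tensor, actions.unsqueeze(-1)).next_states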

def reset(
self,
batch_shape: int | Tuple[int, ...],
103 changes: 97 additions & 6 deletions src/gfn/estimators.py
@@ -241,6 +241,45 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]:
return getattr(ctx, "current_estimator_output", None)


class FastPolicyMixin(PolicyMixin):
"""Optional mixin for policies that ingest tensors directly on fast paths.

Estimators inheriting this mixin should implement the tensor-oriented hooks
below so samplers can bypass `States`/`Actions` allocation when environments
expose compatible helpers.
"""

fast_path_enabled: bool = True

def fast_features(
self,
states_tensor: torch.Tensor,
*,
forward_masks: Optional[torch.Tensor] = None,
backward_masks: Optional[torch.Tensor] = None,
conditions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Preprocess raw tensors into module-ready features."""

raise NotImplementedError(
f"{self.__class__.__name__} does not implement fast_features."
)

def fast_distribution(
self,
features: torch.Tensor,
*,
forward_masks: Optional[torch.Tensor] = None,
backward_masks: Optional[torch.Tensor] = None,
**policy_kwargs: Any,
) -> Distribution:
"""Build the action distribution from tensor features."""

raise NotImplementedError(
f"{self.__class__.__name__} does not implement fast_distribution."
)
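
The intended fast-path call sequence is the two hooks back to back; a sketch, where `estimator`, `states_tensor`, and `forward_masks` are assumed to exist:

features = estimator.fast_features(states_tensor, forward_masks=forward_masks)
dist = estimator.fast_distribution(features, forward_masks=forward_masks)
actions_tensor = dist.sample()
log_probs = dist.log_prob(actions_tensor)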


class RecurrentPolicyMixin(PolicyMixin):
"""Mixin for recurrent policies that maintain and update a rollout carry."""

@@ -1227,7 +1266,7 @@ def init_carry(
return init_carry_fn(batch_size, device)


class DiffusionPolicyEstimator(PolicyMixin, Estimator):
class DiffusionPolicyEstimator(FastPolicyMixin, Estimator):
"""Base class for diffusion policy estimators."""

def __init__(self, s_dim: int, module: nn.Module, is_backward: bool = False):
@@ -1282,6 +1321,16 @@ def to_probability_distribution(
"""
raise NotImplementedError

def fast_features(
self,
states_tensor: torch.Tensor,
*,
forward_masks: torch.Tensor | None = None,
backward_masks: torch.Tensor | None = None,
conditions: torch.Tensor | None = None,
) -> torch.Tensor:
return states_tensor


class PinnedBrownianMotionForward(DiffusionPolicyEstimator): # TODO: support OU process
def __init__(
@@ -1345,8 +1394,13 @@ def to_probability_distribution(
An IsotropicGaussian distribution (the distribution of the next states)
"""
assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
s_curr = states.tensor[:, :-1]
t_curr = states.tensor[:, [-1]]
return self._distribution_from_tensor(states.tensor, module_output)

def _distribution_from_tensor(
self, states_tensor: torch.Tensor, module_output: torch.Tensor
) -> IsotropicGaussian:
s_curr = states_tensor[:, :-1]
t_curr = states_tensor[:, [-1]]

module_output = torch.where(
(1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0
@@ -1359,6 +1413,22 @@ def to_probability_distribution(
fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1)
return IsotropicGaussian(fwd_mean, fwd_std)

def fast_distribution(
self,
features: torch.Tensor,
*,
states_tensor: torch.Tensor | None = None,
forward_masks: torch.Tensor | None = None,
backward_masks: torch.Tensor | None = None,
**policy_kwargs: Any,
) -> IsotropicGaussian:
if states_tensor is None:
raise ValueError(
"states_tensor is required for PinnedBrownianMotionForward fast path."
)
module_output = self.module(features)
return self._distribution_from_tensor(states_tensor, module_output)
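
Because `fast_features` is the identity for diffusion estimators, the fast path mirrors `to_probability_distribution` without allocating States; a sketch with an assumed `estimator` and `states_tensor`:

features = estimator.fast_features(states_tensor)  # identity passthrough
dist = estimator.fast_distribution(features, states_tensor=states_tensor)
next_states_tensor = dist.sample()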


class PinnedBrownianMotionBackward(DiffusionPolicyEstimator): # TODO: support OU process
def __init__(
@@ -1422,10 +1492,15 @@ def to_probability_distribution(
An IsotropicGaussian distribution (the distribution of the previous states)
"""
assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
s_curr = states.tensor[:, :-1]
t_curr = states.tensor[:, [-1]] # shape: (*batch_shape,)
return self._distribution_from_tensor(states.tensor, module_output)

def _distribution_from_tensor(
self, states_tensor: torch.Tensor, module_output: torch.Tensor
) -> IsotropicGaussian:
s_curr = states_tensor[:, :-1]
t_curr = states_tensor[:, [-1]]

is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0
is_s0 = (t_curr - self.dt) < self.dt * 1e-2
bwd_mean = torch.where(
is_s0,
s_curr,
@@ -1437,3 +1512,19 @@
self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(),
)
return IsotropicGaussian(bwd_mean, bwd_std)

def fast_distribution(
self,
features: torch.Tensor,
*,
states_tensor: torch.Tensor | None = None,
forward_masks: torch.Tensor | None = None,
backward_masks: torch.Tensor | None = None,
**policy_kwargs: Any,
) -> IsotropicGaussian:
if states_tensor is None:
raise ValueError(
"states_tensor is required for PinnedBrownianMotionBackward fast path."
)
module_output = self.module(features)
return self._distribution_from_tensor(states_tensor, module_output)
45 changes: 30 additions & 15 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -492,18 +492,32 @@ def get_geometric_within_contributions(
Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, n_trajectories).
"""
L = self.lamda

max_len = trajectories.max_length
t_idx = trajectories.terminating_idx
if max_len == 0 or len(trajectories) == 0:
return torch.zeros(
(0, len(trajectories)),
device=trajectories.device,
dtype=torch.get_default_dtype(),
)

# The following tensor represents the weights given to each possible
# sub-trajectory length.
contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to(
torch.get_default_dtype()
)
contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories))
dtype = torch.get_default_dtype()
device = trajectories.device
t_idx = trajectories.terminating_idx.to(dtype)

# Clamp lambda away from 0/1 to avoid divisions by zero or log(0) while keeping
# the computation compatible with torch.compile.
lamda = torch.as_tensor(self.lamda, device=device, dtype=dtype)
finfo = torch.finfo(dtype)
lamda = torch.clamp(lamda, finfo.tiny, 1 - finfo.eps)

# Geometric weights for each possible sub-trajectory length, computed in log
# space to reduce error when lamda is close to 1.
lengths = torch.arange(max_len, device=device, dtype=dtype)
log_weights = lengths * torch.log(lamda)
contributions = torch.exp(log_weights).unsqueeze(-1).repeat(1, len(trajectories))
contributions = contributions.repeat_interleave(
torch.arange(max_len, 0, -1, device=t_idx.device),
torch.arange(max_len, 0, -1, device=device),
dim=0,
output_size=int(max_len * (max_len + 1) / 2),
)
@@ -512,13 +526,14 @@
# where n is the length of the trajectory corresponding to that column
# We can do it the ugly way, or using the cool identity:
# https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29
per_trajectory_denom = (
1.0
/ (1 - L) ** 2
* (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double())
).to(torch.get_default_dtype())
contributions = contributions / per_trajectory_denom / len(trajectories)
# Closed-form normalization:
# sum_{i=0}^{n-1} (n - i) * lamda^i
lamda_pow_n = torch.pow(lamda, t_idx)
numerator = lamda * (lamda_pow_n - 1) + (1 - lamda) * t_idx
denominator = (1 - lamda) ** 2
per_trajectory_denom = numerator / denominator

contributions = contributions / per_trajectory_denom / len(trajectories)
return contributions
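
The closed form used for `per_trajectory_denom` can be sanity-checked against the naive sum; a standalone snippet:

import torch

lamda, n = torch.tensor(0.9), torch.tensor(7.0)
naive = sum((n - i) * lamda**i for i in range(int(n)))  # sum_{i=0}^{n-1} (n-i)*lamda^i
closed = (lamda * (lamda**n - 1) + (1 - lamda) * n) / (1 - lamda) ** 2
assert torch.allclose(naive, closed)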

def loss(