140 changes: 140 additions & 0 deletions src/gfn/containers/trajectories.py

@@ -511,6 +511,146 @@ def to_states_container(self) -> StatesContainer:
# FIXME: Add log_probs and estimator_outputs.
)

def splice_from_reconstruction(
self,
n_prevs: torch.Tensor,
recon: "Trajectories",
debug: bool = False,
) -> "Trajectories":
"""Splice this (prefix) trajectories with a reconstructed suffix.

This mirrors the sampler's combine logic: for each trajectory ``i``, keep the
first ``n_prevs[i]`` steps from ``self``, then append the reconstructed
portion from ``recon``. When both containers carry per-step ``log_probs``,
those are spliced identically.

Args:
n_prevs: How many prefix steps to take from ``self`` per trajectory.
recon: Trajectories representing the reconstructed suffix per trajectory.
debug: If True, validates the vectorized splicing against a for-loop.

Returns:
A new forward ``Trajectories`` of the same batch size, whose per-trajectory
length is ``n_prevs + recon.terminating_idx`` and which carries the
suffix's log-rewards and, when available, the spliced per-step log-probs.
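
Example:
A minimal usage sketch (names are illustrative; ``prefix`` and ``recon``
must share the same env and batch size, and each reconstructed suffix is
assumed to start at the state where its kept prefix ends)::

n_prevs = prefix.terminating_idx - k  # keep all but the last k steps
combined = prefix.splice_from_reconstruction(n_prevs, recon, debug=True)
assert torch.equal(combined.terminating_idx, n_prevs + recon.terminating_idx)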
"""
# Basic validations to avoid silent shape/device mismatches.
assert self.env is recon.env
bs = self.n_trajectories
assert bs == recon.n_trajectories
assert n_prevs.shape == (bs,)
device = self.states.device

# Determine prefix/suffix lengths and maxima for sizing outputs.
max_n_prev = int(n_prevs.max().item())
n_recons = recon.terminating_idx
max_n_recon = int(n_recons.max().item())
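# ``n_recons[i]`` counts reconstructed actions per trajectory, so the recon
# states tensor spans ``max_n_recon + 1`` time steps (one extra for the
# initial state of each suffix).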

# Episodic reward is taken from the reconstructed suffix.
new_log_rewards = recon.log_rewards
new_dones = n_prevs + n_recons
max_traj_len = int(new_dones.max().item())
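# Each spliced trajectory performs ``n_prevs[i] + n_recons[i]`` actions and
# therefore visits one more state than that.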

# Helper indices and masks over (time, batch).
idx = (
torch.arange(max_traj_len + 1, device=n_prevs.device)
.unsqueeze(1)
.expand(-1, bs)
)
prev_mask = idx < n_prevs
state_recon_mask = (idx >= n_prevs) & (idx <= n_prevs + n_recons)
state_recon_mask2 = idx[: max_n_recon + 1] <= n_recons
action_recon_mask = (idx[:-1] >= n_prevs) & (idx[:-1] <= n_prevs + n_recons - 1)
action_recon_mask2 = idx[:max_n_recon] <= n_recons - 1
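# The unsuffixed masks index the (time, batch) grid of the spliced output;
# the ``*2`` masks index ``recon``'s own grid and select exactly as many
# elements, so each masked assignment below lines up element-for-element.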

# Transpose to (batch, time, ...) for efficient advanced indexing.
prev_states_tsr = self.states.tensor.transpose(0, 1)
prev_actions_tsr = self.actions.tensor.transpose(0, 1)
recon_states_tsr = recon.states.tensor.transpose(0, 1)
recon_actions_tsr = recon.actions.tensor.transpose(0, 1)
prev_mask = prev_mask.transpose(0, 1)
state_recon_mask = state_recon_mask.transpose(0, 1)
state_recon_mask2 = state_recon_mask2.transpose(0, 1)
action_recon_mask = action_recon_mask.transpose(0, 1)
action_recon_mask2 = action_recon_mask2.transpose(0, 1)

# Prepare output tensors in transposed shapes, initialized to sink/dummy.
new_states_tsr = self.env.sf.repeat(bs, max_traj_len + 1, 1).to(
self.states.tensor
)
new_actions_tsr = self.env.dummy_action.repeat(bs, max_traj_len, 1).to(
self.actions.tensor
)

# Fill prefix from ``self`` using ``prev_mask``.
prev_mask_trunc = prev_mask[:, :max_n_prev]
new_states_tsr[prev_mask] = prev_states_tsr[:, :max_n_prev][prev_mask_trunc]
new_actions_tsr[prev_mask[:, :-1]] = prev_actions_tsr[:, :max_n_prev][
prev_mask_trunc
]

# Fill suffix from ``recon`` using recon masks.
new_states_tsr[state_recon_mask] = recon_states_tsr[state_recon_mask2]
new_actions_tsr[action_recon_mask] = recon_actions_tsr[action_recon_mask2]
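# Output position ``n_prevs[i]`` takes ``recon``'s first state, which is
# assumed to coincide with the last kept prefix state (the reconstruction
# starts where the prefix ends).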

# Transpose back to (time, batch, ...).
new_states_tsr = new_states_tsr.transpose(0, 1)
new_actions_tsr = new_actions_tsr.transpose(0, 1)

# Optionally combine PF from container log_probs when both are present.
new_log_pf = None
if (self.log_probs is not None) and (recon.log_probs is not None):
plp = self.log_probs.transpose(0, 1)
rlp = recon.log_probs.transpose(0, 1)
new_log_pf = torch.zeros((bs, max_traj_len), device=device, dtype=plp.dtype)
new_log_pf[prev_mask[:, :-1]] = plp[:, :max_n_prev][prev_mask_trunc]
new_log_pf[action_recon_mask] = rlp[action_recon_mask2]
new_log_pf = new_log_pf.transpose(0, 1)
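# Padded positions keep the neutral 0.0 fill rather than the dummy/sink
# markers used for states and actions.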

# ------------------------------ DEBUG ------------------------------
if debug:
_states = self.env.sf.repeat(max_traj_len + 1, bs, 1).to(self.states.tensor)
_actions = self.env.dummy_action.repeat(max_traj_len, bs, 1).to(
self.actions.tensor
)
# PF/PB debug-specific tensors are omitted; we only validate states/actions.

for i in range(bs):
_n_prev = n_prevs[i]
# Prefix copy
_states[: _n_prev + 1, i] = self.states.tensor[: _n_prev + 1, i]
_actions[:_n_prev, i] = self.actions.tensor[:_n_prev, i]
# Suffix copy
_len_recon = recon.terminating_idx[i]
_states[_n_prev + 1 : _n_prev + _len_recon + 1, i] = recon.states.tensor[
1 : _len_recon + 1,
i,
]
_actions[_n_prev : _n_prev + _len_recon, i] = recon.actions.tensor[
:_len_recon,
i,
]

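# Unlike the vectorized path, the loop writes the junction state from the
# prefix, so the asserts below also verify splice-point consistency
# between ``self`` and ``recon``.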

assert torch.all(_states == new_states_tsr)
assert torch.all(_actions == new_actions_tsr)

# Materialize the forward-time trajectories and carry over conditioning
# and episodic rewards from the reconstructed suffix.
new_trajs = Trajectories(
env=self.env,
states=self.env.states_from_tensor(new_states_tsr),
conditioning=self.conditioning,
actions=self.env.actions_from_tensor(new_actions_tsr),
terminating_idx=new_dones,
is_backward=False,
log_rewards=new_log_rewards,
log_probs=new_log_pf,
)

return new_trajs

def reverse_backward_trajectories(self) -> Trajectories:
"""Returns a reversed version of the backward trajectories."""
assert self.is_backward, "Trajectories must be backward."