Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 61 additions & 20 deletions sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
import time
import os

Expand All @@ -8,7 +9,10 @@
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from orbax.checkpoint import PyTreeCheckpointer
from flax.training.train_state import TrainState
import grain
import orbax.checkpoint as ocp
import optax
from PIL import Image, ImageDraw
import tyro

Expand All @@ -26,6 +30,7 @@ class Args:
image_width: int = 160
data_dir: str = "data/coinrun_episodes"
checkpoint: str = ""
checkpoint_step: Optional[int] = None
# Sampling
batch_size: int = 1
maskgit_steps: int = 25
Expand All @@ -46,6 +51,7 @@ class Args:
lam_patch_size: int = 16
lam_num_blocks: int = 8
lam_num_heads: int = 8
lam_co_train: bool = True
# Dynamics checkpoint
dyna_dim: int = 512
dyna_num_blocks: int = 12
Expand All @@ -72,6 +78,7 @@ class Args:
lam_patch_size=args.lam_patch_size,
lam_num_blocks=args.lam_num_blocks,
lam_num_heads=args.lam_num_heads,
lam_co_train=args.lam_co_train,
# Dynamics
dyna_dim=args.dyna_dim,
dyna_num_blocks=args.dyna_num_blocks,
Expand All @@ -85,8 +92,35 @@ class Args:
)
rng, _rng = jax.random.split(rng)
# Freshly initialised parameters. They are never used for inference directly;
# they only supply the pytree structure, shapes and dtypes that the checkpoint
# is restored into below.
params = genie.init(_rng, dummy_inputs)

# Wrap the parameters in a throwaway TrainState so the restore target contains
# both params and optimizer state.
# NOTE(review): presumably this mirrors the TrainState layout written by the
# training script -- confirm the optimizer matches what was saved.
dummy_train_state = TrainState.create(
    apply_fn=genie.apply,
    params=params,
    tx=optax.adamw(
        optax.warmup_cosine_decay_schedule(
            0, 0, 1, 2  # dummy values
        )
    ),
)

# Register how the `model_state` checkpoint item is read back.
handler_registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
handler_registry.add(
    "model_state",
    ocp.args.StandardRestore,
    ocp.handlers.StandardCheckpointHandler,
)
checkpoint_manager = ocp.CheckpointManager(
    args.checkpoint,
    options=ocp.CheckpointManagerOptions(step_format_fixed_length=6),
    handler_registry=handler_registry,
)

# Only shapes/dtypes are needed as the restore target, not concrete arrays.
abstract_train_state = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, dummy_train_state
)

# Use an explicit `is not None` test: step 0 is a valid checkpoint step and
# must not fall through to "latest".
restore_step = (
    args.checkpoint_step
    if args.checkpoint_step is not None
    else checkpoint_manager.latest_step()
)
if restore_step is None:
    raise FileNotFoundError(
        f"No checkpoint steps found under {args.checkpoint!r}"
    )

restored = checkpoint_manager.restore(
    restore_step,
    args=ocp.args.Composite(
        model_state=ocp.args.StandardRestore(abstract_train_state),
    ),
)
restored_train_state = restored["model_state"]
params = restored_train_state.params


def _sampling_wrapper(module, batch):
Expand All @@ -104,24 +138,31 @@ def _autoreg_sample(rng, video_batch, action_batch):
)
return generated_vid

def _get_dataloader_iterator():
    """Return a Grain iterator over the ``.array_record`` files in ``args.data_dir``.

    The dataloader is built with ``args.seq_len``, ``args.batch_size``, the
    unpacked ``image_shape`` and ``args.seed``, using no worker processes and
    a prefetch buffer of one.

    Returns:
        A ``grain.DataLoaderIterator`` positioned at the dataloader's initial
        state.
    """
    # os.listdir() returns entries in arbitrary order; sort so that the file
    # list (and therefore what a given seed yields) is reproducible.
    array_record_files = sorted(
        os.path.join(args.data_dir, x)
        for x in os.listdir(args.data_dir)
        if x.endswith(".array_record")
    )
    grain_dataloader = get_dataloader(
        array_record_files,
        args.seq_len,
        # NOTE: We deliberately pass the global batch size
        # The dataloader shards the dataset across all processes
        args.batch_size,
        *image_shape,
        num_workers=0,
        prefetch_buffer_size=1,
        seed=args.seed,
    )
    # NOTE(review): relies on the private `_create_initial_state()` helper of
    # the Grain dataloader -- verify on Grain upgrades.
    initial_state = grain_dataloader._create_initial_state()
    grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state)
    return grain_iterator

# --- Get video + latent actions ---
# Draw a single batch of videos through the shared iterator helper. Building a
# second dataloader here would only produce a batch that is immediately
# overwritten.
grain_iterator = _get_dataloader_iterator()
video_batch = next(grain_iterator)

# Get latent actions for all videos in the batch
batch = dict(videos=video_batch)
action_batch = genie.apply(params, batch, False, method=Genie.vq_encode)
Expand Down