diff --git a/sample.py b/sample.py
index 7f11d79..e288bf4 100644
--- a/sample.py
+++ b/sample.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Optional
 import time
 import os
 
@@ -8,7 +9,10 @@
 import jax.numpy as jnp
 import flax.linen as nn
 import numpy as np
-from orbax.checkpoint import PyTreeCheckpointer
+from flax.training.train_state import TrainState
+import grain
+import orbax.checkpoint as ocp
+import optax
 from PIL import Image, ImageDraw
 import tyro
 
@@ -26,6 +30,7 @@ class Args:
     image_width: int = 160
     data_dir: str = "data/coinrun_episodes"
     checkpoint: str = ""
+    checkpoint_step: Optional[int] = None
     # Sampling
     batch_size: int = 1
     maskgit_steps: int = 25
@@ -46,6 +51,7 @@ class Args:
     lam_patch_size: int = 16
     lam_num_blocks: int = 8
     lam_num_heads: int = 8
+    lam_co_train: bool = True
     # Dynamics checkpoint
     dyna_dim: int = 512
     dyna_num_blocks: int = 12
@@ -72,6 +78,7 @@ class Args:
     lam_patch_size=args.lam_patch_size,
     lam_num_blocks=args.lam_num_blocks,
     lam_num_heads=args.lam_num_heads,
+    lam_co_train=args.lam_co_train,
     # Dynamics
     dyna_dim=args.dyna_dim,
     dyna_num_blocks=args.dyna_num_blocks,
@@ -85,8 +92,37 @@ class Args:
 )
 rng, _rng = jax.random.split(rng)
 params = genie.init(_rng, dummy_inputs)
-ckpt = PyTreeCheckpointer().restore(args.checkpoint)["model"]["params"]["params"]
-params["params"].update(ckpt)
+
+# Throwaway TrainState: only its shapes/dtypes are used, as the restore template below.
+dummy_train_state = TrainState.create(
+    apply_fn=genie.apply,
+    params=params,
+    tx=optax.adamw(
+        optax.warmup_cosine_decay_schedule(
+            0, 0, 1, 2  # dummy values
+        )
+    ),
+)
+handler_registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
+handler_registry.add('model_state', ocp.args.StandardRestore, ocp.handlers.StandardCheckpointHandler)
+checkpoint_manager = ocp.CheckpointManager(
+    args.checkpoint,
+    options=ocp.CheckpointManagerOptions(step_format_fixed_length=6),
+    handler_registry=handler_registry
+)
+abstract_train_state = jax.tree_util.tree_map(
+    ocp.utils.to_shape_dtype_struct, dummy_train_state
+)
+
+restored = checkpoint_manager.restore(
+    # Explicit None check: `or` would treat a requested step 0 as "unset".
+    args.checkpoint_step if args.checkpoint_step is not None else checkpoint_manager.latest_step(),
+    args=ocp.args.Composite(
+        model_state=ocp.args.StandardRestore(abstract_train_state),
+    ),
+)
+restored_train_state = restored["model_state"]
+params = restored_train_state.params
 
 
 def _sampling_wrapper(module, batch):
@@ -104,24 +140,32 @@ def _autoreg_sample(rng, video_batch, action_batch):
     )
     return generated_vid
 
+def _get_dataloader_iterator():
+    """Build a grain iterator over the ``.array_record`` files in ``args.data_dir``."""
+    array_record_files = [
+        os.path.join(args.data_dir, x)
+        for x in os.listdir(args.data_dir)
+        if x.endswith(".array_record")
+    ]
+    grain_dataloader = get_dataloader(
+        array_record_files,
+        args.seq_len,
+        # NOTE: We deliberately pass the global batch size
+        # The dataloader shards the dataset across all processes
+        args.batch_size,
+        *image_shape,
+        num_workers=0,
+        prefetch_buffer_size=1,
+        seed=args.seed,
+    )
+    initial_state = grain_dataloader._create_initial_state()
+    grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state)
+    return grain_iterator
+
 # --- Get video + latent actions ---
-array_record_files = [
-    os.path.join(args.data_dir, x)
-    for x in os.listdir(args.data_dir)
-    if x.endswith(".array_record")
-]
-dataloader = get_dataloader(
-    array_record_files,
-    args.seq_len,
-    args.batch_size,
-    args.image_height,
-    args.image_width,
-    args.image_channels,
-    num_workers=8,
-    prefetch_buffer_size=1,
-    seed=args.seed,
-)
-video_batch = next(iter(dataloader))
+grain_iterator = _get_dataloader_iterator()
+video_batch = next(grain_iterator)
+
 # Get latent actions for all videos in the batch
 batch = dict(videos=video_batch)
 action_batch = genie.apply(params, batch, False, method=Genie.vq_encode)