From 75dd7ba5d04442ba33e39a0c1290911ae0f89e90 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Mon, 14 Jul 2025 17:16:18 +0200 Subject: [PATCH 1/4] checkpoint loading from different topogies using new ckpt logic; dataloader init for grain --- sample.py | 80 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 20 deletions(-) diff --git a/sample.py b/sample.py index 7f11d79..8322ec5 100644 --- a/sample.py +++ b/sample.py @@ -8,7 +8,10 @@ import jax.numpy as jnp import flax.linen as nn import numpy as np -from orbax.checkpoint import PyTreeCheckpointer +from flax.training.train_state import TrainState +import grain +import orbax.checkpoint as ocp +import optax from PIL import Image, ImageDraw import tyro @@ -26,6 +29,7 @@ class Args: image_width: int = 160 data_dir: str = "data/coinrun_episodes" checkpoint: str = "" + checkpoint_step: int = None # Sampling batch_size: int = 1 maskgit_steps: int = 25 @@ -46,6 +50,7 @@ class Args: lam_patch_size: int = 16 lam_num_blocks: int = 8 lam_num_heads: int = 8 + lam_co_train: bool = True, # Dynamics checkpoint dyna_dim: int = 512 dyna_num_blocks: int = 12 @@ -72,6 +77,7 @@ class Args: lam_patch_size=args.lam_patch_size, lam_num_blocks=args.lam_num_blocks, lam_num_heads=args.lam_num_heads, + lam_co_train=args.lam_co_train, # Dynamics dyna_dim=args.dyna_dim, dyna_num_blocks=args.dyna_num_blocks, @@ -85,8 +91,35 @@ class Args: ) rng, _rng = jax.random.split(rng) params = genie.init(_rng, dummy_inputs) -ckpt = PyTreeCheckpointer().restore(args.checkpoint)["model"]["params"]["params"] -params["params"].update(ckpt) + +dummy_train_state = TrainState.create( + apply_fn=genie.apply, + params=params, + tx=optax.adamw( + optax.warmup_cosine_decay_schedule( + 0, 0, 1, 2 # dummy values + ) + ), +) +handler_registry = ocp.handlers.DefaultCheckpointHandlerRegistry() +handler_registry.add('model_state', ocp.args.StandardRestore, ocp.handlers.StandardCheckpointHandler) +checkpoint_manager = ocp.CheckpointManager( + args.checkpoint, + options=ocp.CheckpointManagerOptions(step_format_fixed_length=6), + handler_registry=handler_registry +) +abstract_train_state = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, dummy_train_state +) + +restored = checkpoint_manager.restore( + args.checkpoint_step or checkpoint_manager.latest_step(), + args=ocp.args.Composite( + model_state=ocp.args.StandardRestore(abstract_train_state), + ), +) +restored_train_state = restored["model_state"] +params = restored_train_state.params def _sampling_wrapper(module, batch): @@ -104,24 +137,31 @@ def _autoreg_sample(rng, video_batch, action_batch): ) return generated_vid +def _get_dataloader_iterator(): + array_record_files = [ + os.path.join(args.data_dir, x) + for x in os.listdir(args.data_dir) + if x.endswith(".array_record") + ] + grain_dataloader = get_dataloader( + array_record_files, + args.seq_len, + # NOTE: We deliberately pass the global batch size + # The dataloader shards the dataset across all processes + args.batch_size, + *image_shape, + num_workers=0, + prefetch_buffer_size=1, + seed=args.seed, + ) + initial_state = grain_dataloader._create_initial_state() + grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state) + return grain_iterator + # --- Get video + latent actions --- -array_record_files = [ - os.path.join(args.data_dir, x) - for x in os.listdir(args.data_dir) - if x.endswith(".array_record") -] -dataloader = get_dataloader( - array_record_files, - args.seq_len, - args.batch_size, - args.image_height, - args.image_width, - args.image_channels, - num_workers=8, - prefetch_buffer_size=1, - seed=args.seed, -) -video_batch = next(iter(dataloader)) +grain_iterator = _get_dataloader_iterator() +video_batch = next(grain_iterator) + # Get latent actions for all videos in the batch batch = dict(videos=video_batch) action_batch = genie.apply(params, batch, False, method=Genie.vq_encode) From 13eabbb13a40da30a34d27db54bf7af1d54f6eeb Mon Sep 17 00:00:00 2001 From: mihir <78321484+maharajamihir@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:54:42 +0200 Subject: [PATCH 2/4] Update sample.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample.py b/sample.py index 8322ec5..ff6decc 100644 --- a/sample.py +++ b/sample.py @@ -50,7 +50,7 @@ class Args: lam_patch_size: int = 16 lam_num_blocks: int = 8 lam_num_heads: int = 8 - lam_co_train: bool = True, + lam_co_train: bool = True # Dynamics checkpoint dyna_dim: int = 512 dyna_num_blocks: int = 12 From 66fec58c3c1e7adbe1f7e719b961f6dde93be7cc Mon Sep 17 00:00:00 2001 From: mihir <78321484+maharajamihir@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:54:59 +0200 Subject: [PATCH 3/4] Update sample.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample.py b/sample.py index ff6decc..75ff7a4 100644 --- a/sample.py +++ b/sample.py @@ -29,7 +29,7 @@ class Args: image_width: int = 160 data_dir: str = "data/coinrun_episodes" checkpoint: str = "" - checkpoint_step: int = None + checkpoint_step: Optional[int] = None # Sampling batch_size: int = 1 maskgit_steps: int = 25 From a2f40d48c6ae7f7958044fbeab5315b527978fe5 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 15 Jul 2025 11:08:54 +0200 Subject: [PATCH 4/4] added missing import --- sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sample.py b/sample.py index 75ff7a4..e288bf4 100644 --- a/sample.py +++ b/sample.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import time import os