From 3d8ba61c33be90eab3955303e9e4472d6c44b78b Mon Sep 17 00:00:00 2001
From: Mihir Mahajan
Date: Wed, 24 Sep 2025 18:39:12 +0200
Subject: [PATCH 1/9] first attempt at implementing noise level feature

---
 genie.py                   | 37 ++++++++++++++++++++++++++++++++++---
 jasmine/models/dynamics.py | 35 ++++++++++++++++++++++++++---------
 jasmine/train_dynamics.py  |  4 ++++
 3 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/genie.py b/genie.py
index 04ba383..78e8584 100644
--- a/genie.py
+++ b/genie.py
@@ -38,6 +38,8 @@ def __init__(
         dyna_ffn_dim: int,
         dyna_num_blocks: int,
         dyna_num_heads: int,
+        max_noise_level: float,
+        noise_buckets: int,
         param_dtype: jnp.dtype,
         dtype: jnp.dtype,
         use_flash_attention: bool,
@@ -71,6 +73,8 @@ def __init__(
         self.dyna_ffn_dim = dyna_ffn_dim
         self.dyna_num_blocks = dyna_num_blocks
         self.dyna_num_heads = dyna_num_heads
+        self.max_noise_level = max_noise_level
+        self.noise_buckets = noise_buckets
         self.param_dtype = param_dtype
         self.dtype = dtype
         self.use_flash_attention = use_flash_attention
@@ -127,6 +131,8 @@ def __init__(
             num_heads=self.dyna_num_heads,
             dropout=self.dropout,
             mask_limit=self.mask_limit,
+            max_noise_level=self.max_noise_level,
+            noise_buckets=self.noise_buckets,
             param_dtype=self.param_dtype,
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
@@ -141,6 +147,8 @@ def __init__(
             num_blocks=self.dyna_num_blocks,
             num_heads=self.dyna_num_heads,
             dropout=self.dropout,
+            max_noise_level=self.max_noise_level,
+            noise_buckets=self.noise_buckets,
             param_dtype=self.param_dtype,
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
@@ -205,7 +213,7 @@ def sample(
     ) -> tuple[jax.Array, jax.Array]:
         if self.dyna_type == "maskgit":
             return self.sample_maskgit(
-                batch, seq_len, maskgit_steps, temperature, sample_argmax
+                batch, seq_len, 0.0, maskgit_steps, temperature, sample_argmax
             )
         elif self.dyna_type == "causal":
             return self.sample_causal(batch, seq_len, temperature, sample_argmax)
@@ -216,6 +224,7 @@ def sample_maskgit(
         self,
         batch: Dict[str, jax.Array],
         seq_len: int,
+        noise_level: float,
         steps: int = 25,
         temperature: float = 1,
         sample_argmax: bool = False,
@@ -257,6 +266,7 @@ def sample_maskgit(
         init_logits_BSNV = jnp.zeros(
             shape=(*token_idxs_BSN.shape, self.num_patch_latents)
         )
+        noise_level = jnp.array(noise_level)
         if self.use_gt_actions:
             assert self.action_embed is not None
             latent_actions_BT1L = self.action_embed(batch["actions"]).reshape(
@@ -290,6 +300,8 @@ def maskgit_step_fn(
                 num_blocks=self.dyna_num_blocks,
                 num_heads=self.dyna_num_heads,
                 dropout=self.dropout,
+                max_noise_level=self.max_noise_level,
+                noise_buckets=self.noise_buckets,
                 mask_limit=self.mask_limit,
                 param_dtype=self.param_dtype,
                 dtype=self.dtype,
@@ -313,10 +325,29 @@ def maskgit_step_fn(
             act_embed_BS1M = jnp.reshape(
                 act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1])
             )
-            vid_embed_BSNM += act_embed_BS1M
+            # TODO mihir
+
+            rng, _rng_noise = jax.random.split(rng)
+            noise_level_111 = noise_level.reshape(1, 1, 1)
+            noise_level_B11 = jnp.tile(noise_level_111, (B, 1, 1))
+            noise_bucket_idx_B11 = jnp.floor(
+                (noise_level_B11 / self.max_noise_level) * self.noise_buckets
+            ).astype(jnp.int32)
+            noise_level_embed_B11M = dynamics_maskgit.noise_level_embed(
+                noise_bucket_idx_B11
+            )
+            noise_level_embed_BS1M = jnp.tile(noise_level_embed_B11M, (1, S, 1, 1))
+            vid_embed_BSNM += jnp.expand_dims(noise_level_B11, -1)
+
+            vid_embed_BSNp2M = jnp.concatenate(
+                [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2
+            )
             unmasked_ratio = jnp.cos(jnp.pi * (step + 1) / (steps * 2))
step_temp = temperature * (1.0 - unmasked_ratio) - final_logits_BSNV = dynamics_maskgit.transformer(vid_embed_BSNM) / step_temp + final_logits_BSNp2V = ( + dynamics_maskgit.transformer(vid_embed_BSNp2M) / step_temp + ) + final_logits_BSNV = final_logits_BSNp2V[:, :, 2:] # --- Sample new tokens for final frame --- if sample_argmax: diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 6f0288b..c773399 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -28,6 +28,8 @@ def __init__( num_blocks: int, num_heads: int, dropout: float, + max_noise_level: float, + noise_buckets: int, mask_limit: float, param_dtype: jnp.dtype, dtype: jnp.dtype, @@ -41,6 +43,8 @@ def __init__( self.num_blocks = num_blocks self.num_heads = num_heads self.dropout = dropout + self.max_noise_level = max_noise_level + self.noise_buckets = noise_buckets self.mask_limit = mask_limit self.param_dtype = param_dtype self.dtype = dtype @@ -70,6 +74,9 @@ def __init__( dtype=self.dtype, rngs=rngs, ) + self.noise_level_embed = nnx.Embed( + self.noise_buckets, self.model_dim, rngs=rngs + ) def __call__( self, @@ -80,11 +87,9 @@ def __call__( latent_actions_BTm11L = batch["latent_actions"] vid_embed_BTNM = self.patch_embed(video_tokens_BTN) - batch_size = vid_embed_BTNM.shape[0] - _rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], batch_size + 1) - mask_prob = jax.random.uniform( - _rng_prob, shape=(batch_size,), minval=self.mask_limit - ) + B, T, N, M = vid_embed_BTNM.shape + rng, _rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], B + 2) + mask_prob = jax.random.uniform(_rng_prob, shape=(B,), minval=self.mask_limit) per_sample_shape = vid_embed_BTNM.shape[1:-1] mask = jax.vmap( lambda rng, prob: jax.random.bernoulli(rng, prob, per_sample_shape), @@ -95,16 +100,28 @@ def __call__( jnp.expand_dims(mask, -1), self.mask_token.value, vid_embed_BTNM ) + # --- Sample noise --- + rng, _rng_noise = jax.random.split(rng) + noise_level_B11 = jax.random.uniform( + _rng_noise, shape=(B,), minval=0.0, maxval=self.max_noise_level + ).reshape(B, 1, 1) + noise_bucket_idx_B11 = jnp.floor( + (noise_level_B11 / self.max_noise_level) * self.noise_buckets + ).astype(jnp.int32) + noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) + noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) + vid_embed_BTNM += jnp.expand_dims(noise_level_B11, -1) + # --- Predict transition --- act_embed_BTm11M = self.action_up(latent_actions_BTm11L) padded_act_embed_BT1M = jnp.pad( act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0)) ) - padded_act_embed_BTNM = jnp.broadcast_to( - padded_act_embed_BT1M, vid_embed_BTNM.shape + vid_embed_BTNp2M = jnp.concatenate( + [padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2 ) - vid_embed_BTNM += padded_act_embed_BTNM - logits_BTNV = self.transformer(vid_embed_BTNM) + logits_BTNp2V = self.transformer(vid_embed_BTNp2M) + logits_BTNV = logits_BTNp2V[:, :, 2:] return logits_BTNV, mask diff --git a/jasmine/train_dynamics.py b/jasmine/train_dynamics.py index ba6f5b9..2734952 100644 --- a/jasmine/train_dynamics.py +++ b/jasmine/train_dynamics.py @@ -78,6 +78,8 @@ class Args: dyna_ffn_dim: int = 2048 dyna_num_blocks: int = 6 dyna_num_heads: int = 8 + max_noise_level: float = 0.7 + noise_buckets: int = 10 dropout: float = 0.0 mask_limit: float = 0.5 param_dtype = jnp.float32 @@ -136,6 +138,8 @@ def build_model(args: Args, rng: jax.Array) -> tuple[Genie, jax.Array]: dyna_num_blocks=args.dyna_num_blocks, 
dyna_num_heads=args.dyna_num_heads, dropout=args.dropout, + max_noise_level=args.max_noise_level, + noise_buckets=args.noise_buckets, mask_limit=args.mask_limit, param_dtype=args.param_dtype, dtype=args.dtype, From 69d3b80a13d11d464da81a745f04da3cbb2c5585 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Thu, 25 Sep 2025 11:08:55 +0200 Subject: [PATCH 2/9] added noise level to sample.py --- jasmine/sample.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jasmine/sample.py b/jasmine/sample.py index 1d9aa8e..5a4073a 100644 --- a/jasmine/sample.py +++ b/jasmine/sample.py @@ -36,6 +36,8 @@ class Args: temperature: float = 1.0 sample_argmax: bool = True start_frame: int = 1 + noise_level: float = 0.0 + noise_buckets: int = 10 # Tokenizer checkpoint tokenizer_dim: int = 512 tokenizer_ffn_dim: int = 2048 @@ -102,6 +104,8 @@ class Args: lam_num_blocks=args.lam_num_blocks, lam_num_heads=args.lam_num_heads, lam_co_train=False, + max_noise_level=0.0, + noise_buckets=args.noise_buckets, use_gt_actions=args.use_gt_actions, # Dynamics dyna_type=args.dyna_type, From fa9afacf62a0974e93a32bfeb9b120f4fce42993 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Thu, 25 Sep 2025 16:25:36 +0200 Subject: [PATCH 3/9] fix noise augmentation logic --- jasmine/genie.py | 3 +++ jasmine/models/dynamics.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/jasmine/genie.py b/jasmine/genie.py index 78e8584..061cebd 100644 --- a/jasmine/genie.py +++ b/jasmine/genie.py @@ -255,6 +255,9 @@ def sample_maskgit( P: S * N """ assert isinstance(self.dynamics, DynamicsMaskGIT) + assert ( + noise_level < self.max_noise_level + ), "Noise level must me smaller than max_noise_level." # --- Encode videos and actions --- videos_BTHWC = batch["videos"] tokenizer_out = self.tokenizer.vq_encode(videos_BTHWC, training=False) diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index c773399..0637990 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -102,15 +102,21 @@ def __call__( # --- Sample noise --- rng, _rng_noise = jax.random.split(rng) - noise_level_B11 = jax.random.uniform( + noise_level_B = jax.random.uniform( _rng_noise, shape=(B,), minval=0.0, maxval=self.max_noise_level - ).reshape(B, 1, 1) - noise_bucket_idx_B11 = jnp.floor( - (noise_level_B11 / self.max_noise_level) * self.noise_buckets + ) + noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M)) + noise_bucket_idx_B = jnp.floor( + (noise_level_B / self.max_noise_level) * self.noise_buckets ).astype(jnp.int32) + noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) - vid_embed_BTNM += jnp.expand_dims(noise_level_B11, -1) + noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) + vid_embed_BTNM = ( + jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM + + jnp.sqrt(noise_level_B111) * noise_BTNM + ) # --- Predict transition --- act_embed_BTm11M = self.action_up(latent_actions_BTm11L) From 28c431527970770c7a4c95d76f8919a8f4041418 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Sat, 27 Sep 2025 19:13:23 +0200 Subject: [PATCH 4/9] max noise to 1 in sampling logic --- jasmine/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jasmine/sample.py b/jasmine/sample.py index 5a4073a..c0eda88 100644 --- a/jasmine/sample.py +++ b/jasmine/sample.py @@ -104,7 +104,7 @@ class Args: lam_num_blocks=args.lam_num_blocks, 
lam_num_heads=args.lam_num_heads, lam_co_train=False, - max_noise_level=0.0, + max_noise_level=1.0, noise_buckets=args.noise_buckets, use_gt_actions=args.use_gt_actions, # Dynamics From 57dee334a2a983b86b936d8d7673691ee2c0f16a Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Sun, 28 Sep 2025 13:53:20 +0200 Subject: [PATCH 5/9] made noise augmentation safer --- jasmine/models/dynamics.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 0637990..5ce4dcd 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -101,23 +101,27 @@ def __call__( ) # --- Sample noise --- - rng, _rng_noise = jax.random.split(rng) + rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) noise_level_B = jax.random.uniform( - _rng_noise, shape=(B,), minval=0.0, maxval=self.max_noise_level + _rng_noise_lvl, shape=(B,), minval=0.0, maxval=self.max_noise_level ) noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M)) + # We calculate `(noise_level * noise_buckets) / max_noise_level` instead of + # `(noise_level_B / max_noise_level) * noise_buckets` for numerical stability. noise_bucket_idx_B = jnp.floor( - (noise_level_B / self.max_noise_level) * self.noise_buckets + (noise_level_B * self.noise_buckets) / self.max_noise_level ).astype(jnp.int32) noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) - vid_embed_BTNM = ( - jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM - + jnp.sqrt(noise_level_B111) * noise_BTNM - ) + # safe sqrt: clip argument to >= 0 + one_minus_noise = jnp.clip(1.0 - noise_level_B111, a_min=0.0) + sqrt_one_minus = jnp.sqrt(one_minus_noise) + sqrt_noise = jnp.sqrt(jnp.clip(noise_level_B111, a_min=0.0)) + + vid_embed_BTNM = sqrt_one_minus * vid_embed_BTNM + sqrt_noise * noise_BTNM # --- Predict transition --- act_embed_BTm11M = self.action_up(latent_actions_BTm11L) padded_act_embed_BT1M = jnp.pad( From 32f3839f71e68e497c21db7a47e1bf6c9b703b7b Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Sun, 28 Sep 2025 22:18:45 +0200 Subject: [PATCH 6/9] fix nan issue by clipping bucket id; small refactor --- jasmine/genie.py | 43 +++++++++++----- jasmine/models/dynamics.py | 103 ++++++++++++++++++++++++++++--------- jasmine/train_dynamics.py | 11 ++-- 3 files changed, 115 insertions(+), 42 deletions(-) diff --git a/jasmine/genie.py b/jasmine/genie.py index 061cebd..80e9acf 100644 --- a/jasmine/genie.py +++ b/jasmine/genie.py @@ -207,16 +207,22 @@ def sample( self, batch: Dict[str, jax.Array], seq_len: int, + noise_level: float = 0.0, temperature: float = 1, sample_argmax: bool = False, maskgit_steps: int = 25, ) -> tuple[jax.Array, jax.Array]: + assert ( + noise_level < self.max_noise_level + ), "Noise level must me smaller than max_noise_level." 
if self.dyna_type == "maskgit": return self.sample_maskgit( - batch, seq_len, 0.0, maskgit_steps, temperature, sample_argmax + batch, seq_len, noise_level, maskgit_steps, temperature, sample_argmax ) elif self.dyna_type == "causal": - return self.sample_causal(batch, seq_len, temperature, sample_argmax) + return self.sample_causal( + batch, seq_len, noise_level, temperature, sample_argmax + ) else: raise ValueError(f"Dynamics model type unknown: {self.dyna_type}") @@ -255,9 +261,6 @@ def sample_maskgit( P: S * N """ assert isinstance(self.dynamics, DynamicsMaskGIT) - assert ( - noise_level < self.max_noise_level - ), "Noise level must me smaller than max_noise_level." # --- Encode videos and actions --- videos_BTHWC = batch["videos"] tokenizer_out = self.tokenizer.vq_encode(videos_BTHWC, training=False) @@ -328,19 +331,18 @@ def maskgit_step_fn( act_embed_BS1M = jnp.reshape( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - # TODO mihir + # TODO mihir rng, _rng_noise = jax.random.split(rng) noise_level_111 = noise_level.reshape(1, 1, 1) noise_level_B11 = jnp.tile(noise_level_111, (B, 1, 1)) noise_bucket_idx_B11 = jnp.floor( - (noise_level_B11 / self.max_noise_level) * self.noise_buckets + (noise_level_B11 * self.noise_buckets) / self.max_noise_level ).astype(jnp.int32) noise_level_embed_B11M = dynamics_maskgit.noise_level_embed( noise_bucket_idx_B11 ) noise_level_embed_BS1M = jnp.tile(noise_level_embed_B11M, (1, S, 1, 1)) - vid_embed_BSNM += jnp.expand_dims(noise_level_B11, -1) vid_embed_BSNp2M = jnp.concatenate( [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 @@ -441,6 +443,7 @@ def sample_causal( self, batch: Dict[str, jax.Array], seq_len: int, + noise_level: float, temperature: float = 1, sample_argmax: bool = False, ) -> tuple[jax.Array, jax.Array]: @@ -528,12 +531,28 @@ def causal_step_fn( act_embed_BS1M = jnp.reshape( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - vid_embed_BSNp1M = jnp.concatenate([act_embed_BS1M, vid_embed_BSNM], axis=2) - final_logits_BTNp1V = ( - dynamics_causal.transformer(vid_embed_BSNp1M, (step_t, step_n)) + + # TODO mihir + + rng, _rng_noise = jax.random.split(rng) + noise_level_111 = noise_level.reshape(1, 1, 1) + noise_level_B11 = jnp.tile(noise_level_111, (B, 1, 1)) + noise_bucket_idx_B11 = jnp.floor( + (noise_level_B11 * self.noise_buckets) / self.max_noise_level + ).astype(jnp.int32) + noise_level_embed_B11M = dynamics_causal.noise_level_embed( + noise_bucket_idx_B11 + ) + noise_level_embed_BS1M = jnp.tile(noise_level_embed_B11M, (1, S, 1, 1)) + + vid_embed_BSNp2M = jnp.concatenate( + [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 + ) + final_logits_BTNp2V = ( + dynamics_causal.transformer(vid_embed_BSNp2M, (step_t, step_n)) / temperature ) - final_logits_BV = final_logits_BTNp1V[:, step_t, step_n, :] + final_logits_BV = final_logits_BTNp2V[:, step_t, step_n + 1, :] # --- Sample new tokens for final frame --- if sample_argmax: diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 5ce4dcd..4842ab5 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -78,6 +78,36 @@ def __init__( self.noise_buckets, self.model_dim, rngs=rngs ) + def _apply_noise_augmentation(self, vid_embed_BTNM, rng): + B, T, N, M = vid_embed_BTNM.shape + rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) + noise_level_B = jax.random.uniform( + _rng_noise_lvl, + shape=(B,), + minval=0.0, + maxval=self.max_noise_level, + dtype=self.dtype, + ) + noise_BTNM = jax.random.normal(_rng_noise, 
shape=(B, T, N, M), dtype=self.dtype) + noise_bucket_idx_B = jnp.floor( + (noise_level_B * self.noise_buckets) / self.max_noise_level + ).astype(jnp.int32) + + # Clip noise_bucket_idx_B to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B = jnp.clip(noise_bucket_idx_B, 0, self.noise_buckets - 1) + + noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) + noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) + noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) + noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) + + noise_augmented_vid_embed_BTNM = ( + jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM + + noise_level_B111 * noise_BTNM + ) + + return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M + def __call__( self, batch: Dict[str, jax.Array], @@ -87,7 +117,7 @@ def __call__( latent_actions_BTm11L = batch["latent_actions"] vid_embed_BTNM = self.patch_embed(video_tokens_BTN) - B, T, N, M = vid_embed_BTNM.shape + B = vid_embed_BTNM.shape[0] rng, _rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], B + 2) mask_prob = jax.random.uniform(_rng_prob, shape=(B,), minval=self.mask_limit) per_sample_shape = vid_embed_BTNM.shape[1:-1] @@ -100,28 +130,11 @@ def __call__( jnp.expand_dims(mask, -1), self.mask_token.value, vid_embed_BTNM ) - # --- Sample noise --- - rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) - noise_level_B = jax.random.uniform( - _rng_noise_lvl, shape=(B,), minval=0.0, maxval=self.max_noise_level + # --- Apply noise augmentation --- + vid_embed_BTNM, noise_level_embed_BT1M = self._apply_noise_augmentation( + vid_embed_BTNM, rng ) - noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M)) - # We calculate `(noise_level * noise_buckets) / max_noise_level` instead of - # `(noise_level_B / max_noise_level) * noise_buckets` for numerical stability. 
- noise_bucket_idx_B = jnp.floor( - (noise_level_B * self.noise_buckets) / self.max_noise_level - ).astype(jnp.int32) - noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) - noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) - noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) - noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) - - # safe sqrt: clip argument to >= 0 - one_minus_noise = jnp.clip(1.0 - noise_level_B111, a_min=0.0) - sqrt_one_minus = jnp.sqrt(one_minus_noise) - sqrt_noise = jnp.sqrt(jnp.clip(noise_level_B111, a_min=0.0)) - vid_embed_BTNM = sqrt_one_minus * vid_embed_BTNM + sqrt_noise * noise_BTNM # --- Predict transition --- act_embed_BTm11M = self.action_up(latent_actions_BTm11L) padded_act_embed_BT1M = jnp.pad( @@ -147,6 +160,8 @@ def __init__( num_blocks: int, num_heads: int, dropout: float, + max_noise_level: float, + noise_buckets: int, param_dtype: jnp.dtype, dtype: jnp.dtype, use_flash_attention: bool, @@ -160,6 +175,8 @@ def __init__( self.num_blocks = num_blocks self.num_heads = num_heads self.dropout = dropout + self.max_noise_level = max_noise_level + self.noise_buckets = noise_buckets self.param_dtype = param_dtype self.dtype = dtype self.use_flash_attention = use_flash_attention @@ -187,6 +204,39 @@ def __init__( dtype=self.dtype, rngs=rngs, ) + self.noise_level_embed = nnx.Embed( + self.noise_buckets, self.model_dim, rngs=rngs + ) + + def _apply_noise_augmentation(self, vid_embed_BTNM, rng): + B, T, N, M = vid_embed_BTNM.shape + rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) + noise_level_B = jax.random.uniform( + _rng_noise_lvl, + shape=(B,), + minval=0.0, + maxval=self.max_noise_level, + dtype=self.dtype, + ) + noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M), dtype=self.dtype) + noise_bucket_idx_B = jnp.floor( + (noise_level_B * self.noise_buckets) / self.max_noise_level + ).astype(jnp.int32) + + # Clip noise_bucket_idx_B to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B = jnp.clip(noise_bucket_idx_B, 0, self.noise_buckets - 1) + + noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) + noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) + noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) + noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) + + noise_augmented_vid_embed_BTNM = ( + jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM + + noise_level_B111 * noise_BTNM + ) + + return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M def __call__( self, @@ -199,9 +249,12 @@ def __call__( padded_act_embed_BT1M = jnp.pad( act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0)) ) - vid_embed_BTNp1M = jnp.concatenate( - [padded_act_embed_BT1M, vid_embed_BTNM], axis=2 + vid_embed_BTNM, noise_level_embed_BT1M = self._apply_noise_augmentation( + video_tokens_BTN, batch["rng"] ) - logits_BTNp1V = self.transformer(vid_embed_BTNp1M) - logits_BTNV = logits_BTNp1V[:, :, :-1] + vid_embed_BTNp2M = jnp.concatenate( + [padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2 + ) + logits_BTNp2V = self.transformer(vid_embed_BTNp2M) + logits_BTNV = logits_BTNp2V[:, :, 1:-1] return logits_BTNV, jnp.ones_like(video_tokens_BTN) diff --git a/jasmine/train_dynamics.py b/jasmine/train_dynamics.py index 2734952..a550fd4 100644 --- a/jasmine/train_dynamics.py +++ b/jasmine/train_dynamics.py @@ -506,11 +506,12 @@ def val_step(genie: Genie, inputs: dict) -> dict: :, :-1 ] # remove last frame for generation recon_full_frame, 
logits_full_frame = genie.sample( - inputs, - args.seq_len, - args.val_temperature, - args.val_sample_argmax, - args.val_maskgit_steps, + batch=inputs, + seq_len=args.seq_len, + noise_level=0.0, + temperature=args.val_temperature, + sample_argmax=args.val_sample_argmax, + maskgit_steps=args.val_maskgit_steps, ) # Calculate metrics for the last frame only step_outputs = { From f87e90f811ae95e3e93af50aed7ac57eb4f1004a Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Sun, 28 Sep 2025 22:38:07 +0200 Subject: [PATCH 7/9] fix typo: forgot sqrt in the formula --- jasmine/genie.py | 12 ++++++++++++ jasmine/models/dynamics.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/jasmine/genie.py b/jasmine/genie.py index 80e9acf..bb87843 100644 --- a/jasmine/genie.py +++ b/jasmine/genie.py @@ -339,6 +339,12 @@ def maskgit_step_fn( noise_bucket_idx_B11 = jnp.floor( (noise_level_B11 * self.noise_buckets) / self.max_noise_level ).astype(jnp.int32) + + # Clip noise_bucket_idx to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B11 = jnp.clip( + noise_bucket_idx_B11, 0, self.noise_buckets - 1 + ) + noise_level_embed_B11M = dynamics_maskgit.noise_level_embed( noise_bucket_idx_B11 ) @@ -540,6 +546,12 @@ def causal_step_fn( noise_bucket_idx_B11 = jnp.floor( (noise_level_B11 * self.noise_buckets) / self.max_noise_level ).astype(jnp.int32) + + # Clip noise_bucket_idx to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B11 = jnp.clip( + noise_bucket_idx_B11, 0, self.noise_buckets - 1 + ) + noise_level_embed_B11M = dynamics_causal.noise_level_embed( noise_bucket_idx_B11 ) diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 4842ab5..a0e2f90 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -103,7 +103,7 @@ def _apply_noise_augmentation(self, vid_embed_BTNM, rng): noise_augmented_vid_embed_BTNM = ( jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM - + noise_level_B111 * noise_BTNM + + jnp.sqrt(noise_level_B111) * noise_BTNM ) return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M @@ -233,7 +233,7 @@ def _apply_noise_augmentation(self, vid_embed_BTNM, rng): noise_augmented_vid_embed_BTNM = ( jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM - + noise_level_B111 * noise_BTNM + + jnp.sqrt(noise_level_B111) * noise_BTNM ) return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M From bc3e6ee77b36dca354a6febd78c626d285d39768 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Mon, 29 Sep 2025 15:28:28 +0200 Subject: [PATCH 8/9] refactored the pr --- jasmine/genie.py | 50 +++++++++++--------------------------- jasmine/models/dynamics.py | 27 ++++++++++---------- jasmine/sample.py | 14 ++++++----- 3 files changed, 36 insertions(+), 55 deletions(-) diff --git a/jasmine/genie.py b/jasmine/genie.py index bb87843..678c5c8 100644 --- a/jasmine/genie.py +++ b/jasmine/genie.py @@ -192,7 +192,7 @@ def __call__( else latent_actions_BTm11L ), ) - outputs["mask_rng"] = batch["rng"] + outputs["rng"] = batch["rng"] dyna_logits_BTNV, dyna_mask = self.dynamics(outputs) outputs["token_logits"] = dyna_logits_BTNV outputs["mask"] = dyna_mask @@ -213,8 +213,8 @@ def sample( maskgit_steps: int = 25, ) -> tuple[jax.Array, jax.Array]: assert ( - noise_level < self.max_noise_level - ), "Noise level must me smaller than max_noise_level." + noise_level <= self.max_noise_level + ), "Noise level must not be greater than max_noise_level." 
if self.dyna_type == "maskgit": return self.sample_maskgit( batch, seq_len, noise_level, maskgit_steps, temperature, sample_argmax @@ -332,24 +332,12 @@ def maskgit_step_fn( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - # TODO mihir - rng, _rng_noise = jax.random.split(rng) - noise_level_111 = noise_level.reshape(1, 1, 1) - noise_level_B11 = jnp.tile(noise_level_111, (B, 1, 1)) - noise_bucket_idx_B11 = jnp.floor( - (noise_level_B11 * self.noise_buckets) / self.max_noise_level - ).astype(jnp.int32) - - # Clip noise_bucket_idx to ensure it stays within valid range to prevent NaNs - noise_bucket_idx_B11 = jnp.clip( - noise_bucket_idx_B11, 0, self.noise_buckets - 1 + rng, _rng_noise_augmentation = jax.random.split(rng) + noise_level_B = jnp.tile(noise_level, B) + _, noise_level_embed_BS1M = dynamics_maskgit.apply_noise_augmentation( + vid_embed_BSNM, _rng_noise_augmentation, noise_level_B ) - noise_level_embed_B11M = dynamics_maskgit.noise_level_embed( - noise_bucket_idx_B11 - ) - noise_level_embed_BS1M = jnp.tile(noise_level_embed_B11M, (1, S, 1, 1)) - vid_embed_BSNp2M = jnp.concatenate( [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 ) @@ -488,6 +476,7 @@ def sample_causal( token_idxs_BSN = jnp.concatenate([token_idxs_BTN, pad], axis=1) logits_BSNV = jnp.zeros((*token_idxs_BSN.shape, self.num_patch_latents)) dynamics_state = nnx.state(self.dynamics) + noise_level = jnp.array(noise_level) if self.use_gt_actions: assert self.action_embed is not None @@ -519,6 +508,8 @@ def causal_step_fn( num_blocks=self.dyna_num_blocks, num_heads=self.dyna_num_heads, dropout=self.dropout, + max_noise_level=self.max_noise_level, + noise_buckets=self.noise_buckets, param_dtype=self.param_dtype, dtype=self.dtype, use_flash_attention=self.use_flash_attention, @@ -538,24 +529,11 @@ def causal_step_fn( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - # TODO mihir - - rng, _rng_noise = jax.random.split(rng) - noise_level_111 = noise_level.reshape(1, 1, 1) - noise_level_B11 = jnp.tile(noise_level_111, (B, 1, 1)) - noise_bucket_idx_B11 = jnp.floor( - (noise_level_B11 * self.noise_buckets) / self.max_noise_level - ).astype(jnp.int32) - - # Clip noise_bucket_idx to ensure it stays within valid range to prevent NaNs - noise_bucket_idx_B11 = jnp.clip( - noise_bucket_idx_B11, 0, self.noise_buckets - 1 - ) - - noise_level_embed_B11M = dynamics_causal.noise_level_embed( - noise_bucket_idx_B11 + rng, _rng_noise_augmentation = jax.random.split(rng) + noise_level_B = jnp.tile(noise_level, B) + _, noise_level_embed_BS1M = dynamics_causal.apply_noise_augmentation( + vid_embed_BSNM, _rng_noise_augmentation, noise_level_B ) - noise_level_embed_BS1M = jnp.tile(noise_level_embed_B11M, (1, S, 1, 1)) vid_embed_BSNp2M = jnp.concatenate( [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index a0e2f90..13c7ec5 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -78,16 +78,17 @@ def __init__( self.noise_buckets, self.model_dim, rngs=rngs ) - def _apply_noise_augmentation(self, vid_embed_BTNM, rng): + def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None): B, T, N, M = vid_embed_BTNM.shape rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) - noise_level_B = jax.random.uniform( - _rng_noise_lvl, - shape=(B,), - minval=0.0, - maxval=self.max_noise_level, - dtype=self.dtype, - ) + if noise_level_B is None: + noise_level_B = jax.random.uniform( + _rng_noise_lvl, + shape=(B,), + 
minval=0.0, + maxval=self.max_noise_level, + dtype=self.dtype, + ) noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M), dtype=self.dtype) noise_bucket_idx_B = jnp.floor( (noise_level_B * self.noise_buckets) / self.max_noise_level @@ -118,7 +119,7 @@ def __call__( vid_embed_BTNM = self.patch_embed(video_tokens_BTN) B = vid_embed_BTNM.shape[0] - rng, _rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], B + 2) + rng, _rng_prob, *_rngs_mask = jax.random.split(batch["rng"], B + 2) mask_prob = jax.random.uniform(_rng_prob, shape=(B,), minval=self.mask_limit) per_sample_shape = vid_embed_BTNM.shape[1:-1] mask = jax.vmap( @@ -131,7 +132,7 @@ def __call__( ) # --- Apply noise augmentation --- - vid_embed_BTNM, noise_level_embed_BT1M = self._apply_noise_augmentation( + vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation( vid_embed_BTNM, rng ) @@ -208,7 +209,7 @@ def __init__( self.noise_buckets, self.model_dim, rngs=rngs ) - def _apply_noise_augmentation(self, vid_embed_BTNM, rng): + def apply_noise_augmentation(self, vid_embed_BTNM, rng): B, T, N, M = vid_embed_BTNM.shape rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) noise_level_B = jax.random.uniform( @@ -249,8 +250,8 @@ def __call__( padded_act_embed_BT1M = jnp.pad( act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0)) ) - vid_embed_BTNM, noise_level_embed_BT1M = self._apply_noise_augmentation( - video_tokens_BTN, batch["rng"] + vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation( + vid_embed_BTNM, batch["rng"] ) vid_embed_BTNp2M = jnp.concatenate( [padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2 diff --git a/jasmine/sample.py b/jasmine/sample.py index c0eda88..f2f7573 100644 --- a/jasmine/sample.py +++ b/jasmine/sample.py @@ -37,6 +37,7 @@ class Args: sample_argmax: bool = True start_frame: int = 1 noise_level: float = 0.0 + max_noise_level: float = 0.7 noise_buckets: int = 10 # Tokenizer checkpoint tokenizer_dim: int = 512 @@ -104,7 +105,7 @@ class Args: lam_num_blocks=args.lam_num_blocks, lam_num_heads=args.lam_num_heads, lam_co_train=False, - max_noise_level=1.0, + max_noise_level=args.max_noise_level, noise_buckets=args.noise_buckets, use_gt_actions=args.use_gt_actions, # Dynamics @@ -170,11 +171,12 @@ def _sampling_fn(model: Genie, batch: dict) -> jax.Array: "causal", ], f"Invalid dynamics type: {args.dyna_type}" frames, _ = model.sample( - batch, - args.seq_len, - args.temperature, - args.sample_argmax, - args.maskgit_steps, + batch=batch, + seq_len=args.seq_len, + noise_level=args.noise_level, + temperature=args.temperature, + sample_argmax=args.sample_argmax, + maskgit_steps=args.maskgit_steps, ) return frames From 93c0388794f3f19d3530c0ea100ffe6c29631165 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 30 Sep 2025 14:34:20 +0200 Subject: [PATCH 9/9] forgot to add noise_level optional param to causal noise augmentation function --- jasmine/models/dynamics.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 13c7ec5..8cce945 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -209,16 +209,17 @@ def __init__( self.noise_buckets, self.model_dim, rngs=rngs ) - def apply_noise_augmentation(self, vid_embed_BTNM, rng): + def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None): B, T, N, M = vid_embed_BTNM.shape rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) - noise_level_B = jax.random.uniform( - 
_rng_noise_lvl,
-            shape=(B,),
-            minval=0.0,
-            maxval=self.max_noise_level,
-            dtype=self.dtype,
-        )
+        if noise_level_B is None:
+            noise_level_B = jax.random.uniform(
+                _rng_noise_lvl,
+                shape=(B,),
+                minval=0.0,
+                maxval=self.max_noise_level,
+                dtype=self.dtype,
+            )
         noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M), dtype=self.dtype)
         noise_bucket_idx_B = jnp.floor(
             (noise_level_B * self.noise_buckets) / self.max_noise_level
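
For reference, below is a minimal standalone sketch of the noise-augmentation scheme this series converges on: a per-sample noise level, a clipped bucket index for the conditioning embedding, and a sqrt(1 - sigma) / sqrt(sigma) mix of embeddings and Gaussian noise. It is illustrative only; `noise_augment` and its signature are not part of the repository, and the default values simply mirror the Args defaults (max_noise_level=0.7, noise_buckets=10).

# Illustrative sketch only -- mirrors the logic added to jasmine/models/dynamics.py,
# but as a standalone function rather than a method on the dynamics modules.
import jax
import jax.numpy as jnp


def noise_augment(embeds_BTNM, rng, max_noise_level=0.7, noise_buckets=10, noise_level_B=None):
    """Mix per-sample Gaussian noise into embeddings and return bucket indices.

    embeds_BTNM: (batch, time, patches, model_dim) patch embeddings.
    Returns the noised embeddings and an int32 bucket index per sample, which
    would be looked up in an nnx.Embed(noise_buckets, model_dim) table and
    prepended as an extra conditioning token next to the action token.
    """
    B, T, N, M = embeds_BTNM.shape
    rng_lvl, rng_eps = jax.random.split(rng)
    if noise_level_B is None:
        # Training: sample a noise level per batch element in [0, max_noise_level).
        noise_level_B = jax.random.uniform(rng_lvl, (B,), minval=0.0, maxval=max_noise_level)
    eps_BTNM = jax.random.normal(rng_eps, (B, T, N, M))
    # Bucketize the noise level for the conditioning embedding; clip so that a
    # noise level equal to max_noise_level cannot index one past the table.
    bucket_B = jnp.floor(noise_level_B * noise_buckets / max_noise_level).astype(jnp.int32)
    bucket_B = jnp.clip(bucket_B, 0, noise_buckets - 1)
    # Variance-preserving mix, as in the final version of the patches.
    sigma_B111 = noise_level_B.reshape(B, 1, 1, 1)
    noised_BTNM = jnp.sqrt(1.0 - sigma_B111) * embeds_BTNM + jnp.sqrt(sigma_B111) * eps_BTNM
    return noised_BTNM, bucket_B


# Example usage on a dummy batch of embeddings.
rng = jax.random.PRNGKey(0)
x = jnp.zeros((2, 4, 16, 32))
y, buckets = noise_augment(x, rng)
print(y.shape, buckets)  # (2, 4, 16, 32) and one int32 bucket per sample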