diff --git a/jasmine/genie.py b/jasmine/genie.py index 04ba383..678c5c8 100644 --- a/jasmine/genie.py +++ b/jasmine/genie.py @@ -38,6 +38,8 @@ def __init__( dyna_ffn_dim: int, dyna_num_blocks: int, dyna_num_heads: int, + max_noise_level: float, + noise_buckets: int, param_dtype: jnp.dtype, dtype: jnp.dtype, use_flash_attention: bool, @@ -71,6 +73,8 @@ def __init__( self.dyna_ffn_dim = dyna_ffn_dim self.dyna_num_blocks = dyna_num_blocks self.dyna_num_heads = dyna_num_heads + self.max_noise_level = max_noise_level + self.noise_buckets = noise_buckets self.param_dtype = param_dtype self.dtype = dtype self.use_flash_attention = use_flash_attention @@ -127,6 +131,8 @@ def __init__( num_heads=self.dyna_num_heads, dropout=self.dropout, mask_limit=self.mask_limit, + max_noise_level=self.max_noise_level, + noise_buckets=self.noise_buckets, param_dtype=self.param_dtype, dtype=self.dtype, use_flash_attention=self.use_flash_attention, @@ -141,6 +147,8 @@ def __init__( num_blocks=self.dyna_num_blocks, num_heads=self.dyna_num_heads, dropout=self.dropout, + max_noise_level=self.max_noise_level, + noise_buckets=self.noise_buckets, param_dtype=self.param_dtype, dtype=self.dtype, use_flash_attention=self.use_flash_attention, @@ -184,7 +192,7 @@ def __call__( else latent_actions_BTm11L ), ) - outputs["mask_rng"] = batch["rng"] + outputs["rng"] = batch["rng"] dyna_logits_BTNV, dyna_mask = self.dynamics(outputs) outputs["token_logits"] = dyna_logits_BTNV outputs["mask"] = dyna_mask @@ -199,16 +207,22 @@ def sample( self, batch: Dict[str, jax.Array], seq_len: int, + noise_level: float = 0.0, temperature: float = 1, sample_argmax: bool = False, maskgit_steps: int = 25, ) -> tuple[jax.Array, jax.Array]: + assert ( + noise_level <= self.max_noise_level + ), "Noise level must not be greater than max_noise_level." if self.dyna_type == "maskgit": return self.sample_maskgit( - batch, seq_len, maskgit_steps, temperature, sample_argmax + batch, seq_len, noise_level, maskgit_steps, temperature, sample_argmax ) elif self.dyna_type == "causal": - return self.sample_causal(batch, seq_len, temperature, sample_argmax) + return self.sample_causal( + batch, seq_len, noise_level, temperature, sample_argmax + ) else: raise ValueError(f"Dynamics model type unknown: {self.dyna_type}") @@ -216,6 +230,7 @@ def sample_maskgit( self, batch: Dict[str, jax.Array], seq_len: int, + noise_level: float, steps: int = 25, temperature: float = 1, sample_argmax: bool = False, @@ -257,6 +272,7 @@ def sample_maskgit( init_logits_BSNV = jnp.zeros( shape=(*token_idxs_BSN.shape, self.num_patch_latents) ) + noise_level = jnp.array(noise_level) if self.use_gt_actions: assert self.action_embed is not None latent_actions_BT1L = self.action_embed(batch["actions"]).reshape( @@ -290,6 +306,8 @@ def maskgit_step_fn( num_blocks=self.dyna_num_blocks, num_heads=self.dyna_num_heads, dropout=self.dropout, + max_noise_level=self.max_noise_level, + noise_buckets=self.noise_buckets, mask_limit=self.mask_limit, param_dtype=self.param_dtype, dtype=self.dtype, @@ -313,10 +331,22 @@ def maskgit_step_fn( act_embed_BS1M = jnp.reshape( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - vid_embed_BSNM += act_embed_BS1M + + rng, _rng_noise_augmentation = jax.random.split(rng) + noise_level_B = jnp.tile(noise_level, B) + _, noise_level_embed_BS1M = dynamics_maskgit.apply_noise_augmentation( + vid_embed_BSNM, _rng_noise_augmentation, noise_level_B + ) + + vid_embed_BSNp2M = jnp.concatenate( + [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 + ) unmasked_ratio = jnp.cos(jnp.pi * (step + 1) / (steps * 2)) step_temp = temperature * (1.0 - unmasked_ratio) - final_logits_BSNV = dynamics_maskgit.transformer(vid_embed_BSNM) / step_temp + final_logits_BSNp2V = ( + dynamics_maskgit.transformer(vid_embed_BSNp2M) / step_temp + ) + final_logits_BSNV = final_logits_BSNp2V[:, :, 2:] # --- Sample new tokens for final frame --- if sample_argmax: @@ -407,6 +437,7 @@ def sample_causal( self, batch: Dict[str, jax.Array], seq_len: int, + noise_level: float, temperature: float = 1, sample_argmax: bool = False, ) -> tuple[jax.Array, jax.Array]: @@ -445,6 +476,7 @@ def sample_causal( token_idxs_BSN = jnp.concatenate([token_idxs_BTN, pad], axis=1) logits_BSNV = jnp.zeros((*token_idxs_BSN.shape, self.num_patch_latents)) dynamics_state = nnx.state(self.dynamics) + noise_level = jnp.array(noise_level) if self.use_gt_actions: assert self.action_embed is not None @@ -476,6 +508,8 @@ def causal_step_fn( num_blocks=self.dyna_num_blocks, num_heads=self.dyna_num_heads, dropout=self.dropout, + max_noise_level=self.max_noise_level, + noise_buckets=self.noise_buckets, param_dtype=self.param_dtype, dtype=self.dtype, use_flash_attention=self.use_flash_attention, @@ -494,12 +528,21 @@ def causal_step_fn( act_embed_BS1M = jnp.reshape( act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1]) ) - vid_embed_BSNp1M = jnp.concatenate([act_embed_BS1M, vid_embed_BSNM], axis=2) - final_logits_BTNp1V = ( - dynamics_causal.transformer(vid_embed_BSNp1M, (step_t, step_n)) + + rng, _rng_noise_augmentation = jax.random.split(rng) + noise_level_B = jnp.tile(noise_level, B) + _, noise_level_embed_BS1M = dynamics_causal.apply_noise_augmentation( + vid_embed_BSNM, _rng_noise_augmentation, noise_level_B + ) + + vid_embed_BSNp2M = jnp.concatenate( + [act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2 + ) + final_logits_BTNp2V = ( + dynamics_causal.transformer(vid_embed_BSNp2M, (step_t, step_n)) / temperature ) - final_logits_BV = final_logits_BTNp1V[:, step_t, step_n, :] + final_logits_BV = final_logits_BTNp2V[:, step_t, step_n + 1, :] # --- Sample new tokens for final frame --- if sample_argmax: diff --git a/jasmine/models/dynamics.py b/jasmine/models/dynamics.py index 6f0288b..8cce945 100644 --- a/jasmine/models/dynamics.py +++ b/jasmine/models/dynamics.py @@ -28,6 +28,8 @@ def __init__( num_blocks: int, num_heads: int, dropout: float, + max_noise_level: float, + noise_buckets: int, mask_limit: float, param_dtype: jnp.dtype, dtype: jnp.dtype, @@ -41,6 +43,8 @@ def __init__( self.num_blocks = num_blocks self.num_heads = num_heads self.dropout = dropout + self.max_noise_level = max_noise_level + self.noise_buckets = noise_buckets self.mask_limit = mask_limit self.param_dtype = param_dtype self.dtype = dtype @@ -70,6 +74,40 @@ def __init__( dtype=self.dtype, rngs=rngs, ) + self.noise_level_embed = nnx.Embed( + self.noise_buckets, self.model_dim, rngs=rngs + ) + + def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None): + B, T, N, M = vid_embed_BTNM.shape + rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) + if noise_level_B is None: + noise_level_B = jax.random.uniform( + _rng_noise_lvl, + shape=(B,), + minval=0.0, + maxval=self.max_noise_level, + dtype=self.dtype, + ) + noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M), dtype=self.dtype) + noise_bucket_idx_B = jnp.floor( + (noise_level_B * self.noise_buckets) / self.max_noise_level + ).astype(jnp.int32) + + # Clip noise_bucket_idx_B to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B = jnp.clip(noise_bucket_idx_B, 0, self.noise_buckets - 1) + + noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) + noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) + noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) + noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) + + noise_augmented_vid_embed_BTNM = ( + jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM + + jnp.sqrt(noise_level_B111) * noise_BTNM + ) + + return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M def __call__( self, @@ -80,11 +118,9 @@ def __call__( latent_actions_BTm11L = batch["latent_actions"] vid_embed_BTNM = self.patch_embed(video_tokens_BTN) - batch_size = vid_embed_BTNM.shape[0] - _rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], batch_size + 1) - mask_prob = jax.random.uniform( - _rng_prob, shape=(batch_size,), minval=self.mask_limit - ) + B = vid_embed_BTNM.shape[0] + rng, _rng_prob, *_rngs_mask = jax.random.split(batch["rng"], B + 2) + mask_prob = jax.random.uniform(_rng_prob, shape=(B,), minval=self.mask_limit) per_sample_shape = vid_embed_BTNM.shape[1:-1] mask = jax.vmap( lambda rng, prob: jax.random.bernoulli(rng, prob, per_sample_shape), @@ -95,16 +131,21 @@ def __call__( jnp.expand_dims(mask, -1), self.mask_token.value, vid_embed_BTNM ) + # --- Apply noise augmentation --- + vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation( + vid_embed_BTNM, rng + ) + # --- Predict transition --- act_embed_BTm11M = self.action_up(latent_actions_BTm11L) padded_act_embed_BT1M = jnp.pad( act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0)) ) - padded_act_embed_BTNM = jnp.broadcast_to( - padded_act_embed_BT1M, vid_embed_BTNM.shape + vid_embed_BTNp2M = jnp.concatenate( + [padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2 ) - vid_embed_BTNM += padded_act_embed_BTNM - logits_BTNV = self.transformer(vid_embed_BTNM) + logits_BTNp2V = self.transformer(vid_embed_BTNp2M) + logits_BTNV = logits_BTNp2V[:, :, 2:] return logits_BTNV, mask @@ -120,6 +161,8 @@ def __init__( num_blocks: int, num_heads: int, dropout: float, + max_noise_level: float, + noise_buckets: int, param_dtype: jnp.dtype, dtype: jnp.dtype, use_flash_attention: bool, @@ -133,6 +176,8 @@ def __init__( self.num_blocks = num_blocks self.num_heads = num_heads self.dropout = dropout + self.max_noise_level = max_noise_level + self.noise_buckets = noise_buckets self.param_dtype = param_dtype self.dtype = dtype self.use_flash_attention = use_flash_attention @@ -160,6 +205,40 @@ def __init__( dtype=self.dtype, rngs=rngs, ) + self.noise_level_embed = nnx.Embed( + self.noise_buckets, self.model_dim, rngs=rngs + ) + + def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None): + B, T, N, M = vid_embed_BTNM.shape + rng, _rng_noise_lvl, _rng_noise = jax.random.split(rng, 3) + if noise_level_B is None: + noise_level_B = jax.random.uniform( + _rng_noise_lvl, + shape=(B,), + minval=0.0, + maxval=self.max_noise_level, + dtype=self.dtype, + ) + noise_BTNM = jax.random.normal(_rng_noise, shape=(B, T, N, M), dtype=self.dtype) + noise_bucket_idx_B = jnp.floor( + (noise_level_B * self.noise_buckets) / self.max_noise_level + ).astype(jnp.int32) + + # Clip noise_bucket_idx_B to ensure it stays within valid range to prevent NaNs + noise_bucket_idx_B = jnp.clip(noise_bucket_idx_B, 0, self.noise_buckets - 1) + + noise_bucket_idx_B11 = noise_bucket_idx_B.reshape(B, 1, 1) + noise_level_embed_B11M = self.noise_level_embed(noise_bucket_idx_B11) + noise_level_embed_BT1M = jnp.tile(noise_level_embed_B11M, (1, T, 1, 1)) + noise_level_B111 = noise_level_B.reshape(B, 1, 1, 1) + + noise_augmented_vid_embed_BTNM = ( + jnp.sqrt(1 - noise_level_B111) * vid_embed_BTNM + + jnp.sqrt(noise_level_B111) * noise_BTNM + ) + + return noise_augmented_vid_embed_BTNM, noise_level_embed_BT1M def __call__( self, @@ -172,9 +251,12 @@ def __call__( padded_act_embed_BT1M = jnp.pad( act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0)) ) - vid_embed_BTNp1M = jnp.concatenate( - [padded_act_embed_BT1M, vid_embed_BTNM], axis=2 + vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation( + vid_embed_BTNM, batch["rng"] + ) + vid_embed_BTNp2M = jnp.concatenate( + [padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2 ) - logits_BTNp1V = self.transformer(vid_embed_BTNp1M) - logits_BTNV = logits_BTNp1V[:, :, :-1] + logits_BTNp2V = self.transformer(vid_embed_BTNp2M) + logits_BTNV = logits_BTNp2V[:, :, 1:-1] return logits_BTNV, jnp.ones_like(video_tokens_BTN) diff --git a/jasmine/sample.py b/jasmine/sample.py index 1d9aa8e..f2f7573 100644 --- a/jasmine/sample.py +++ b/jasmine/sample.py @@ -36,6 +36,9 @@ class Args: temperature: float = 1.0 sample_argmax: bool = True start_frame: int = 1 + noise_level: float = 0.0 + max_noise_level: float = 0.7 + noise_buckets: int = 10 # Tokenizer checkpoint tokenizer_dim: int = 512 tokenizer_ffn_dim: int = 2048 @@ -102,6 +105,8 @@ class Args: lam_num_blocks=args.lam_num_blocks, lam_num_heads=args.lam_num_heads, lam_co_train=False, + max_noise_level=args.max_noise_level, + noise_buckets=args.noise_buckets, use_gt_actions=args.use_gt_actions, # Dynamics dyna_type=args.dyna_type, @@ -166,11 +171,12 @@ def _sampling_fn(model: Genie, batch: dict) -> jax.Array: "causal", ], f"Invalid dynamics type: {args.dyna_type}" frames, _ = model.sample( - batch, - args.seq_len, - args.temperature, - args.sample_argmax, - args.maskgit_steps, + batch=batch, + seq_len=args.seq_len, + noise_level=args.noise_level, + temperature=args.temperature, + sample_argmax=args.sample_argmax, + maskgit_steps=args.maskgit_steps, ) return frames diff --git a/jasmine/train_dynamics.py b/jasmine/train_dynamics.py index 06cd966..797d52b 100644 --- a/jasmine/train_dynamics.py +++ b/jasmine/train_dynamics.py @@ -78,6 +78,8 @@ class Args: dyna_ffn_dim: int = 2048 dyna_num_blocks: int = 6 dyna_num_heads: int = 8 + max_noise_level: float = 0.7 + noise_buckets: int = 10 dropout: float = 0.0 mask_limit: float = 0.5 z_loss_weight: float = 0.0 @@ -137,6 +139,8 @@ def build_model(args: Args, rng: jax.Array) -> tuple[Genie, jax.Array]: dyna_num_blocks=args.dyna_num_blocks, dyna_num_heads=args.dyna_num_heads, dropout=args.dropout, + max_noise_level=args.max_noise_level, + noise_buckets=args.noise_buckets, mask_limit=args.mask_limit, param_dtype=args.param_dtype, dtype=args.dtype, @@ -538,11 +542,12 @@ def val_step(genie: Genie, inputs: dict) -> dict: :, :-1 ] # remove last frame for generation recon_full_frame, logits_full_frame = genie.sample( - inputs, - args.seq_len, - args.val_temperature, - args.val_sample_argmax, - args.val_maskgit_steps, + batch=inputs, + seq_len=args.seq_len, + noise_level=0.0, + temperature=args.val_temperature, + sample_argmax=args.val_sample_argmax, + maskgit_steps=args.val_maskgit_steps, ) # Calculate metrics for the last frame only step_outputs = {