Skip to content
Draft
61 changes: 52 additions & 9 deletions jasmine/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
dyna_ffn_dim: int,
dyna_num_blocks: int,
dyna_num_heads: int,
max_noise_level: float,
noise_buckets: int,
param_dtype: jnp.dtype,
dtype: jnp.dtype,
use_flash_attention: bool,
Expand Down Expand Up @@ -71,6 +73,8 @@ def __init__(
self.dyna_ffn_dim = dyna_ffn_dim
self.dyna_num_blocks = dyna_num_blocks
self.dyna_num_heads = dyna_num_heads
self.max_noise_level = max_noise_level
self.noise_buckets = noise_buckets
self.param_dtype = param_dtype
self.dtype = dtype
self.use_flash_attention = use_flash_attention
Expand Down Expand Up @@ -127,6 +131,8 @@ def __init__(
num_heads=self.dyna_num_heads,
dropout=self.dropout,
mask_limit=self.mask_limit,
max_noise_level=self.max_noise_level,
noise_buckets=self.noise_buckets,
param_dtype=self.param_dtype,
dtype=self.dtype,
use_flash_attention=self.use_flash_attention,
Expand All @@ -141,6 +147,8 @@ def __init__(
num_blocks=self.dyna_num_blocks,
num_heads=self.dyna_num_heads,
dropout=self.dropout,
max_noise_level=self.max_noise_level,
noise_buckets=self.noise_buckets,
param_dtype=self.param_dtype,
dtype=self.dtype,
use_flash_attention=self.use_flash_attention,
Expand Down Expand Up @@ -184,7 +192,7 @@ def __call__(
else latent_actions_BTm11L
),
)
outputs["mask_rng"] = batch["rng"]
outputs["rng"] = batch["rng"]
dyna_logits_BTNV, dyna_mask = self.dynamics(outputs)
outputs["token_logits"] = dyna_logits_BTNV
outputs["mask"] = dyna_mask
Expand All @@ -199,23 +207,30 @@ def sample(
self,
batch: Dict[str, jax.Array],
seq_len: int,
noise_level: float = 0.0,
temperature: float = 1,
sample_argmax: bool = False,
maskgit_steps: int = 25,
) -> tuple[jax.Array, jax.Array]:
assert (
noise_level <= self.max_noise_level
), "Noise level must not be greater than max_noise_level."
if self.dyna_type == "maskgit":
return self.sample_maskgit(
batch, seq_len, maskgit_steps, temperature, sample_argmax
batch, seq_len, noise_level, maskgit_steps, temperature, sample_argmax
)
elif self.dyna_type == "causal":
return self.sample_causal(batch, seq_len, temperature, sample_argmax)
return self.sample_causal(
batch, seq_len, noise_level, temperature, sample_argmax
)
else:
raise ValueError(f"Dynamics model type unknown: {self.dyna_type}")

def sample_maskgit(
self,
batch: Dict[str, jax.Array],
seq_len: int,
noise_level: float,
steps: int = 25,
temperature: float = 1,
sample_argmax: bool = False,
Expand Down Expand Up @@ -257,6 +272,7 @@ def sample_maskgit(
init_logits_BSNV = jnp.zeros(
shape=(*token_idxs_BSN.shape, self.num_patch_latents)
)
noise_level = jnp.array(noise_level)
if self.use_gt_actions:
assert self.action_embed is not None
latent_actions_BT1L = self.action_embed(batch["actions"]).reshape(
Expand Down Expand Up @@ -290,6 +306,8 @@ def maskgit_step_fn(
num_blocks=self.dyna_num_blocks,
num_heads=self.dyna_num_heads,
dropout=self.dropout,
max_noise_level=self.max_noise_level,
noise_buckets=self.noise_buckets,
mask_limit=self.mask_limit,
param_dtype=self.param_dtype,
dtype=self.dtype,
Expand All @@ -313,10 +331,22 @@ def maskgit_step_fn(
act_embed_BS1M = jnp.reshape(
act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1])
)
vid_embed_BSNM += act_embed_BS1M

rng, _rng_noise_augmentation = jax.random.split(rng)
noise_level_B = jnp.tile(noise_level, B)
_, noise_level_embed_BS1M = dynamics_maskgit.apply_noise_augmentation(
vid_embed_BSNM, _rng_noise_augmentation, noise_level_B
)

vid_embed_BSNp2M = jnp.concatenate(
[act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2
)
unmasked_ratio = jnp.cos(jnp.pi * (step + 1) / (steps * 2))
step_temp = temperature * (1.0 - unmasked_ratio)
final_logits_BSNV = dynamics_maskgit.transformer(vid_embed_BSNM) / step_temp
final_logits_BSNp2V = (
dynamics_maskgit.transformer(vid_embed_BSNp2M) / step_temp
)
final_logits_BSNV = final_logits_BSNp2V[:, :, 2:]

# --- Sample new tokens for final frame ---
if sample_argmax:
Expand Down Expand Up @@ -407,6 +437,7 @@ def sample_causal(
self,
batch: Dict[str, jax.Array],
seq_len: int,
noise_level: float,
temperature: float = 1,
sample_argmax: bool = False,
) -> tuple[jax.Array, jax.Array]:
Expand Down Expand Up @@ -445,6 +476,7 @@ def sample_causal(
token_idxs_BSN = jnp.concatenate([token_idxs_BTN, pad], axis=1)
logits_BSNV = jnp.zeros((*token_idxs_BSN.shape, self.num_patch_latents))
dynamics_state = nnx.state(self.dynamics)
noise_level = jnp.array(noise_level)

if self.use_gt_actions:
assert self.action_embed is not None
Expand Down Expand Up @@ -476,6 +508,8 @@ def causal_step_fn(
num_blocks=self.dyna_num_blocks,
num_heads=self.dyna_num_heads,
dropout=self.dropout,
max_noise_level=self.max_noise_level,
noise_buckets=self.noise_buckets,
param_dtype=self.param_dtype,
dtype=self.dtype,
use_flash_attention=self.use_flash_attention,
Expand All @@ -494,12 +528,21 @@ def causal_step_fn(
act_embed_BS1M = jnp.reshape(
act_embed_BSM, (B, S, 1, act_embed_BSM.shape[-1])
)
vid_embed_BSNp1M = jnp.concatenate([act_embed_BS1M, vid_embed_BSNM], axis=2)
final_logits_BTNp1V = (
dynamics_causal.transformer(vid_embed_BSNp1M, (step_t, step_n))

rng, _rng_noise_augmentation = jax.random.split(rng)
noise_level_B = jnp.tile(noise_level, B)
_, noise_level_embed_BS1M = dynamics_causal.apply_noise_augmentation(
vid_embed_BSNM, _rng_noise_augmentation, noise_level_B
)

vid_embed_BSNp2M = jnp.concatenate(
[act_embed_BS1M, noise_level_embed_BS1M, vid_embed_BSNM], axis=2
)
final_logits_BTNp2V = (
dynamics_causal.transformer(vid_embed_BSNp2M, (step_t, step_n))
/ temperature
)
final_logits_BV = final_logits_BTNp1V[:, step_t, step_n, :]
final_logits_BV = final_logits_BTNp2V[:, step_t, step_n + 1, :]

# --- Sample new tokens for final frame ---
if sample_argmax:
Expand Down
108 changes: 95 additions & 13 deletions jasmine/models/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(
num_blocks: int,
num_heads: int,
dropout: float,
max_noise_level: float,
noise_buckets: int,
mask_limit: float,
param_dtype: jnp.dtype,
dtype: jnp.dtype,
Expand All @@ -41,6 +43,8 @@ def __init__(
self.num_blocks = num_blocks
self.num_heads = num_heads
self.dropout = dropout
self.max_noise_level = max_noise_level
self.noise_buckets = noise_buckets
self.mask_limit = mask_limit
self.param_dtype = param_dtype
self.dtype = dtype
Expand Down Expand Up @@ -70,6 +74,40 @@ def __init__(
dtype=self.dtype,
rngs=rngs,
)
self.noise_level_embed = nnx.Embed(
self.noise_buckets, self.model_dim, rngs=rngs
)

def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None):
    """Blend Gaussian noise into patch embeddings and fetch a noise-level embedding.

    Args:
        vid_embed_BTNM: patch embeddings of shape (B, T, N, M).
        rng: JAX PRNG key consumed for the noise draw (and for sampling
            levels when ``noise_level_B`` is not supplied).
        noise_level_B: optional per-sample noise levels, shape (B,). When
            None (training), levels are drawn uniformly from
            [0, max_noise_level).

    Returns:
        A tuple of the noise-augmented embeddings, shape (B, T, N, M), and
        the bucketed noise-level embedding tiled over time, shape (B, T, 1, M).
    """
    batch, time, patches, dim = vid_embed_BTNM.shape
    # Three-way split so the level stream and the noise stream use
    # independent keys (first key is deliberately discarded).
    _, key_level, key_noise = jax.random.split(rng, 3)
    if noise_level_B is None:
        noise_level_B = jax.random.uniform(
            key_level,
            shape=(batch,),
            minval=0.0,
            maxval=self.max_noise_level,
            dtype=self.dtype,
        )

    gaussian_BTNM = jax.random.normal(
        key_noise, shape=(batch, time, patches, dim), dtype=self.dtype
    )

    # Discretize each continuous level into one of `noise_buckets`
    # embedding-table indices; the clip keeps level == max_noise_level
    # from indexing one past the table.
    bucket_B = jnp.floor(
        noise_level_B * self.noise_buckets / self.max_noise_level
    ).astype(jnp.int32)
    bucket_B = jnp.clip(bucket_B, 0, self.noise_buckets - 1)
    level_embed_B11M = self.noise_level_embed(bucket_B.reshape(batch, 1, 1))
    level_embed_BT1M = jnp.tile(level_embed_B11M, (1, time, 1, 1))

    # Variance-preserving interpolation between signal and noise.
    level_B111 = noise_level_B.reshape(batch, 1, 1, 1)
    mixed_BTNM = (
        jnp.sqrt(1 - level_B111) * vid_embed_BTNM
        + jnp.sqrt(level_B111) * gaussian_BTNM
    )
    return mixed_BTNM, level_embed_BT1M

def __call__(
self,
Expand All @@ -80,11 +118,9 @@ def __call__(
latent_actions_BTm11L = batch["latent_actions"]
vid_embed_BTNM = self.patch_embed(video_tokens_BTN)

batch_size = vid_embed_BTNM.shape[0]
_rng_prob, *_rngs_mask = jax.random.split(batch["mask_rng"], batch_size + 1)
mask_prob = jax.random.uniform(
_rng_prob, shape=(batch_size,), minval=self.mask_limit
)
B = vid_embed_BTNM.shape[0]
rng, _rng_prob, *_rngs_mask = jax.random.split(batch["rng"], B + 2)
mask_prob = jax.random.uniform(_rng_prob, shape=(B,), minval=self.mask_limit)
per_sample_shape = vid_embed_BTNM.shape[1:-1]
mask = jax.vmap(
lambda rng, prob: jax.random.bernoulli(rng, prob, per_sample_shape),
Expand All @@ -95,16 +131,21 @@ def __call__(
jnp.expand_dims(mask, -1), self.mask_token.value, vid_embed_BTNM
)

# --- Apply noise augmentation ---
vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation(
vid_embed_BTNM, rng
)

# --- Predict transition ---
act_embed_BTm11M = self.action_up(latent_actions_BTm11L)
padded_act_embed_BT1M = jnp.pad(
act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0))
)
padded_act_embed_BTNM = jnp.broadcast_to(
padded_act_embed_BT1M, vid_embed_BTNM.shape
vid_embed_BTNp2M = jnp.concatenate(
[padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2
)
vid_embed_BTNM += padded_act_embed_BTNM
logits_BTNV = self.transformer(vid_embed_BTNM)
logits_BTNp2V = self.transformer(vid_embed_BTNp2M)
logits_BTNV = logits_BTNp2V[:, :, 2:]
return logits_BTNV, mask


Expand All @@ -120,6 +161,8 @@ def __init__(
num_blocks: int,
num_heads: int,
dropout: float,
max_noise_level: float,
noise_buckets: int,
param_dtype: jnp.dtype,
dtype: jnp.dtype,
use_flash_attention: bool,
Expand All @@ -133,6 +176,8 @@ def __init__(
self.num_blocks = num_blocks
self.num_heads = num_heads
self.dropout = dropout
self.max_noise_level = max_noise_level
self.noise_buckets = noise_buckets
self.param_dtype = param_dtype
self.dtype = dtype
self.use_flash_attention = use_flash_attention
Expand Down Expand Up @@ -160,6 +205,40 @@ def __init__(
dtype=self.dtype,
rngs=rngs,
)
self.noise_level_embed = nnx.Embed(
self.noise_buckets, self.model_dim, rngs=rngs
)

def apply_noise_augmentation(self, vid_embed_BTNM, rng, noise_level_B=None):
    """Corrupt patch embeddings with Gaussian noise at per-sample levels.

    Args:
        vid_embed_BTNM: patch embeddings, shape (B, T, N, M).
        rng: PRNG key; split for level sampling and noise generation.
        noise_level_B: optional (B,) noise levels. If None, one level per
            batch element is drawn uniformly from [0, max_noise_level).

    Returns:
        (corrupted embeddings of shape (B, T, N, M),
         noise-level embedding broadcast over time, shape (B, T, 1, M)).
    """
    B, T, N, M = vid_embed_BTNM.shape
    # Keys 1 and 2 of the 3-way split drive level sampling and noise,
    # respectively; key 0 is intentionally unused.
    _, level_key, noise_key = jax.random.split(rng, 3)

    if noise_level_B is None:
        # Training path: sample an independent level for each batch element.
        noise_level_B = jax.random.uniform(
            level_key,
            shape=(B,),
            minval=0.0,
            maxval=self.max_noise_level,
            dtype=self.dtype,
        )

    # Map each continuous level to a discrete embedding-table index; the
    # clip keeps level == max_noise_level inside [0, noise_buckets - 1].
    buckets_B = jnp.clip(
        jnp.floor(noise_level_B * self.noise_buckets / self.max_noise_level).astype(
            jnp.int32
        ),
        0,
        self.noise_buckets - 1,
    )
    embed_B11M = self.noise_level_embed(buckets_B.reshape(B, 1, 1))
    embed_BT1M = jnp.broadcast_to(embed_B11M, (B, T, 1, embed_B11M.shape[-1]))

    # Mix signal and noise with sqrt weights so total variance is preserved.
    noise_BTNM = jax.random.normal(noise_key, shape=(B, T, N, M), dtype=self.dtype)
    level_B111 = noise_level_B.reshape(B, 1, 1, 1)
    corrupted_BTNM = (
        jnp.sqrt(1 - level_B111) * vid_embed_BTNM
        + jnp.sqrt(level_B111) * noise_BTNM
    )
    return corrupted_BTNM, embed_BT1M

def __call__(
self,
Expand All @@ -172,9 +251,12 @@ def __call__(
padded_act_embed_BT1M = jnp.pad(
act_embed_BTm11M, ((0, 0), (1, 0), (0, 0), (0, 0))
)
vid_embed_BTNp1M = jnp.concatenate(
[padded_act_embed_BT1M, vid_embed_BTNM], axis=2
vid_embed_BTNM, noise_level_embed_BT1M = self.apply_noise_augmentation(
vid_embed_BTNM, batch["rng"]
)
vid_embed_BTNp2M = jnp.concatenate(
[padded_act_embed_BT1M, noise_level_embed_BT1M, vid_embed_BTNM], axis=2
)
logits_BTNp1V = self.transformer(vid_embed_BTNp1M)
logits_BTNV = logits_BTNp1V[:, :, :-1]
logits_BTNp2V = self.transformer(vid_embed_BTNp2M)
logits_BTNV = logits_BTNp2V[:, :, 1:-1]
return logits_BTNV, jnp.ones_like(video_tokens_BTN)
16 changes: 11 additions & 5 deletions jasmine/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class Args:
temperature: float = 1.0
sample_argmax: bool = True
start_frame: int = 1
noise_level: float = 0.0
max_noise_level: float = 0.7
noise_buckets: int = 10
# Tokenizer checkpoint
tokenizer_dim: int = 512
tokenizer_ffn_dim: int = 2048
Expand Down Expand Up @@ -102,6 +105,8 @@ class Args:
lam_num_blocks=args.lam_num_blocks,
lam_num_heads=args.lam_num_heads,
lam_co_train=False,
max_noise_level=args.max_noise_level,
noise_buckets=args.noise_buckets,
use_gt_actions=args.use_gt_actions,
# Dynamics
dyna_type=args.dyna_type,
Expand Down Expand Up @@ -166,11 +171,12 @@ def _sampling_fn(model: Genie, batch: dict) -> jax.Array:
"causal",
], f"Invalid dynamics type: {args.dyna_type}"
frames, _ = model.sample(
batch,
args.seq_len,
args.temperature,
args.sample_argmax,
args.maskgit_steps,
batch=batch,
seq_len=args.seq_len,
noise_level=args.noise_level,
temperature=args.temperature,
sample_argmax=args.sample_argmax,
maskgit_steps=args.maskgit_steps,
)
return frames

Expand Down
Loading