68 changes: 29 additions & 39 deletions train_dynamics.py
@@ -13,13 +13,13 @@
 import jax
 import jax.numpy as jnp
 import tyro
-import wandb
 
 from genie import Genie, restore_genie_components
 from models.tokenizer import TokenizerVQVAE
 from models.lam import LatentActionModel
 from utils.dataloader import get_dataloader
 from utils.parameter_utils import count_parameters_by_component
+from utils.logger import CompositeLogger
 
 
 @dataclass
Expand Down Expand Up @@ -60,7 +60,8 @@ class Args:
dropout: float = 0.0
mask_limit: float = 0.5
# Logging
log: bool = False
log_dir: str = "logs/"
loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb
entity: str = ""
project: str = ""
name: str = "train_dynamics"
Expand Down Expand Up @@ -173,19 +174,11 @@ def train_step(state, inputs):

param_counts = count_parameters_by_component(init_params)

if args.log and jax.process_index() == 0:
wandb.init(
entity=args.entity,
project=args.project,
name=args.name,
tags=args.tags,
group="debug",
config=args,
)
wandb.config.update({"model_param_count": param_counts})

print("Parameter counts:")
print(param_counts)
if jax.process_index() == 0:
cfg = vars(args).copy()
cfg["model_param_count"] = param_counts
logger = CompositeLogger(args.loggers, cfg)
print(f"Training Dynamics Model with {param_counts['total']} parameters")

# --- Initialize optimizer ---
lr_schedule = optax.warmup_cosine_decay_schedule(
Expand Down Expand Up @@ -240,31 +233,28 @@ def train_step(state, inputs):
step += 1

# --- Logging ---
if args.log:
if step % args.log_interval == 0 and jax.process_index() == 0:
wandb.log(
{
"loss": loss,
"step": step,
**metrics,
}
)
if step % args.log_image_interval == 0:
gt_seq = inputs["videos"][0]
recon_seq = recon[0].clip(0, 1)
comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
comparison_seq = einops.rearrange(
comparison_seq * 255, "t h w c -> h (t w) c"
if step % args.log_interval == 0 and jax.process_index() == 0:
logger.log_metrics(
{
"loss": loss,
**metrics,
},
step
)
if step % args.log_image_interval == 0:
gt_seq = inputs["videos"][0]
recon_seq = recon[0].clip(0, 1)
comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
comparison_seq = einops.rearrange(
comparison_seq * 255, "t h w c -> h (t w) c"
)
if jax.process_index() == 0:
log_images = dict(
image=np.asarray(gt_seq[args.seq_len - 1] * 255.).astype(np.uint8),
recon=np.asarray(recon_seq[args.seq_len - 1] * 255.).astype(np.uint8),
true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)),
)
if jax.process_index() == 0:
log_images = dict(
image=wandb.Image(np.asarray(gt_seq[args.seq_len - 1])),
recon=wandb.Image(np.asarray(recon_seq[args.seq_len - 1])),
true_vs_recon=wandb.Image(
np.asarray(comparison_seq.astype(np.uint8))
),
)
wandb.log(log_images)
logger.log_images(log_images, step)
if step % args.log_checkpoint_interval == 0:
ckpt = {"model": train_state}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
Expand Down
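Note: `CompositeLogger` lives in `utils/logger.py`, which this PR adds but which is not part of the diff shown here. Below is a minimal sketch of the interface the three training scripts rely on; only `CompositeLogger(loggers, config)`, `log_metrics(metrics, step)`, and `log_images(images, step)` are taken from the diff, and everything else (the fan-out design, the `ConsoleLogger` backend, the `BACKENDS` registry) is an assumption about how such a logger could be structured.

# Hypothetical sketch of utils/logger.py -- not the PR's actual implementation.
class ConsoleLogger:
    def __init__(self, config):
        print(f"config: {config}")

    def log_metrics(self, metrics, step):
        print(f"step {step}: " + ", ".join(f"{k}={v}" for k, v in metrics.items()))

    def log_images(self, images, step):
        pass  # raw uint8 arrays; nothing useful to print on a console


class CompositeLogger:
    # Fans each call out to every backend named in `loggers`
    # ("console", "local", "tb", "wandb" per the Args comment above).
    BACKENDS = {"console": ConsoleLogger}  # other backends would register here

    def __init__(self, loggers, config):
        self.backends = [self.BACKENDS[name](config) for name in loggers]

    def log_metrics(self, metrics, step):
        for backend in self.backends:
            backend.log_metrics(metrics, step)

    def log_images(self, images, step):
        for backend in self.backends:
            backend.log_images(images, step)

Since `loggers` is a `list[str]` field parsed by tyro, backends would be selected with space-separated values, e.g. `python train_dynamics.py --loggers console tb` (assuming tyro's default handling of list-typed fields).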
73 changes: 29 additions & 44 deletions train_lam.py
@@ -14,12 +14,11 @@
 import jax
 import jax.numpy as jnp
 import tyro
-import wandb
 
 from models.lam import LatentActionModel
 from utils.dataloader import get_dataloader
 from utils.parameter_utils import count_parameters_by_component
-
+from utils.logger import CompositeLogger
 
 @dataclass
 class Args:
@@ -49,7 +48,8 @@ class Args:
     dropout: float = 0.0
     codebook_dropout: float = 0.0
     # Logging
-    log: bool = False
+    log_dir: str = "logs/"
+    loggers: list[str] = field(default_factory=lambda: ["console"])  # options: console, local, tb, wandb
     entity: str = ""
     project: str = ""
     name: str = "train_lam"
@@ -59,10 +59,8 @@ class Args:
     ckpt_dir: str = ""
     log_checkpoint_interval: int = 10000
 
-
 args = tyro.cli(Args)
 
-
 def lam_loss_fn(params, state, inputs):
     # --- Compute loss ---
     outputs = state.apply_fn(
@@ -94,7 +92,6 @@ def lam_loss_fn(params, state, inputs):
     )
     return loss, (outputs["recon"], index_counts, metrics)
 
-
 @jax.jit
 def train_step(state, inputs, action_last_active):
     # --- Update model ---
@@ -118,7 +115,6 @@ def train_step(state, inputs, action_last_active):
     action_last_active = jnp.where(do_reset, 0, action_last_active)
     return state, loss, recon, action_last_active, metrics
 
-
 if __name__ == "__main__":
     jax.distributed.initialize()
     num_devices = jax.device_count()
@@ -164,19 +160,11 @@ def train_step(state, inputs, action_last_active):
 
     param_counts = count_parameters_by_component(init_params)
 
-    if args.log and jax.process_index() == 0:
-        wandb.init(
-            entity=args.entity,
-            project=args.project,
-            name=args.name,
-            tags=args.tags,
-            group="debug",
-            config=args,
-        )
-        wandb.config.update({"model_param_count": param_counts})
-
-    print("Parameter counts:")
-    print(param_counts)
+    if jax.process_index() == 0:
+        cfg = vars(args).copy()
+        cfg["model_param_count"] = param_counts
+        logger = CompositeLogger(args.loggers, cfg)
+        print(f"Training Latent Action Model with {param_counts['total']} parameters")
 
     # --- Initialize optimizer ---
     lr_schedule = optax.warmup_cosine_decay_schedule(
@@ -240,31 +228,28 @@ def train_step(state, inputs, action_last_active):
             step += 1
 
             # --- Logging ---
-            if args.log:
-                if step % args.log_interval == 0 and jax.process_index() == 0:
-                    wandb.log(
-                        {
-                            "loss": loss,
-                            "step": step,
-                            **metrics,
-                        }
-                    )
-                if step % args.log_image_interval == 0:
-                    gt_seq = inputs["videos"][0][1:]
-                    recon_seq = recon[0].clip(0, 1)
-                    comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
-                    comparison_seq = einops.rearrange(
-                        comparison_seq * 255, "t h w c -> h (t w) c"
-                    )
-                    if jax.process_index() == 0:
-                        log_images = dict(
-                            image=wandb.Image(np.asarray(gt_seq[0])),
-                            recon=wandb.Image(np.asarray(recon_seq[0])),
-                            true_vs_recon=wandb.Image(
-                                np.asarray(comparison_seq.astype(np.uint8))
-                            ),
-                        )
-                        wandb.log(log_images)
+            if step % args.log_interval == 0 and jax.process_index() == 0:
+                logger.log_metrics(
+                    {
+                        "loss": loss,
+                        **metrics,
+                    },
+                    step
+                )
+            if step % args.log_image_interval == 0:
+                gt_seq = inputs["videos"][0][1:]
+                recon_seq = recon[0].clip(0, 1)
+                comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
+                comparison_seq = einops.rearrange(
+                    comparison_seq * 255, "t h w c -> h (t w) c"
+                )
+                if jax.process_index() == 0:
+                    log_images = dict(
+                        image=np.asarray(gt_seq[0] * 255.).astype(np.uint8),
+                        recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8),
+                        true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)),
+                    )
+                    logger.log_images(log_images, step)
             if step % args.log_checkpoint_interval == 0:
                 ckpt = {"model": train_state}
                 orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
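One consequence of the new interface, visible in both hunks above: image payloads are now plain `np.uint8` arrays instead of `wandb.Image` objects, so backend-specific wrapping has to happen inside the logger. A sketch of what a wandb backend for `CompositeLogger` could look like; the class name, constructor, and use of `config.get` are assumptions mirroring the deleted `wandb.init` call, not the PR's actual `utils/logger.py`:

import numpy as np
import wandb


class WandbLogger:
    # Hypothetical backend; mirrors the inline wandb.init call this PR removes.
    def __init__(self, config):
        wandb.init(
            entity=config.get("entity") or None,
            project=config.get("project") or None,
            name=config.get("name") or None,
            config=config,
        )

    def log_metrics(self, metrics, step):
        wandb.log(dict(metrics), step=step)

    def log_images(self, images, step):
        # Re-wrap the raw uint8 arrays into wandb.Image before logging.
        wandb.log({k: wandb.Image(np.asarray(v)) for k, v in images.items()}, step=step)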
73 changes: 33 additions & 40 deletions train_tokenizer.py
@@ -14,11 +14,11 @@
 import jax
 import jax.numpy as jnp
 import tyro
-import wandb
 
 from models.tokenizer import TokenizerVQVAE
 from utils.dataloader import get_dataloader
 from utils.parameter_utils import count_parameters_by_component
+from utils.logger import CompositeLogger
 
 
 @dataclass
@@ -48,7 +48,8 @@ class Args:
     dropout: float = 0.0
     codebook_dropout: float = 0.01
     # Logging
-    log: bool = False
+    log_dir: str = "logs/"
+    loggers: list[str] = field(default_factory=lambda: ["console"])  # options: console, local, tb, wandb
     entity: str = ""
     project: str = ""
     name: str = "train_tokenizer"
@@ -158,16 +159,11 @@ def train_step(state, inputs):
 
     param_counts = count_parameters_by_component(init_params)
 
-    if args.log and jax.process_index() == 0:
-        wandb.init(
-            entity=args.entity,
-            project=args.project,
-            name=args.name,
-            tags=args.tags,
-            group="debug",
-            config=args,
-        )
-        wandb.config.update({"model_param_count": param_counts})
+    if jax.process_index() == 0:
+        cfg = vars(args).copy()
+        cfg["model_param_count"] = param_counts
+        logger = CompositeLogger(args.loggers, cfg)
+        print(f"Training Tokenizer Model with {param_counts['total']} parameters")
 
     print("Parameter counts:")
     print(param_counts)
@@ -228,38 +224,35 @@ def train_step(state, inputs):
 
         inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout)
         train_state, loss, recon, metrics = train_step(train_state, inputs)
         print(f"Step {step}, loss: {loss}")
         step += 1
 
         # --- Logging ---
-        if args.log:
-            if step % args.log_interval == 0 and jax.process_index() == 0:
-                wandb.log(
-                    {
-                        "loss": loss,
-                        "step": step,
-                        **metrics,
-                    }
-                )
-            if step % args.log_image_interval == 0:
-                gt_seq = inputs["videos"][0]
-                recon_seq = recon[0].clip(0, 1)
-                comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
-                comparison_seq = einops.rearrange(
-                    comparison_seq * 255, "t h w c -> h (t w) c"
-                )
-                # NOTE: Process-dependent control flow deliberately happens
-                # after indexing operation since it must not contain code
-                # sections that lead to cross-accelerator communication.
-                if jax.process_index() == 0:
-                    log_images = dict(
-                        image=wandb.Image(np.asarray(gt_seq[0])),
-                        recon=wandb.Image(np.asarray(recon_seq[0])),
-                        true_vs_recon=wandb.Image(
-                            np.asarray(comparison_seq.astype(np.uint8))
-                        ),
-                    )
-                    wandb.log(log_images)
+        if step % args.log_interval == 0 and jax.process_index() == 0:
+            logger.log_metrics(
+                {
+                    "loss": loss,
+                    **metrics,
+                },
+                step
+            )
+        if step % args.log_image_interval == 0:
+            gt_seq = inputs["videos"][0]
+            recon_seq = recon[0].clip(0, 1)
+            comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1)
+            comparison_seq = einops.rearrange(
+                comparison_seq * 255, "t h w c -> h (t w) c"
+            )
+            # NOTE: Process-dependent control flow deliberately happens
+            # after indexing operation since it must not contain code
+            # sections that lead to cross-accelerator communication.
+            if jax.process_index() == 0:
+                log_images = dict(
+                    image=np.asarray(gt_seq[0] * 255.).astype(np.uint8),
+                    recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8),
+                    true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)),
+                )
+                logger.log_images(log_images, step)
        if step % args.log_checkpoint_interval == 0:
            ckpt = {"model": train_state}
            orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
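The NOTE retained in the hunk above is the key constraint behind the guard placement in all three scripts: under `jax.distributed`, array operations on sharded values (the indexing, concatenate, and `einops.rearrange`) can involve cross-host communication, so every process must execute them. Only the purely host-local conversion and logging may be gated on `jax.process_index() == 0`; gating the array ops too would leave process 0 waiting on collectives the other hosts never enter. A minimal sketch of that ordering, using names from the diff (`log_first_frame` is a hypothetical helper, not part of the PR):

import einops
import jax
import jax.numpy as jnp
import numpy as np


def log_first_frame(inputs, recon, logger, step):
    # Every process runs the indexing, concatenate, and rearrange: on
    # sharded arrays these may trigger cross-host gathers, so no process
    # is allowed to skip them.
    gt_seq = inputs["videos"][0]
    recon_seq = recon[0].clip(0, 1)
    comparison_seq = einops.rearrange(
        jnp.concatenate((gt_seq, recon_seq), axis=1) * 255, "t h w c -> h (t w) c"
    )
    # Only the lead host performs the purely local conversion and logging.
    if jax.process_index() == 0:
        logger.log_images(
            dict(true_vs_recon=np.asarray(comparison_seq).astype(np.uint8)), step
        )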