diff --git a/train_dynamics.py b/train_dynamics.py index 27f6f2c..a60bd9e 100644 --- a/train_dynamics.py +++ b/train_dynamics.py @@ -13,13 +13,13 @@ import jax import jax.numpy as jnp import tyro -import wandb from genie import Genie, restore_genie_components from models.tokenizer import TokenizerVQVAE from models.lam import LatentActionModel from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component +from utils.logger import CompositeLogger @dataclass @@ -60,7 +60,8 @@ class Args: dropout: float = 0.0 mask_limit: float = 0.5 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_dynamics" @@ -173,19 +174,11 @@ def train_step(state, inputs): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) - - print("Parameter counts:") - print(param_counts) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) + print(f"Training Dynamics Model with {param_counts['total']} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( @@ -240,31 +233,28 @@ def train_step(state, inputs): step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step 
% args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0] + recon_seq = recon[0].clip(0, 1) + comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[args.seq_len - 1] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[args.seq_len - 1] * 255.).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)), ) - if jax.process_index() == 0: - log_images = dict( - image=wandb.Image(np.asarray(gt_seq[args.seq_len - 1])), - recon=wandb.Image(np.asarray(recon_seq[args.seq_len - 1])), - true_vs_recon=wandb.Image( - np.asarray(comparison_seq.astype(np.uint8)) - ), - ) - wandb.log(log_images) + logger.log_images(log_images, step) if step % args.log_checkpoint_interval == 0: ckpt = {"model": train_state} orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/train_lam.py b/train_lam.py index 0629d2a..0aabc0f 100644 --- a/train_lam.py +++ b/train_lam.py @@ -14,12 +14,11 @@ import jax import jax.numpy as jnp import tyro -import wandb from models.lam import LatentActionModel from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component - +from utils.logger import CompositeLogger @dataclass class Args: @@ -49,7 +48,8 @@ class Args: dropout: float = 0.0 codebook_dropout: float = 0.0 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_lam" @@ -59,10 +59,8 @@ class Args: ckpt_dir: str = "" log_checkpoint_interval: int = 10000 - args = tyro.cli(Args) - def lam_loss_fn(params, state, inputs): # --- Compute loss --- outputs = 
state.apply_fn( @@ -94,7 +92,6 @@ def lam_loss_fn(params, state, inputs): ) return loss, (outputs["recon"], index_counts, metrics) - @jax.jit def train_step(state, inputs, action_last_active): # --- Update model --- @@ -118,7 +115,6 @@ def train_step(state, inputs, action_last_active): action_last_active = jnp.where(do_reset, 0, action_last_active) return state, loss, recon, action_last_active, metrics - if __name__ == "__main__": jax.distributed.initialize() num_devices = jax.device_count() @@ -164,19 +160,11 @@ def train_step(state, inputs, action_last_active): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) - - print("Parameter counts:") - print(param_counts) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) + print(f"Training Latent Action Model with {param_counts['total']} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( @@ -240,31 +228,28 @@ def train_step(state, inputs, action_last_active): step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0][1:] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step % args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0][1:] + recon_seq = recon[0].clip(0, 1) + comparison_seq = 
jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[0] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)), ) - if jax.process_index() == 0: - log_images = dict( - image=wandb.Image(np.asarray(gt_seq[0])), - recon=wandb.Image(np.asarray(recon_seq[0])), - true_vs_recon=wandb.Image( - np.asarray(comparison_seq.astype(np.uint8)) - ), - ) - wandb.log(log_images) + logger.log_images(log_images, step) if step % args.log_checkpoint_interval == 0: ckpt = {"model": train_state} orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/train_tokenizer.py b/train_tokenizer.py index eb2411b..0508cdd 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -14,11 +14,11 @@ import jax import jax.numpy as jnp import tyro -import wandb from models.tokenizer import TokenizerVQVAE from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component +from utils.logger import CompositeLogger @dataclass @@ -48,7 +48,8 @@ class Args: dropout: float = 0.0 codebook_dropout: float = 0.01 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_tokenizer" @@ -158,16 +159,11 @@ def train_step(state, inputs): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) 
+ print(f"Training Tokenizer Model with {param_counts['total']} parameters") print("Parameter counts:") print(param_counts) @@ -228,38 +224,35 @@ def train_step(state, inputs): inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout) train_state, loss, recon, metrics = train_step(train_state, inputs) - print(f"Step {step}, loss: {loss}") step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step % args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0] + recon_seq = recon[0].clip(0, 1) + comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + # NOTE: Process-dependent control flow deliberately happens + # after indexing operation since it must not contain code + # sections that lead to cross-accelerator communication. + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[0] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8) + ), ) - # NOTE: Process-dependent control flow deliberately happens - # after indexing operation since it must not contain code - # sections that lead to cross-accelerator communication. 
"""Logging backends (console, local files, TensorBoard, wandb) and a composite.

Training scripts build a ``CompositeLogger`` from a list of backend names and
a flat configuration dictionary, then call ``log_metrics`` / ``log_images``.
"""
import json
import os
from abc import ABC, abstractmethod
from pprint import pprint
from typing import Any, Dict, List

import jax
import numpy as np


def _unique_dir(base_dir: str) -> str:
    """Return ``base_dir`` or, if it exists, ``base_dir-<n>`` for the first free n.

    Suffixes are tried in order starting at 1. The directory is not created
    here; the result is only known not to exist at the time of the check.
    """
    candidate = base_dir
    suffix = 1
    while os.path.exists(candidate):
        candidate = f"{base_dir}-{suffix}"
        suffix += 1
    return candidate


def _to_python(value: Any) -> Any:
    """Convert a NumPy/JAX metric leaf into a plain Python value.

    NumPy scalars and single-element arrays become Python scalars, larger
    arrays become (nested) lists, and anything else is returned unchanged.
    """
    if isinstance(value, np.generic):
        # e.g. np.float32: not an ndarray, and not JSON serialisable as-is.
        return value.item()
    if isinstance(value, (jax.Array, np.ndarray)):
        if value.size == 1:
            return value.item()
        # .item() raises for arrays with more than one element.
        return np.asarray(value).tolist()
    return value


class BaseLogger(ABC):
    """
    Abstract base class for all loggers.

    Defines the interface for logging metrics and images.
    """

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Log metrics at a given step.

        Args:
            metrics (Dict[str, Any]): Dictionary of metric names and values.
            step (int): The current step or epoch.
        """

    @abstractmethod
    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Log images at a given step.

        Args:
            images (Dict[str, Any]): Dictionary of image names and image data.
            step (int): The current step or epoch.
        """

    def close(self) -> None:
        """
        Release any resources held by the logger.

        The default implementation does nothing; backends that hold a
        writer or a remote run override it.
        """


class WandbLogger(BaseLogger):
    """
    Logger for Weights & Biases (wandb) integration.

    Logs metrics and images to the wandb dashboard.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the WandbLogger and start a wandb run.

        Args:
            config (dict): Configuration dictionary. Must contain "entity",
                "project", "name" and "tags"; an optional "group" overrides
                the default group "debug". The whole dictionary is also
                passed to wandb as the run config.

        Raises:
            KeyError: If a required key is missing from ``config``.
        """
        # Imported lazily so wandb is only required when this backend is used.
        import wandb

        self.wandb = wandb
        self.wandb.init(
            entity=config["entity"],
            project=config["project"],
            name=config["name"],
            tags=config["tags"],
            group=config.get("group", "debug"),
            config=config,
        )

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Log metrics to wandb, with the step stored under the "step" key.

        Args:
            metrics (dict): Dictionary of metric names and values.
            step (int): The current step or epoch.
        """
        self.wandb.log({**metrics, "step": step})

    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Log images to wandb, with the step stored under the "step" key.

        Args:
            images (dict): Dictionary of image names and image data.
            step (int): The current step or epoch.
        """
        payload = {name: self.wandb.Image(data) for name, data in images.items()}
        self.wandb.log({**payload, "step": step})

    def close(self) -> None:
        """Finish the wandb run."""
        self.wandb.finish()


class TensorboardLogger(BaseLogger):
    """
    Logger for TensorBoard integration.

    Logs metrics and images to TensorBoard.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the TensorboardLogger.

        The writer targets ``<log_dir>/tb_logger/<name>``; if that path
        already exists a numeric suffix (``-1``, ``-2``, ...) is appended.

        Args:
            config (dict): Configuration dictionary containing "log_dir"
                and "name".
        """
        # Imported lazily so tensorboardX is only required when selected.
        from tensorboardX import SummaryWriter

        self.log_dir = _unique_dir(
            os.path.join(config["log_dir"], "tb_logger", config["name"])
        )
        self.writer = SummaryWriter(log_dir=self.log_dir)

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Log each metric as a scalar under the ``metrics/`` prefix.

        Args:
            metrics (dict): Dictionary of metric names and values.
            step (int): The current step or epoch.
        """
        for name, value in metrics.items():
            self.writer.add_scalar(f"metrics/{name}", value, step)

    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Log each image under the ``media/`` prefix in HWC layout.

        Args:
            images (dict): Dictionary of image names and image data.
            step (int): The current step or epoch.
        """
        for name, data in images.items():
            self.writer.add_image(f"media/{name}", data, step, dataformats="HWC")

    def close(self) -> None:
        """Close the underlying summary writer."""
        self.writer.close()


class LocalLogger(BaseLogger):
    """
    Logger for local filesystem logging.

    Logs metrics to a JSONL file and images as PNGs in a directory.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize the LocalLogger and create its output directories.

        Output goes to ``<log_dir>/local_logger/<name>``; if that path
        already exists a numeric suffix (``-1``, ``-2``, ...) is appended.

        Args:
            config (dict): Configuration dictionary containing "log_dir"
                and "name".
        """
        log_dir = _unique_dir(
            os.path.join(config["log_dir"], "local_logger", config["name"])
        )
        self.metrics_file = os.path.join(log_dir, "metrics.jsonl")
        self.images_dir = os.path.join(log_dir, "images")
        # images_dir lives inside log_dir, so this creates both levels.
        os.makedirs(self.images_dir, exist_ok=True)

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Append one JSON object per call to the metrics JSONL file.

        Args:
            metrics (dict): Dictionary of metric names and values; values
                must be JSON serialisable.
            step (int): The current step or epoch.
        """
        record = {"step": step, **metrics}
        with open(self.metrics_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Save each image as ``<name>_step<step>.png`` in the images directory.

        Args:
            images (dict): Dictionary of image names and image data; each
                value is expected to be a numpy array (HWC, uint8).
            step (int): The current step or epoch.
        """
        # Imported once per call (not once per image); Pillow stays optional
        # for runs that never write images locally.
        from PIL import Image

        for name, data in images.items():
            target = os.path.join(self.images_dir, f"{name}_step{step}.png")
            Image.fromarray(data).save(target)


class ConsoleLogger(BaseLogger):
    """
    Logger for console output.

    Prints metrics and image logging information to the console.
    """

    def __init__(self, cfg: Dict[str, Any]) -> None:
        """
        Initialize the ConsoleLogger and print the configuration.

        Args:
            cfg (dict): Configuration dictionary to print at initialization.
        """
        pprint(cfg, compact=True)

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Print metrics to the console on a single line.

        Args:
            metrics (dict): Dictionary of metric names and values.
            step (int): The current step or epoch.
        """
        rendered = ", ".join(f"{k}: {v}" for k, v in metrics.items())
        print(f"[Step {step}] Metrics: " + rendered)

    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Print the names of the logged images; image data is not shown.

        Args:
            images (dict): Dictionary of image names and image data.
            step (int): The current step or epoch.
        """
        print(f"[Step {step}] Images logged: {', '.join(images.keys())}")


class CompositeLogger(BaseLogger):
    """
    Logger that combines multiple logger backends.

    Forwards logging calls to all specified loggers.
    """

    def __init__(self, loggers: List[str], cfg: Dict[str, Any]) -> None:
        """
        Initialize the CompositeLogger.

        Args:
            loggers (list): Names of the backends to instantiate. Known
                names are "wandb", "tb", "local" and "console". A single
                string is treated as one name.
            cfg (dict): Configuration dictionary to pass to each logger.

        Raises:
            ValueError: If any requested name is not a known backend. All
                names are checked before any backend is constructed.
        """
        available_loggers = {
            "wandb": WandbLogger,
            "tb": TensorboardLogger,
            "local": LocalLogger,
            "console": ConsoleLogger,
        }
        if isinstance(loggers, str):
            # Avoid iterating a bare string character by character.
            loggers = [loggers]
        unknown = [name for name in loggers if name not in available_loggers]
        if unknown:
            options = ", ".join(available_loggers)
            # Explicit raise: an assert would vanish under ``python -O``.
            raise ValueError(
                f'Logger "{unknown[0]}" not known. '
                f"Available loggers are: {options}"
            )
        self.loggers = [available_loggers[name](cfg) for name in loggers]

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """
        Log metrics to all contained loggers.

        NumPy/JAX values are converted to plain Python values first, so
        every backend receives the same, serialisable data.

        Args:
            metrics (dict): Dictionary of metric names and values.
            step (int): The current step or epoch.
        """
        metrics = jax.tree_util.tree_map(_to_python, metrics)
        for backend in self.loggers:
            backend.log_metrics(metrics, step)

    def log_images(self, images: Dict[str, Any], step: int) -> None:
        """
        Log images to all contained loggers.

        Args:
            images (dict): Dictionary of image names and image data.
            step (int): The current step or epoch.
        """
        for backend in self.loggers:
            backend.log_images(images, step)

    def close(self) -> None:
        """Close all contained loggers."""
        for backend in self.loggers:
            backend.close()