From f613a979eea292a8e321eecd07b3192af779003e Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Mon, 7 Jul 2025 17:27:15 +0200 Subject: [PATCH 1/5] init other typed of loggers --- train_tokenizer.py | 71 +++++++++++++---------------- utils/logger.py | 109 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 39 deletions(-) create mode 100644 utils/logger.py diff --git a/train_tokenizer.py b/train_tokenizer.py index eb2411b..fe43a01 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -19,6 +19,7 @@ from models.tokenizer import TokenizerVQVAE from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component +from utils.logger import CompositeLogger @dataclass @@ -48,7 +49,8 @@ class Args: dropout: float = 0.0 codebook_dropout: float = 0.01 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) entity: str = "" project: str = "" name: str = "train_tokenizer" @@ -158,16 +160,10 @@ def train_step(state, inputs): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) print("Parameter counts:") print(param_counts) @@ -228,38 +224,35 @@ def train_step(state, inputs): inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout) train_state, loss, recon, metrics = train_step(train_state, inputs) - print(f"Step {step}, loss: {loss}") step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step % args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0] + recon_seq = recon[0].clip(0, 1) + comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + # NOTE: Process-dependent control flow deliberately happens + # after indexing operation since it must not contain code + # sections that lead to cross-accelerator communication. + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[0]).astype(np.uint8), + recon=np.asarray(recon_seq[0]).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8) + ), ) - # NOTE: Process-dependent control flow deliberately happens - # after indexing operation since it must not contain code - # sections that lead to cross-accelerator communication. - if jax.process_index() == 0: - log_images = dict( - image=wandb.Image(np.asarray(gt_seq[0])), - recon=wandb.Image(np.asarray(recon_seq[0])), - true_vs_recon=wandb.Image( - np.asarray(comparison_seq.astype(np.uint8)) - ), - ) - wandb.log(log_images) + logger.log_images(log_images, step) if step % args.log_checkpoint_interval == 0: ckpt = {"model": train_state} orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..e16e8ec --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,109 @@ +# utils/logger.py + +import os +import json +from abc import ABC, abstractmethod +from typing import Dict, Any +from pprint import pprint + +class BaseLogger(ABC): + @abstractmethod + def log_metrics(self, metrics: Dict[str, Any], step: int): + pass + + @abstractmethod + def log_images(self, images: Dict[str, Any], step: int): + pass + +class WandbLogger(BaseLogger): + def __init__(self, config): + import wandb + self.wandb = wandb + self.wandb.init( + entity=config["entity"], + project=config["project"], + name=config["name"], + tags=config["tags"], + group="debug", + config=config, + ) + + def log_metrics(self, metrics, step): + self.wandb.log({**metrics, "step": step}) + + def log_images(self, images, step): + log_images = {k: self.wandb.Image(v) for k, v in images.items()} + self.wandb.log({**log_images, "step": step}) + +class TensorboardLogger(BaseLogger): + def __init__(self, config): + from torch.utils.tensorboard import SummaryWriter + log_dir = os.path.join(config["log_dir"], "tb_logger") + self.log_dir = log_dir + self.writer = SummaryWriter(log_dir=log_dir) + + def log_metrics(self, metrics, step): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_images(self, images, step): + for k, v in images.items(): + self.writer.add_image(k, v, step, dataformats='HWC') + +class LocalLogger(BaseLogger): + def __init__(self, config): + log_dir = os.path.join(config["log_dir"], "local_logger") + self.log_dir = log_dir + os.makedirs(log_dir, exist_ok=True) + self.metrics_file = os.path.join(log_dir, "metrics.jsonl") + self.images_dir = os.path.join(log_dir, "images") + os.makedirs(self.images_dir, exist_ok=True) + + def log_metrics(self, metrics, step): + with open(self.metrics_file, "a") as f: + metrics = {k: str(v) for k, v in metrics.items()} + f.write(json.dumps({"step": step, **metrics}) + "\n") + + def log_images(self, images, step): + for k, v in images.items(): + # v is expected to be a numpy array (HWC, uint8) + from PIL import Image + img = Image.fromarray(v) + img.save(os.path.join(self.images_dir, f"{k}_step{step}.png")) + +class ConsoleLogger(BaseLogger): + def __init__(self, cfg): + pprint(cfg, compact=True) + + def log_metrics(self, metrics, step): + print(f"[Step {step}] Metrics: " + ", ".join(f"{k}: {v}" for k, v in metrics.items())) + + def log_images(self, images, step): + print(f"[Step {step}] Images logged: {', '.join(images.keys())}") + + +class CompositeLogger(BaseLogger): + def __init__(self, loggers, cfg): + available_loggers = {"wandb": WandbLogger, + "tb": TensorboardLogger, + "json": TensorboardLogger, + "local": LocalLogger, + "console": ConsoleLogger} + self.loggers = [] + for logger in loggers: + assert logger in available_loggers.keys(), f"Logger \"{logger}\" not known. Available loggers are: {available_loggers.keys()}" + logger_class = available_loggers[logger] + self.loggers.append(logger_class(cfg)) + + + def log_metrics(self, metrics, step): + for logger in self.loggers: + logger.log_metrics(metrics, step) + + def log_images(self, images, step): + for logger in self.loggers: + logger.log_images(images, step) + + def log_checkpoint(self, checkpoint, step): + for logger in self.loggers: + logger.log_checkpoint(checkpoint, step) \ No newline at end of file From 2e190718a49cd12d3d8212adf730974e0efb4dd9 Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 8 Jul 2025 16:51:23 +0200 Subject: [PATCH 2/5] bugfixes and cleanup logger --- train_dynamics.py | 68 +++++++---------- train_lam.py | 73 +++++++----------- train_tokenizer.py | 8 +- utils/logger.py | 185 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 233 insertions(+), 101 deletions(-) diff --git a/train_dynamics.py b/train_dynamics.py index 27f6f2c..4b814f8 100644 --- a/train_dynamics.py +++ b/train_dynamics.py @@ -13,13 +13,13 @@ import jax import jax.numpy as jnp import tyro -import wandb from genie import Genie, restore_genie_components from models.tokenizer import TokenizerVQVAE from models.lam import LatentActionModel from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component +from utils.logger import CompositeLogger @dataclass @@ -60,7 +60,8 @@ class Args: dropout: float = 0.0 mask_limit: float = 0.5 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_dynamics" @@ -173,19 +174,11 @@ def train_step(state, inputs): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) - - print("Parameter counts:") - print(param_counts) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) + print(f"Training Dynamics Model with {param_counts["total"]} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( @@ -240,31 +233,28 @@ def train_step(state, inputs): step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step % args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0] + recon_seq = recon[0].clip(0, 1) + comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[args.seq_len - 1] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[args.seq_len - 1] * 255.).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)), ) - if jax.process_index() == 0: - log_images = dict( - image=wandb.Image(np.asarray(gt_seq[args.seq_len - 1])), - recon=wandb.Image(np.asarray(recon_seq[args.seq_len - 1])), - true_vs_recon=wandb.Image( - np.asarray(comparison_seq.astype(np.uint8)) - ), - ) - wandb.log(log_images) + logger.log_images(log_images, step) if step % args.log_checkpoint_interval == 0: ckpt = {"model": train_state} orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/train_lam.py b/train_lam.py index 0629d2a..558d733 100644 --- a/train_lam.py +++ b/train_lam.py @@ -14,12 +14,11 @@ import jax import jax.numpy as jnp import tyro -import wandb from models.lam import LatentActionModel from utils.dataloader import get_dataloader from utils.parameter_utils import count_parameters_by_component - +from utils.logger import CompositeLogger @dataclass class Args: @@ -49,7 +48,8 @@ class Args: dropout: float = 0.0 codebook_dropout: float = 0.0 # Logging - log: bool = False + log_dir: str = "logs/" + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_lam" @@ -59,10 +59,8 @@ class Args: ckpt_dir: str = "" log_checkpoint_interval: int = 10000 - args = tyro.cli(Args) - def lam_loss_fn(params, state, inputs): # --- Compute loss --- outputs = state.apply_fn( @@ -94,7 +92,6 @@ def lam_loss_fn(params, state, inputs): ) return loss, (outputs["recon"], index_counts, metrics) - @jax.jit def train_step(state, inputs, action_last_active): # --- Update model --- @@ -118,7 +115,6 @@ def train_step(state, inputs, action_last_active): action_last_active = jnp.where(do_reset, 0, action_last_active) return state, loss, recon, action_last_active, metrics - if __name__ == "__main__": jax.distributed.initialize() num_devices = jax.device_count() @@ -164,19 +160,11 @@ def train_step(state, inputs, action_last_active): param_counts = count_parameters_by_component(init_params) - if args.log and jax.process_index() == 0: - wandb.init( - entity=args.entity, - project=args.project, - name=args.name, - tags=args.tags, - group="debug", - config=args, - ) - wandb.config.update({"model_param_count": param_counts}) - - print("Parameter counts:") - print(param_counts) + if jax.process_index() == 0: + cfg = vars(args).copy() + cfg["model_param_count"] = param_counts + logger = CompositeLogger(args.loggers, cfg) + print(f"Training Latent Action Model with {param_counts["total"]} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( @@ -240,31 +228,28 @@ def train_step(state, inputs, action_last_active): step += 1 # --- Logging --- - if args.log: - if step % args.log_interval == 0 and jax.process_index() == 0: - wandb.log( - { - "loss": loss, - "step": step, - **metrics, - } - ) - if step % args.log_image_interval == 0: - gt_seq = inputs["videos"][0][1:] - recon_seq = recon[0].clip(0, 1) - comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) - comparison_seq = einops.rearrange( - comparison_seq * 255, "t h w c -> h (t w) c" + if step % args.log_interval == 0 and jax.process_index() == 0: + logger.log_metrics( + { + "loss": loss, + **metrics, + }, + step + ) + if step % args.log_image_interval == 0: + gt_seq = inputs["videos"][0][1:] + recon_seq = recon[0].clip(0, 1) + comparison_seq = jnp.concatenate((gt_seq, recon_seq), axis=1) + comparison_seq = einops.rearrange( + comparison_seq * 255, "t h w c -> h (t w) c" + ) + if jax.process_index() == 0: + log_images = dict( + image=np.asarray(gt_seq[0] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8), + true_vs_recon=np.asarray(comparison_seq.astype(np.uint8)), ) - if jax.process_index() == 0: - log_images = dict( - image=wandb.Image(np.asarray(gt_seq[0])), - recon=wandb.Image(np.asarray(recon_seq[0])), - true_vs_recon=wandb.Image( - np.asarray(comparison_seq.astype(np.uint8)) - ), - ) - wandb.log(log_images) + logger.log_images(log_images, step) if step % args.log_checkpoint_interval == 0: ckpt = {"model": train_state} orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/train_tokenizer.py b/train_tokenizer.py index fe43a01..188c834 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -14,7 +14,6 @@ import jax import jax.numpy as jnp import tyro -import wandb from models.tokenizer import TokenizerVQVAE from utils.dataloader import get_dataloader @@ -50,7 +49,7 @@ class Args: codebook_dropout: float = 0.01 # Logging log_dir: str = "logs/" - loggers: list[str] = field(default_factory=lambda: ["console"]) + loggers: list[str] = field(default_factory=lambda: ["console"]) # options: console, local, tb, wandb entity: str = "" project: str = "" name: str = "train_tokenizer" @@ -164,6 +163,7 @@ def train_step(state, inputs): cfg = vars(args).copy() cfg["model_param_count"] = param_counts logger = CompositeLogger(args.loggers, cfg) + print(f"Training Tokenizer Model with {param_counts["total"]} parameters") print("Parameter counts:") print(param_counts) @@ -247,8 +247,8 @@ def train_step(state, inputs): # sections that lead to cross-accelerator communication. if jax.process_index() == 0: log_images = dict( - image=np.asarray(gt_seq[0]).astype(np.uint8), - recon=np.asarray(recon_seq[0]).astype(np.uint8), + image=np.asarray(gt_seq[0] * 255.).astype(np.uint8), + recon=np.asarray(recon_seq[0] * 255.).astype(np.uint8), true_vs_recon=np.asarray(comparison_seq.astype(np.uint8) ), ) diff --git a/utils/logger.py b/utils/logger.py index e16e8ec..6afb344 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -1,22 +1,54 @@ -# utils/logger.py - import os import json from abc import ABC, abstractmethod from typing import Dict, Any from pprint import pprint +import numpy as np +import jax class BaseLogger(ABC): + """ + Abstract base class for all loggers. + + Defines the interface for logging metrics and images. + """ + @abstractmethod def log_metrics(self, metrics: Dict[str, Any], step: int): + """ + Log metrics at a given step. + + Args: + metrics (Dict[str, Any]): Dictionary of metric names and values. + step (int): The current step or epoch. + """ pass @abstractmethod def log_images(self, images: Dict[str, Any], step: int): + """ + Log images at a given step. + + Args: + images (Dict[str, Any]): Dictionary of image names and image data. + step (int): The current step or epoch. + """ pass class WandbLogger(BaseLogger): + """ + Logger for Weights & Biases (wandb) integration. + + Logs metrics and images to the wandb dashboard. + """ + def __init__(self, config): + """ + Initialize the WandbLogger. + + Args: + config (dict): Configuration dictionary containing wandb parameters. + """ import wandb self.wandb = wandb self.wandb.init( @@ -29,42 +61,116 @@ def __init__(self, config): ) def log_metrics(self, metrics, step): + """ + Log metrics to wandb. + + Args: + metrics (dict): Dictionary of metric names and values. + step (int): The current step or epoch. + """ self.wandb.log({**metrics, "step": step}) def log_images(self, images, step): + """ + Log images to wandb. + + Args: + images (dict): Dictionary of image names and image data. + step (int): The current step or epoch. + """ log_images = {k: self.wandb.Image(v) for k, v in images.items()} self.wandb.log({**log_images, "step": step}) class TensorboardLogger(BaseLogger): + """ + Logger for TensorBoard integration. + + Logs metrics and images to TensorBoard. + """ + def __init__(self, config): - from torch.utils.tensorboard import SummaryWriter - log_dir = os.path.join(config["log_dir"], "tb_logger") + """ + Initialize the TensorboardLogger. + + Args: + config (dict): Configuration dictionary containing log directory and experiment name. + """ + from tensorboardX import SummaryWriter + base_log_dir = os.path.join(config["log_dir"], "tb_logger", config["name"]) + log_dir = base_log_dir + idx = 1 + while os.path.exists(log_dir): + log_dir = f"{base_log_dir}-{idx}" + idx += 1 self.log_dir = log_dir self.writer = SummaryWriter(log_dir=log_dir) def log_metrics(self, metrics, step): + """ + Log metrics to TensorBoard. + + Args: + metrics (dict): Dictionary of metric names and values. + step (int): The current step or epoch. + """ for k, v in metrics.items(): - self.writer.add_scalar(k, v, step) + self.writer.add_scalar(f"metrics/{k}", v, step) def log_images(self, images, step): + """ + Log images to TensorBoard. + + Args: + images (dict): Dictionary of image names and image data. + step (int): The current step or epoch. + """ for k, v in images.items(): - self.writer.add_image(k, v, step, dataformats='HWC') + self.writer.add_image(f"media/{k}", v, step, dataformats='HWC') class LocalLogger(BaseLogger): + """ + Logger for local filesystem logging. + + Logs metrics to a JSONL file and images as PNGs in a directory. + """ + def __init__(self, config): - log_dir = os.path.join(config["log_dir"], "local_logger") - self.log_dir = log_dir + """ + Initialize the LocalLogger. + + Args: + config (dict): Configuration dictionary containing log directory and experiment name. + """ + base_log_dir = os.path.join(config["log_dir"], "local_logger", config["name"]) + log_dir = base_log_dir + idx = 1 + while os.path.exists(log_dir): + log_dir = f"{base_log_dir}-{idx}" + idx += 1 os.makedirs(log_dir, exist_ok=True) self.metrics_file = os.path.join(log_dir, "metrics.jsonl") self.images_dir = os.path.join(log_dir, "images") os.makedirs(self.images_dir, exist_ok=True) def log_metrics(self, metrics, step): + """ + Log metrics to a local JSONL file. + + Args: + metrics (dict): Dictionary of metric names and values. + step (int): The current step or epoch. + """ with open(self.metrics_file, "a") as f: - metrics = {k: str(v) for k, v in metrics.items()} f.write(json.dumps({"step": step, **metrics}) + "\n") def log_images(self, images, step): + """ + Log images as PNG files to the local filesystem. + + Args: + images (dict): Dictionary of image names and image data (numpy arrays). + step (int): The current step or epoch. + """ for k, v in images.items(): # v is expected to be a numpy array (HWC, uint8) from PIL import Image @@ -72,21 +178,59 @@ def log_images(self, images, step): img.save(os.path.join(self.images_dir, f"{k}_step{step}.png")) class ConsoleLogger(BaseLogger): + """ + Logger for console output. + + Prints metrics and image logging information to the console. + """ + def __init__(self, cfg): + """ + Initialize the ConsoleLogger. + + Args: + cfg (dict): Configuration dictionary to print at initialization. + """ pprint(cfg, compact=True) def log_metrics(self, metrics, step): + """ + Print metrics to the console. + + Args: + metrics (dict): Dictionary of metric names and values. + step (int): The current step or epoch. + """ print(f"[Step {step}] Metrics: " + ", ".join(f"{k}: {v}" for k, v in metrics.items())) def log_images(self, images, step): + """ + Print image logging information to the console. + + Args: + images (dict): Dictionary of image names and image data. + step (int): The current step or epoch. + """ print(f"[Step {step}] Images logged: {', '.join(images.keys())}") class CompositeLogger(BaseLogger): + """ + Logger that combines multiple logger backends. + + Forwards logging calls to all specified loggers. + """ + def __init__(self, loggers, cfg): + """ + Initialize the CompositeLogger. + + Args: + loggers (list): List of logger names to instantiate. + cfg (dict): Configuration dictionary to pass to each logger. + """ available_loggers = {"wandb": WandbLogger, "tb": TensorboardLogger, - "json": TensorboardLogger, "local": LocalLogger, "console": ConsoleLogger} self.loggers = [] @@ -97,13 +241,26 @@ def __init__(self, loggers, cfg): def log_metrics(self, metrics, step): + """ + Log metrics to all contained loggers. + + Args: + metrics (dict): Dictionary of metric names and values. + step (int): The current step or epoch. + """ for logger in self.loggers: + metrics = jax.tree.map( + lambda x: x.item() if isinstance(x, (jax.Array, np.ndarray)) else x, metrics + ) logger.log_metrics(metrics, step) def log_images(self, images, step): - for logger in self.loggers: - logger.log_images(images, step) + """ + Log images to all contained loggers. - def log_checkpoint(self, checkpoint, step): + Args: + images (dict): Dictionary of image names and image data. + step (int): The current step or epoch. + """ for logger in self.loggers: - logger.log_checkpoint(checkpoint, step) \ No newline at end of file + logger.log_images(images, step) From 35c92aafc707dd0bc71973d5dcd390a5fcb234c5 Mon Sep 17 00:00:00 2001 From: mihir <78321484+maharajamihir@users.noreply.github.com> Date: Tue, 8 Jul 2025 17:03:54 +0200 Subject: [PATCH 3/5] Update utils/logger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- utils/logger.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/logger.py b/utils/logger.py index 6afb344..694bb25 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -248,10 +248,10 @@ def log_metrics(self, metrics, step): metrics (dict): Dictionary of metric names and values. step (int): The current step or epoch. """ + metrics = jax.tree.map( + lambda x: x.item() if isinstance(x, (jax.Array, np.ndarray)) else x, metrics + ) for logger in self.loggers: - metrics = jax.tree.map( - lambda x: x.item() if isinstance(x, (jax.Array, np.ndarray)) else x, metrics - ) logger.log_metrics(metrics, step) def log_images(self, images, step): From 08e0eb70670fe6e8ecfc626e4995860fcf9783bb Mon Sep 17 00:00:00 2001 From: mihir <78321484+maharajamihir@users.noreply.github.com> Date: Tue, 8 Jul 2025 17:04:03 +0200 Subject: [PATCH 4/5] Update utils/logger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- utils/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/logger.py b/utils/logger.py index 694bb25..d3a6985 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -235,7 +235,7 @@ def __init__(self, loggers, cfg): "console": ConsoleLogger} self.loggers = [] for logger in loggers: - assert logger in available_loggers.keys(), f"Logger \"{logger}\" not known. Available loggers are: {available_loggers.keys()}" + assert logger in available_loggers.keys(), f"Logger \"{logger}\" not known. Available loggers are: {', '.join(available_loggers.keys())}" logger_class = available_loggers[logger] self.loggers.append(logger_class(cfg)) From c80f2c93d3bef6f6c1fd43670793a400ed09639e Mon Sep 17 00:00:00 2001 From: Mihir Mahajan Date: Tue, 8 Jul 2025 17:06:45 +0200 Subject: [PATCH 5/5] fixed copilots nitpicks --- train_dynamics.py | 2 +- train_lam.py | 2 +- train_tokenizer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train_dynamics.py b/train_dynamics.py index 4b814f8..a60bd9e 100644 --- a/train_dynamics.py +++ b/train_dynamics.py @@ -178,7 +178,7 @@ def train_step(state, inputs): cfg = vars(args).copy() cfg["model_param_count"] = param_counts logger = CompositeLogger(args.loggers, cfg) - print(f"Training Dynamics Model with {param_counts["total"]} parameters") + print(f"Training Dynamics Model with {param_counts['total']} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( diff --git a/train_lam.py b/train_lam.py index 558d733..0aabc0f 100644 --- a/train_lam.py +++ b/train_lam.py @@ -164,7 +164,7 @@ def train_step(state, inputs, action_last_active): cfg = vars(args).copy() cfg["model_param_count"] = param_counts logger = CompositeLogger(args.loggers, cfg) - print(f"Training Latent Action Model with {param_counts["total"]} parameters") + print(f"Training Latent Action Model with {param_counts['total']} parameters") # --- Initialize optimizer --- lr_schedule = optax.warmup_cosine_decay_schedule( diff --git a/train_tokenizer.py b/train_tokenizer.py index 188c834..0508cdd 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -163,7 +163,7 @@ def train_step(state, inputs): cfg = vars(args).copy() cfg["model_param_count"] = param_counts logger = CompositeLogger(args.loggers, cfg) - print(f"Training Tokenizer Model with {param_counts["total"]} parameters") + print(f"Training Tokenizer Model with {param_counts['total']} parameters") print("Parameter counts:") print(param_counts)