From 6d2c242a6c14f567ae22d08f85c2cdc1f9b01664 Mon Sep 17 00:00:00 2001 From: Alfred Date: Thu, 23 Oct 2025 13:38:45 +0200 Subject: [PATCH] feature: log number of seen sequence and frames --- jasmine/train_dynamics.py | 11 +++++++++-- jasmine/train_lam.py | 10 +++++++++- jasmine/train_tokenizer.py | 10 +++++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/jasmine/train_dynamics.py b/jasmine/train_dynamics.py index 06cd966..1ea43de 100644 --- a/jasmine/train_dynamics.py +++ b/jasmine/train_dynamics.py @@ -1,6 +1,5 @@ import os - os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.98") from dataclasses import dataclass, field @@ -686,7 +685,15 @@ def calculate_validation_metrics(val_dataloader, genie, rng): # --- Logging --- if args.log: if step % args.log_interval == 0 and jax.process_index() == 0: - log_dict = {"loss": loss, "step": step, **metrics} + sequences_seen = step * args.batch_size + frames_seen = step * args.seq_len * args.batch_size + log_dict = { + "loss": loss, + "step": step, + "sequences_seen": sequences_seen, + "frames_seen": frames_seen, + **metrics, + } if val_results: log_dict.update(val_results["metrics"]) wandb.log(log_dict) diff --git a/jasmine/train_lam.py b/jasmine/train_lam.py index 95fcb67..f9c56e9 100644 --- a/jasmine/train_lam.py +++ b/jasmine/train_lam.py @@ -506,7 +506,15 @@ def calculate_validation_metrics(val_dataloader, lam): # --- Logging --- if args.log: if step % args.log_interval == 0 and jax.process_index() == 0: - log_dict = {"loss": loss, "step": step, **metrics} + sequences_seen = step * args.batch_size + frames_seen = step * args.seq_len * args.batch_size + log_dict = { + "loss": loss, + "step": step, + "sequences_seen": sequences_seen, + "frames_seen": frames_seen, + **metrics, + } if val_results: log_dict.update(val_results["metrics"]) wandb.log(log_dict) diff --git a/jasmine/train_tokenizer.py b/jasmine/train_tokenizer.py index 0bdcd7b..f02a297 100644 --- a/jasmine/train_tokenizer.py +++ b/jasmine/train_tokenizer.py @@ -482,7 +482,15 @@ def calculate_validation_metrics(val_dataloader, tokenizer): # --- Logging --- if args.log: if step % args.log_interval == 0 and jax.process_index() == 0: - log_dict = {"loss": loss, "step": step, **metrics} + sequences_seen = step * args.batch_size + frames_seen = step * args.seq_len * args.batch_size + log_dict = { + "loss": loss, + "step": step, + "sequences_seen": sequences_seen, + "frames_seen": frames_seen, + **metrics, + } if val_results: log_dict.update(val_results["metrics"]) wandb.log(log_dict)