Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions jasmine/train_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os


os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.98")

from dataclasses import dataclass, field
Expand Down Expand Up @@ -686,7 +685,15 @@ def calculate_validation_metrics(val_dataloader, genie, rng):
# --- Logging ---
if args.log:
if step % args.log_interval == 0 and jax.process_index() == 0:
log_dict = {"loss": loss, "step": step, **metrics}
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Comment on lines +688 to +689
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

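These calculations may miscount sequences and frames in distributed training, depending on what `args.batch_size` means. With data parallelism across multiple processes, every process advances `step` in lockstep while consuming its own shard of data. If `args.batch_size` is the per-process batch size, `step * args.batch_size` undercounts the global totals by a factor of `jax.process_count()`; multiply by `jax.process_count()` to get accurate global counts. If `args.batch_size` is already the global batch size, the formulas are correct as written, and dividing by `jax.process_count()` would instead give per-process metrics. Please confirm which convention `args.batch_size` follows before applying the suggestion below.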

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
log_dict = {
"loss": loss,
"step": step,
"sequences_seen": sequences_seen,
"frames_seen": frames_seen,
**metrics,
}
if val_results:
log_dict.update(val_results["metrics"])
wandb.log(log_dict)
Expand Down
10 changes: 9 additions & 1 deletion jasmine/train_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,15 @@ def calculate_validation_metrics(val_dataloader, lam):
# --- Logging ---
if args.log:
if step % args.log_interval == 0 and jax.process_index() == 0:
log_dict = {"loss": loss, "step": step, **metrics}
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Comment on lines +509 to +510
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

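These calculations may miscount sequences and frames in distributed training, depending on what `args.batch_size` means. With data parallelism across multiple processes, every process advances `step` in lockstep while consuming its own shard of data. If `args.batch_size` is the per-process batch size, `step * args.batch_size` undercounts the global totals by a factor of `jax.process_count()`; multiply by `jax.process_count()` to get accurate global counts. If `args.batch_size` is already the global batch size, the formulas are correct as written, and dividing by `jax.process_count()` would instead give per-process metrics. Please confirm which convention `args.batch_size` follows before applying the suggestion below.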

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
log_dict = {
"loss": loss,
"step": step,
"sequences_seen": sequences_seen,
"frames_seen": frames_seen,
**metrics,
}
if val_results:
log_dict.update(val_results["metrics"])
wandb.log(log_dict)
Expand Down
10 changes: 9 additions & 1 deletion jasmine/train_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,15 @@ def calculate_validation_metrics(val_dataloader, tokenizer):
# --- Logging ---
if args.log:
if step % args.log_interval == 0 and jax.process_index() == 0:
log_dict = {"loss": loss, "step": step, **metrics}
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Comment on lines +485 to +486
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

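These calculations may miscount sequences and frames in distributed training, depending on what `args.batch_size` means. With data parallelism across multiple processes, every process advances `step` in lockstep while consuming its own shard of data. If `args.batch_size` is the per-process batch size, `step * args.batch_size` undercounts the global totals by a factor of `jax.process_count()`; multiply by `jax.process_count()` to get accurate global counts. If `args.batch_size` is already the global batch size, the formulas are correct as written, and dividing by `jax.process_count()` would instead give per-process metrics. Please confirm which convention `args.batch_size` follows before applying the suggestion below.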

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
log_dict = {
"loss": loss,
"step": step,
"sequences_seen": sequences_seen,
"frames_seen": frames_seen,
**metrics,
}
if val_results:
log_dict.update(val_results["metrics"])
wandb.log(log_dict)
Expand Down