Skip to content

Conversation

@avocadoali
Copy link

No description provided.

@avocadoali avocadoali requested a review from Copilot October 23, 2025 11:43
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds throughput tracking metrics to the training logging system by calculating and logging the cumulative number of sequences and frames processed during training.

  • Adds sequences_seen and frames_seen metrics to training logs
  • Implements consistent tracking across tokenizer, LAM, and dynamics training scripts
  • Removes an unnecessary blank line in train_dynamics.py

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
jasmine/train_tokenizer.py Adds sequences_seen and frames_seen calculations to log_dict
jasmine/train_lam.py Adds sequences_seen and frames_seen calculations to log_dict
jasmine/train_dynamics.py Adds sequences_seen and frames_seen calculations to log_dict and removes extra blank line

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

Comment on lines +485 to +486
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These calculations count each sequence and frame multiple times in distributed training scenarios. When using data parallelism across multiple devices, each device processes its own batch independently, but step increments on all devices. This means the totals will be multiplied by the number of processes. Consider multiplying by jax.process_count() to get accurate global counts, or divide by process count if tracking per-process metrics.

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
Comment on lines +509 to +510
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These calculations count each sequence and frame multiple times in distributed training scenarios. When using data parallelism across multiple devices, each device processes its own batch independently, but step increments on all devices. This means the totals will be multiplied by the number of processes. Consider multiplying by jax.process_count() to get accurate global counts, or divide by process count if tracking per-process metrics.

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
Comment on lines +688 to +689
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These calculations count each sequence and frame multiple times in distributed training scenarios. When using data parallelism across multiple devices, each device processes its own batch independently, but step increments on all devices. This means the totals will be multiplied by the number of processes. Consider multiplying by jax.process_count() to get accurate global counts, or divide by process count if tracking per-process metrics.

Suggested change
sequences_seen = step * args.batch_size
frames_seen = step * args.seq_len * args.batch_size
sequences_seen = step * args.batch_size * jax.process_count()
frames_seen = step * args.seq_len * args.batch_size * jax.process_count()

Copilot uses AI. Check for mistakes.
Copy link

@emergenz emergenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test this on multi-host.

@emergenz
Copy link

@avocadoali small ping

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants