Skip to content

Commit 5f3ea88

Browse files
Fix logging inconsistency: unify metrics logging to use global_steps
- Add global_steps property to PeftTrainer (returns train_steps for compatibility) - Update DPO trainer to use global_steps for aux metrics logging - Update progress bar to use global_steps for initial_steps - Ensures consistent step counting across all RL algorithms - Resolves confusion about max_steps parameter interpretation
1 parent f8baab9 commit 5f3ea88

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tunix/sft/dpo/dpo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _prepare_inputs(
262262

263263
@override
264264
def _post_process_train_step(self, aux: Any) -> None:
265-
m, s = self._mode, self._train_steps
265+
m, s = self._mode, self.global_steps
266266
self.metrics_logger.log("rewards/chosen", aux["rewards/chosen"], m, s)
267267
self.metrics_logger.log("rewards/rejected", aux["rewards/rejected"], m, s)
268268
self.metrics_logger.log("rewards/margin", aux["rewards/margin"], m, s)
@@ -274,7 +274,7 @@ def _post_process_train_step(self, aux: Any) -> None:
274274

275275
@override
276276
def _post_process_eval_step(self, aux: Any) -> None:
277-
m, s = self._mode, self._train_steps
277+
m, s = self._mode, self.global_steps
278278
self.metrics_logger.log("rewards/chosen", aux["rewards/chosen"], m, s)
279279
self.metrics_logger.log("rewards/rejected", aux["rewards/rejected"], m, s)
280280
self.metrics_logger.log("rewards/margin", aux["rewards/margin"], m, s)

tunix/sft/peft_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def train(
586586
if self.config.max_steps is not None and self._pbar is None:
587587
self._pbar = progress_bar.ProgressBar(
588588
metrics_logger=self.metrics_logger,
589-
initial_steps=self._train_steps,
589+
initial_steps=self.global_steps,
590590
max_steps=self.config.max_steps,
591591
description=self.config.pbar_description,
592592
)
@@ -713,6 +713,11 @@ def train_steps(self) -> int:
713713
"""Returns the number of train steps taken."""
714714
return self._train_steps
715715

716+
@property
717+
def global_steps(self) -> int:
718+
"""Returns the number of global steps taken (same as train_steps for compatibility)."""
719+
return self._train_steps
720+
716721
@property
717722
def iter_steps(self) -> int:
718723
"""Returns the number of iterator steps taken."""

0 commit comments

Comments
 (0)