Problem
_aggregate_metrics (lines 368-415) and _aggregate_eval_metrics (lines 285-321) in eagle3_trainer.py share ~30 lines of nearly identical logic:
- torch.stack + mean + all_reduce over plosses and acces
- simulated_acc_len cumulative calculation (cumulative *= acces[i])
- 0.8**i weighted-loss computation
- per-position metric extraction
As a result, every bug fix must be applied in two places, which is error-prone.
Proposed Solution
Extract a shared helper method:
def _compute_weighted_loss_and_acc(self, avg_plosses, avg_acces, prefix="train"):
"""Shared logic: simulated_acc_len, weighted loss, per-position metrics."""
...
Both _aggregate_metrics and _aggregate_eval_metrics would call this helper, adding only their unique fields (e.g., grad_norm, lr for training).
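To make the proposal concrete, here is a minimal sketch of what the shared helper could look like. It is shown as a free function for brevity (the proposal makes it a method), operates on plain per-position values after the callers have already done stack + mean + all_reduce, and the metric key names and the exact accumulation of simulated_acc_len (sum of cumulative products of acceptance rates) are assumptions for illustration, not the project's actual implementation.

```python
# Hypothetical sketch of the shared helper; metric-key naming and the
# simulated_acc_len formula are assumptions, not the project's actual code.

def compute_weighted_loss_and_acc(avg_plosses, avg_acces, prefix="train"):
    """Shared logic: simulated_acc_len, weighted loss, per-position metrics.

    avg_plosses / avg_acces: per-position values already averaged and
    all-reduced by the caller.
    """
    metrics = {}

    # simulated_acc_len: expected number of accepted draft tokens, taken
    # here as the sum of cumulative products of per-position acceptance
    # rates (cumulative *= acces[i], as in the duplicated code).
    cumulative = 1.0
    simulated_acc_len = 0.0
    for acc in avg_acces:
        cumulative *= acc
        simulated_acc_len += cumulative
    metrics[f"{prefix}/simulated_acc_len"] = simulated_acc_len

    # Weighted loss: position i is discounted by 0.8**i.
    weighted_loss = sum(0.8 ** i * pl for i, pl in enumerate(avg_plosses))
    metrics[f"{prefix}/weighted_loss"] = weighted_loss

    # Per-position metrics.
    for i, (pl, acc) in enumerate(zip(avg_plosses, avg_acces)):
        metrics[f"{prefix}/ploss_{i}"] = pl
        metrics[f"{prefix}/acc_{i}"] = acc

    return metrics
```

Each caller would then merge this dict with its own unique fields, so a fix to the weighting or acceptance-length logic lands in exactly one place.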
Files
torchspec/training/eagle3_trainer.py