
refactor: extract shared metric aggregation logic in Eagle3Trainer #11

@cicirori

Description


Problem

_aggregate_metrics (lines 368-415) and _aggregate_eval_metrics (lines 285-321) in eagle3_trainer.py share ~30 lines of nearly identical logic:

  • torch.stack + mean + all_reduce for plosses and acces
  • simulated_acc_len cumulative calculation (cumulative *= acces[i])
  • 0.8**i weighted loss computation
  • Per-position metric extraction

Any bug fix to this logic must currently be applied in both places, which is error-prone.
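For reference, the duplicated pattern looks roughly like the following. This is a hypothetical standalone sketch, not the actual trainer code: names follow the bullets above, the dist.all_reduce across ranks is omitted so it runs single-process, and whether simulated_acc_len includes an initial offset is an assumption.

```python
import torch

def aggregate_positionwise(plosses, acces):
    # torch.stack + mean over steps for plosses and acces
    # (the real trainer follows this with an all_reduce across ranks).
    avg_plosses = torch.stack(plosses).mean(dim=0)
    avg_acces = torch.stack(acces).mean(dim=0)

    # simulated_acc_len cumulative calculation: cumulative *= acces[i].
    cumulative = 1.0
    simulated_acc_len = 0.0
    for acc in avg_acces.tolist():
        cumulative *= acc
        simulated_acc_len += cumulative

    # 0.8**i weighted loss computation.
    weighted_loss = sum(
        0.8 ** i * ploss for i, ploss in enumerate(avg_plosses.tolist())
    )
    return avg_plosses, avg_acces, simulated_acc_len, weighted_loss
```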

Proposed Solution

Extract a shared helper method:

def _compute_weighted_loss_and_acc(self, avg_plosses, avg_acces, prefix="train"):
    """Shared logic: simulated_acc_len, weighted loss, per-position metrics."""
    ...

Both _aggregate_metrics and _aggregate_eval_metrics would call this helper, adding only their unique fields (e.g., grad_norm, lr for training).

Files

  • torchspec/training/eagle3_trainer.py
