
refactor: extract a dedicated ploss weighting function #12

@cicirori

Description

Problem

The TTT position loss weighting logic 0.8**i is hardcoded and duplicated in 3 separate locations in eagle3_trainer.py:

  • Line 251 (_backward): ploss_weight = [0.8**i for i in range(len(plosses))]
  • Line 302 (_aggregate_eval_metrics): [0.8**i for i in range(avg_plosses.shape[0])]
  • Line 394 (_aggregate_metrics): [0.8**i for i in range(avg_plosses.shape[0])]

The weighting strategy, decay factor, and reduction are all implicitly embedded in each call site with no abstraction.

Proposed Solution

Extract a weighted_ploss function that takes only the per-position losses and fully encapsulates the weighting strategy internally:

import torch


def weighted_ploss(plosses: list[torch.Tensor]) -> torch.Tensor:
    """Compute the weighted mean of the per-position losses.

    Encapsulates the full weighting strategy; callers don't need to know
    whether it's exponential decay, linear decay, or something else.
    """
    weights = [0.8 ** i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

The key design choice: the function signature is just plosses in, scalar out. The weighting strategy is an implementation detail hidden inside. This makes it easy to swap to a completely different strategy (e.g., non-uniform, learned weights, position-dependent) without changing any call sites.
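As a quick sanity check of the plosses-in, scalar-out contract: because the weighted sum is normalized by sum(weights), uniform losses map back to the common value, and the decaying weights pull the result toward early positions. A minimal sketch, using plain floats as stand-ins for scalar tensors so it runs without torch:

```python
def weighted_ploss(plosses):
    # Same body as the proposed function, minus the torch type hints.
    weights = [0.8 ** i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

# Uniform per-position losses: normalization returns the common value.
print(weighted_ploss([1.0, 1.0, 1.0]))  # -> 1.0

# Early positions carry more weight, so front-loaded losses score higher.
print(weighted_ploss([4.0, 2.0, 1.0]) > weighted_ploss([1.0, 2.0, 4.0]))  # -> True
```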

This function should be used in:

  • _backward for the training loss
  • _aggregate_metrics for the training metric
  • _aggregate_eval_metrics for the eval metric
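Concretely, each of the three call sites collapses to a single call. A hedged before/after sketch of the _backward site; the surrounding trainer code is not shown in this issue, and the normalized reduction in the "before" half is an assumption about what currently follows the hardcoded weight list:

```python
# Before (assumed shape of the inlined logic at line 251):
plosses = [3.0, 2.0, 1.0]  # stand-ins for per-position scalar loss tensors
ploss_weight = [0.8 ** i for i in range(len(plosses))]
ploss_before = sum(w * p for w, p in zip(ploss_weight, plosses)) / sum(ploss_weight)

# After: the strategy lives in one place.
def weighted_ploss(plosses):
    weights = [0.8 ** i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)

ploss_after = weighted_ploss(plosses)
assert ploss_before == ploss_after
```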

Files

  • torchspec/training/eagle3_trainer.py
