Problem
The TTT position loss weighting logic `0.8**i` is hardcoded and duplicated in 3 separate locations in eagle3_trainer.py:
- Line 251 (`_backward`): `ploss_weight = [0.8**i for i in range(len(plosses))]`
- Line 302 (`_aggregate_eval_metrics`): `[0.8**i for i in range(avg_plosses.shape[0])]`
- Line 394 (`_aggregate_metrics`): `[0.8**i for i in range(avg_plosses.shape[0])]`
The weighting strategy, decay factor, and reduction are all implicitly embedded in each call site with no abstraction.
Proposed Solution
Extract a ploss function that takes only the per-position losses and fully encapsulates the weighting strategy internally:
```python
import torch


def weighted_ploss(plosses: list[torch.Tensor]) -> torch.Tensor:
    """Compute the weighted position loss.

    Encapsulates the full weighting strategy: callers don't need to know
    whether it's exponential decay, linear decay, or something else.
    """
    weights = [0.8**i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)
```

The key design choice: the function signature is just plosses in, scalar out. The weighting strategy is an implementation detail hidden inside. This makes it easy to swap in a completely different strategy (e.g., non-uniform, learned, or position-dependent weights) without changing any call sites.
This function should be used in:
- `_backward` for the training loss
- `_aggregate_metrics` for the training metric
- `_aggregate_eval_metrics` for the eval metric
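One wrinkle the shared signature has to absorb: `_backward` holds a Python list of per-position losses, while the aggregate methods hold a stacked `avg_plosses` tensor indexed via `avg_plosses.shape[0]`. Iterating a 1-D tensor yields its per-position scalars, so both shapes can feed the same helper. A float-based sketch (torch-free so it runs anywhere; `list(avg_plosses)` would play the same role for a real tensor):

```python
def weighted_ploss(plosses):
    """Proposed helper: per-position losses in, weighted scalar out."""
    weights = [0.8**i for i in range(len(plosses))]
    return sum(w * p for w, p in zip(weights, plosses)) / sum(weights)


# _backward-style usage: already a list of per-position losses.
train_ploss = weighted_ploss([1.0, 0.5, 0.25])

# _aggregate_*-style usage: a stacked vector of averaged per-position
# losses; list(...) stands in for iterating a 1-D torch tensor.
avg_plosses = [1.0, 0.5, 0.25]
eval_ploss = weighted_ploss(list(avg_plosses))

# Same inputs, same weighting, so train and eval reductions agree.
assert train_ploss == eval_ploss
```

Keeping all three call sites on the one helper guarantees the training loss and the logged train/eval metrics use the exact same reduction.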
Files
`torchspec/training/eagle3_trainer.py`