355 changes: 322 additions & 33 deletions tests/test_eagle3_loss.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tools/benchmark_eagle3.py
@@ -127,7 +127,7 @@ def run_eagle3_forward(model, batch):
eagle3.length,
)

plosses, _, acces = model(
plosses, _, acces, _ = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
target=target,
2 changes: 1 addition & 1 deletion tools/max_seq_search.py
@@ -157,7 +157,7 @@ def run_eagle3_forward(
eagle3.length,
)

plosses, _, acces = model(
plosses, _, acces, _ = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
target=target,
2 changes: 2 additions & 0 deletions torchspec/config/train_config.py
@@ -100,6 +100,8 @@ class TrainingConfig:

gradient_checkpointing: bool = False
learning_rate: float = 1e-4
lk_eta: float = 3.0
loss_type: str = "forward_kl"
load_path: Optional[str] = None
lr_decay_style: str = "cosine"
lr_total_steps: Optional[int] = None
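For context, a minimal usage sketch of the two new fields (assuming TrainingConfig is constructed with keyword arguments; the launcher/CLI wiring that populates it is not shown in this diff):

from torchspec.config.train_config import TrainingConfig

# Hypothetical sketch: only loss_type and lk_eta are added by this PR;
# learning_rate is an existing field shown here for context.
config = TrainingConfig(
    learning_rate=1e-4,
    loss_type="lk_lambda",  # "forward_kl" (default), "lk_alpha", or "lk_lambda"
    lk_eta=3.0,             # eta in lambda = exp(-eta * sg[alpha]); used only by "lk_lambda"
)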
70 changes: 48 additions & 22 deletions torchspec/models/eagle3.py
@@ -29,6 +29,10 @@
from torchspec.models.ops.loss import (
compiled_forward_kl_loss,
compiled_forward_kl_loss_from_hs,
compiled_lk_alpha_loss,
compiled_lk_alpha_loss_from_hs,
compiled_lk_lambda_loss,
compiled_lk_lambda_loss_from_hs,
)
from torchspec.utils.tensor import padding

@@ -56,13 +60,26 @@ def __init__(
length: int = 7,
attention_backend="sdpa",
gradient_checkpointing: bool = False,
loss_type: str = "forward_kl",
lk_eta: float = 3.0,
):
super().__init__()
self.draft_model = draft_model
self.length = length
self.attention_backend = attention_backend
self.gradient_checkpointing = gradient_checkpointing
self.vocab_pruning = draft_model.vocab_size != draft_model.target_vocab_size
self.loss_type = loss_type
self.lk_eta = lk_eta

def _select_loss_fns(self):
"""Return (precomputed_fn, lazy_fn) based on self.loss_type."""
if self.loss_type == "lk_alpha":
return compiled_lk_alpha_loss, compiled_lk_alpha_loss_from_hs
elif self.loss_type == "lk_lambda":
return compiled_lk_lambda_loss, compiled_lk_lambda_loss_from_hs
else:
return compiled_forward_kl_loss, compiled_forward_kl_loss_from_hs

def _calculate_loss(
self,
@@ -74,38 +91,40 @@ def _calculate_loss(
norm_weight: torch.Tensor,
lm_head_weight: torch.Tensor,
norm_eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute forward-KL loss and accuracy for one TTT step.
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute loss, accuracy, and alpha for one TTT step.

Both paths pass full (B*T, ...) flat views + valid_idx into the
compiled function so torch.compile can fuse index_select with
subsequent ops, avoiding separate (N_valid, V) copies outside.

- PrecomputedTarget (vocab pruning): compiled_forward_kl_loss
with pre-computed target probs.
- LazyTarget (no pruning): compiled_forward_kl_loss_from_hs
computes target softmax inside the compiled graph.
- PrecomputedTarget (vocab pruning): compiled loss with pre-computed target probs.
- LazyTarget (no pruning): compiled loss computes target softmax inside the graph.

Returns (loss, acc, alpha) where alpha is 0.0 for forward_kl.
"""
valid_idx = mask.flatten().nonzero().squeeze(-1)
# Guard against all-masked positions to avoid nan from mean() on empty tensors.
if valid_idx.numel() == 0:
zero = hidden_states.new_tensor(0.0)
return zero, zero
return zero, zero, zero
# Important as it prevents recompilation.
torch._dynamo.mark_dynamic(valid_idx, 0)
hs_flat = hidden_states.reshape(-1, hidden_states.shape[-1])

precomputed_fn, lazy_fn = self._select_loss_fns()
is_lk = self.loss_type in ("lk_alpha", "lk_lambda")

if isinstance(target, PrecomputedTarget):
target_p_step = target.target_p_padded[:, idx : idx + seq_length, :]
tp_flat = target_p_step.reshape(-1, target_p_step.shape[-1])
args = (hs_flat, tp_flat, valid_idx, norm_weight, lm_head_weight, norm_eps)
if self.loss_type == "lk_lambda":
args = args + (self.lk_eta,)
if self.gradient_checkpointing and self.training:
return torch_checkpoint(
compiled_forward_kl_loss,
*args,
use_reentrant=False,
)
return compiled_forward_kl_loss(*args)
result = torch_checkpoint(precomputed_fn, *args, use_reentrant=False)
else:
result = precomputed_fn(*args)
else:
# lazy
ths_flat = target.hidden_states_padded[:, idx : idx + seq_length, :].reshape(
Expand All @@ -120,13 +139,18 @@ def _calculate_loss(
target.lm_head_weight,
norm_eps,
)
if self.loss_type == "lk_lambda":
args = args + (self.lk_eta,)
if self.gradient_checkpointing and self.training:
return torch_checkpoint(
compiled_forward_kl_loss_from_hs,
*args,
use_reentrant=False,
)
return compiled_forward_kl_loss_from_hs(*args)
result = torch_checkpoint(lazy_fn, *args, use_reentrant=False)
else:
result = lazy_fn(*args)

if is_lk:
return result # (loss, acc, alpha)
else:
loss, acc = result
return loss, acc, hidden_states.new_tensor(0.0)

def forward(
self,
@@ -137,7 +161,7 @@ def forward(
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
@@ -180,6 +204,7 @@ def forward(
plosses = []
vlosses = []
acces = []
alphas = []
cache_keys = None
cache_values = None

@@ -218,7 +243,7 @@

hidden_states = hidden_states_out

loss, acc = self._calculate_loss(
loss, acc, alpha = self._calculate_loss(
hidden_states=hidden_states,
target=target,
mask=mask,
@@ -230,11 +255,12 @@
)
plosses.append(loss)
acces.append(acc)
alphas.append(alpha)

if not is_last:
input_ids = padding(input_ids, left=False)
mask = padding(mask, left=False)
return plosses, vlosses, acces
return plosses, vlosses, acces, alphas


@torch.no_grad()
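For callers, a minimal sketch of consuming the widened return value; `outputs` stands for the result of the model call shown in tools/benchmark_eagle3.py and is illustrative, not part of this PR:

import torch

# outputs = model(input_ids=..., attention_mask=..., target=..., hidden_states=...)
plosses, vlosses, acces, alphas = outputs
train_loss = torch.stack(plosses).mean()        # one loss per TTT step
mean_alpha = torch.stack(alphas).mean().item()  # mean acceptance rate; 0.0 when loss_type == "forward_kl"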
8 changes: 7 additions & 1 deletion torchspec/models/ops/__init__.py
@@ -23,13 +23,19 @@
compile_friendly_flex_attention,
generate_eagle3_mask,
)
from torchspec.models.ops.loss import compiled_forward_kl_loss
from torchspec.models.ops.loss import (
compiled_forward_kl_loss,
compiled_lk_alpha_loss,
compiled_lk_lambda_loss,
)
from torchspec.models.ops.loss_mask import compute_assistant_loss_mask

__all__ = [
"compile_friendly_create_block_mask",
"compile_friendly_flex_attention",
"generate_eagle3_mask",
"compiled_forward_kl_loss",
"compiled_lk_alpha_loss",
"compiled_lk_lambda_loss",
"compute_assistant_loss_mask",
]
151 changes: 151 additions & 0 deletions torchspec/models/ops/loss.py
@@ -65,6 +65,82 @@ def compiled_forward_kl_loss(
return loss, acc


@torch.compile(dynamic=None)
def compiled_lk_alpha_loss(
prenorm_hidden_states_flat,
target_p_flat,
valid_idx,
norm_weight,
lm_head_weight,
norm_eps,
):
"""LK^α loss: -log(acceptance_rate).mean().

Directly maximizes the expected log acceptance rate, where α_i = Σ_x min(p_i(x), q_i(x)).
"""
hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
tp = target_p_flat.index_select(0, valid_idx)

# RMSNorm
hs_f32 = hs.float()
variance = hs_f32.pow(2).mean(-1, keepdim=True)
rstd = torch.rsqrt(variance + norm_eps)
norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight

logits = F.linear(norm_hs, lm_head_weight) # (N, V_out)
q = F.softmax(logits.float(), dim=-1)

# Acceptance rate per position
alpha = torch.min(tp, q).sum(-1) # (N,)
loss = -torch.log(alpha.clamp(min=1e-8)).mean()

acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()

return loss, acc, alpha.mean()
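# Editorial note (not part of this diff): a tiny standalone illustration of the
# quantity compiled_lk_alpha_loss optimizes -- the per-position acceptance rate
# alpha = sum_x min(p(x), q(x)) and its negative log.
import torch

p = torch.tensor([0.6, 0.3, 0.1])         # target distribution p over a toy vocab
q = torch.tensor([0.5, 0.2, 0.3])         # draft distribution q
alpha = torch.minimum(p, q).sum()         # 0.5 + 0.2 + 0.1 = 0.8
loss = -torch.log(alpha.clamp(min=1e-8))  # -ln(0.8) ~ 0.223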


@torch.compile(dynamic=None)
def compiled_lk_lambda_loss(
prenorm_hidden_states_flat,
target_p_flat,
valid_idx,
norm_weight,
lm_head_weight,
norm_eps,
eta,
):
"""LK^λ loss: λ·KL(p‖q) + (1-λ)·TV(p,q) where λ = exp(-η·sg[α])."""
hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
tp = target_p_flat.index_select(0, valid_idx)

# RMSNorm
hs_f32 = hs.float()
variance = hs_f32.pow(2).mean(-1, keepdim=True)
rstd = torch.rsqrt(variance + norm_eps)
norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight

logits = F.linear(norm_hs, lm_head_weight) # (N, V_out)
q = F.softmax(logits.float(), dim=-1)
log_q = F.log_softmax(logits.float(), dim=-1)

# Acceptance rate (stop-gradient for λ computation)
alpha = torch.min(tp, q).sum(-1) # (N,)
lam = torch.exp(-eta * alpha.detach()) # (N,)

# KL(p‖q) per position
kl = F.kl_div(log_q, tp, reduction="none").sum(-1) # (N,)

# TV(p,q) per position
tv = 0.5 * (tp - q).abs().sum(-1) # (N,)

# Combined loss
loss = (lam * kl + (1.0 - lam) * tv).mean()

acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()

return loss, acc, alpha.mean()
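# Editorial note (not part of this diff): how the lambda schedule above behaves
# for a few acceptance rates -- low alpha keeps lambda near 1 so KL(p||q)
# dominates, high alpha pushes lambda toward 0 so total variation dominates.
# eta = 3.0 matches the lk_eta default added to TrainingConfig.
import torch

eta = 3.0
for a in (0.1, 0.5, 0.9):
    lam = torch.exp(torch.tensor(-eta * a))
    print(f"alpha={a:.1f}  lambda={lam.item():.3f}")
# alpha=0.1  lambda=0.741
# alpha=0.5  lambda=0.223
# alpha=0.9  lambda=0.067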


@torch.compile(dynamic=None)
def compiled_forward_kl_loss_from_hs(
prenorm_hidden_states_flat,
@@ -106,3 +182,78 @@ def compiled_forward_kl_loss_from_hs(
acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()

return loss, acc


@torch.compile(dynamic=None)
def compiled_lk_alpha_loss_from_hs(
prenorm_hidden_states_flat,
target_hidden_states_flat,
valid_idx,
norm_weight,
lm_head_weight,
target_lm_head_weight,
norm_eps,
):
"""LK^α loss from hidden states (LazyTarget path)."""
hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
ths = target_hidden_states_flat.index_select(0, valid_idx)

# Target probs
tp = F.softmax(F.linear(ths, target_lm_head_weight).float(), dim=-1)

# RMSNorm
hs_f32 = hs.float()
variance = hs_f32.pow(2).mean(-1, keepdim=True)
rstd = torch.rsqrt(variance + norm_eps)
norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight

logits = F.linear(norm_hs, lm_head_weight)
q = F.softmax(logits.float(), dim=-1)

alpha = torch.min(tp, q).sum(-1)
loss = -torch.log(alpha.clamp(min=1e-8)).mean()

acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()

return loss, acc, alpha.mean()


@torch.compile(dynamic=None)
def compiled_lk_lambda_loss_from_hs(
prenorm_hidden_states_flat,
target_hidden_states_flat,
valid_idx,
norm_weight,
lm_head_weight,
target_lm_head_weight,
norm_eps,
eta,
):
"""LK^λ loss from hidden states (LazyTarget path)."""
hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
ths = target_hidden_states_flat.index_select(0, valid_idx)

# Target probs
tp = F.softmax(F.linear(ths, target_lm_head_weight).float(), dim=-1)

# RMSNorm
hs_f32 = hs.float()
variance = hs_f32.pow(2).mean(-1, keepdim=True)
rstd = torch.rsqrt(variance + norm_eps)
norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight

logits = F.linear(norm_hs, lm_head_weight)
q = F.softmax(logits.float(), dim=-1)
log_q = F.log_softmax(logits.float(), dim=-1)

alpha = torch.min(tp, q).sum(-1)
lam = torch.exp(-eta * alpha.detach())

kl = F.kl_div(log_q, tp, reduction="none").sum(-1)
tv = 0.5 * (tp - q).abs().sum(-1)

loss = (lam * kl + (1.0 - lam) * tv).mean()

acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()

return loss, acc, alpha.mean()