diff --git a/tests/test_eagle3_loss.py b/tests/test_eagle3_loss.py index ab0feb0..3d96e79 100644 --- a/tests/test_eagle3_loss.py +++ b/tests/test_eagle3_loss.py @@ -2,10 +2,11 @@ Verifies that: 1. compiled_forward_kl_loss matches a naive reference implementation. -2. compute_target_p_padded produces correct shapes and valid probabilities +2. compiled_lk_alpha_loss and compiled_lk_lambda_loss match reference implementations. +3. compute_target_p_padded produces correct shapes and valid probabilities for both pruning and non-pruning paths. -3. The lazy target path (non-pruning, target_p_padded=None) produces identical - losses to the pre-computed target_p_padded path. +4. The lazy target path (non-pruning, target_p_padded=None) produces identical + losses to the pre-computed target_p_padded path for all loss types. """ import unittest @@ -24,6 +25,8 @@ from torchspec.models.ops.loss import ( compiled_forward_kl_loss, compiled_forward_kl_loss_from_hs, + compiled_lk_alpha_loss, + compiled_lk_lambda_loss, ) @@ -41,6 +44,44 @@ def _reference_forward_kl_loss(hs_flat, target_p_flat, norm_weight, lm_head_weig return loss, acc +def _reference_lk_alpha_loss(hs_flat, target_p_flat, norm_weight, lm_head_weight, norm_eps): + """Pure-Python reference for LK^α loss.""" + hs_f32 = hs_flat.float() + variance = hs_f32.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = (hs_f32 * rstd).to(hs_flat.dtype) * norm_weight + + logits = F.linear(norm_hs, lm_head_weight) + q = F.softmax(logits.float(), dim=-1) + + alpha = torch.min(target_p_flat, q).sum(-1) + loss = -torch.log(alpha.clamp(min=1e-8)).mean() + acc = (logits.argmax(-1) == target_p_flat.argmax(-1)).float().mean() + return loss, acc, alpha.mean() + + +def _reference_lk_lambda_loss(hs_flat, target_p_flat, norm_weight, lm_head_weight, norm_eps, eta): + """Pure-Python reference for LK^λ loss.""" + hs_f32 = hs_flat.float() + variance = hs_f32.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = (hs_f32 * rstd).to(hs_flat.dtype) * norm_weight + + logits = F.linear(norm_hs, lm_head_weight) + q = F.softmax(logits.float(), dim=-1) + log_q = F.log_softmax(logits.float(), dim=-1) + + alpha = torch.min(target_p_flat, q).sum(-1) + lam = torch.exp(-eta * alpha.detach()) + + kl = F.kl_div(log_q, target_p_flat, reduction="none").sum(-1) + tv = 0.5 * (target_p_flat - q).abs().sum(-1) + + loss = (lam * kl + (1.0 - lam) * tv).mean() + acc = (logits.argmax(-1) == target_p_flat.argmax(-1)).float().mean() + return loss, acc, alpha.mean() + + def _make_config(H=128, V=256, draft_V=None, num_heads=4, num_kv_heads=2): config = LlamaConfig( hidden_size=H, @@ -59,13 +100,17 @@ def _make_config(H=128, V=256, draft_V=None, num_heads=4, num_kv_heads=2): return config -def _make_model(config, length=3, attention_backend="sdpa", device="cpu"): +def _make_model( + config, length=3, attention_backend="sdpa", device="cpu", loss_type="forward_kl", lk_eta=3.0 +): draft_model = LlamaForCausalLMEagle3(config, attention_backend=attention_backend) draft_model = draft_model.to(device=device, dtype=torch.bfloat16) model = Eagle3Model( draft_model, length=length, attention_backend=attention_backend, + loss_type=loss_type, + lk_eta=lk_eta, ) model.eval() return model @@ -155,6 +200,155 @@ def test_loss_non_negative_and_finite(self): self.assertLessEqual(acc.item(), 1.0) +class TestCompiledLkAlphaLoss(unittest.TestCase): + """compiled_lk_alpha_loss should match the reference implementation.""" + + def test_matches_reference(self): + 
torch.manual_seed(42) + N, H, V = 32, 128, 256 + hs = torch.randn(N, H, dtype=torch.bfloat16) + norm_weight = torch.randn(H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(V, H, dtype=torch.bfloat16) + norm_eps = 1e-6 + valid_idx = torch.arange(N) + + raw_logits = F.linear(hs.float(), lm_head_weight.float()) + target_p = F.softmax(raw_logits + torch.randn_like(raw_logits) * 0.5, dim=-1) + + loss, acc, alpha = compiled_lk_alpha_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps + ) + ref_loss, ref_acc, ref_alpha = _reference_lk_alpha_loss( + hs, target_p, norm_weight, lm_head_weight, norm_eps + ) + + torch.testing.assert_close(loss, ref_loss, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(acc, ref_acc, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(alpha, ref_alpha, atol=1e-3, rtol=1e-3) + + def test_perfect_prediction_loss_zero(self): + """When draft == target, α=1 so -log(α)=0.""" + torch.manual_seed(0) + N, H, V = 16, 64, 32 + norm_weight = torch.ones(H, dtype=torch.float32) + lm_head_weight = torch.randn(V, H, dtype=torch.float32) + norm_eps = 1e-6 + valid_idx = torch.arange(N) + + hs = torch.randn(N, H, dtype=torch.float32) + variance = hs.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = hs * rstd * norm_weight + logits = F.linear(norm_hs, lm_head_weight) + target_p = F.softmax(logits, dim=-1) + + loss, acc, alpha = compiled_lk_alpha_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps + ) + self.assertAlmostEqual(loss.item(), 0.0, places=3) + self.assertAlmostEqual(alpha.item(), 1.0, places=3) + self.assertAlmostEqual(acc.item(), 1.0, places=2) + + def test_loss_finite_and_alpha_in_range(self): + torch.manual_seed(0) + N, H, V = 16, 64, 32 + hs = torch.randn(N, H, dtype=torch.bfloat16) + norm_weight = torch.randn(H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(V, H, dtype=torch.bfloat16) + target_p = F.softmax(torch.randn(N, V), dim=-1) + valid_idx = torch.arange(N) + + loss, acc, alpha = compiled_lk_alpha_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6 + ) + self.assertTrue(torch.isfinite(loss)) + self.assertGreaterEqual(alpha.item(), 0.0) + self.assertLessEqual(alpha.item(), 1.0) + + +class TestCompiledLkLambdaLoss(unittest.TestCase): + """compiled_lk_lambda_loss should match the reference implementation.""" + + def test_matches_reference(self): + torch.manual_seed(42) + N, H, V = 32, 128, 256 + hs = torch.randn(N, H, dtype=torch.bfloat16) + norm_weight = torch.randn(H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(V, H, dtype=torch.bfloat16) + norm_eps = 1e-6 + eta = 3.0 + valid_idx = torch.arange(N) + + raw_logits = F.linear(hs.float(), lm_head_weight.float()) + target_p = F.softmax(raw_logits + torch.randn_like(raw_logits) * 0.5, dim=-1) + + loss, acc, alpha = compiled_lk_lambda_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps, eta + ) + ref_loss, ref_acc, ref_alpha = _reference_lk_lambda_loss( + hs, target_p, norm_weight, lm_head_weight, norm_eps, eta + ) + + torch.testing.assert_close(loss, ref_loss, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(acc, ref_acc, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(alpha, ref_alpha, atol=1e-3, rtol=1e-3) + + def test_eta_sensitivity(self): + """Different η values should produce different losses.""" + torch.manual_seed(42) + N, H, V = 32, 128, 256 + hs = torch.randn(N, H, dtype=torch.bfloat16) + norm_weight = torch.randn(H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(V, H, 
dtype=torch.bfloat16) + target_p = F.softmax(torch.randn(N, V), dim=-1) + valid_idx = torch.arange(N) + + loss_eta3, _, _ = compiled_lk_lambda_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6, 3.0 + ) + loss_eta10, _, _ = compiled_lk_lambda_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6, 10.0 + ) + self.assertFalse(torch.allclose(loss_eta3, loss_eta10)) + + def test_perfect_prediction_loss_zero(self): + """When draft == target, KL=0 and TV=0 so loss=0.""" + torch.manual_seed(0) + N, H, V = 16, 64, 32 + norm_weight = torch.ones(H, dtype=torch.float32) + lm_head_weight = torch.randn(V, H, dtype=torch.float32) + norm_eps = 1e-6 + valid_idx = torch.arange(N) + + hs = torch.randn(N, H, dtype=torch.float32) + variance = hs.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = hs * rstd * norm_weight + logits = F.linear(norm_hs, lm_head_weight) + target_p = F.softmax(logits, dim=-1) + + loss, acc, alpha = compiled_lk_lambda_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps, 3.0 + ) + self.assertAlmostEqual(loss.item(), 0.0, places=3) + self.assertAlmostEqual(alpha.item(), 1.0, places=3) + + def test_loss_finite(self): + torch.manual_seed(0) + N, H, V = 16, 64, 32 + hs = torch.randn(N, H, dtype=torch.bfloat16) + norm_weight = torch.randn(H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(V, H, dtype=torch.bfloat16) + target_p = F.softmax(torch.randn(N, V), dim=-1) + valid_idx = torch.arange(N) + + loss, acc, alpha = compiled_lk_lambda_loss( + hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6, 3.0 + ) + self.assertTrue(torch.isfinite(loss)) + self.assertGreaterEqual(loss.item(), 0.0) + + class TestComputeTargetPPadded(unittest.TestCase): """compute_target_p_padded: shape, dtype, and probability correctness.""" @@ -212,12 +406,14 @@ def test_loss_mask_respected_in_position_mask(self): class TestLazyVsPrecomputedTarget(unittest.TestCase): """The lazy path (target_p_padded=None) must produce identical losses.""" - def _run_both_paths(self, device="cpu"): + def _run_both_paths(self, device="cpu", loss_type="forward_kl", lk_eta=3.0): torch.manual_seed(42) H, V, B, T, length = 128, 256, 1, 32, 3 config = _make_config(H=H, V=V) - model = _make_model(config, length=length, device=device) + model = _make_model( + config, length=length, device=device, loss_type=loss_type, lk_eta=lk_eta + ) batch = _make_batch(B, T, H, V, device=device) draft_model = model.draft_model @@ -230,7 +426,7 @@ def _run_both_paths(self, device="cpu"): precomputed = PrecomputedTarget(target_p_padded) with torch.no_grad(): - plosses_pre, _, acces_pre = model( + plosses_pre, _, acces_pre, alphas_pre = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=precomputed, @@ -244,7 +440,7 @@ def _run_both_paths(self, device="cpu"): length, ) with torch.no_grad(): - plosses_lazy, _, acces_lazy = model( + plosses_lazy, _, acces_lazy, alphas_lazy = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=lazy, @@ -252,47 +448,57 @@ def _run_both_paths(self, device="cpu"): hidden_states=batch["hidden_states"], ) - return plosses_pre, acces_pre, plosses_lazy, acces_lazy + return plosses_pre, acces_pre, alphas_pre, plosses_lazy, acces_lazy, alphas_lazy - def test_losses_match_cpu(self): - plosses_pre, acces_pre, plosses_lazy, acces_lazy = self._run_both_paths("cpu") + def _assert_paths_match(self, device, loss_type="forward_kl", lk_eta=3.0, atol=1e-4, rtol=1e-4): + results = 
self._run_both_paths(device, loss_type=loss_type, lk_eta=lk_eta) + plosses_pre, acces_pre, alphas_pre, plosses_lazy, acces_lazy, alphas_lazy = results for i, (pre, lazy) in enumerate(zip(plosses_pre, plosses_lazy)): torch.testing.assert_close( pre, lazy, - atol=1e-4, - rtol=1e-4, - msg=f"Loss mismatch at position {i}", + atol=atol, + rtol=rtol, + msg=f"Loss mismatch at position {i} (loss_type={loss_type})", ) for i, (pre, lazy) in enumerate(zip(acces_pre, acces_lazy)): torch.testing.assert_close( pre, lazy, - atol=1e-4, - rtol=1e-4, - msg=f"Accuracy mismatch at position {i}", - ) - - @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") - def test_losses_match_cuda(self): - plosses_pre, acces_pre, plosses_lazy, acces_lazy = self._run_both_paths("cuda") - for i, (pre, lazy) in enumerate(zip(plosses_pre, plosses_lazy)): - torch.testing.assert_close( - pre, - lazy, - atol=1e-3, - rtol=1e-3, - msg=f"Loss mismatch at position {i}", + atol=atol, + rtol=rtol, + msg=f"Accuracy mismatch at position {i} (loss_type={loss_type})", ) - for i, (pre, lazy) in enumerate(zip(acces_pre, acces_lazy)): + for i, (pre, lazy) in enumerate(zip(alphas_pre, alphas_lazy)): torch.testing.assert_close( pre, lazy, - atol=1e-3, - rtol=1e-3, - msg=f"Accuracy mismatch at position {i}", + atol=atol, + rtol=rtol, + msg=f"Alpha mismatch at position {i} (loss_type={loss_type})", ) + def test_forward_kl_losses_match_cpu(self): + self._assert_paths_match("cpu", loss_type="forward_kl") + + def test_lk_alpha_losses_match_cpu(self): + self._assert_paths_match("cpu", loss_type="lk_alpha") + + def test_lk_lambda_losses_match_cpu(self): + self._assert_paths_match("cpu", loss_type="lk_lambda") + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_forward_kl_losses_match_cuda(self): + self._assert_paths_match("cuda", loss_type="forward_kl", atol=1e-3, rtol=1e-3) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_lk_alpha_losses_match_cuda(self): + self._assert_paths_match("cuda", loss_type="lk_alpha", atol=1e-3, rtol=1e-3) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_lk_lambda_losses_match_cuda(self): + self._assert_paths_match("cuda", loss_type="lk_lambda", atol=1e-3, rtol=1e-3) + def _make_mask_patterns(BT): """Return (name, valid_idx) pairs covering diverse masking patterns.""" @@ -399,6 +605,75 @@ def _check_forward_kl_from_hs(self, valid_idx): torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5) torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5) + def _check_lk_alpha(self, valid_idx): + torch.manual_seed(7) + hs_flat = torch.randn(self.BT, self.H, dtype=torch.bfloat16) + norm_weight = torch.randn(self.H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(self.V, self.H, dtype=torch.bfloat16) + tp_flat = F.softmax(torch.randn(self.BT, self.V), dim=-1) + norm_eps = 1e-6 + + loss, acc, alpha = compiled_lk_alpha_loss( + hs_flat, + tp_flat, + valid_idx, + norm_weight, + lm_head_weight, + norm_eps, + ) + + hs_valid = hs_flat[valid_idx] + tp_valid = tp_flat[valid_idx] + all_idx = torch.arange(hs_valid.shape[0]) + loss_ref, acc_ref, alpha_ref = compiled_lk_alpha_loss( + hs_valid, + tp_valid, + all_idx, + norm_weight, + lm_head_weight, + norm_eps, + ) + + torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(alpha, alpha_ref, atol=1e-5, rtol=1e-5) + + def _check_lk_lambda(self, valid_idx): + 
torch.manual_seed(7) + hs_flat = torch.randn(self.BT, self.H, dtype=torch.bfloat16) + norm_weight = torch.randn(self.H, dtype=torch.bfloat16) + lm_head_weight = torch.randn(self.V, self.H, dtype=torch.bfloat16) + tp_flat = F.softmax(torch.randn(self.BT, self.V), dim=-1) + norm_eps = 1e-6 + eta = 3.0 + + loss, acc, alpha = compiled_lk_lambda_loss( + hs_flat, + tp_flat, + valid_idx, + norm_weight, + lm_head_weight, + norm_eps, + eta, + ) + + hs_valid = hs_flat[valid_idx] + tp_valid = tp_flat[valid_idx] + all_idx = torch.arange(hs_valid.shape[0]) + loss_ref, acc_ref, alpha_ref = compiled_lk_lambda_loss( + hs_valid, + tp_valid, + all_idx, + norm_weight, + lm_head_weight, + norm_eps, + eta, + ) + + torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(alpha, alpha_ref, atol=1e-5, rtol=1e-5) + # Dynamically generate one test method per mask pattern per loss function. for _name, _vidx in _make_mask_patterns(TestValidIdxSubsetting.BT): @@ -415,8 +690,22 @@ def test(self): return test + def _make_lk_alpha(vidx=_vidx): + def test(self): + self._check_lk_alpha(vidx) + + return test + + def _make_lk_lambda(vidx=_vidx): + def test(self): + self._check_lk_lambda(vidx) + + return test + setattr(TestValidIdxSubsetting, f"test_forward_kl_{_name}", _make_kl()) setattr(TestValidIdxSubsetting, f"test_forward_kl_from_hs_{_name}", _make_kl_from_hs()) + setattr(TestValidIdxSubsetting, f"test_lk_alpha_{_name}", _make_lk_alpha()) + setattr(TestValidIdxSubsetting, f"test_lk_lambda_{_name}", _make_lk_lambda()) if __name__ == "__main__": diff --git a/tools/benchmark_eagle3.py b/tools/benchmark_eagle3.py index 6ed61f3..9a75d96 100755 --- a/tools/benchmark_eagle3.py +++ b/tools/benchmark_eagle3.py @@ -127,7 +127,7 @@ def run_eagle3_forward(model, batch): eagle3.length, ) - plosses, _, acces = model( + plosses, _, acces, _ = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=target, diff --git a/tools/max_seq_search.py b/tools/max_seq_search.py index 50ada6c..57d6efa 100755 --- a/tools/max_seq_search.py +++ b/tools/max_seq_search.py @@ -157,7 +157,7 @@ def run_eagle3_forward( eagle3.length, ) - plosses, _, acces = model( + plosses, _, acces, _ = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=target, diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 11bea12..ed0624a 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -100,6 +100,8 @@ class TrainingConfig: gradient_checkpointing: bool = False learning_rate: float = 1e-4 + lk_eta: float = 3.0 + loss_type: str = "forward_kl" load_path: Optional[str] = None lr_decay_style: str = "cosine" lr_total_steps: Optional[int] = None diff --git a/torchspec/models/eagle3.py b/torchspec/models/eagle3.py index 00ad122..e0cfed9 100644 --- a/torchspec/models/eagle3.py +++ b/torchspec/models/eagle3.py @@ -29,6 +29,10 @@ from torchspec.models.ops.loss import ( compiled_forward_kl_loss, compiled_forward_kl_loss_from_hs, + compiled_lk_alpha_loss, + compiled_lk_alpha_loss_from_hs, + compiled_lk_lambda_loss, + compiled_lk_lambda_loss_from_hs, ) from torchspec.utils.tensor import padding @@ -56,6 +60,8 @@ def __init__( length: int = 7, attention_backend="sdpa", gradient_checkpointing: bool = False, + loss_type: str = "forward_kl", + lk_eta: float = 3.0, ): super().__init__() self.draft_model = draft_model @@ -63,6 +69,17 @@ def __init__( 
self.attention_backend = attention_backend self.gradient_checkpointing = gradient_checkpointing self.vocab_pruning = draft_model.vocab_size != draft_model.target_vocab_size + self.loss_type = loss_type + self.lk_eta = lk_eta + + def _select_loss_fns(self): + """Return (precomputed_fn, lazy_fn) based on self.loss_type.""" + if self.loss_type == "lk_alpha": + return compiled_lk_alpha_loss, compiled_lk_alpha_loss_from_hs + elif self.loss_type == "lk_lambda": + return compiled_lk_lambda_loss, compiled_lk_lambda_loss_from_hs + else: + return compiled_forward_kl_loss, compiled_forward_kl_loss_from_hs def _calculate_loss( self, @@ -74,38 +91,40 @@ def _calculate_loss( norm_weight: torch.Tensor, lm_head_weight: torch.Tensor, norm_eps: float, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute forward-KL loss and accuracy for one TTT step. + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute loss, accuracy, and alpha for one TTT step. Both paths pass full (B*T, ...) flat views + valid_idx into the compiled function so torch.compile can fuse index_select with subsequent ops, avoiding separate (N_valid, V) copies outside. - - PrecomputedTarget (vocab pruning): compiled_forward_kl_loss - with pre-computed target probs. - - LazyTarget (no pruning): compiled_forward_kl_loss_from_hs - computes target softmax inside the compiled graph. + - PrecomputedTarget (vocab pruning): compiled loss with pre-computed target probs. + - LazyTarget (no pruning): compiled loss computes target softmax inside the graph. + + Returns (loss, acc, alpha) where alpha is 0.0 for forward_kl. """ valid_idx = mask.flatten().nonzero().squeeze(-1) # Guard against all-masked positions to avoid nan from mean() on empty tensors. if valid_idx.numel() == 0: zero = hidden_states.new_tensor(0.0) - return zero, zero + return zero, zero, zero # Important as it prevents recompilation. 
torch._dynamo.mark_dynamic(valid_idx, 0) hs_flat = hidden_states.reshape(-1, hidden_states.shape[-1]) + precomputed_fn, lazy_fn = self._select_loss_fns() + is_lk = self.loss_type in ("lk_alpha", "lk_lambda") + if isinstance(target, PrecomputedTarget): target_p_step = target.target_p_padded[:, idx : idx + seq_length, :] tp_flat = target_p_step.reshape(-1, target_p_step.shape[-1]) args = (hs_flat, tp_flat, valid_idx, norm_weight, lm_head_weight, norm_eps) + if self.loss_type == "lk_lambda": + args = args + (self.lk_eta,) if self.gradient_checkpointing and self.training: - return torch_checkpoint( - compiled_forward_kl_loss, - *args, - use_reentrant=False, - ) - return compiled_forward_kl_loss(*args) + result = torch_checkpoint(precomputed_fn, *args, use_reentrant=False) + else: + result = precomputed_fn(*args) else: # lazy ths_flat = target.hidden_states_padded[:, idx : idx + seq_length, :].reshape( @@ -120,13 +139,18 @@ def _calculate_loss( target.lm_head_weight, norm_eps, ) + if self.loss_type == "lk_lambda": + args = args + (self.lk_eta,) if self.gradient_checkpointing and self.training: - return torch_checkpoint( - compiled_forward_kl_loss_from_hs, - *args, - use_reentrant=False, - ) - return compiled_forward_kl_loss_from_hs(*args) + result = torch_checkpoint(lazy_fn, *args, use_reentrant=False) + else: + result = lazy_fn(*args) + + if is_lk: + return result # (loss, acc, alpha) + else: + loss, acc = result + return loss, acc, hidden_states.new_tensor(0.0) def forward( self, @@ -137,7 +161,7 @@ def forward( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.Tensor] = None, - ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 @@ -180,6 +204,7 @@ def forward( plosses = [] vlosses = [] acces = [] + alphas = [] cache_keys = None cache_values = None @@ -218,7 +243,7 @@ def forward( hidden_states = hidden_states_out - loss, acc = self._calculate_loss( + loss, acc, alpha = self._calculate_loss( hidden_states=hidden_states, target=target, mask=mask, @@ -230,11 +255,12 @@ def forward( ) plosses.append(loss) acces.append(acc) + alphas.append(alpha) if not is_last: input_ids = padding(input_ids, left=False) mask = padding(mask, left=False) - return plosses, vlosses, acces + return plosses, vlosses, acces, alphas @torch.no_grad() diff --git a/torchspec/models/ops/__init__.py b/torchspec/models/ops/__init__.py index 194be8d..27c49f3 100644 --- a/torchspec/models/ops/__init__.py +++ b/torchspec/models/ops/__init__.py @@ -23,7 +23,11 @@ compile_friendly_flex_attention, generate_eagle3_mask, ) -from torchspec.models.ops.loss import compiled_forward_kl_loss +from torchspec.models.ops.loss import ( + compiled_forward_kl_loss, + compiled_lk_alpha_loss, + compiled_lk_lambda_loss, +) from torchspec.models.ops.loss_mask import compute_assistant_loss_mask __all__ = [ @@ -31,5 +35,7 @@ "compile_friendly_flex_attention", "generate_eagle3_mask", "compiled_forward_kl_loss", + "compiled_lk_alpha_loss", + "compiled_lk_lambda_loss", "compute_assistant_loss_mask", ] diff --git a/torchspec/models/ops/loss.py b/torchspec/models/ops/loss.py index 476d36b..edfbe4c 100644 --- a/torchspec/models/ops/loss.py +++ b/torchspec/models/ops/loss.py @@ -65,6 +65,82 @@ def compiled_forward_kl_loss( return loss, acc 
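+# Usage sketch for the LK losses below (illustrative only; it mirrors how
+# Eagle3Model._calculate_loss invokes them on (B*T, ...) flat views):
+#
+#     valid_idx = mask.flatten().nonzero().squeeze(-1)
+#     torch._dynamo.mark_dynamic(valid_idx, 0)  # length varies per batch
+#     loss, acc, alpha = compiled_lk_alpha_loss(
+#         hs_flat,  # (B*T, H) pre-norm draft hidden states
+#         tp_flat,  # (B*T, V) target probabilities
+#         valid_idx, norm_weight, lm_head_weight, norm_eps,
+#     )
+#
+# compiled_lk_lambda_loss takes the same arguments plus a trailing eta. Both
+# return scalar (loss, acc, alpha), where alpha is the mean acceptance rate
+# Σ_x min(p(x), q(x)); for LK^α, loss → 0 as alpha → 1.
+
+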
+@torch.compile(dynamic=None)
+def compiled_lk_alpha_loss(
+    prenorm_hidden_states_flat,
+    target_p_flat,
+    valid_idx,
+    norm_weight,
+    lm_head_weight,
+    norm_eps,
+):
+    """LK^α loss: -log(acceptance_rate).mean().
+
+    Directly maximizes the expected log acceptance rate, where the per-position
+    acceptance rate is α_i = Σ_x min(p_i(x), q_i(x)); α is clamped at 1e-8
+    before the log for numerical stability.
+    """
+    hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
+    tp = target_p_flat.index_select(0, valid_idx)
+
+    # RMSNorm
+    hs_f32 = hs.float()
+    variance = hs_f32.pow(2).mean(-1, keepdim=True)
+    rstd = torch.rsqrt(variance + norm_eps)
+    norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight
+
+    logits = F.linear(norm_hs, lm_head_weight)  # (N, V_out)
+    q = F.softmax(logits.float(), dim=-1)
+
+    # Acceptance rate per position
+    alpha = torch.min(tp, q).sum(-1)  # (N,)
+    loss = -torch.log(alpha.clamp(min=1e-8)).mean()
+
+    acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()
+
+    return loss, acc, alpha.mean()
+
+
+@torch.compile(dynamic=None)
+def compiled_lk_lambda_loss(
+    prenorm_hidden_states_flat,
+    target_p_flat,
+    valid_idx,
+    norm_weight,
+    lm_head_weight,
+    norm_eps,
+    eta,
+):
+    """LK^λ loss: λ·KL(p‖q) + (1-λ)·TV(p,q) where λ = exp(-η·sg[α]).
+
+    sg[·] is stop-gradient. Poorly matched positions (α → 0) give λ → 1 and are
+    trained mostly with KL; well-matched positions (α → 1) give λ → exp(-η) and
+    shift weight toward total variation.
+    """
+    hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
+    tp = target_p_flat.index_select(0, valid_idx)
+
+    # RMSNorm
+    hs_f32 = hs.float()
+    variance = hs_f32.pow(2).mean(-1, keepdim=True)
+    rstd = torch.rsqrt(variance + norm_eps)
+    norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight
+
+    logits = F.linear(norm_hs, lm_head_weight)  # (N, V_out)
+    q = F.softmax(logits.float(), dim=-1)
+    log_q = F.log_softmax(logits.float(), dim=-1)
+
+    # Acceptance rate (stop-gradient for λ computation)
+    alpha = torch.min(tp, q).sum(-1)  # (N,)
+    lam = torch.exp(-eta * alpha.detach())  # (N,)
+
+    # KL(p‖q) per position
+    kl = F.kl_div(log_q, tp, reduction="none").sum(-1)  # (N,)
+
+    # TV(p,q) per position
+    tv = 0.5 * (tp - q).abs().sum(-1)  # (N,)
+
+    # Combined loss
+    loss = (lam * kl + (1.0 - lam) * tv).mean()
+
+    acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()
+
+    return loss, acc, alpha.mean()
+
+
 @torch.compile(dynamic=None)
 def compiled_forward_kl_loss_from_hs(
     prenorm_hidden_states_flat,
@@ -106,3 +182,78 @@ def compiled_forward_kl_loss_from_hs(
     acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()
 
     return loss, acc
+
+
+@torch.compile(dynamic=None)
+def compiled_lk_alpha_loss_from_hs(
+    prenorm_hidden_states_flat,
+    target_hidden_states_flat,
+    valid_idx,
+    norm_weight,
+    lm_head_weight,
+    target_lm_head_weight,
+    norm_eps,
+):
+    """LK^α loss from hidden states (LazyTarget path)."""
+    hs = prenorm_hidden_states_flat.index_select(0, valid_idx)
+    ths = target_hidden_states_flat.index_select(0, valid_idx)
+
+    # Target probs
+    tp = F.softmax(F.linear(ths, target_lm_head_weight).float(), dim=-1)
+
+    # RMSNorm
+    hs_f32 = hs.float()
+    variance = hs_f32.pow(2).mean(-1, keepdim=True)
+    rstd = torch.rsqrt(variance + norm_eps)
+    norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight
+
+    logits = F.linear(norm_hs, lm_head_weight)
+    q = F.softmax(logits.float(), dim=-1)
+
+    alpha = torch.min(tp, q).sum(-1)
+    loss = -torch.log(alpha.clamp(min=1e-8)).mean()
+
+    acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean()
+
+    return loss, acc, alpha.mean()
+
+
+@torch.compile(dynamic=None)
+def compiled_lk_lambda_loss_from_hs(
+    prenorm_hidden_states_flat,
+    target_hidden_states_flat,
+    valid_idx,
+    norm_weight,
+    lm_head_weight,
+    target_lm_head_weight,
+    norm_eps,
+    eta,
+):
+    """LK^λ loss from hidden states (LazyTarget path)."""
+    hs = 
prenorm_hidden_states_flat.index_select(0, valid_idx) + ths = target_hidden_states_flat.index_select(0, valid_idx) + + # Target probs + tp = F.softmax(F.linear(ths, target_lm_head_weight).float(), dim=-1) + + # RMSNorm + hs_f32 = hs.float() + variance = hs_f32.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight + + logits = F.linear(norm_hs, lm_head_weight) + q = F.softmax(logits.float(), dim=-1) + log_q = F.log_softmax(logits.float(), dim=-1) + + alpha = torch.min(tp, q).sum(-1) + lam = torch.exp(-eta * alpha.detach()) + + kl = F.kl_div(log_q, tp, reduction="none").sum(-1) + tv = 0.5 * (tp - q).abs().sum(-1) + + loss = (lam * kl + (1.0 - lam) * tv).mean() + + acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean() + + return loss, acc, alpha.mean() diff --git a/torchspec/training/eagle3_trainer.py b/torchspec/training/eagle3_trainer.py index cf75d4e..84bd716 100644 --- a/torchspec/training/eagle3_trainer.py +++ b/torchspec/training/eagle3_trainer.py @@ -93,6 +93,8 @@ def init_model( length=self.args.ttt_length, attention_backend=self.args.attention_backend, gradient_checkpointing=getattr(self.args, "gradient_checkpointing", True), + loss_type=getattr(self.args, "loss_type", "forward_kl"), + lk_eta=getattr(self.args, "lk_eta", 3.0), ) full_state = eagle3_model.state_dict() if dist.get_rank() == 0 else {} @@ -213,7 +215,9 @@ def _init_target_lm_head(self, target_model_path: str) -> None: # Forward / backward # ------------------------------------------------------------------ - def _forward(self, batch: dict) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + def _forward( + self, batch: dict + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: input_ids = padding(batch["input_ids"], left=False).cuda() target_hidden_states = padding(batch["last_hidden_states"], left=False).cuda() @@ -238,14 +242,14 @@ def _forward(self, batch: dict) -> Tuple[List[torch.Tensor], List[torch.Tensor]] ) del target_hidden_states - plosses, _, acces = self.model( + plosses, _, acces, alphas = self.model( input_ids=input_ids, attention_mask=batch["attention_mask"].cuda(), target=target, loss_mask=loss_mask, hidden_states=batch["hidden_states"].cuda(), ) - return plosses, acces + return plosses, acces, alphas def _backward(self, plosses: List[torch.Tensor], accumulation_steps: int = 1) -> torch.Tensor: ploss_weight = [0.8**i for i in range(len(plosses))] @@ -260,10 +264,11 @@ def _backward(self, plosses: List[torch.Tensor], accumulation_steps: int = 1) -> def eval_forward(self, batch: dict) -> dict: """Single forward pass without backward — returns per-position metrics.""" with torch.no_grad(): - plosses, acces = self._forward(batch) + plosses, acces, alphas = self._forward(batch) return { "plosses": torch.stack(plosses).detach(), "acces": torch.stack(acces).detach(), + "alphas": torch.stack(alphas).detach(), } def eval_from_cache(self) -> dict: @@ -299,9 +304,11 @@ def _aggregate_eval_metrics(self, all_step_metrics: list[dict]) -> dict: avg_plosses = torch.stack([m["plosses"] for m in all_step_metrics]).mean(dim=0) avg_acces = torch.stack([m["acces"] for m in all_step_metrics]).mean(dim=0) + avg_alphas = torch.stack([m["alphas"] for m in all_step_metrics]).mean(dim=0) dist.all_reduce(avg_plosses, op=dist.ReduceOp.AVG) dist.all_reduce(avg_acces, op=dist.ReduceOp.AVG) + dist.all_reduce(avg_alphas, op=dist.ReduceOp.AVG) cumulative = 1.0 simulated_acc_len = 0.0 @@ -317,11 +324,13 @@ def _aggregate_eval_metrics(self, 
all_step_metrics: list[dict]) -> dict: metrics: dict = { "eval/avg_loss": weighted_avg_loss, "eval/avg_acc": avg_acces.mean().item(), + "eval/avg_alpha": avg_alphas.mean().item(), "eval/simulated_acc_len": simulated_acc_len, } for i in range(avg_plosses.shape[0]): metrics[f"eval/ploss_{i}"] = avg_plosses[i].item() metrics[f"eval/acc_{i}"] = avg_acces[i].item() + metrics[f"eval/alpha_{i}"] = avg_alphas[i].item() if dist.get_rank() == 0: logger.info( @@ -343,12 +352,13 @@ def _train_step( batch_idx: int, num_batches: int, ) -> dict: - plosses, acces = self._forward(batch) + plosses, acces, alphas = self._forward(batch) total_loss = self._backward(plosses, accumulation_steps=accumulation_steps) return { "plosses": torch.stack(plosses).detach(), "acces": torch.stack(acces).detach(), + "alphas": torch.stack(alphas).detach(), "plosses_raw": [p.detach() for p in plosses], "acces_raw": [a.detach() for a in acces], "total_loss": total_loss.detach(), @@ -384,12 +394,15 @@ def _aggregate_metrics( plosses = [m["plosses"] for m in all_step_metrics] acces = [m["acces"] for m in all_step_metrics] + alphas_list = [m["alphas"] for m in all_step_metrics] avg_plosses = torch.stack(plosses).mean(dim=0) avg_acces = torch.stack(acces).mean(dim=0) + avg_alphas = torch.stack(alphas_list).mean(dim=0) dist.all_reduce(avg_plosses, op=dist.ReduceOp.AVG) dist.all_reduce(avg_acces, op=dist.ReduceOp.AVG) + dist.all_reduce(avg_alphas, op=dist.ReduceOp.AVG) # Simulated acceptance length: acc_0 + acc_0*acc_1 + acc_0*acc_1*acc_2 + ... # Models the expected number of consecutively accepted draft tokens, @@ -409,6 +422,7 @@ def _aggregate_metrics( metrics = { "train/avg_loss": weighted_avg_loss, "train/avg_acc": avg_acces.mean().item(), + "train/avg_alpha": avg_alphas.mean().item(), "train/simulated_acc_len": simulated_acc_len, "train/grad_norm": grad_norm.item() if grad_norm is not None else 0.0, "train/global_step": self.global_step, @@ -419,6 +433,7 @@ def _aggregate_metrics( for i in range(avg_plosses.shape[0]): metrics[f"train/ploss_{i}"] = avg_plosses[i].item() metrics[f"train/acc_{i}"] = avg_acces[i].item() + metrics[f"train/alpha_{i}"] = avg_alphas[i].item() if dist.get_rank() == 0: logger.debug(f"step {step}: {metrics}")
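
# ---------------------------------------------------------------------------
# Configuration sketch (editor's illustration, not part of the diff above):
# enabling the new losses via TrainingConfig; any fields not shown here are
# assumed to keep their defaults.
#
#     from torchspec.config.train_config import TrainingConfig
#
#     args = TrainingConfig(
#         loss_type="lk_lambda",  # "forward_kl" (default), "lk_alpha", or "lk_lambda"
#         lk_eta=3.0,             # used only by lk_lambda: λ = exp(-η·sg[α])
#     )
#
# With either LK loss, the trainer logs train/avg_alpha and eval/avg_alpha plus
# per-TTT-step train/alpha_i and eval/alpha_i (the mean acceptance rate
# Σ_x min(p(x), q(x)) at step i); for forward_kl these alphas are reported as 0.0.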