From f3d75066f1e0396a378a99a50884fdcf225e1a6b Mon Sep 17 00:00:00 2001 From: autodl Date: Fri, 6 Mar 2026 15:33:35 +0800 Subject: [PATCH] loss: disable triton focal path when gamma is zero --- sam3/train/loss/loss_fns.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sam3/train/loss/loss_fns.py b/sam3/train/loss/loss_fns.py index 8fa17741e..aa4ed7cea 100644 --- a/sam3/train/loss/loss_fns.py +++ b/sam3/train/loss/loss_fns.py @@ -149,6 +149,10 @@ def sigmoid_focal_loss( """ if not (0 <= alpha <= 1) and triton: raise RuntimeError(f"Alpha should be in [0,1], got {alpha}") + # The Triton backward path is numerically invalid for gamma == 0 because it + # computes (1 - p_t) ** (gamma - 1), i.e. 1 / (1 - p_t), which diverges as p_t -> 1. + if triton and float(gamma) == 0.0: + triton = False if triton: if reduce and not loss_on_multimask: loss = triton_sigmoid_focal_loss_reduce(inputs, targets, alpha, gamma)