diff --git a/sam3/train/loss/loss_fns.py b/sam3/train/loss/loss_fns.py index 8fa17741e..aa4ed7cea 100644 --- a/sam3/train/loss/loss_fns.py +++ b/sam3/train/loss/loss_fns.py @@ -149,6 +149,10 @@ def sigmoid_focal_loss( """ if not (0 <= alpha <= 1) and triton: raise RuntimeError(f"Alpha should be in [0,1], got {alpha}") + # The Triton backward path is numerically invalid for gamma == 0 because it + # computes (1 - p_t) ** (gamma - 1), i.e. an inverse power at zero. + if triton and float(gamma) == 0.0: + triton = False if triton: if reduce and not loss_on_multimask: loss = triton_sigmoid_focal_loss_reduce(inputs, targets, alpha, gamma)