
Disable Triton sigmoid focal loss path when gamma == 0#484

Open
jushanshine wants to merge 1 commit into facebookresearch:main from jushanshine:fix/triton-focal-gamma-zero

Conversation

@jushanshine

Summary

This fixes a numerical issue in the Triton sigmoid focal loss backward path when gamma == 0.

In the current Triton kernel, backward computes:

tmp = libdevice.pow(1 - p_t, gamma - 1)

When gamma == 0, that becomes (1 - p_t) ** (-1). For positive targets with large positive logits, p_t approaches 1, so this term can diverge and produce inf/nan gradients.
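The divergence can be reproduced in a few lines of plain PyTorch (values are illustrative, not the captured presence tensors):

```python
import torch

gamma = 0.0
logits = torch.tensor([30.0])   # large positive logit
targets = torch.tensor([1.0])   # positive target

p = torch.sigmoid(logits)       # rounds to exactly 1.0 in float32
p_t = p * targets + (1 - p) * (1 - targets)

# The kernel's backward factor at gamma == 0 is (1 - p_t) ** (-1):
tmp = (1 - p_t) ** (gamma - 1)
print(tmp)  # tensor([inf])
```

Because float32 has no headroom near 1.0, 1 - p_t underflows to exactly 0, and the kernel's pow term overflows to inf before it is even multiplied by anything.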

In local reproduction, this surfaced as:

  • non-finite matcher inputs
  • SigmoidFocalLossReducedBackward anomaly-detection failure
  • an eventual "Loss is nan" failure during training

Change

This PR adds a small guard in the focal-loss wrapper:

  • if triton is enabled and gamma == 0.0
  • automatically fall back to the existing PyTorch implementation

This keeps Triton enabled for the usual gamma > 0 case and only bypasses the numerically invalid edge case.
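A minimal sketch of the guard; the flag and function names (USE_TRITON_FOCAL_LOSS, _focal_loss_triton, _focal_loss_pytorch) are hypothetical stand-ins for the actual wrapper in loss_fns.py:

```python
import torch
import torch.nn.functional as F

USE_TRITON_FOCAL_LOSS = True  # illustrative config flag


def _focal_loss_triton(logits, targets, alpha, gamma):
    # Placeholder for the custom Triton kernel wrapper (not reproduced here).
    raise NotImplementedError


def _focal_loss_pytorch(logits, targets, alpha=0.25, gamma=2.0):
    # Reference sigmoid focal loss. PyTorch's pow backward special-cases a
    # zero exponent, so gamma == 0 never forms (1 - p_t) ** (gamma - 1).
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce * (1 - p_t) ** gamma
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss


def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Guard: take the Triton path only for gamma > 0; gamma == 0 falls
    # back to the numerically safe PyTorch implementation.
    if USE_TRITON_FOCAL_LOSS and gamma != 0.0:
        return _focal_loss_triton(logits, targets, alpha, gamma)
    return _focal_loss_pytorch(logits, targets, alpha=alpha, gamma=gamma)
```

With this shape, the gamma > 0 configuration is untouched and the gamma == 0 case never reaches the Triton kernel.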

Why a Wrapper Fallback

I chose the wrapper fallback instead of changing the Triton kernel directly because:

  • it preserves the exact focal-loss semantics for gamma == 0
  • it is a minimal, low-risk fix for a clearly defined edge case
  • it avoids adding Triton-side branching and special-case backward logic in the custom kernel

If preferred, a follow-up could implement a Triton-side special case for gamma == 0, but the wrapper guard is the smallest safe fix.

Why This Is Safe

  • The fallback path already exists and is used by the same wrapper.
  • The change is minimal and only affects the gamma == 0 edge case.
  • No behavior changes for the normal Triton-accelerated focal-loss configuration.

Reproduction

The issue was reproduced from a saved training checkpoint by replaying the real train order (forward -> backward) with anomaly detection enabled.

Observed failure mode:

  • first bad point at epoch=1, step=0, chunk=1
  • forward values remained finite
  • first non-finite values appeared during backward
  • anomaly detection pointed to SigmoidFocalLossReducedBackward

For the exact failing presence tensors:

  • Triton backward produced non-finite gradients
  • PyTorch fallback stayed finite
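The finiteness check on the fallback path can be sketched as follows (tensor values are illustrative, not the captured presence tensors; at gamma == 0 the focal loss reduces to alpha-weighted BCE, which is what the fallback effectively computes):

```python
import torch
import torch.nn.functional as F


def replay_backward(logits_data, targets, alpha=0.25):
    # Replay-style check: run forward + backward with anomaly detection
    # enabled and return the input gradients for inspection.
    logits = logits_data.clone().requires_grad_(True)
    with torch.autograd.detect_anomaly():
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        ce = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        (alpha_t * ce).sum().backward()
    return logits.grad
```

Even for saturating logits, binary_cross_entropy_with_logits is computed in a log-sum-exp-stable form, so the gradients stay finite where the Triton backward diverged.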

Validation

After this change:

  • the minimal failing replay point no longer produced non-finite gradients
  • replay over multiple consecutive training steps completed without non-finite events
  • the same training job could be restarted on the patched codebase

Files

  • sam3/train/loss/loss_fns.py

@meta-cla

meta-cla bot commented Mar 6, 2026

Hi @jushanshine!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla

meta-cla bot commented Mar 6, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 6, 2026
