Disable Triton sigmoid focal loss path when gamma == 0 #484
jushanshine wants to merge 1 commit into facebookresearch:main from
Conversation
Hi @jushanshine! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Summary
This fixes a numerical issue in the Triton sigmoid focal loss backward path when `gamma == 0`. In the current Triton kernel, the backward pass computes a term of the form `(1 - p_t) ** (gamma - 1)`.

When `gamma == 0`, that becomes `(1 - p_t) ** (-1)`. For positive targets with large positive logits, `p_t` approaches `1`, so this term can diverge and produce `inf`/`nan` gradients. In local reproduction, this surfaced as:

- a `SigmoidFocalLossReducedBackward` anomaly-detection failure
- `Loss is nan` during training

Change
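To make the failure mode concrete, here is a small pure-Python sketch of the problematic backward factor (helper names are illustrative, not the actual kernel code; `p_t` follows the standard focal-loss definition):

```python
import math

def p_t(logit: float, target: float) -> float:
    """p_t from the focal-loss paper: the model's probability of the true class."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return p if target == 1.0 else 1.0 - p

def backward_term(logit: float, target: float, gamma: float) -> float:
    """The (1 - p_t) ** (gamma - 1) factor appearing in the backward pass."""
    return (1.0 - p_t(logit, target)) ** (gamma - 1.0)

# gamma > 0: the factor stays tame for a confident positive prediction.
print(backward_term(30.0, 1.0, gamma=2.0))  # ~9.4e-14

# gamma == 0: the factor is (1 - p_t) ** (-1) and explodes as p_t -> 1.
print(backward_term(30.0, 1.0, gamma=0.0))  # ~1.1e13

# Push the logit a bit further and 1 - p_t underflows to exactly 0.0;
# in tensor math the term becomes inf and the gradient turns nan.
try:
    backward_term(50.0, 1.0, gamma=0.0)
except ZeroDivisionError:
    print("diverged: 1 - p_t underflowed to 0")
```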
This PR adds a small guard in the focal-loss wrapper: bypass the Triton path when `triton` is enabled and `gamma == 0.0`.

This keeps Triton enabled for the usual `gamma > 0` case and only bypasses the numerically invalid edge case.

Why Wrapper Fallback
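A minimal sketch of such a wrapper guard, in pure Python with illustrative names (the actual sam3 wrapper and Triton kernel are not reproduced here):

```python
import math

def _focal_loss_eager(logits, targets, gamma):
    """Numerically stable eager-path sigmoid focal loss (per element)."""
    out = []
    for x, t in zip(logits, targets):
        # Stable BCE-with-logits: max(x, 0) - x*t + log(1 + exp(-|x|))
        bce = max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
        p = 1.0 / (1.0 + math.exp(-x))
        p_t = p * t + (1.0 - p) * (1.0 - t)
        out.append(((1.0 - p_t) ** gamma) * bce)
    return out

def _focal_loss_triton(logits, targets, gamma):
    """Stand-in for the Triton kernel, assumed valid only for gamma > 0."""
    return _focal_loss_eager(logits, targets, gamma)

def sigmoid_focal_loss(logits, targets, gamma=2.0, use_triton=True):
    # Guard: the Triton backward holds a (1 - p_t) ** (gamma - 1) factor,
    # which is invalid at gamma == 0, so route that case to the eager path.
    if use_triton and gamma != 0.0:
        return _focal_loss_triton(logits, targets, gamma)
    return _focal_loss_eager(logits, targets, gamma)

# The gamma == 0 edge case now stays finite even for very confident logits.
print(sigmoid_focal_loss([30.0], [1.0], gamma=0.0))
```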
I chose the wrapper fallback instead of changing the Triton kernel directly.

If preferred, a follow-up could implement a Triton-side special case for `gamma == 0`, but the wrapper guard is the smallest safe fix.

Why This Is Safe
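One way to see why routing `gamma == 0` to the eager path is safe: the modulating factor `(1 - p_t) ** gamma` is identically `1` at `gamma == 0`, so the loss reduces to plain BCE-with-logits. A quick forward-path check (standard definitions; code is illustrative, not from the repo):

```python
import math

def bce_with_logits(x, t):
    """Standard numerically stable binary cross-entropy with logits."""
    return max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))

def focal_loss(x, t, gamma):
    """Sigmoid focal loss, forward only, per the standard definition."""
    p = 1.0 / (1.0 + math.exp(-x))
    p_t = p * t + (1.0 - p) * (1.0 - t)
    return ((1.0 - p_t) ** gamma) * bce_with_logits(x, t)

# At gamma == 0 the modulating factor is exactly 1.0, so the two agree.
for x, t in [(-3.0, 0.0), (0.5, 1.0), (12.0, 1.0)]:
    assert focal_loss(x, t, gamma=0.0) == bce_with_logits(x, t)
```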
The change only affects the `gamma == 0` edge case.

Reproduction
The issue was reproduced from a saved training checkpoint by replaying the real train order (`forward -> backward`) with anomaly detection enabled.

Observed failure mode:

- `epoch=1, step=0, chunk=1`
- `SigmoidFocalLossReducedBackward`

For the exact failing `presence` tensors:

Validation
After this change:
Files
`sam3/train/loss/loss_fns.py`