diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 2d1b050cdba6e7..75feccd2f7a5a2 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -1453,6 +1453,7 @@ def scaled_dot_product_attention(
             dropout_p,
             is_causal,
         )
+        return out
     if attn_mask is None:
         # downgraded to ordinary flash attention implementation
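
A minimal usage sketch of the public API touched by this patch. It assumes a GPU build of Paddle with flash attention support; the tensor shapes and float16 dtype are illustrative and not taken from the patch, and which internal backend branch is taken depends on runtime dispatch. The point of the fix is that the backend branch which computes `out` now returns it instead of falling through to the `attn_mask is None` path.

import paddle
import paddle.nn.functional as F

# Illustrative shapes: [batch, seq_len, num_heads, head_dim] (assumption,
# not from the patch). float16 is commonly required by the attention kernels.
q = paddle.rand([2, 128, 8, 64]).astype("float16")
k = paddle.rand([2, 128, 8, 64]).astype("float16")
v = paddle.rand([2, 128, 8, 64]).astype("float16")

# Call the patched entry point; with the added `return out`, the selected
# backend's result is returned directly.
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(out.shape)  # [2, 128, 8, 64]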