Skip to content
67 changes: 47 additions & 20 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import inspect
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

Expand All @@ -42,7 +43,7 @@
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


if TYPE_CHECKING:
Expand Down Expand Up @@ -82,24 +83,11 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None


if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
else:
flash_attn_3_func_hub = None

if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
Expand Down Expand Up @@ -262,6 +250,25 @@ def _is_context_parallel_enabled(
return supports_context_parallel and is_degree_greater_than_1


@dataclass
class _HubKernelConfig:
"""Configuration for downloading and using a hub-based attention kernel."""

repo_id: str
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None


# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
)
}
Comment on lines +264 to +269

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about the other backends ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be populated as we incorporate others.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean shouldn't FA2 be here already ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah okay sounds good !



@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
Expand Down Expand Up @@ -418,13 +425,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None

# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend == AttentionBackendName.AITER:
Expand Down Expand Up @@ -574,6 +577,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx


# ===== Helpers for downloading kernels =====
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]

if config.kernel_fn is not None:
return

try:
from kernels import get_kernel

kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)

# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func

except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
raise


# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
Expand Down Expand Up @@ -1341,7 +1367,8 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
Expand Down
8 changes: 7 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
from .attention_dispatch import (
AttentionBackendName,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)

# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
Expand All @@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
_maybe_download_kernel_for_backend(backend)

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
Expand Down
23 changes: 0 additions & 23 deletions src/diffusers/utils/kernels_utils.py

This file was deleted.

1 change: 0 additions & 1 deletion tests/others/test_attention_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
export DIFFUSERS_ENABLE_HUB_KERNELS=yes

pytest tests/others/test_attention_backends.py
```
Expand Down
Loading