[core] Refactor hub attn kernels #12475
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very good to only load the invoked attention implementation! Thanks for adding this.
I think the control flow here is a bit difficult to follow. We should aim to minimize the number of new objects/concepts introduced in this module since there's already quite a lot of routing going on in here. My recommendations:

```python
def _set_attention_backend(backend: AttentionBackendName) -> None:
    _check_attention_backend_requirements(backend)
    _maybe_download_kernel_for_backend(backend)
```

In attention dispatch, let's create a `@dataclass`:

```python
@dataclass
class _HubKernelConfig:
    """Configuration for downloading and using a hub-based attention kernel."""

    repo_id: str
    function_attr: str
    revision: Optional[str] = None
    kernel_fn: Optional[Callable] = None


# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
    )
}
```

Then in your hub function, fetch the downloaded kernel from the registry:

```python
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
    q=query,
    ...
)
```

We shouldn't attempt kernel downloads from the dispatch function. It should already be downloaded/available beforehand.
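To make the suggestion concrete, here is a minimal sketch (not the actual implementation in this PR) of how the hub-backed attention op could consume the cached registry entry at dispatch time; the function name `_flash_attention_3_hub` and the keyword arguments mirror the snippet above and are illustrative only:

```python
# Minimal sketch, assuming the _HubKernelConfig / _HUB_KERNELS_REGISTRY objects above.
# The op only reads the cached kernel_fn; it never triggers a download itself.
def _flash_attention_3_hub(query, key, value, is_causal: bool = False):
    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
    if config.kernel_fn is None:
        raise RuntimeError(
            "The FA3 hub kernel has not been downloaded yet. "
            "Select the backend (which triggers the download) before dispatch."
        )
    # Keyword names follow the reviewer's snippet; the real kernel signature may differ.
    return config.kernel_fn(q=query, k=key, v=value, causal=is_causal)
```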
Okay but
This is also good 👍🏽

Something like:

```python
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    if backend not in _HUB_KERNELS_REGISTRY:
        return

    config = _HUB_KERNELS_REGISTRY[backend]
    if config.kernel_fn is not None:
        return

    try:
        from kernels import get_kernel

        kernel_module = get_kernel(config.repo_id, revision=config.revision)
        kernel_func = getattr(kernel_module, config.function_attr)
        # Cache the downloaded kernel function in the config object
        config.kernel_fn = kernel_func
    except Exception as e:
        # Re-raise with context so the repo/function being fetched is visible in the traceback
        raise RuntimeError(
            f"Could not load `{config.function_attr}` from `{config.repo_id}`."
        ) from e
```
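For anyone who hasn't used the `kernels` package before, here's a standalone illustration of the download step in isolation, using the repo id, revision, and attribute name from the registry entry above (run it only where the kernel's hardware/software requirements are met):

```python
# Standalone illustration of the download step used in _maybe_download_kernel_for_backend.
# get_kernel pulls the kernel repo from the Hugging Face Hub (cached locally) and
# returns a module-like object whose attributes are the kernel's exported functions.
from kernels import get_kernel

kernel_module = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")
flash_attn_func = getattr(kernel_module, "flash_attn_func")
print(callable(flash_attn_func))  # True if the kernel exposes the expected entry point
```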
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
@DN6 check now. Your feedback should have been addressed. I was able to completely get rid of
Nice cleanup! Thanks
```python
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
    # TODO: temporary revision for now. Remove when merged upstream into `main`.
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
    )
}
```
What about the other backends?
This will be populated as we incorporate others.
I mean, shouldn't FA2 be here already?
Ah okay, sounds good!
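For context on how the registry is expected to grow once the FA and SAGE PRs land (see the description below), here is a hedged sketch of what an additional entry might look like; the backend member, repo id, and attribute name are placeholders, not the values those PRs will actually use:

```python
# Hypothetical future entry; the backend member, repo_id, and function_attr are placeholders.
_HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_HUB] = _HubKernelConfig(
    repo_id="kernels-community/flash-attn",  # placeholder repo id
    function_attr="flash_attn_func",         # placeholder attribute name
)
```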
What does this PR do?
Refactors how we load the attention kernels from the Hub.
Currently, when a user sets the `DIFFUSERS_ENABLE_HUB_KERNELS` env var, we always download every supported kernel. Right now that is just FA3, but we have ongoing PRs that add FA and SAGE support: #12387 and #12439. So we would download ALL of them even when they're not required. This is not good.

This PR makes it so that only the relevant kernel gets downloaded, without breaking `torch.compile` compliance (fullgraph and no recompilation triggers).

Cc: @MekkCyber
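For reviewers who want to try this out locally, a hypothetical usage sketch; the env var value, the `set_attention_backend` helper, and the `"_flash_3_hub"` backend string are assumptions to double-check against the attention backend docs:

```python
# Hypothetical usage sketch; the names flagged below are assumptions, not confirmed API.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # assumed truthy value

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumed helper / backend name: with this PR, only the FA3 hub kernel should be
# downloaded here, not FA or SAGE.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```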