Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import inspect
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -83,12 +84,20 @@
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub
from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
from ..utils.sage_utils import _get_sage_attn_fn_for_device

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func

sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])

else:
flash_attn_3_func_hub = None
sage_attn_func_hub = None

if _CAN_USE_SAGE_ATTN:
from sageattention import (
Expand Down Expand Up @@ -162,10 +171,6 @@ def wrap(func):
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet

_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
Comment on lines -165 to -167
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see their usage, hence removed.



class AttentionBackendName(str, Enum):
# EAGER = "eager"
Expand All @@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):

# `sageattention`
SAGE = "sage"
SAGE_HUB = "sage_hub"
SAGE_VARLEN = "sage_varlen"
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
Expand Down Expand Up @@ -1756,6 +1762,31 @@ def _sage_attention(
return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if _parallel_config is None:
out = sage_attn_func_hub(q=query, k=key, v=value)
if return_lse:
out, lse, *_ = out
else:
raise NotImplementedError("SAGE attention doesn't yet support parallelism.")

return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
Expand Down
17 changes: 12 additions & 5 deletions src/diffusers/utils/kernels_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@


_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
_KERNEL_REVISION = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
_DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
_DEFAULT_HUB_ID_SAGE: "compile",
}


def _get_fa3_from_hub():
def _get_kernel_from_hub(kernel_id):
if not is_kernels_available():
return None
else:
from kernels import get_kernel

try:
# TODO: temporary revision for now. Remove when merged upstream into `main`.
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
return flash_attn_3_hub
if kernel_id not in _KERNEL_REVISION:
raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
return kernel_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
raise
137 changes: 137 additions & 0 deletions src/diffusers/utils/sage_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Copyright (c) 2024 by SageAttention, The HuggingFace team.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""

"""
Modified from
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
"""


import torch # noqa


SAGE_ATTENTION_DISPATCH = {
"sm80": {
"func": "sageattn_qk_int8_pv_fp16_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32",
},
},
"sm89": {
"func": "sageattn_qk_int8_pv_fp8_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp16",
},
},
"sm90": {
"func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp32",
},
},
"sm120": {
"func": "sageattn_qk_int8_pv_fp8_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"qk_quant_gran": "per_warp",
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp16",
},
},
}


def get_cuda_version():
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
return major, minor
else:
raise EnvironmentError("CUDA not found.")


def get_cuda_arch_versions():
if not torch.cuda.is_available():
EnvironmentError("CUDA not found.")
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs


# Unlike the actual implementation, we just maintain function names rather than actual
# implementations.
def _get_sage_attn_fn_for_device():
"""
Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
capability.

Parameters ---------- q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

tensor_layout : str
The tensor layout, either "HND" or "NHD". Default: "HND".

is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.

sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.

Returns ------- torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

Note ----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
"""
device_index = torch.cuda.current_device()
arch = get_cuda_arch_versions()[device_index]
return SAGE_ATTENTION_DISPATCH[arch]
Loading