-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[core] support sage attention through kernels
#12439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool ! I will try to look into the torch compile compatibility, but for the other variants, they are the same as sageattn
, what i mean is sageattn
is just a wrapper that dispatches to the correct kernel depending on the hardware used : https://github.com/thu-ml/SageAttention/blob/main/sageattention/core.py#L140
So, you mean we shouldn't have to have different dispatched functions like this?
|
Yes I think we don't need that because it depends on the hardware. For example if a user chooses : |
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] | ||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] | ||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see their usage, hence removed.
What does this PR do?
Code to test:
Result:

Notes
torch.compile
support when using sage attention like we have for flash and flash 3. Currently, this fails.Code to test
Error: https://pastebin.com/3HS6HNzR
sageattn
variants (see here), which would be cool to expose from the Hub kernel.Cc: @MekkCyber