
Conversation

@sayakpaul (Member) commented on Oct 6, 2025

What does this PR do?

Adds a `sage_hub` attention backend so SageAttention can be used through a kernel pulled from the Hub.

Code to test:

from diffusers import DiffusionPipeline 
import torch 

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("sage_hub")

image = pipe(
    prompt="a dog sitting by the sea, waiting for its companion to come",
    guidance_scale=3.5,
    num_inference_steps=30,
    max_sequence_length=512,
    generator=torch.manual_seed(0)
).images[0]
image.save("sage_flux.png")

Result: (generated image attached in the original PR; saved as sage_flux.png)


Notes

  1. It would be nice to get torch.compile support when using sage attention, like we have for flash and flash 3. Currently, this fails:
Code to test
from diffusers import DiffusionPipeline 
import torch 

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("sage_hub")
pipe.transformer.compile_repeated_blocks(fullgraph=True)

with (
    torch._inductor.utils.fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
):
    image = pipe(
        prompt="a dog sitting by the sea, waiting for its companion to come",
        guidance_scale=3.5,
        num_inference_steps=30,
        max_sequence_length=512,
        generator=torch.manual_seed(0)
    ).images[0]
image.save("sage_flux.png")

Error: https://pastebin.com/3HS6HNzR

  2. We have other sageattn variants (see here), which would be cool to expose from the Hub kernel. A rough sketch of what that could look like from the user side follows.
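For illustration only, a sketch of how exposing a specific variant might look on the user side; `sage_hub` is the backend added in this PR, while the variant-specific backend name in the comment is made up:

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Backend added in this PR: lets sageattn dispatch to whatever kernel fits the GPU.
pipe.transformer.set_attention_backend("sage_hub")

# Hypothetical variant-specific backend (not implemented); it would pin the
# fp8 PV-accumulation kernel instead of letting sageattn choose.
# pipe.transformer.set_attention_backend("sage_qk_int8_pv_fp8_hub")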

Cc: @MekkCyber

@sayakpaul added the performance label on Oct 6, 2025
@sayakpaul requested a review from DN6 on October 6, 2025 05:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber left a comment

Very cool! I'll try to look into the torch.compile compatibility. As for the other variants, they are the same as sageattn: sageattn is just a wrapper that dispatches to the correct kernel depending on the hardware used, see https://github.com/thu-ml/SageAttention/blob/main/sageattention/core.py#L140
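For context, a rough sketch (not the actual SageAttention code) of the kind of hardware dispatch the linked core.py performs; the real mapping is more detailed than this:

import torch

def which_sage_kernel(device=None):
    # Illustrative only: pick a kernel name from the GPU's compute capability,
    # roughly the way the sageattn wrapper does in core.py.
    major, minor = torch.cuda.get_device_capability(device)
    arch = major * 10 + minor
    if arch >= 90:      # Hopper: sm90 fp8 PV-accumulation kernel
        return "qk_int8_pv_fp8_cuda_sm90"
    if arch == 89:      # Ada: fp8 PV-accumulation kernel
        return "qk_int8_pv_fp8_cuda"
    if arch >= 80:      # Ampere: fp16 PV-accumulation kernel
        return "qk_int8_pv_fp16_cuda"
    raise RuntimeError(f"no SageAttention kernel for sm{arch}")

print(which_sage_kernel())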

@sayakpaul (Member, PR author)

"they are the same as sageattn, what i mean is sageattn is just a wrapper that dispatches to the correct kernel depending on the hardware used"

So, you mean we shouldn't need separate dispatch functions like this?

_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"

@MekkCyber

Yes, I don't think we need that, because it depends on the hardware. For example, if a user chooses `_sage_qk_int8_pv_fp8_cuda` on an A100 (compute capability 8.0), it will fail, because this function is only supported and compiled for 8.9 GPUs.
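To make that concrete, a minimal, hypothetical guard along these lines would have to accompany every variant-specific backend (the helper name is made up; the 8.9 threshold comes from the comment above):

import torch

def check_fp8_pv_support(device=None):
    # Hypothetical validation sketch: the fp8 PV-accumulation kernel behind
    # "_sage_qk_int8_pv_fp8_cuda" is compiled for compute capability 8.9+,
    # so an A100 (8.0) needs to be rejected up front.
    major, minor = torch.cuda.get_device_capability(device)
    if (major, minor) < (8, 9):
        raise RuntimeError(
            f"_sage_qk_int8_pv_fp8_cuda requires compute capability >= 8.9, "
            f"but this GPU reports {major}.{minor}"
        )

check_fp8_pv_support()  # passes on 8.9/9.0 GPUs, raises on A100 (8.0)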

@sayakpaul sayakpaul marked this pull request as draft October 7, 2025 13:38
Comment on lines -165 to -167
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@sayakpaul (Member, PR author):

I don't see these being used anywhere, hence removed.
