
Commit 1d04c58

Aya-ZIbra authored and facebook-github-bot committed
FAv4 CuteDSL Bench for decode
Summary: as title
Differential Revision: D80830933
1 parent 41b2c7c commit 1d04c58

File tree

1 file changed: +31 −0 lines

tritonbench/operators/decoding_attention/operator.py

Lines changed: 31 additions & 0 deletions
@@ -72,6 +72,17 @@
     HAS_AITER = False
 
 
+# [Optional] flash_fwd cute-DSL backend
+HAS_FLASH_CUTE = True
+try:
+    from flash_attn.cute.interface import (
+        flash_attn_func as flash_attn_cute_func
+    )
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+    flash_attn_cute_func = None  # Define it as None to avoid NameError
+
+
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, help="Batch size")
@@ -559,6 +570,26 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )
 
+
+    @register_benchmark(enabled=HAS_FLASH_CUTE)
+    def flash_cute_dsl(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        """Flash Attention implementation using cute-DSL backend."""
+        # For GQA, cute-DSL handles the head expansion internally;
+        # we pass the original KV tensors without manual expansion.
+        q_heads = q.shape[2]
+        kv_heads = k_cache.shape[2]
+        return lambda: flash_attn_cute_func(
+            q, k_cache, v_cache,
+            causal=CAUSAL,
+            pack_gqa=(q_heads != kv_heads)
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
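
For reference, below is a minimal standalone sketch of the call the new flash_cute_dsl benchmark wraps, assuming the flash-attn cute-DSL package is installed and that tensors use the (batch, seqlen, heads, head_dim) layout implied by q.shape[2] being the head dimension above. The shape constants, dtype, and device are illustrative assumptions, not values taken from the benchmark:

    import torch
    from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func

    # Assumed decode-style shapes: one new query token per sequence, and a
    # GQA setup with more query heads than KV heads (illustrative values).
    B, H_Q, H_KV, D, S_KV = 8, 32, 8, 128, 4096
    q = torch.randn(B, 1, H_Q, D, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(B, S_KV, H_KV, D, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn(B, S_KV, H_KV, D, device="cuda", dtype=torch.bfloat16)

    # Mirrors the call inside flash_cute_dsl: causal masking (CAUSAL in the
    # operator) and pack_gqa enabled whenever query and KV head counts differ.
    out = flash_attn_cute_func(
        q, k_cache, v_cache,
        causal=True,
        pack_gqa=(H_Q != H_KV),
    )

As in the benchmark, pack_gqa is driven purely by the head-count mismatch; per the comment in the diff, the cute-DSL backend handles the GQA head expansion internally, so the KV caches are passed without manual repetition.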
