tritonbench/operators/decoding_attention: 1 file changed, +31 −0 lines
@@ -72,6 +72,17 @@
 HAS_AITER = False


+# [Optional] flash_fwd cute-DSL backend
+HAS_FLASH_CUTE = True
+try:
+    from flash_attn.cute.interface import (
+        flash_attn_func as flash_attn_cute_func
+    )
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+    flash_attn_cute_func = None  # Define it as None to avoid NameError
+
+
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, help="Batch size")
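The guarded import makes the cute-DSL backend strictly optional: if flash-attn does not ship flash_attn.cute.interface, or the import fails for environment reasons (hence the broad IOError/AttributeError handling), HAS_FLASH_CUTE flips to False and flash_attn_cute_func stays defined as None so later references never raise NameError. Below is a minimal, illustrative availability check in the same spirit; it is not part of this change, and only assumes the import path shown in the diff.

# Illustrative availability probe (not part of this PR); mirrors the guarded
# import above so the cute-DSL backend can be checked outside the benchmark.
HAS_FLASH_CUTE = True
try:
    from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func
except (ImportError, IOError, AttributeError) as exc:
    HAS_FLASH_CUTE = False
    flash_attn_cute_func = None  # keep the name defined to avoid NameError later
    _cute_import_error = repr(exc)
else:
    _cute_import_error = None

if __name__ == "__main__":
    if HAS_FLASH_CUTE:
        print("flash_attn cute-DSL backend: available")
    else:
        print(f"flash_attn cute-DSL backend: unavailable ({_cute_import_error})")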
@@ -559,6 +570,26 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )

+
+    @register_benchmark(enabled=HAS_FLASH_CUTE)
+    def flash_cute_dsl(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        """Flash Attention implementation using cute-DSL backend."""
+        # For GQA, cute-DSL handles the head expansion internally
+        # We pass the original KV tensors without manual expansion
+        q_heads = q.shape[2]
+        kv_heads = k_cache.shape[2]
+        return lambda: flash_attn_cute_func(
+            q, k_cache, v_cache,
+            causal=CAUSAL,
+            pack_gqa=(q_heads != kv_heads)
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
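The new flash_cute_dsl benchmark returns a lambda so the harness times only the kernel call, and it sets pack_gqa whenever the query and KV head counts differ, letting the cute-DSL kernel perform the GQA head expansion instead of materializing expanded KV tensors. A rough standalone sketch of the same call follows. It assumes flash-attn with the cute-DSL interface is installed, a CUDA GPU is present, and the (batch, seqlen, heads, head_dim) layout implied by q.shape[2] in the diff; the tensor sizes and causal=True are illustrative stand-ins for the operator's inputs and its CAUSAL flag, and the return value is handled defensively since some flash-attn versions also return the LSE.

# Standalone sketch, not the benchmark itself. Assumptions: flash-attn with the
# cute-DSL interface installed, a CUDA device, and (batch, seqlen, heads, head_dim)
# tensor layout as implied by q.shape[2] in the diff; sizes are illustrative.
import torch

try:
    from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func
except (ImportError, IOError, AttributeError):
    flash_attn_cute_func = None

if flash_attn_cute_func is not None and torch.cuda.is_available():
    batch, seqlen_kv, q_heads, kv_heads, head_dim = 2, 4096, 32, 8, 128
    # Decoding: one new query token per sequence attends over the full KV cache.
    q = torch.randn(batch, 1, q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(batch, seqlen_kv, kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn_like(k_cache)

    result = flash_attn_cute_func(
        q, k_cache, v_cache,
        causal=True,                     # stand-in for the operator's CAUSAL flag
        pack_gqa=(q_heads != kv_heads),  # let the kernel expand KV heads for GQA
    )
    # Depending on the flash-attn version this is either the output tensor or an
    # (output, lse) tuple; the attention output is (batch, 1, q_heads, head_dim).
    out = result[0] if isinstance(result, tuple) else result
    print(out.shape)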