From 2b5cc79bfacf76c8b2ee8b9b9ca85b0f9d4463fa Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Thu, 4 Sep 2025 09:02:00 -0700
Subject: [PATCH 1/2] FAv4 CuteDSL Bench for decode

Summary: Using headq = 8 performs much better, likely because headq = 5
does not work with the TMA_q used here.

Differential Revision: D80830933
---
 .../operators/decoding_attention/operator.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tritonbench/operators/decoding_attention/operator.py b/tritonbench/operators/decoding_attention/operator.py
index 680b0d5df..e5ff8096f 100644
--- a/tritonbench/operators/decoding_attention/operator.py
+++ b/tritonbench/operators/decoding_attention/operator.py
@@ -72,6 +72,17 @@
 HAS_AITER = False
+# [Optional] flash_fwd cute-DSL backend
+HAS_FLASH_CUTE = True
+try:
+    from flash_attn.cute.interface import (
+        flash_attn_func as flash_attn_cute_func
+    )
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+    flash_attn_cute_func = None  # Define it as None to avoid NameError
+
+
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, help="Batch size")
@@ -559,6 +570,26 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )
+
+    @register_benchmark(enabled=HAS_FLASH_CUTE)
+    def flash_cute_dsl(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        """Flash Attention implementation using the cute-DSL backend."""
+        # For GQA, cute-DSL handles the head expansion internally,
+        # so we pass the original KV tensors without manual expansion.
+        q_heads = q.shape[2]
+        kv_heads = k_cache.shape[2]
+        return lambda: flash_attn_cute_func(
+            q, k_cache, v_cache,
+            causal=CAUSAL,
+            pack_gqa=(q_heads != kv_heads)
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,

From 7a4063e25c738a46f28ed0985f6db5e1b2c266ad Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Thu, 4 Sep 2025 09:06:32 -0700
Subject: [PATCH 2/2] Add trtllm to tritonbench (#379)

Summary:
Pull Request resolved: https://github.com/meta-pytorch/tritonbench/pull/379

Run the C++ example:

FLASHINFER_CUBIN_DIR=/data/users/$USER/fbsource/fbcode/deeplearning/flashinfer/fb/cubins/ buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //deeplearning/flashinfer/trtllm_kernel_interfaces:run_example

Run the Triton bench operator:

buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run -- --op decoding_attention --only trtllm_decode_fmha --seq-len-q 1 --metrics gbps

Todo: Support the non-paged case.

Differential Revision: D81021980
---
 .../operators/decoding_attention/operator.py  | 19 ++++++
 .../decoding_attention/trtllm_utils.py        | 68 +++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 tritonbench/operators/decoding_attention/trtllm_utils.py

diff --git a/tritonbench/operators/decoding_attention/operator.py b/tritonbench/operators/decoding_attention/operator.py
index e5ff8096f..6ddcd9447 100644
--- a/tritonbench/operators/decoding_attention/operator.py
+++ b/tritonbench/operators/decoding_attention/operator.py
@@ -55,6 +55,11 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental:gen_ai_attention_ops"
 )
+torch.ops.load_library(
+    "//deeplearning/flashinfer/trtllm_kernel_interfaces:trtllm_fmha_pybind"
+)
+
+from .trtllm_utils import trtllm_decode_fmha_func
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -660,3 +665,17 @@ def aiter_paged_fp8kv(
             k_scale_asm,
             v_scale_asm,
         )
+
+    @register_benchmark()
+    def trtllm_decode_fmha(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+
+        args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)
+        return lambda: torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(
+            *args
+        )
diff --git a/tritonbench/operators/decoding_attention/trtllm_utils.py b/tritonbench/operators/decoding_attention/trtllm_utils.py
new file mode 100644
index 000000000..cc257a94e
--- /dev/null
+++ b/tritonbench/operators/decoding_attention/trtllm_utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation.
+"""
+
+import torch
+
+
+def trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens):
+    """
+    TRTLLM FMHA decode function that converts standard tensors to paged format
+    and calls the TRTLLM FMHA kernel via the PyBind extension.
+    """
+
+    device = q.device
+    # Convert input tensors to paged format for TRTLLM FMHA
+    batch_size, seq_len_q, num_qo_heads, head_dim = q.shape
+    _, max_seq_len_kv, num_kv_heads, _ = k_cache.shape
+
+    # Use a page size of 16 for TRTLLM FMHA
+    page_size = 16
+    max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size
+    total_pages = batch_size * max_num_blocks_per_seq
+
+    # Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim]
+    k_cache_paged = k_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
+    k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
+    k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)
+
+    v_cache_paged = v_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
+    v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
+    v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)
+
+    # Create block tables
+    block_tables = torch.zeros(
+        (batch_size, max_num_blocks_per_seq),
+        dtype=torch.int32,
+        device=device
+    )
+    for i in range(batch_size):
+        for j in range(max_num_blocks_per_seq):
+            block_tables[i, j] = i * max_num_blocks_per_seq + j
+
+    # Create output tensor
+    out = torch.zeros_like(q)
+
+    # Create workspace buffer
+    workspace_size = 128 * 1024 * 1024  # 128 MB
+    workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device)
+
+    # Attention parameters
+    max_seq_len = cache_seqlens.max().item()
+    bmm1_scale = 1.0 / (head_dim ** 0.5)
+    bmm2_scale = 1.0
+    window_left = -1  # No sliding window
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+
+    args = (
+        out, q, k_cache_paged, v_cache_paged, workspace_buffer,
+        block_tables, cache_seqlens, max_seq_len,
+        bmm1_scale, bmm2_scale, window_left, sm_count
+    )
+    return args
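
Note: the snippet below is a minimal, standalone sketch of how the new
trtllm_decode_fmha_func helper is meant to be driven, mirroring what the
trtllm_decode_fmha benchmark in the patch does. The batch size, head counts,
dtype (bfloat16), and KV-cache length are illustrative assumptions only, not
values taken from the patch; it also assumes a CUDA device and that the
trtllm_fmha_pybind library has been loaded via torch.ops.load_library, as
operator.py does above.

    import torch

    from tritonbench.operators.decoding_attention.trtllm_utils import (
        trtllm_decode_fmha_func,
    )

    # Assumed decode shapes: seq_len_q=1, GQA with 8 query heads per KV head.
    batch, q_heads, kv_heads, head_dim = 4, 8, 1, 128
    max_seq_len_kv = 4096  # multiple of the helper's 16-entry page size

    device = "cuda"
    dtype = torch.bfloat16  # assumption; use whatever dtype the operator runs with
    q = torch.randn(batch, 1, q_heads, head_dim, device=device, dtype=dtype)
    k_cache = torch.randn(batch, max_seq_len_kv, kv_heads, head_dim, device=device, dtype=dtype)
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.full((batch,), 2048, device=device, dtype=torch.int32)

    # Build the paged-KV argument tuple, then invoke the PyBind kernel with it.
    args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)
    out = args[0]  # the helper pre-allocates the output tensor as the first element
    torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(*args)
    print(out.shape)  # same shape as q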