50 changes: 50 additions & 0 deletions tritonbench/operators/decoding_attention/operator.py
@@ -55,6 +55,11 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental:gen_ai_attention_ops"
)
torch.ops.load_library(
"//deeplearning/flashinfer/trtllm_kernel_interfaces:trtllm_fmha_pybind"
)

from .trtllm_utils import trtllm_decode_fmha_func

from tritonbench.utils.triton_op import (
BenchmarkOperator,
@@ -72,6 +77,17 @@
HAS_AITER = False


# [Optional] flash_fwd cute-DSL backend
HAS_FLASH_CUTE = True
try:
from flash_attn.cute.interface import (
flash_attn_func as flash_attn_cute_func
)
except (ImportError, IOError, AttributeError):
HAS_FLASH_CUTE = False
flash_attn_cute_func = None # Define it as None to avoid NameError


def parse_op_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, help="Batch size")
@@ -559,6 +575,26 @@ def fbgemm_gqa_fp8kv(
cache_logical_dtype_int=1, # FP8 = 1
)


@register_benchmark(enabled=HAS_FLASH_CUTE)
def flash_cute_dsl(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
) -> Callable:
"""Flash Attention implementation using cute-DSL backend."""
# For GQA, cute-DSL handles the head expansion internally
# We pass the original KV tensors without manual expansion
q_heads = q.shape[2]
kv_heads = k_cache.shape[2]
return lambda: flash_attn_cute_func(
q, k_cache, v_cache,
causal=CAUSAL,
pack_gqa=(q_heads != kv_heads)
)

@register_benchmark(enabled=HAS_AITER)
def aiter_paged_fp8kv(
self,
@@ -629,3 +665,17 @@ def aiter_paged_fp8kv(
k_scale_asm,
v_scale_asm,
)

@register_benchmark()
def trtllm_decode_fmha(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
) -> Callable:
"""TRTLLM FMHA decode: convert the KV cache to paged format and call the kernel via the PyBind extension."""
args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)
return lambda: torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(
*args
)
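For context on the flash_cute_dsl benchmark above: pack_gqa is set whenever the query has more heads than the KV cache, so the cute-DSL kernel expands the grouped KV heads internally instead of the benchmark repeating K/V in memory. A minimal shape sketch, assuming the optional flash_attn cute-DSL package and a CUDA device are available (all sizes below are illustrative, not taken from the PR):

import torch
from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func

# Decode-shaped inputs: one new query token per sequence, grouped KV heads.
batch, seqlen_q, q_heads, head_dim = 4, 1, 32, 128
max_seqlen_kv, kv_heads = 4096, 8            # 32 query heads share 8 KV heads

q = torch.randn(batch, seqlen_q, q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(batch, max_seqlen_kv, kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn_like(k_cache)

# pack_gqa mirrors the check in flash_cute_dsl: enabled only when head counts differ.
out = flash_attn_cute_func(
    q, k_cache, v_cache,
    causal=True,                              # the benchmark passes its CAUSAL constant here
    pack_gqa=(q.shape[2] != k_cache.shape[2]),
)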
68 changes: 68 additions & 0 deletions tritonbench/operators/decoding_attention/trtllm_utils.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation.
"""

import torch


def trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens):
"""
TRTLLM FMHA decode function that converts standard tensors to paged format
and calls the TRTLLM FMHA kernel via PyBind extension.
"""

device = q.device
# Convert input tensors to paged format for TRTLLM FMHA
batch_size, seq_len_q, num_qo_heads, head_dim = q.shape
_, max_seq_len_kv, num_kv_heads, _ = k_cache.shape

# Use page size of 16 for TRTLLM FMHA
page_size = 16
max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size
total_pages = batch_size * max_num_blocks_per_seq

# Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim]
k_cache_paged = k_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

v_cache_paged = v_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)
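# Note: the conversion above assumes max_seq_len_kv is a multiple of page_size;
# otherwise the first .view raises. Shape trace for k_cache (v_cache is identical):
#   [B, S_kv, H_kv, D]
#   -> view:    [B, S_kv // 16, 16, H_kv, D]   (16-token pages per sequence)
#   -> permute: [B, S_kv // 16, H_kv, 16, D]
#   -> view:    [B * (S_kv // 16), H_kv, 16, D] == [total_pages, num_kv_heads, page_size, head_dim]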

# Create block tables
block_tables = torch.zeros(
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32,
device=device
)
for i in range(batch_size):
for j in range(max_num_blocks_per_seq):
block_tables[i, j] = i * max_num_blocks_per_seq + j
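# The loop above builds an identity mapping: row i points at physical pages
# i * max_num_blocks_per_seq .. (i + 1) * max_num_blocks_per_seq - 1. An equivalent
# loop-free form would be:
#   block_tables = torch.arange(total_pages, dtype=torch.int32, device=device).view(
#       batch_size, max_num_blocks_per_seq
#   )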

# Create output tensor
out = torch.zeros_like(q)

# Create workspace buffer
workspace_size = 128 * 1024 * 1024 # 128MB
workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device)

# Attention parameters
max_seq_len = cache_seqlens.max().item()
bmm1_scale = 1.0 / (head_dim ** 0.5)
bmm2_scale = 1.0
window_left = -1 # No sliding window
sm_count = torch.cuda.get_device_properties(device).multi_processor_count

args = (
out, q, k_cache_paged, v_cache_paged, workspace_buffer,
block_tables, cache_seqlens, max_seq_len,
bmm1_scale, bmm2_scale, window_left, sm_count
)
return args
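For reference, a minimal usage sketch of the helper above, mirroring how the trtllm_decode_fmha benchmark in operator.py consumes it. It assumes the trtllm_fmha_pybind library has already been loaded via torch.ops.load_library (as operator.py does) and that a CUDA device is available; all tensor sizes and the dtype are illustrative:

import torch

from tritonbench.operators.decoding_attention.trtllm_utils import trtllm_decode_fmha_func

# Decode shapes: max_seq_len_kv must be a multiple of the 16-token page size.
batch, num_qo_heads, num_kv_heads, head_dim = 4, 32, 8, 128
max_seq_len_kv = 4096

q = torch.randn(batch, 1, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(batch, max_seq_len_kv, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

# Convert to paged format and collect the kernel arguments, then invoke the
# PyBind op exactly as the benchmark's returned lambda does.
args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)
out = torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(*args)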