Commit ecd621f

feat: Add head size 72 support for QKV Preprocessing kernel (#3743)
* refactor: Fix head size 72 attention error for the TRTLLM attn backend in the PyTorch workflow
  - Removed the head size pre-check logic in AttentionOp, because head size 72 can be supported with fmha kernels.
  - Added support for head size 72 in the unfused attention kernels (QKVPreprocessing).
  - Enhanced unit tests by introducing a scenario generation function for better coverage of attention configurations (including head size 72).

  Signed-off-by: qixiang-99 <[email protected]>

* update: Waive head_dim=72 test cases and enhance test representation
  - Added a waiver for head_dim=72 cases on post-SM100 GPUs in the test suite to address known issues.
  - Introduced a custom __repr__ method in the Scenario class for pytest substring matching.

  Signed-off-by: qixiang-99 <[email protected]>

---------

Signed-off-by: qixiang-99 <[email protected]>
1 parent 5b9897a commit ecd621f

File tree: 3 files changed, +40 -20 lines changed


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 2 additions & 2 deletions
@@ -2297,8 +2297,8 @@ int AttentionOp::initialize() noexcept
         "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
 
     // Pre-check whether the head size is supported by MMHA.
-    // Support head size == 72 only for fmha kernels (in Cross Attention), so skip pre-check here.
-    if (getHeadSize() == 72 && mCrossAttention)
+    // Support head size == 72 only for fmha kernels, so skip pre-check here.
+    if (getHeadSize() == 72)
     {
         ;
     }
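
To summarize the change above: head size 72 previously skipped the MMHA pre-check only when cross attention was enabled, and it now always skips it because the fmha kernels can handle it. A minimal Python sketch of that condition, for illustration only (the helper name skips_mmha_precheck is hypothetical, not part of AttentionOp):

# Hypothetical sketch, not AttentionOp code: restates the condition changed above.
def skips_mmha_precheck(head_size: int, is_cross_attention: bool,
                        after_this_commit: bool = True) -> bool:
    if after_this_commit:
        return head_size == 72                        # fmha kernels handle 72 everywhere
    return head_size == 72 and is_cross_attention     # old behavior: cross attention only


assert skips_mmha_precheck(72, is_cross_attention=False)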

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h

Lines changed: 1 addition & 0 deletions
@@ -1603,6 +1603,7 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(QKVPreprocessingParams<T, KVCacheB
     case 32: kernelV2DispatchHeadSize<256, 32, T, TCache, KVCacheBuffer>(params, stream); break;
     case 48: kernelV2DispatchHeadSize<192, 48, T, TCache, KVCacheBuffer>(params, stream); break;
     case 64: kernelV2DispatchHeadSize<256, 64, T, TCache, KVCacheBuffer>(params, stream); break;
+    case 72: kernelV2DispatchHeadSize<288, 72, T, TCache, KVCacheBuffer>(params, stream); break;
     case 80: kernelV2DispatchHeadSize<160, 80, T, TCache, KVCacheBuffer>(params, stream); break;
     case 96: kernelV2DispatchHeadSize<192, 96, T, TCache, KVCacheBuffer>(params, stream); break;
     case 104: kernelV2DispatchHeadSize<416, 104, T, TCache, KVCacheBuffer>(params, stream); break;
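
The switch above selects a kernelV2DispatchHeadSize specialization keyed on the head size, and this commit adds the 72 entry. Below is a rough, non-authoritative Python sketch of that table-driven dispatch, using only the pairs visible in this hunk; the meaning of the first template argument follows the kernel's own convention and is not documented here, and the function name dispatch_qkv_preprocessing is hypothetical.

# Hypothetical sketch, not TensorRT-LLM code: mirrors the head-size dispatch
# shown in invokeApplyBiasRopeUpdateKVCacheDispatch for the cases in this hunk.
_HEAD_SIZE_TABLE = {
    32: (256, 32),
    48: (192, 48),
    64: (256, 64),
    72: (288, 72),   # entry added by this commit
    80: (160, 80),
    96: (192, 96),
    104: (416, 104),
}


def dispatch_qkv_preprocessing(head_size: int):
    """Return the template arguments a matching case would instantiate,
    or raise if the head size has no entry in this (partial) table."""
    if head_size not in _HEAD_SIZE_TABLE:
        raise ValueError(f"head size {head_size} is not covered by this hunk")
    return _HEAD_SIZE_TABLE[head_size]


print(dispatch_qkv_preprocessing(72))  # (288, 72)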

tests/unittest/_torch/test_attention_no_cache.py

Lines changed: 37 additions & 18 deletions
@@ -1,16 +1,35 @@
+import itertools
 import math
 import random
 from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple
 
 import pytest
 import torch
+from utils.util import skip_blackwell
 
 from tensorrt_llm._torch.attention_backend.interface import \
     PredefinedAttentionMask
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
 
 
+def generate_attn_scenarios(num_q_heads_kv_heads: List[Tuple[int, int]],
+                            head_dim: List[int], num_layers: List[int],
+                            dtype: List[torch.dtype]):
+    scenarios = []
+    product_iter = itertools.product(num_q_heads_kv_heads, head_dim, num_layers,
+                                     dtype)
+    for num_q_heads_kv_head, head_dim, num_layers, dtype in product_iter:
+        num_q_heads, num_kv_heads = num_q_heads_kv_head
+        scenarios.append(
+            Scenario(num_heads=num_q_heads,
+                     num_kv_heads=num_kv_heads,
+                     head_dim=head_dim,
+                     num_layers=num_layers,
+                     dtype=dtype))
+    return scenarios
+
+
 def calculate_ref_result(q: torch.Tensor,
                          k: torch.Tensor,
                          v: torch.Tensor,

@@ -110,6 +129,10 @@ class Scenario:
     def num_kv_groups(self) -> int:
         return self.num_heads // self.num_kv_heads
 
+    # self-defined repr for pytest substring match
+    def __repr__(self) -> str:
+        return f"Scenario(num_heads_{self.num_heads}, num_kv_heads_{self.num_kv_heads}, head_dim_{self.head_dim}, num_layers_{self.num_layers}, dtype_{self.dtype})"
+
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """

@@ -144,26 +167,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     random_context_sequence_lengths,
 ]
 
-scenarios = [
-    # num_heads == num_kv_heads, single layer
-    Scenario(
-        num_layers=1,
-        num_heads=32,
-        num_kv_heads=32,
-        head_dim=128,
-        dtype=torch.float16,
-    ),
-    # num_heads > num_kv_heads, multi-layer
-    Scenario(
-        num_layers=2,
-        num_heads=32,
-        num_kv_heads=8,
-        head_dim=128,
-        dtype=torch.float16,
-    ),
+num_q_heads_kv_heads = [
+    (32, 32),
+    (32, 8),
+    (16, 16),
 ]
+num_layers = [1, 2, 16]
+head_dim = [64, 72, 128]
+dtype = [torch.float16]
+
+scenarios = generate_attn_scenarios(num_q_heads_kv_heads, head_dim, num_layers,
+                                    dtype)
 
 
+# skip for blackwell
+@skip_blackwell
 # Convert parameterized tests to pytest parametrize
 @pytest.mark.parametrize("accuracy", [(1e-2, 1e-3)],
                          ids=lambda x: f"atol={x[0]} rtol={x[1]}")

@@ -178,6 +196,7 @@ def test_attention_no_cache(scenario: Scenario,
                             context_sequence_lengths: List[int], mask_type,
                             accuracy):
     """Test attention computation without using cache for both FULL and CAUSAL masks"""
+
     num_heads = scenario.num_heads
     num_kv_heads = scenario.num_kv_heads
     head_dim = scenario.head_dim

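To make the new test plumbing concrete: the helper builds the cross product of all configuration axes, and the custom __repr__ gives each Scenario a stable, substring-matchable id. The sketch below is self-contained (it re-declares a minimal Scenario with a plain-string dtype instead of importing the test module or torch), so treat it as an illustration of the pattern rather than the test file itself.

import itertools
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Scenario:
    # Minimal stand-in for the Scenario dataclass in test_attention_no_cache.py.
    num_heads: int
    num_kv_heads: int
    head_dim: int
    num_layers: int
    dtype: str  # the real test uses torch.dtype; a string keeps this sketch dependency-free

    def __repr__(self) -> str:
        # Same shape as the repr added in the diff, so pytest -k can match
        # substrings such as "head_dim_72".
        return (f"Scenario(num_heads_{self.num_heads}, num_kv_heads_{self.num_kv_heads}, "
                f"head_dim_{self.head_dim}, num_layers_{self.num_layers}, dtype_{self.dtype})")


def generate_attn_scenarios(num_q_heads_kv_heads: List[Tuple[int, int]],
                            head_dim: List[int], num_layers: List[int],
                            dtype: List[str]) -> List[Scenario]:
    # Cross product of every configuration axis, as in the new test helper.
    return [
        Scenario(num_heads=q, num_kv_heads=kv, head_dim=hd, num_layers=nl, dtype=dt)
        for (q, kv), hd, nl, dt in itertools.product(num_q_heads_kv_heads, head_dim,
                                                     num_layers, dtype)
    ]


scenarios = generate_attn_scenarios([(32, 32), (32, 8), (16, 16)],
                                    head_dim=[64, 72, 128],
                                    num_layers=[1, 2, 16],
                                    dtype=["float16"])
print(len(scenarios))                                      # 27 scenarios
print([s for s in scenarios if "head_dim_72" in repr(s)][0])

With ids of this shape, a selection such as pytest -k "head_dim_72" (or a waiver keyed on the same substring) can presumably target exactly the head size 72 cases added by this commit.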