
Commit 0488883

eqy authored and pytorchmergebot committed
[cuDNN][SDPA] Fix head-dim 256 condition for SM 10.0 (pytorch#152076)
Turns out the backward is not supported yet, whoops.

Pull Request resolved: pytorch#152076
Approved by: https://github.com/drisspg
1 parent 07290bd commit 0488883
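
Context: the cuDNN SDPA head_dim-256 path is gated on both the cuDNN version and the GPU architecture. A quick, hedged way to see what a given machine reports for those two inputs (an illustrative check, not part of this commit):

import torch

# (9, 0) is SM 9.0 (Hopper); (10, 0) is SM 10.0 (Blackwell), the case this commit restricts.
print(torch.cuda.get_device_capability())
# cuDNN version encoded as major*10000 + minor*100 + patch, e.g. 90501 for cuDNN 9.5.1.
print(torch.backends.cudnn.version())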

File tree

2 files changed, +26 -1 lines changed


aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
   auto head_dim_limit = 128;
   if (cudnn_version >= 90501) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    if ((dprops->major == 9 || dprops->major == 10) && !dprops->minor) {
+    if (dprops->major == 9 && !dprops->minor) {
       head_dim_limit = 256;
     }
   }
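
In effect, head_dim_limit stays at 128 on SM 10.0 because the head-dim-256 backward is not supported there yet; only SM 9.0 keeps the 256 limit. A rough Python sketch of the gate after this change, assuming torch.cuda.get_device_capability() and torch.backends.cudnn.version() report the same values the C++ path sees (the function name is illustrative, not a PyTorch API):

import torch

def cudnn_sdpa_head_dim_limit() -> int:
    # Illustrative mirror of check_cudnn_tensor_shapes() after this commit.
    limit = 128
    if not torch.cuda.is_available():
        return limit
    cudnn_version = torch.backends.cudnn.version()  # e.g. 90501 for cuDNN 9.5.1
    major, minor = torch.cuda.get_device_capability()
    # Only SM 9.0 (Hopper) advertises head_dim 256; SM 10.0 (Blackwell) is
    # excluded because its head-dim-256 backward is not yet supported.
    if cudnn_version is not None and cudnn_version >= 90501 and major == 9 and minor == 0:
        limit = 256
    return limit

print(cudnn_sdpa_head_dim_limit())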

test/test_transformers.py

Lines changed: 25 additions & 0 deletions
@@ -2493,6 +2493,31 @@ def test_cudnn_attention_gqa(self, device):
 
         self.assertEqual(output_math, output_cudnn)
 
+    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
+    def test_cudnn_attention_d256_heuristic(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 256, 64
+        seq_len = 640
+        q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+            actual.backward(torch.randn_like(actual))
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
+            math_ref = torch.nn.functional.scaled_dot_product_attention(
+                query.contiguous().to(torch.float32),
+                key.contiguous().to(torch.float32),
+                value.contiguous().to(torch.float32),
+                attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
     @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
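
The new test exercises the path end to end: with head_dim 256 for query/key, the prioritized backend list lets cuDNN take the call where it is supported and lets math handle it otherwise, and the backward pass covers the case this commit fixes. A minimal standalone sketch of the same scenario, assuming a CUDA build where cuDNN attention is available and a PyTorch version whose sdpa_kernel accepts set_priority (shapes match the test; variable names are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device, dtype = "cuda", torch.bfloat16
batch, num_heads, seq_len, head_dim_qk, head_dim_v = 32, 16, 640, 256, 64

q = torch.rand(batch, num_heads, seq_len, head_dim_qk, device=device, dtype=dtype, requires_grad=True)
k = torch.rand(batch, num_heads, seq_len, head_dim_qk, device=device, dtype=dtype, requires_grad=True)
v = torch.rand(batch, num_heads, seq_len, head_dim_v, device=device, dtype=dtype, requires_grad=True)

# With set_priority=True the listed backends are tried in order: cuDNN first,
# math as the fallback when cuDNN declines (e.g. head_dim above its limit).
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out.backward(torch.randn_like(out))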
