@@ -2493,6 +2493,33 @@ def test_cudnn_attention_gqa(self, device):
 
         self.assertEqual(output_math, output_cudnn)
 
+    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
+    def test_cudnn_attention_d256_heuristic(self, device):
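+        # head_dim_k=256 (with head_dim_v=64) exercises the cuDNN d=256 heuristic; MATH is listed as a fallback backend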
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 256, 64
+        seq_len = 640
+        q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+            actual.backward(torch.randn_like(actual))
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
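+            # Reference path: math backend in float32, compared below to the bf16 result with relaxed tolerances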
+            math_ref = torch.nn.functional.scaled_dot_product_attention(
+                query.contiguous().to(torch.float32),
+                key.contiguous().to(torch.float32),
+                value.contiguous().to(torch.float32),
+                attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
     @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):