@@ -42,9 +42,15 @@ struct SM90ArchSpec {
42
42
43
43
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
44
44
// Or too many register spills
45
- if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192 ))
46
- return false ;
47
45
46
+ if (get_env<int >(" ENABLE_SWAPAB" )){
47
+ if (block_n != 64 and block_n != 128 and block_n != 256 )
48
+ return false ;
49
+ }else {
50
+ if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192 ))
51
+ return false ;
52
+ }
53
+
48
54
// Avoid bank conflicts for FP32 output
49
55
if (cd_dtype == torch::kFloat and block_n % 16 == 0 )
50
56
return false ;
@@ -79,7 +85,13 @@ struct SM90ArchSpec {
79
85
80
86
static ThreadConfig get_thread_config (const KernelType& kernel_type,
81
87
const int & block_m, const int & block_n) {
82
- return ThreadConfig::sm90 (128 , (block_m == 64 ? 1 : 2 ) * 128 );
88
+ int tile = 64 ;
89
+ if (get_env<int >(" ENABLE_SWAPAB" )){
90
+ tile = block_n;
91
+ }else {
92
+ tile = block_m;
93
+ }
94
+ return ThreadConfig::sm90 (128 , (tile > 64 ? 2 : 1 ) * 128 );
83
95
}
84
96
85
97
static int get_smem_cd_size (const KernelType& kernel_type,
@@ -104,7 +116,8 @@ struct SM90ArchSpec {
104
116
105
117
static int get_extra_sfb_smem_size (const int & m, const int & n, const int & k,
106
118
const int & block_m, const int & block_n, const int & block_k) {
107
- const auto & use_uniform_sfb = block_k % block_n == 0 ? 1 : 2 ;
119
+ const auto & use_uniform_sfb = get_env<int >(" ENABLE_SWAPAB" ) ? (block_n / 64 ):(block_k % block_n == 0 ? 1 : 2 );
120
+
108
121
return align<int >(ceil_div (k, block_k) * static_cast <int >(sizeof (float )) * use_uniform_sfb, 8 );
109
122
}
110
123
0 commit comments