@@ -413,9 +413,13 @@ class TllmGenFmhaKernel
413
413
return std::make_tuple (numCtasPerSeqQ, numCtasPerSeqKv, numCtasX, numCtasY, numCtasZ, clusterDimX);
414
414
}
415
415
416
- // Compute the seqLenPerCtaKv for selecting the MLA generation kernel .
417
- int computeSeqLenPerCtaKv (RunnerParams const & params) const
416
+ // Determine if we should use the SwapsMmaAbForGeneration kernel for MLA generation.
417
+ bool useSwapsMmaAbMlaGenKernel (RunnerParams const & params) const
418
418
{
419
+ // Use the SwapsMmaAbForGeneration kernel for MLA generation when the following conditions are met:
420
+ // 1. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later).
421
+ // 2. The numCtas (after splitting the heads across multiple CTAs) <= params.mMultiProcessorCount.
422
+
419
423
// The maximum number of CTAs per Kv sequence, which makes sure that each CtaKv has work to do.
420
424
// Here we assume the stepKv is 256.
421
425
int const maxNumCtasPerSeqKv = (params.mMaxSeqLenKv + 256 - 1 ) / 256 ;
@@ -427,8 +431,8 @@ class TllmGenFmhaKernel
427
431
= std::min (maxNumCtasPerSeqKv, std::max (1 , int32_t (params.mMultiProcessorCount / numCtas)));
428
432
// Compute the seqLenPerCtaKv.
429
433
int const seqLenPerCtaKv = (params.mMaxSeqLenKv + numCtasPerSeqKv - 1 ) / numCtasPerSeqKv;
430
- // Return the seqLenPerCtaKv .
431
- return seqLenPerCtaKv;
434
+ // Whether we should use the SwapsMmaAbForGeneration kernel for MLA generation.
435
+ return seqLenPerCtaKv <= 1024 && numCtas <= params. mMultiProcessorCount ;
432
436
}
433
437
434
438
std::pair<uint64_t , std::string> hashFromRunnerParams (
@@ -442,10 +446,11 @@ class TllmGenFmhaKernel
442
446
// We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the following
443
447
// conditions are met:
444
448
// 1. The number of headsQPerKv is <= 32.
445
- // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later).
449
+ // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later) and
450
+ // the numCtas (after splitting the heads across multiple CTAs) <= params.mMultiProcessorCount.
446
451
447
452
// Check the conditions.
448
- if (params.mNumHeadsQPerKv <= 32 || computeSeqLenPerCtaKv (params) <= 1024 )
453
+ if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel (params))
449
454
{
450
455
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
451
456
}
0 commit comments