Skip to content

Commit 2a3ff38

Browse files
PerkzZhengNVShreyas
authored and committed
[Fix] the bug in the trtllm-gen heuristic for MLA kernels. (NVIDIA#6284)
Signed-off-by: Perkz Zheng <[email protected]> Signed-off-by: Shreyas Misra <[email protected]>
1 parent cd32157 commit 2a3ff38

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -413,9 +413,13 @@ class TllmGenFmhaKernel
413413
return std::make_tuple(numCtasPerSeqQ, numCtasPerSeqKv, numCtasX, numCtasY, numCtasZ, clusterDimX);
414414
}
415415

416-
// Compute the seqLenPerCtaKv for selecting the MLA generation kernel.
417-
int computeSeqLenPerCtaKv(RunnerParams const& params) const
416+
// Determine if we should use the SwapsMmaAbForGeneration kernel for MLA generation.
417+
bool useSwapsMmaAbMlaGenKernel(RunnerParams const& params) const
418418
{
419+
// Use the SwapsMmaAbForGeneration kernel for MLA generation when the following conditions are met:
420+
// 1. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later).
421+
// 2. The numCtas (after splitting the heads across multiple CTAs) <= params.mMultiProcessorCount.
422+
419423
// The maximum number of CTAs per Kv sequence, which makes sure that each CtaKv has work to do.
420424
// Here we assume the stepKv is 256.
421425
int const maxNumCtasPerSeqKv = (params.mMaxSeqLenKv + 256 - 1) / 256;
@@ -427,8 +431,8 @@ class TllmGenFmhaKernel
427431
= std::min(maxNumCtasPerSeqKv, std::max(1, int32_t(params.mMultiProcessorCount / numCtas)));
428432
// Compute the seqLenPerCtaKv.
429433
int const seqLenPerCtaKv = (params.mMaxSeqLenKv + numCtasPerSeqKv - 1) / numCtasPerSeqKv;
430-
// Return the seqLenPerCtaKv.
431-
return seqLenPerCtaKv;
434+
// Whether we should use the SwapsMmaAbForGeneration kernel for MLA generation.
435+
return seqLenPerCtaKv <= 1024 && numCtas <= params.mMultiProcessorCount;
432436
}
433437

434438
std::pair<uint64_t, std::string> hashFromRunnerParams(
@@ -442,10 +446,11 @@ class TllmGenFmhaKernel
442446
// We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the following
443447
// conditions are met:
444448
// 1. The number of headsQPerKv is <= 32.
445-
// 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later).
449+
// 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later) and
450+
// the numCtas (after splitting the heads across multiple CTAs) <= params.mMultiProcessorCount.
446451

447452
// Check the conditions.
448-
if (params.mNumHeadsQPerKv <= 32 || computeSeqLenPerCtaKv(params) <= 1024)
453+
if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params))
449454
{
450455
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
451456
}

0 commit comments

Comments
 (0)