Skip to content

Commit 545497d

Browse files
jhaoting authored and clancelly committed
[None][feat] Multi-block mode for Hopper spec dec XQA kernel (NVIDIA#4416)
Signed-off-by: Jhao-Ting Chen <[email protected]> Signed-off-by: Lanyu Liao <[email protected]>
1 parent 389b472 commit 545497d

File tree

7 files changed

+115
-37
lines changed

7 files changed

+115
-37
lines changed

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,35 +2077,31 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
20772077
debugCheckSemaphores(stream);
20782078
#endif
20792079

2080-
// Medusa doesn't support multi-block mode.
2081-
if (!(mIsSpecDecodingEnabled && mUseSpecDecoding))
2080+
if (params.runtime_perf_knobs)
20822081
{
2083-
if (params.runtime_perf_knobs)
2084-
{
2085-
int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
2086-
mMultiBlockMode = multi_block_mode_val == 1;
2087-
if (common::getEnvForceDeterministicAttention())
2088-
{
2089-
mMultiBlockMode = false;
2090-
}
2091-
}
2092-
2082+
int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
2083+
mMultiBlockMode = multi_block_mode_val == 1;
20932084
if (common::getEnvForceDeterministicAttention())
20942085
{
20952086
mMultiBlockMode = false;
20962087
}
2088+
}
20972089

2098-
// TODO only for debug usage
2099-
if (!mMultiBlockMode)
2100-
{
2101-
char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
2102-
bool isForceMultiBlockMode
2103-
= (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
2104-
TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
2105-
"FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
2106-
"the same time.");
2107-
mMultiBlockMode = isForceMultiBlockMode;
2108-
}
2090+
if (common::getEnvForceDeterministicAttention())
2091+
{
2092+
mMultiBlockMode = false;
2093+
}
2094+
2095+
// TODO only for debug usage
2096+
if (!mMultiBlockMode)
2097+
{
2098+
char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
2099+
bool isForceMultiBlockMode
2100+
= (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
2101+
TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
2102+
"FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
2103+
"the same time.");
2104+
mMultiBlockMode = isForceMultiBlockMode;
21092105
}
21102106

21112107
// Check that the chunked-attention and sliding-window-attention are not enabled at the same time.
@@ -2723,7 +2719,6 @@ int AttentionOp::initialize() noexcept
27232719
{
27242720
fixedParams.outputDataType = DATA_TYPE_E4M3;
27252721
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
2726-
TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
27272722
}
27282723
fixedParams.numQHeads = mNumAttnHeads;
27292724
fixedParams.numKvHeads = mNumAttnKVHeads;

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,5 +381,90 @@ inline int computeMultiBlockCountForMLA(XQAParams const& xqaParams, int multipro
381381
return 1; // disable multi-block for MLA kernel for now.
382382
}
383383

384+
//! @brief Heuristic for the multi-block count (CTAs split along the KV/history dimension)
//! used by the Hopper GMMA spec-dec XQA kernel.
//!
//! The kernel launches gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * batch_size}; this
//! routine chooses the second dimension so that, when the base grid underfills the GPU, extra
//! CTAs are added up to (at most) one full wave of SMs.
//!
//! @param xqaParams            kernel parameters; reads num_kv_heads and max_past_kv_length.
//! @param batch_size           number of sequences in the batch.
//! @param multiprocessor_count number of SMs on the device.
//! @param specDecBlocks        number of spec-dec CTAs per (kv-head, sequence) pair.
//! @return multi-block count in [1, 64]; 1 disables multi-block mode.
inline int computeMultiBlockCountSpecDecGMMA(
    XQAParams const& xqaParams, int batch_size, int multiprocessor_count, int specDecBlocks)
{
    // An explicit user override (env var) always wins over the heuristic below.
    auto const userSpecified = tensorrt_llm::common::getEnvXqaBlocksPerSequence();
    if (userSpecified.has_value())
    {
        return userSpecified.value();
    }
    int multi_block_count = 1;

    int num_kv_heads = xqaParams.num_kv_heads;
    int history_length = xqaParams.max_past_kv_length;

    // Skip tuning for large batch (grid already large) or short input sequence length
    // (not enough KV work to split).
    if (batch_size > 32 || history_length < 2048)
    {
        return multi_block_count;
    }

    // gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
    int single_block_count = specDecBlocks * num_kv_heads * batch_size;
    double wave_count = (double) single_block_count / (double) multiprocessor_count;

    // Multi-block tuning for low CTA count: populate CTAs up to at most 1 wave of SMs.
    if (wave_count < 1)
    {
        // Round down to the highest power of 2 <= x (x > 0): smear the top bit right,
        // then clear everything below it.
        auto highestPowerof2 = [](int x)
        {
            x |= x >> 1;
            x |= x >> 2;
            x |= x >> 4;
            x |= x >> 8;
            x |= x >> 16;
            return x ^ (x >> 1);
        };

        // Maximum blocks that still fit in one wave. Integer division already truncates,
        // so no floating-point floor() is needed (the original int -> double -> int
        // round-trip via an unqualified floor was redundant).
        multi_block_count = multiprocessor_count / single_block_count;
        // Make multi_block_count a power of 2 for tuning convenience.
        multi_block_count = highestPowerof2(multi_block_count);
        // Clamp multi_block_count to [1, 64].
        multi_block_count = std::min(multi_block_count, 64);
        multi_block_count = std::max(multi_block_count, 1);

        // Tune only when the original CTA count is very small, multi_block_count is large,
        // and history length < 2^16.
        // For Hopper, most parts have 114, 132, or 144 SMs; H20 has about 78.
        // single_block_count = [1..8]
        // multi_block_count  = [16,32,64,128]
        // history_length     = [1024..65536)
        if (single_block_count <= 8 && multi_block_count >= 16 && history_length < 65536)
        {
            if (history_length < 2048)
            {
                // NOTE(review): unreachable under the current early return above
                // (history_length < 2048 already returned). Kept defensively in case
                // the early-return guard changes; hard-limit multi_block_count to 4
                // since scaling is not effective for short histories at low CTA count.
                multi_block_count = std::min(multi_block_count, 4);
            }
            else if (history_length < 65536) // always true here; mirrors the outer guard
            {
                // At single_block_count == 8, multi_block_count can only be 16 (SM / 8 ~= 16).
                // Tune only 2048 <= kvlen <= 8192.
                if (single_block_count == 8 && history_length <= 8192)
                {
                    multi_block_count >>= 1;
                }
                else
                {
                    // floor(log2(x)) for x > 0; -1 for x == 0 (gcc/clang builtin).
                    auto getLog2 = [](int x) { return x ? 31 - __builtin_clz(x) : -1; };
                    auto history_length_log2 = getLog2(history_length);
                    // Scale multi_block_count down as history grows, using:
                    //   shift = 3 - (log2(history_length) - 10) / 2
                    // which yields:
                    //   log2 == 11            -> shift by 3
                    //   log2 in {12, 13}      -> shift by 2
                    //   log2 in {14, 15}      -> shift by 1
                    multi_block_count >>= 3 - (history_length_log2 - 10) / 2;
                }
            }
        }
        TLLM_CHECK_WITH_INFO((multi_block_count * single_block_count) <= multiprocessor_count,
            "The adjusted MultiBlock exceed number of SMs, adding additional wave may result to perf drop.");
    }
    return multi_block_count;
}
468+
384469
} // namespace kernels
385470
} // namespace tensorrt_llm

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
441441
uint32_t multi_block = 1;
442442
if (xqaParams.multi_block_mode)
443443
{
444-
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
444+
if (isSpecDec && isGMMAKernel)
445+
{
446+
multi_block = computeMultiBlockCountSpecDecGMMA(
447+
xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks);
448+
}
449+
else if (!isSpecDec)
450+
{
451+
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
452+
}
445453
}
446454
uint32_t const nbKVHeads = xqaParams.num_kv_heads;
447455
auto const gridDim = (isGMMAKernel ? dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,8 @@ class XQAKernelList
287287
void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
288288
&launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size,
289289
&launchParams.kv_scale_quant_orig, &launchParams.scratch};
290+
// precompiled XQA Spec-dec kernel does not support multi-block mode
290291
int multi_block = 1;
291-
if (xqaParams.multi_block_mode)
292-
{
293-
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
294-
check_cuda_error(cudaMemsetAsync(xqaParams.workspaces, 0,
295-
sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream));
296-
sync_check_cuda_error(stream);
297-
}
298292
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
299293
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr));
300294
}

cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
7272
mMaskType = mask_type;
7373
mBlockSparseParams = block_sparse_params;
7474
mType = type;
75-
mMultiBlockMode
76-
= is_spec_decoding_enabled ? false : true; // set to true in build time to account for enough workspace size
75+
mMultiBlockMode = true;
7776
mEnableXQA = true;
7877
mKVCacheQuantMode = tc::QuantMode(kv_cache_quant_mode);
7978
mRemovePadding = remove_input_padding;

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
703703
= static_cast<bool>(reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_USE)])[0]);
704704
changeSpecDecodingMode = mUseSpecDecoding != useSpecDecoding;
705705
mUseSpecDecoding = useSpecDecoding;
706-
// change mMultiBlockMode to default
707-
mMultiBlockMode = mUseSpecDecoding ? false : true;
708706
}
709707

710708
[[maybe_unused]] MlaParams<T> mla_params;

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,6 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
528528
"Expecting 2 bools for spec-dec mode, is_spec_decoding_enabled and use_spec_decoding.");
529529
op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
530530
op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
531-
op->mMultiBlockMode = op->mIsSpecDecodingEnabled ? false : true;
532531

533532
if (is_mla_enable)
534533
{

0 commit comments

Comments
 (0)