From abc9df4526a078490682b8daf2af2d88182c15d8 Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Thu, 31 Jul 2025 12:06:57 -0700 Subject: [PATCH 1/4] add xqa spec-dec kernel multi block tuning heuristic Signed-off-by: Jhao-Ting Chen --- .../decoderXQAImplCommon.h | 82 +++++++++++++++++++ .../decoderXQAImplJIT/decoderXQAImplJIT.cpp | 6 ++ .../gptAttentionPlugin/gptAttentionPlugin.cpp | 1 + 3 files changed, 89 insertions(+) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h index 884f3c0f64..034dddb5de 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h @@ -381,5 +381,87 @@ inline int computeMultiBlockCountForMLA(XQAParams const& xqaParams, int multipro return 1; // disable multi-block for MLA kernel for now. } +inline int computeMultiBlockCountSpecDecGMMA( + XQAParams const& xqaParams, int batch_size, int multiprocessor_count, int specDecBlocks) +{ + auto const userSpecified = tensorrt_llm::common::getEnvXqaBlocksPerSequence(); + if (userSpecified.has_value()) + { + return userSpecified.value(); + } + int multi_block_count = 1; + + int num_kv_heads = xqaParams.num_kv_heads; + int history_length = xqaParams.max_past_kv_length; + + // skip tuning for large BS or short ISL case. 
+ if (batch_size > 32 || history_length < 2048) + { + return multi_block_count; + } + + // gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size} + int single_block_count = specDecBlocks * num_kv_heads * batch_size; + double wave_count = (double) single_block_count / (double) multiprocessor_count; + + // Multi block tuning for low CTA: populating CTAs to at most 1 wave of SMs + if (wave_count < 1) + { + auto highestPowerof2 = [](int x) + { + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + return x ^ (x >> 1); + }; + + // calculate the maximum blocks to be populated at most 1 wave + multi_block_count = floor(multiprocessor_count / single_block_count); + // make multi_block_count a power of 2 for tuning convenience. + multi_block_count = highestPowerof2(multi_block_count); + // make multi_block_count at most 64 and at least 1. + multi_block_count = std::min(multi_block_count, 64); + multi_block_count = std::max(multi_block_count, 1); + + // tune only when original CTA is too small, multi_block_count is too big, and history length < 2^16 + // For Hopper, most cases there are 114, 132, 144 SMs. For H20 about 78. + // single_block_count = [1..8] + // multi_block_count = [16,32,64] + // history_length = [2048..65536) + if (single_block_count <= 8 && multi_block_count >= 16 && history_length < 65536) + { + if (history_length < 2048) + { + // for history length < 2048 and low CTA, scaling is not effective, so we set a hard limit to + // multi_block_count = 4 + multi_block_count = std::min(multi_block_count, 4); + } + else if (history_length < 65536) + { + // at single_block == 8, multi_block_count can only be 16. (SM / 8 ~= 16) + // tune only 2048 <= kvlen <= 8192 + if (single_block_count == 8 && history_length <= 8192) + { + multi_block_count >>= 1; + } + else + { + auto getLog2 = [](int x) { return x ? 
31 - __builtin_clz(x) : -1; }; + auto history_length_log2 = getLog2(history_length); + multi_block_count >>= 3 - (history_length_log2 - 10) / 2; + // 2^15 (< 65536) -> shift 1 + // 2^13, 2^14 -> shift 2 + // 2^11, 2^12 (> 1024) -> shift 3 + } + } + } + TLLM_CHECK_WITH_INFO((multi_block_count * single_block_count) <= multiprocessor_count, + "The adjusted MultiBlock exceed number of SMs, adding additional wave may result to perf drop."); + } + return multi_block_count; +} + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 802ffe31c5..408f9fcf25 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -443,6 +443,12 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& { multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count); } + // A WAR to enable Hopper XQA multi-token multi_block mode + if (isSpecDec && isGMMAKernel) + { + multi_block = computeMultiBlockCountSpecDecGMMA( + xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks); + } uint32_t const nbKVHeads = xqaParams.num_kv_heads; auto const gridDim = (isGMMAKernel ? 
dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size} : dim3{multi_block, nbKVHeads, xqaParams.batch_size}); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index ed21db8663..4f56749f8f 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -705,6 +705,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 mUseSpecDecoding = useSpecDecoding; // change mMultiBlockMode to default mMultiBlockMode = mUseSpecDecoding ? false : true; + // if Hopper XQA kernel is enabled, multi block mode will be true in decoderXQAImplJIT::runImpl } [[maybe_unused]] MlaParams mla_params; From 07679bf4b2e4de0a4733af60595be8568faf0aa4 Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Thu, 31 Jul 2025 14:54:23 -0700 Subject: [PATCH 2/4] enable multiblock in torch attentionOp.cpp Signed-off-by: Jhao-Ting Chen --- cpp/tensorrt_llm/common/attentionOp.cpp | 43 ++++++++----------- .../decoderXQAImplCommon.h | 9 ++-- .../decoderXQAImplJIT/decoderXQAImplJIT.cpp | 16 ++++--- .../gptAttentionPlugin/gptAttentionPlugin.cpp | 1 - cpp/tensorrt_llm/thop/attentionOp.cpp | 1 - 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index a7e28defb0..be64673122 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -2077,35 +2077,31 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud debugCheckSemaphores(stream); #endif - // Medusa doesn't support multi-block mode. 
- if (!(mIsSpecDecodingEnabled && mUseSpecDecoding)) + if (params.runtime_perf_knobs) { - if (params.runtime_perf_knobs) - { - int64_t multi_block_mode_val = params.runtime_perf_knobs[0]; - mMultiBlockMode = multi_block_mode_val == 1; - if (common::getEnvForceDeterministicAttention()) - { - mMultiBlockMode = false; - } - } - + int64_t multi_block_mode_val = params.runtime_perf_knobs[0]; + mMultiBlockMode = multi_block_mode_val == 1; if (common::getEnvForceDeterministicAttention()) { mMultiBlockMode = false; } + } - // TODO only for debug usage - if (!mMultiBlockMode) - { - char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE"); - bool isForceMultiBlockMode - = (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON"); - TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode), - "FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at " - "the same time."); - mMultiBlockMode = isForceMultiBlockMode; - } + if (common::getEnvForceDeterministicAttention()) + { + mMultiBlockMode = false; + } + + // TODO only for debug usage + if (!mMultiBlockMode) + { + char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE"); + bool isForceMultiBlockMode + = (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON"); + TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode), + "FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at " + "the same time."); + mMultiBlockMode = isForceMultiBlockMode; } // Check that the chunked-attention and sliding-window-attention are not enabled at the same time. 
@@ -2723,7 +2719,6 @@ int AttentionOp::initialize() noexcept { fixedParams.outputDataType = DATA_TYPE_E4M3; TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads."); - TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode."); } fixedParams.numQHeads = mNumAttnHeads; fixedParams.numKvHeads = mNumAttnKVHeads; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h index 034dddb5de..bc6bbf49d8 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h @@ -450,10 +450,13 @@ inline int computeMultiBlockCountSpecDecGMMA( { auto getLog2 = [](int x) { return x ? 31 - __builtin_clz(x) : -1; }; auto history_length_log2 = getLog2(history_length); + // Adjust multi_block_count based on history length using formula: + // shift_amount = 3 - (floor(log2(history_length)) - 10) / 2 + // With integer division this gives us: + // - history_length in [2^11, 2^12): shift by 3 + // - history_length in [2^12, 2^14): shift by 2 + // - history_length in [2^14, 2^16): shift by 1 multi_block_count >>= 3 - (history_length_log2 - 10) / 2; - // 2^15 (< 65536) -> shift 1 - // 2^13, 2^14 -> shift 2 - // 2^11, 2^12 (> 1024) -> shift 3 } } } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 408f9fcf25..35c7ffcd84 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -441,13 +441,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& uint32_t multi_block = 1; if 
(xqaParams.multi_block_mode) { - multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count); - } - // A WAR to enable Hopper XQA multi-token multi_block mode - if (isSpecDec && isGMMAKernel) - { - multi_block = computeMultiBlockCountSpecDecGMMA( - xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks); + if (isSpecDec && isGMMAKernel) + { + multi_block = computeMultiBlockCountSpecDecGMMA( + xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks); + } + else + { + multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count); + } } uint32_t const nbKVHeads = xqaParams.num_kv_heads; auto const gridDim = (isGMMAKernel ? dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size} diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 4f56749f8f..ed21db8663 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -705,7 +705,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 mUseSpecDecoding = useSpecDecoding; // change mMultiBlockMode to default mMultiBlockMode = mUseSpecDecoding ? 
false : true; - // if Hopper XQA kernel is enabled, multi block mode will be true in decoderXQAImplJIT::runImpl } [[maybe_unused]] MlaParams mla_params; diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 7a77fc49bb..9017369627 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -528,7 +528,6 @@ void attention_inplace(torch::Tensor q, torch::optional k, torch: "Expecting 2 bools for spec-dec mode, is_spec_decoding_enabled and use_spec_decoding."); op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding - op->mMultiBlockMode = op->mIsSpecDecodingEnabled ? false : true; if (is_mla_enable) { From bec4406514d9443180b12c08bc6ea8a79e6d1040 Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Thu, 31 Jul 2025 16:28:58 -0700 Subject: [PATCH 3/4] enable trt backend sd kernel multi-block mode Signed-off-by: Jhao-Ting Chen --- .../plugins/gptAttentionCommon/gptAttentionCommon.cpp | 3 +-- .../plugins/gptAttentionPlugin/gptAttentionPlugin.cpp | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index 0d9eef80fb..98e59c8fdd 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -72,8 +72,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, mMaskType = mask_type; mBlockSparseParams = block_sparse_params; mType = type; - mMultiBlockMode - = is_spec_decoding_enabled ? 
false : true; // set to true in build time to account for enough workspace size + mMultiBlockMode = true; mEnableXQA = true; mKVCacheQuantMode = tc::QuantMode(kv_cache_quant_mode); mRemovePadding = remove_input_padding; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index ed21db8663..56281416fa 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -703,8 +703,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 = static_cast(reinterpret_cast(inputs[getIdx(IdxEntry::SPEC_DECODING_USE)])[0]); changeSpecDecodingMode = mUseSpecDecoding != useSpecDecoding; mUseSpecDecoding = useSpecDecoding; - // change mMultiBlockMode to default - mMultiBlockMode = mUseSpecDecoding ? false : true; } [[maybe_unused]] MlaParams mla_params; From d3735437815c9e691ae7c7b490b873c372ef536e Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Fri, 1 Aug 2025 17:10:49 -0700 Subject: [PATCH 4/4] remove multi-block mode for precompiled XQA, and spec-dec but not QGMMA path Signed-off-by: Jhao-Ting Chen --- .../decoderXQAImplJIT/decoderXQAImplJIT.cpp | 2 +- .../decoderXQAImplPrecompiled.cpp | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 35c7ffcd84..9406141471 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -446,7 +446,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& multi_block = computeMultiBlockCountSpecDecGMMA( xqaParams, 
xqaParams.batch_size, multiprocessor_count, specDecBlocks); } - else + else if (!isSpecDec) { multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count); } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index 6c6d4cd0b2..ebe6722ac7 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -287,14 +287,8 @@ class XQAKernelList void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens, &launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size, &launchParams.kv_scale_quant_orig, &launchParams.scratch}; + // precompiled XQA Spec-dec kernel does not support multi-block mode int multi_block = 1; - if (xqaParams.multi_block_mode) - { - multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count); - check_cuda_error(cudaMemsetAsync(xqaParams.workspaces, 0, - sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream)); - sync_check_cuda_error(stream); - } TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp, xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr)); }