Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions cpp/tensorrt_llm/common/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2077,35 +2077,31 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
debugCheckSemaphores(stream);
#endif

// Medusa doesn't support multi-block mode.
if (!(mIsSpecDecodingEnabled && mUseSpecDecoding))
if (params.runtime_perf_knobs)
{
if (params.runtime_perf_knobs)
{
int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
mMultiBlockMode = multi_block_mode_val == 1;
if (common::getEnvForceDeterministicAttention())
{
mMultiBlockMode = false;
}
}

int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
mMultiBlockMode = multi_block_mode_val == 1;
if (common::getEnvForceDeterministicAttention())
{
mMultiBlockMode = false;
}
}

// TODO only for debug usage
if (!mMultiBlockMode)
{
char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
bool isForceMultiBlockMode
= (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
"FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
"the same time.");
mMultiBlockMode = isForceMultiBlockMode;
}
if (common::getEnvForceDeterministicAttention())
{
mMultiBlockMode = false;
}

// TODO only for debug usage
if (!mMultiBlockMode)
{
char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
bool isForceMultiBlockMode
= (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
"FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
"the same time.");
mMultiBlockMode = isForceMultiBlockMode;
}

// Check that the chunked-attention and sliding-window-attention are not enabled at the same time.
Expand Down Expand Up @@ -2723,7 +2719,6 @@ int AttentionOp::initialize() noexcept
{
fixedParams.outputDataType = DATA_TYPE_E4M3;
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
}
fixedParams.numQHeads = mNumAttnHeads;
fixedParams.numKvHeads = mNumAttnKVHeads;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,90 @@ inline int computeMultiBlockCountForMLA(XQAParams const& xqaParams, int multipro
return 1; // disable multi-block for MLA kernel for now.
}

/// @brief Heuristically pick the multi-block count for the spec-dec GMMA XQA kernel.
///
/// The kernel launches gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * batch_size}.
/// When one wave of single-block CTAs underfills the GPU, split each sequence across
/// several blocks to raise occupancy, then damp the split for shorter histories where
/// the extra reduction cost outweighs the parallelism gain.
///
/// @param xqaParams           kernel parameters; reads num_kv_heads and max_past_kv_length
/// @param batch_size          number of sequences in the batch
/// @param multiprocessor_count number of SMs on the device
/// @param specDecBlocks       X-dimension of the launch grid (spec-dec token blocks)
/// @return multi-block count in [1, 64]; 1 disables multi-block
inline int computeMultiBlockCountSpecDecGMMA(
    XQAParams const& xqaParams, int batch_size, int multiprocessor_count, int specDecBlocks)
{
    // An explicit user override via env var wins over every heuristic below.
    auto const userSpecified = tensorrt_llm::common::getEnvXqaBlocksPerSequence();
    if (userSpecified.has_value())
    {
        return userSpecified.value();
    }
    int multi_block_count = 1;

    int const num_kv_heads = xqaParams.num_kv_heads;
    int const history_length = xqaParams.max_past_kv_length;

    // Skip tuning for large BS or short ISL case. From here on history_length >= 2048.
    if (batch_size > 32 || history_length < 2048)
    {
        return multi_block_count;
    }

    // gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
    int const single_block_count = specDecBlocks * num_kv_heads * batch_size;
    double const wave_count = static_cast<double>(single_block_count) / static_cast<double>(multiprocessor_count);

    // Multi block tuning for low CTA: populating CTAs to at most 1 wave of SMs.
    if (wave_count < 1)
    {
        // Round x down to the largest power of two <= x (valid for x > 0).
        auto highestPowerof2 = [](int x)
        {
            x |= x >> 1;
            x |= x >> 2;
            x |= x >> 4;
            x |= x >> 8;
            x |= x >> 16;
            return x ^ (x >> 1);
        };

        // Maximum blocks that still fit in one wave; integer division already floors,
        // so the previous floor() round-trip through double was redundant.
        multi_block_count = multiprocessor_count / single_block_count;
        // Make multi_block_count a power of 2 for tuning convenience.
        multi_block_count = highestPowerof2(multi_block_count);
        // Keep multi_block_count within [1, 64].
        multi_block_count = std::clamp(multi_block_count, 1, 64);

        // Tune only when the original CTA count is very small, the tentative
        // multi_block_count is large, and history length < 2^16.
        // For Hopper, most parts have 114, 132, or 144 SMs; H20 about 78.
        // single_block_count = [1..8], multi_block_count = [16,32,64].
        // Note: the early return above guarantees history_length >= 2048 here, so the
        // old `history_length < 2048` damping branch was unreachable and is removed.
        if (single_block_count <= 8 && multi_block_count >= 16 && history_length < 65536)
        {
            if (single_block_count == 8 && history_length <= 8192)
            {
                // At single_block_count == 8, multi_block_count can only be 16 (SM / 8 ~= 16);
                // halve it for 2048 <= kvlen <= 8192.
                multi_block_count >>= 1;
            }
            else
            {
                auto getLog2 = [](int x) { return x ? 31 - __builtin_clz(x) : -1; };
                int const history_length_log2 = getLog2(history_length);
                // Adjust multi_block_count based on history length using:
                //   shift_amount = 3 - (log2(history_length) - 10) / 2
                // which, with integer division, yields:
                //   - history_length in [2^11, 2^12): shift by 3
                //   - history_length in [2^12, 2^14): shift by 2
                //   - history_length in [2^14, 2^16): shift by 1
                multi_block_count >>= 3 - (history_length_log2 - 10) / 2;
            }
        }
        TLLM_CHECK_WITH_INFO((multi_block_count * single_block_count) <= multiprocessor_count,
            "The adjusted MultiBlock exceed number of SMs, adding additional wave may result to perf drop.");
    }
    return multi_block_count;
}

} // namespace kernels
} // namespace tensorrt_llm
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
uint32_t multi_block = 1;
if (xqaParams.multi_block_mode)
{
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
if (isSpecDec && isGMMAKernel)
{
multi_block = computeMultiBlockCountSpecDecGMMA(
xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks);
}
else if (!isSpecDec)
{
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
}
}
uint32_t const nbKVHeads = xqaParams.num_kv_heads;
auto const gridDim = (isGMMAKernel ? dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,8 @@ class XQAKernelList
void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
&launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size,
&launchParams.kv_scale_quant_orig, &launchParams.scratch};
// precompiled XQA Spec-dec kernel does not support multi-block mode
int multi_block = 1;
if (xqaParams.multi_block_mode)
{
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
check_cuda_error(cudaMemsetAsync(xqaParams.workspaces, 0,
sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream));
sync_check_cuda_error(stream);
}
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
mMaskType = mask_type;
mBlockSparseParams = block_sparse_params;
mType = type;
mMultiBlockMode
= is_spec_decoding_enabled ? false : true; // set to true in build time to account for enough workspace size
mMultiBlockMode = true;
mEnableXQA = true;
mKVCacheQuantMode = tc::QuantMode(kv_cache_quant_mode);
mRemovePadding = remove_input_padding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
= static_cast<bool>(reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_USE)])[0]);
changeSpecDecodingMode = mUseSpecDecoding != useSpecDecoding;
mUseSpecDecoding = useSpecDecoding;
// change mMultiBlockMode to default
mMultiBlockMode = mUseSpecDecoding ? false : true;
}

[[maybe_unused]] MlaParams<T> mla_params;
Expand Down
1 change: 0 additions & 1 deletion cpp/tensorrt_llm/thop/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
"Expecting 2 bools for spec-dec mode, is_spec_decoding_enabled and use_spec_decoding.");
op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
op->mMultiBlockMode = op->mIsSpecDecodingEnabled ? false : true;

if (is_mla_enable)
{
Expand Down