Commit 92daec1

[TRTLLM-7348] [feat] Enable Cross-Attention to use XQA kernels for Whisper (NVIDIA#7035)
Signed-off-by: Dom Brown <[email protected]>
1 parent 8ac7dec commit 92daec1

File tree: 4 files changed, +153 -106 lines

cpp/tensorrt_llm/common/attentionOp.cpp
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 8 additions & 1 deletion
@@ -285,6 +285,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
 
+    // Cross attention parameters.
+    xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
+
     return true;
 }
 
@@ -2229,6 +2232,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
         {
             TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
         }
+        else
+        {
+            TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase.");
+        }
     }
 
     // This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv
@@ -2750,7 +2757,7 @@ int AttentionOp::initialize() noexcept
             !useCustomMask() || mEnableContextFMHA, "Only Context FMHA supports custom mask input currently.");
     }
 
-    mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled) && !mCrossAttention
+    mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled)
        && (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16) && mUseKVCache;
 
    if (mEnableXQA)
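
Note: the attentionOp.cpp change does two things: it forwards the per-request encoder lengths into XQAParams, and it drops cross-attention from the XQA enablement check. The following is a minimal stand-alone sketch of the updated predicate, using simplified stand-in types rather than the real AttentionOp members (names here only mirror the diff and are not part of the commit):

// Sketch only: illustrates the enablement condition after this commit,
// where cross-attention is no longer part of the conjunction.
enum class DataType { kHALF, kBF16, kFLOAT };

struct AttentionConfig
{
    bool enableXQA = true;            // stands in for mEnableXQA
    bool specDecodingEnabled = false; // stands in for mIsSpecDecodingEnabled
    bool crossAttention = false;      // no longer consulted by the check below
    bool useKVCache = true;           // stands in for mUseKVCache
    DataType type = DataType::kHALF;  // stands in for mType
};

inline bool xqaEnabled(AttentionConfig const& c)
{
    // Before this commit the condition also required !c.crossAttention.
    return (c.enableXQA || c.specDecodingEnabled)
        && (c.type == DataType::kHALF || c.type == DataType::kBF16) && c.useKVCache;
}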

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,9 @@ struct XQAParams
 
     void* quant_q_buffer_ptr = nullptr;
 
+    // for cross attention
+    int32_t const* encoder_input_lengths = nullptr;
+
     cudaStream_t stream = 0;
 
     std::string toString() const
@@ -175,6 +178,7 @@ struct XQAParams
            << "total_num_input_tokens :" << total_num_input_tokens << std ::endl
            << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
            << "fp8_out_scale :" << fp8_out_scale << std ::endl
+           << "encoder_input_lengths: " << encoder_input_lengths << std::endl
            << "stream :" << stream;
 
        return ss.str();

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h

Lines changed: 46 additions & 40 deletions
@@ -1348,15 +1348,17 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
     int const batch_idx = blockIdx.z;
 
     // The decoder sequence length.
-    int const decoder_seq_len = params.seq_lens[batch_idx];
+    // Spec decoding not supported for cross-attention at the moment so we can set 1 and batch_idx here
+    int const decoder_seq_len = params.generation_phase ? 1 : params.seq_lens[batch_idx];
     // The decoder sequence offset.
-    int const decoder_seq_offset = params.cu_seq_lens[batch_idx];
+    int const decoder_seq_offset = params.generation_phase ? batch_idx : params.cu_seq_lens[batch_idx];
     // The decoder cache sequence length (includes the current input).
     int const decoder_cache_seq_len = params.cache_seq_lens[batch_idx];
     // The encoder sequence length.
     int const encoder_seq_len = params.encoder_seq_lens[batch_idx];
     // The encoder sequence offset.
-    int const encoder_seq_offset = params.cu_kv_seq_lens[batch_idx];
+    // Not needed in Gen phase
+    int const encoder_seq_offset = params.generation_phase ? -1 : params.cu_kv_seq_lens[batch_idx];
     // THe maximum sequence length of encoder and decoder.
     int const max_seq_len = max(decoder_seq_len, encoder_seq_len);
 
@@ -1411,45 +1413,49 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
         }
     }
 
-    // Encoder tokens (i.e. KV tokens).
-    if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
-        && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
+    if (!params.generation_phase)
     {
-        // The global token idx in all sequences.
-        int global_token_idx = token_idx + encoder_seq_offset;
-
-        // The memory offset.
-        auto const src_k_idx = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
-        auto const src_v_idx
-            = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
-
-        // Only load K,V tokens from encoder qkv input.
-        auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
-        auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
-
-        // The kv cache pointers.
-        auto k_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
-        auto v_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
-        // The vector idx in the cache block.
-        auto block_vec_idx
-            = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
-
-        // Store K and V to the cache.
-        // INT8/FP8 kv cache.
-        if constexpr (sizeof(TCache) == 1)
-        {
-            // The element index inside the block.
-            auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
-            // Store 8bits kv cache.
-            mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
-            mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
-        }
-        else
+        // Encoder tokens (i.e. KV tokens).
+        if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
+            && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
         {
-            reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
-            reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+            // The global token idx in all sequences.
+            int global_token_idx = token_idx + encoder_seq_offset;
+
+            // The memory offset.
+            auto const src_k_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
+            auto const src_v_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
+
+            // Only load K,V tokens from encoder qkv input.
+            auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
+            auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
+
+            // The kv cache pointers.
+            auto k_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
+            auto v_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
+            // The vector idx in the cache block.
+            auto block_vec_idx
+                = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
+
+            // Store K and V to the cache.
+            // INT8/FP8 kv cache.
+            if constexpr (sizeof(TCache) == 1)
+            {
+                // The element index inside the block.
+                auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
+                // Store 8bits kv cache.
+                mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
+                mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
+            }
+            else
+            {
+                reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
+                reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+            }
         }
     }
 }
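
Note: the kernel change relies on a simple generation-phase convention: spec decoding is not supported for cross-attention, so each sequence contributes exactly one decoder token, the flattened token offset is just the batch index, and the encoder K/V (already written during the context phase) are not appended again. A small illustrative sketch of that convention follows, using a stand-in struct rather than the real QKVPreprocessingParams:

// Sketch only: a stand-in for the few fields the kernel consults per CUDA block.
struct CrossAttnIndexing
{
    bool generation_phase = false;
    int const* seq_lens = nullptr;       // decoder lengths (context phase)
    int const* cu_seq_lens = nullptr;    // decoder token offsets (context phase)
    int const* cu_kv_seq_lens = nullptr; // encoder token offsets (context phase)
};

// One new decoder token per sequence during generation (no spec decoding for cross-attention).
inline int decoderSeqLen(CrossAttnIndexing const& p, int batch_idx)
{
    return p.generation_phase ? 1 : p.seq_lens[batch_idx];
}

// With one token per sequence, the flattened token offset is simply the batch index.
inline int decoderSeqOffset(CrossAttnIndexing const& p, int batch_idx)
{
    return p.generation_phase ? batch_idx : p.cu_seq_lens[batch_idx];
}

// Encoder K/V were appended during the context phase, so the offset is an unused sentinel here.
inline int encoderSeqOffset(CrossAttnIndexing const& p, int batch_idx)
{
    return p.generation_phase ? -1 : p.cu_kv_seq_lens[batch_idx];
}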

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 95 additions & 65 deletions
@@ -16,7 +16,9 @@
 
 #include "xqaDispatcher.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
+#include <cstdint>
 
 namespace
 {
@@ -38,6 +40,87 @@ constexpr inline T roundUp(T a, T b)
 namespace tensorrt_llm::kernels
 {
 
+namespace
+{
+
+template <typename T, typename KVCacheBuffer>
+QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams const& params,
+    XQALaunchParam<KVCacheBuffer> const& launchParams, void* xqa_q_input_ptr, Data_type QDataType,
+    KvCacheDataType cache_type, int32_t batch_beam_size, KVCacheBuffer const& kv_cache_buffer,
+    int32_t const* cu_seqlens, int32_t const* cu_kv_seqlens, float const* rotary_inv_freq_buf, int multiProcessorCount)
+{
+    QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms;
+    memset(&preprocessingParms, 0, sizeof(preprocessingParms));
+    // Set parameters.
+    preprocessingParms.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
+    preprocessingParms.q_output = static_cast<T*>(xqa_q_input_ptr);
+    preprocessingParms.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParms.kv_cache_block_scales_buffer = {};
+    preprocessingParms.qkv_bias = static_cast<T const*>(params.qkv_bias);
+    // Prepare values for fmha.
+    preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
+    preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
+    bool const is_fp8_q_input = (QDataType == DATA_TYPE_E4M3);
+    if (params.kv_cache_quant_mode.hasFp8KvCache())
+    {
+        preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig;
+        preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig;
+    }
+    if (params.is_fp8_output)
+    {
+        preprocessingParms.o_scale_orig_quant = params.fp8_out_scale;
+    }
+    // Buffers.
+    preprocessingParms.logn_scaling = params.logn_scaling_ptr;
+    preprocessingParms.seq_lens = params.spec_decoding_generation_lengths;
+    preprocessingParms.cache_seq_lens = params.sequence_lengths;
+    preprocessingParms.cu_seq_lens = cu_seqlens;
+    preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
+    preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
+    preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParms.kv_cache_scale_factors = nullptr;
+    preprocessingParms.spec_decoding_position_offsets
+        = params.cross_attention ? nullptr : params.spec_decoding_position_offsets;
+    preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
+    // Scalar parameters.
+    preprocessingParms.batch_size = int(batch_beam_size);
+    preprocessingParms.max_input_seq_len = params.generation_input_length;
+    preprocessingParms.max_kv_seq_len = params.max_past_kv_length;
+    preprocessingParms.cyclic_kv_cache_len
+        = params.cross_attention ? params.max_past_kv_length : params.cyclic_attention_window_size;
+    preprocessingParms.sink_token_len = params.cross_attention ? 0 : params.sink_token_length;
+    preprocessingParms.token_num = params.total_num_input_tokens;
+    preprocessingParms.remove_padding = true;
+    preprocessingParms.cross_attention = params.cross_attention;
+    preprocessingParms.head_num = params.num_q_heads;
+    preprocessingParms.kv_head_num = params.num_kv_heads;
+    preprocessingParms.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
+    preprocessingParms.size_per_head = params.head_size;
+    preprocessingParms.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
+    preprocessingParms.rotary_embedding_dim = params.rotary_embedding_dim;
+    preprocessingParms.rotary_embedding_base = params.rotary_embedding_base;
+    preprocessingParms.rotary_scale_type = params.rotary_embedding_scale_type;
+    preprocessingParms.rotary_embedding_scale = params.rotary_embedding_scale;
+    preprocessingParms.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
+    preprocessingParms.position_embedding_type = params.position_embedding_type;
+    preprocessingParms.position_shift_enabled = params.position_shift_enabled;
+    preprocessingParms.cache_type = cache_type;
+    preprocessingParms.separate_q_kv_output = true;
+    preprocessingParms.quantized_fp8_output = is_fp8_q_input;
+    preprocessingParms.generation_phase = true;
+    preprocessingParms.multi_processor_count = multiProcessorCount;
+    preprocessingParms.rotary_vision_start = params.rotary_vision_start;
+    preprocessingParms.rotary_vision_length = params.rotary_vision_length;
+
+    // Cross-attention only.
+
+    preprocessingParms.encoder_seq_lens = params.encoder_input_lengths;
+
+    return preprocessingParms;
+}
+
+} // namespace
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 XqaDispatcher::XqaDispatcher(XqaFixedParams fixedParams)
@@ -137,9 +220,10 @@ bool XqaDispatcher::shouldUse(XQAParams const& params)
     {
         SHOULD_NOT_USE("Fallback to MMHA as unidirectional is not supported by TRTLLM-GEN kernels.");
     }
-    if (params.cross_attention)
+    if (params.cross_attention && !params.paged_kv_cache)
     {
-        SHOULD_NOT_USE("Fallback to MMHA as cross attention is not supported by TRTLLM-GEN kernels.");
+        SHOULD_NOT_USE(
+            "Fallback to MMHA as cross attention without paged KV Cache is not supported by TRTLLM-GEN kernels.");
     }
     if (params.paged_kv_cache && params.tokens_per_block < 8)
     {
@@ -252,8 +336,8 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     decoder_params.seqQOffsets = launchParams.cu_seq_lens;
     decoder_params.seqKVOffsets = launchParams.cu_kv_seq_lens;
     decoder_params.seqQLengths = params.spec_decoding_generation_lengths;
-    decoder_params.seqKVLengths = params.sequence_lengths;
-    decoder_params.batchSize = int(batch_beam_size);
+    decoder_params.seqKVLengths = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
+    decoder_params.batchSize = static_cast<int>(batch_beam_size);
     decoder_params.maxQSeqLength = params.generation_input_length;
     decoder_params.numTokens = params.total_num_input_tokens;
     decoder_params.removePadding = true;
@@ -273,10 +357,12 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     float const* rotary_inv_freq_buf = params.rotary_embedding_inv_freq_cache;
     // Use the nullptr for cu_seqlens when it is not computed.
     int const* cu_seqlens{nullptr};
+    int const* cu_kv_seqlens{nullptr};
     if (decoder_params.isBuildDecoderInfoKernelNeeded())
     {
         rotary_inv_freq_buf = launchParams.rotary_inv_freq_buf;
         cu_seqlens = launchParams.cu_seq_lens;
+        cu_kv_seqlens = launchParams.cu_kv_seq_lens;
         invokeBuildDecoderInfo(decoder_params, params.stream);
         sync_check_cuda_error(params.stream);
     }
@@ -285,66 +371,10 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     // NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
     void* xqa_q_input_ptr = inputScratch;
     // The preprocessing kernel that applies RoPE and updates kv cache.
-    QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms;
-    memset(&preprocessingParms, 0, sizeof(preprocessingParms));
-    // Set parameters.
-    preprocessingParms.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
-    preprocessingParms.q_output = static_cast<T*>(xqa_q_input_ptr);
-    preprocessingParms.kv_cache_buffer = kv_cache_buffer;
-    preprocessingParms.kv_cache_block_scales_buffer = {};
-    preprocessingParms.qkv_bias = static_cast<T const*>(params.qkv_bias);
-    // Prepare values for fmha.
-    preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
-    preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
-    bool const is_fp8_q_input = (mQDataType == DATA_TYPE_E4M3);
-    if (params.kv_cache_quant_mode.hasFp8KvCache())
-    {
-        preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig;
-        preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig;
-    }
-    if (params.is_fp8_output)
-    {
-        preprocessingParms.o_scale_orig_quant = params.fp8_out_scale;
-    }
-    // Buffers.
-    preprocessingParms.logn_scaling = params.logn_scaling_ptr;
-    preprocessingParms.seq_lens = params.spec_decoding_generation_lengths;
-    preprocessingParms.cache_seq_lens = params.sequence_lengths;
-    preprocessingParms.cu_seq_lens = cu_seqlens;
-    preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
-    preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
-    preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
-    preprocessingParms.kv_cache_scale_factors = nullptr;
-    preprocessingParms.spec_decoding_position_offsets = params.spec_decoding_position_offsets;
-    preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
-    // Scalar parameters.
-    preprocessingParms.batch_size = int(batch_beam_size);
-    preprocessingParms.max_input_seq_len = params.generation_input_length;
-    preprocessingParms.max_kv_seq_len = params.max_past_kv_length;
-    preprocessingParms.cyclic_kv_cache_len = params.cyclic_attention_window_size;
-    preprocessingParms.sink_token_len = params.sink_token_length;
-    preprocessingParms.token_num = params.total_num_input_tokens;
-    preprocessingParms.remove_padding = true;
-    preprocessingParms.cross_attention = false;
-    preprocessingParms.head_num = params.num_q_heads;
-    preprocessingParms.kv_head_num = params.num_kv_heads;
-    preprocessingParms.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
-    preprocessingParms.size_per_head = params.head_size;
-    preprocessingParms.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
-    preprocessingParms.rotary_embedding_dim = params.rotary_embedding_dim;
-    preprocessingParms.rotary_embedding_base = params.rotary_embedding_base;
-    preprocessingParms.rotary_scale_type = params.rotary_embedding_scale_type;
-    preprocessingParms.rotary_embedding_scale = params.rotary_embedding_scale;
-    preprocessingParms.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
-    preprocessingParms.position_embedding_type = params.position_embedding_type;
-    preprocessingParms.position_shift_enabled = params.position_shift_enabled;
-    preprocessingParms.cache_type = cache_type;
-    preprocessingParms.separate_q_kv_output = true;
-    preprocessingParms.quantized_fp8_output = is_fp8_q_input;
-    preprocessingParms.generation_phase = true;
-    preprocessingParms.multi_processor_count = mMultiProcessorCount;
-    preprocessingParms.rotary_vision_start = params.rotary_vision_start;
-    preprocessingParms.rotary_vision_length = params.rotary_vision_length;
+
+    auto preprocessingParms = makeQKVPreprocessingParams<T, KVCacheBuffer>(params, launchParams, xqa_q_input_ptr,
+        mQDataType, cache_type, batch_beam_size, kv_cache_buffer, cu_seqlens, cu_kv_seqlens, rotary_inv_freq_buf,
+        mMultiProcessorCount);
 
     invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, params.stream);
     sync_check_cuda_error(params.stream);
@@ -394,7 +424,7 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
         = reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
     tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
     // The sequence lengths for K/V.
-    tllmRunnerParams.seqLensKvPtr = params.sequence_lengths;
+    tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
 
     tllmRunnerParams.oPtr = params.output;
     tllmRunnerParams.oSfPtr = params.output_sf;
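
Note: in the dispatcher, the cross-attention-specific choice is which lengths drive the K/V side: cross-attention keys and values come from the encoder output, so encoder_input_lengths replaces the decoder KV-cache lengths both for invokeBuildDecoderInfo and for the TRTLLM-GEN runner. A minimal sketch of that selection, with a hypothetical stand-in struct in place of XQAParams:

#include <cstdint>

// Sketch only: stand-in for the two XQAParams fields involved in the selection.
struct KvLengthSource
{
    bool cross_attention = false;
    int32_t const* encoder_input_lengths = nullptr; // per-request encoder sequence lengths
    int32_t const* sequence_lengths = nullptr;      // per-request decoder KV-cache lengths
};

// Cross-attention iterates over encoder lengths instead of decoder cache lengths.
inline int32_t const* selectSeqLensKv(KvLengthSource const& s)
{
    return s.cross_attention ? s.encoder_input_lengths : s.sequence_lengths;
}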
