 
 #include "xqaDispatcher.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
+#include <cstdint>
 
 namespace
 {
@@ -38,6 +40,87 @@ constexpr inline T roundUp(T a, T b)
 namespace tensorrt_llm::kernels
 {
 
+namespace
+{
+
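+// Builds the QKVPreprocessingParams for the XQA dispatch path. Factored out of
+// XqaDispatcher::runImpl so the generation and cross-attention cases share one setup routine.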
+template <typename T, typename KVCacheBuffer>
+QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams const& params,
+    XQALaunchParam<KVCacheBuffer> const& launchParams, void* xqa_q_input_ptr, Data_type QDataType,
+    KvCacheDataType cache_type, int32_t batch_beam_size, KVCacheBuffer const& kv_cache_buffer,
+    int32_t const* cu_seqlens, int32_t const* cu_kv_seqlens, float const* rotary_inv_freq_buf, int multiProcessorCount)
+{
+    QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms;
+    memset(&preprocessingParms, 0, sizeof(preprocessingParms));
+    // Set parameters.
+    preprocessingParms.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
+    preprocessingParms.q_output = static_cast<T*>(xqa_q_input_ptr);
+    preprocessingParms.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParms.kv_cache_block_scales_buffer = {};
+    preprocessingParms.qkv_bias = static_cast<T const*>(params.qkv_bias);
+    // Prepare values for fmha.
+    preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
+    preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
+    bool const is_fp8_q_input = (QDataType == DATA_TYPE_E4M3);
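+    // With an FP8 KV cache, Q reuses the K/V dequantization scale.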
+    if (params.kv_cache_quant_mode.hasFp8KvCache())
+    {
+        preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig;
+        preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig;
+    }
+    if (params.is_fp8_output)
+    {
+        preprocessingParms.o_scale_orig_quant = params.fp8_out_scale;
+    }
+    // Buffers.
+    preprocessingParms.logn_scaling = params.logn_scaling_ptr;
+    preprocessingParms.seq_lens = params.spec_decoding_generation_lengths;
+    preprocessingParms.cache_seq_lens = params.sequence_lengths;
+    preprocessingParms.cu_seq_lens = cu_seqlens;
+    preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
+    preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
+    preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParms.kv_cache_scale_factors = nullptr;
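+    // Cross attention does not use speculative-decoding position offsets.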
+    preprocessingParms.spec_decoding_position_offsets
+        = params.cross_attention ? nullptr : params.spec_decoding_position_offsets;
+    preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
+    // Scalar parameters.
+    preprocessingParms.batch_size = int(batch_beam_size);
+    preprocessingParms.max_input_seq_len = params.generation_input_length;
+    preprocessingParms.max_kv_seq_len = params.max_past_kv_length;
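+    // Cross attention attends over the full encoder output, so the cyclic window and sink tokens do not apply.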
+    preprocessingParms.cyclic_kv_cache_len
+        = params.cross_attention ? params.max_past_kv_length : params.cyclic_attention_window_size;
+    preprocessingParms.sink_token_len = params.cross_attention ? 0 : params.sink_token_length;
+    preprocessingParms.token_num = params.total_num_input_tokens;
+    preprocessingParms.remove_padding = true;
+    preprocessingParms.cross_attention = params.cross_attention;
+    preprocessingParms.head_num = params.num_q_heads;
+    preprocessingParms.kv_head_num = params.num_kv_heads;
+    preprocessingParms.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
+    preprocessingParms.size_per_head = params.head_size;
+    preprocessingParms.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
+    preprocessingParms.rotary_embedding_dim = params.rotary_embedding_dim;
+    preprocessingParms.rotary_embedding_base = params.rotary_embedding_base;
+    preprocessingParms.rotary_scale_type = params.rotary_embedding_scale_type;
+    preprocessingParms.rotary_embedding_scale = params.rotary_embedding_scale;
+    preprocessingParms.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
+    preprocessingParms.position_embedding_type = params.position_embedding_type;
+    preprocessingParms.position_shift_enabled = params.position_shift_enabled;
+    preprocessingParms.cache_type = cache_type;
+    preprocessingParms.separate_q_kv_output = true;
+    preprocessingParms.quantized_fp8_output = is_fp8_q_input;
+    preprocessingParms.generation_phase = true;
+    preprocessingParms.multi_processor_count = multiProcessorCount;
+    preprocessingParms.rotary_vision_start = params.rotary_vision_start;
+    preprocessingParms.rotary_vision_length = params.rotary_vision_length;
+
+    // Cross-attention only.
+    preprocessingParms.encoder_seq_lens = params.encoder_input_lengths;
+
+    return preprocessingParms;
+}
+
+} // namespace
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 XqaDispatcher::XqaDispatcher(XqaFixedParams fixedParams)
@@ -137,9 +220,10 @@ bool XqaDispatcher::shouldUse(XQAParams const& params)
     {
         SHOULD_NOT_USE("Fallback to MMHA as unidirectional is not supported by TRTLLM-GEN kernels.");
     }
-    if (params.cross_attention)
+    if (params.cross_attention && !params.paged_kv_cache)
     {
-        SHOULD_NOT_USE("Fallback to MMHA as cross attention is not supported by TRTLLM-GEN kernels.");
+        SHOULD_NOT_USE(
+            "Fallback to MMHA as cross attention without paged KV cache is not supported by TRTLLM-GEN kernels.");
     }
     if (params.paged_kv_cache && params.tokens_per_block < 8)
     {
@@ -252,8 +336,8 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     decoder_params.seqQOffsets = launchParams.cu_seq_lens;
     decoder_params.seqKVOffsets = launchParams.cu_kv_seq_lens;
     decoder_params.seqQLengths = params.spec_decoding_generation_lengths;
-    decoder_params.seqKVLengths = params.sequence_lengths;
-    decoder_params.batchSize = int(batch_beam_size);
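+    // For cross attention, the K/V sequence lengths come from the encoder inputs.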
+    decoder_params.seqKVLengths = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
+    decoder_params.batchSize = static_cast<int>(batch_beam_size);
     decoder_params.maxQSeqLength = params.generation_input_length;
     decoder_params.numTokens = params.total_num_input_tokens;
     decoder_params.removePadding = true;
@@ -273,10 +357,12 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     float const* rotary_inv_freq_buf = params.rotary_embedding_inv_freq_cache;
     // Use nullptr for cu_seqlens when it is not computed.
     int const* cu_seqlens{nullptr};
+    int const* cu_kv_seqlens{nullptr};
     if (decoder_params.isBuildDecoderInfoKernelNeeded())
     {
         rotary_inv_freq_buf = launchParams.rotary_inv_freq_buf;
         cu_seqlens = launchParams.cu_seq_lens;
+        cu_kv_seqlens = launchParams.cu_kv_seq_lens;
         invokeBuildDecoderInfo(decoder_params, params.stream);
         sync_check_cuda_error(params.stream);
     }
@@ -285,66 +371,10 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     // NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
     void* xqa_q_input_ptr = inputScratch;
     // The preprocessing kernel that applies RoPE and updates kv cache.
-    QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms;
-    memset(&preprocessingParms, 0, sizeof(preprocessingParms));
-    // Set parameters.
-    preprocessingParms.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
-    preprocessingParms.q_output = static_cast<T*>(xqa_q_input_ptr);
-    preprocessingParms.kv_cache_buffer = kv_cache_buffer;
-    preprocessingParms.kv_cache_block_scales_buffer = {};
-    preprocessingParms.qkv_bias = static_cast<T const*>(params.qkv_bias);
-    // Prepare values for fmha.
-    preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
-    preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
-    bool const is_fp8_q_input = (mQDataType == DATA_TYPE_E4M3);
-    if (params.kv_cache_quant_mode.hasFp8KvCache())
-    {
-        preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig;
-        preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig;
-    }
-    if (params.is_fp8_output)
-    {
-        preprocessingParms.o_scale_orig_quant = params.fp8_out_scale;
-    }
-    // Buffers.
-    preprocessingParms.logn_scaling = params.logn_scaling_ptr;
-    preprocessingParms.seq_lens = params.spec_decoding_generation_lengths;
-    preprocessingParms.cache_seq_lens = params.sequence_lengths;
-    preprocessingParms.cu_seq_lens = cu_seqlens;
-    preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
-    preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
-    preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
-    preprocessingParms.kv_cache_scale_factors = nullptr;
-    preprocessingParms.spec_decoding_position_offsets = params.spec_decoding_position_offsets;
-    preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
-    // Scalar parameters.
-    preprocessingParms.batch_size = int(batch_beam_size);
-    preprocessingParms.max_input_seq_len = params.generation_input_length;
-    preprocessingParms.max_kv_seq_len = params.max_past_kv_length;
-    preprocessingParms.cyclic_kv_cache_len = params.cyclic_attention_window_size;
-    preprocessingParms.sink_token_len = params.sink_token_length;
-    preprocessingParms.token_num = params.total_num_input_tokens;
-    preprocessingParms.remove_padding = true;
-    preprocessingParms.cross_attention = false;
-    preprocessingParms.head_num = params.num_q_heads;
-    preprocessingParms.kv_head_num = params.num_kv_heads;
-    preprocessingParms.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
-    preprocessingParms.size_per_head = params.head_size;
-    preprocessingParms.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
-    preprocessingParms.rotary_embedding_dim = params.rotary_embedding_dim;
-    preprocessingParms.rotary_embedding_base = params.rotary_embedding_base;
-    preprocessingParms.rotary_scale_type = params.rotary_embedding_scale_type;
-    preprocessingParms.rotary_embedding_scale = params.rotary_embedding_scale;
-    preprocessingParms.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
-    preprocessingParms.position_embedding_type = params.position_embedding_type;
-    preprocessingParms.position_shift_enabled = params.position_shift_enabled;
-    preprocessingParms.cache_type = cache_type;
-    preprocessingParms.separate_q_kv_output = true;
-    preprocessingParms.quantized_fp8_output = is_fp8_q_input;
-    preprocessingParms.generation_phase = true;
-    preprocessingParms.multi_processor_count = mMultiProcessorCount;
-    preprocessingParms.rotary_vision_start = params.rotary_vision_start;
-    preprocessingParms.rotary_vision_length = params.rotary_vision_length;
+
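+    // cu_seqlens/cu_kv_seqlens remain nullptr when the build-decoder-info kernel was skipped above.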
+    auto preprocessingParms = makeQKVPreprocessingParams<T, KVCacheBuffer>(params, launchParams, xqa_q_input_ptr,
+        mQDataType, cache_type, batch_beam_size, kv_cache_buffer, cu_seqlens, cu_kv_seqlens, rotary_inv_freq_buf,
+        mMultiProcessorCount);
 
     invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, params.stream);
     sync_check_cuda_error(params.stream);
@@ -394,7 +424,7 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
         = reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
     tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
     // The sequence lengths for K/V.
-    tllmRunnerParams.seqLensKvPtr = params.sequence_lengths;
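+    // Cross attention likewise takes the K/V sequence lengths from the encoder inputs.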
+    tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
 
     tllmRunnerParams.oPtr = params.output;
     tllmRunnerParams.oSfPtr = params.output_sf;