
Commit 9ae979b

Add basic support for cross-attention to XQA dispatch in support of Whisper
Signed-off-by: Dom Brown <[email protected]>
1 parent 953f4fd commit 9ae979b

File tree

7 files changed: +188 −106 lines changed


cpp/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -358,7 +358,7 @@ endif()
 
 setup_sanitizers()
 
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda -lineinfo")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 if(FAST_MATH)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --use_fast_math")

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 13 additions & 1 deletion

@@ -30,6 +30,7 @@
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <algorithm>
 #include <cstdint>
+#include <cstdlib>
 #include <type_traits>
 
 using namespace tensorrt_llm::kernels;
@@ -285,6 +286,13 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
 
+    xqaParams.num_tokens = generationsParams.num_tokens;
+    // Cross attention parameters.
+    xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
+    // xqaParams.cross_kv = generationsParams.cross_kv;
+    // xqaParams.cross_kv_length = generationsParams.cross_kv_length;
+    // xqaParams.num_encoder_tokens = generationsParams.num_encoder_tokens;
+
     return true;
 }
 
@@ -2210,6 +2218,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
         {
             TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
         }
+        else
+        {
+            TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase. mEnableXQA: %d", mEnableXQA);
+        }
     }
 
     // This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv
@@ -2728,7 +2740,7 @@ int AttentionOp::initialize() noexcept
        !useCustomMask() || mEnableContextFMHA, "Only Context FMHA supports custom mask input currently.");
    }
 
-   mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled) && !mCrossAttention
+   mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled)
        && (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16) && mUseKVCache;
 
    if (mEnableXQA)
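
The net effect of the initialize() change is that cross-attention no longer disqualifies XQA dispatch. Below is a minimal standalone sketch of the revised gate; the AttentionConfig struct and xqaEligible() helper are illustrative stand-ins for the real AttentionOp members (mEnableXQA, mIsSpecDecodingEnabled, mCrossAttention, mType, mUseKVCache), not TensorRT-LLM code.

// Illustrative-only sketch of the XQA enable condition after this commit.
#include <cstdio>

namespace sketch
{

enum class DataType
{
    kHALF,
    kBF16,
    kFLOAT
};

struct AttentionConfig
{
    bool enableXQA = true;
    bool isSpecDecodingEnabled = false;
    bool crossAttention = false; // stand-in for mCrossAttention
    bool useKVCache = true;
    DataType type = DataType::kHALF;
};

// Before this commit the gate also required !crossAttention; now only the dtype and
// KV-cache checks remain, so Whisper-style cross-attention layers can reach XQA dispatch.
bool xqaEligible(AttentionConfig const& c)
{
    return (c.enableXQA || c.isSpecDecodingEnabled)
        && (c.type == DataType::kHALF || c.type == DataType::kBF16) && c.useKVCache;
}

} // namespace sketch

int main()
{
    sketch::AttentionConfig crossAttn;
    crossAttn.crossAttention = true; // e.g. a Whisper decoder cross-attention layer
    std::printf("XQA eligible: %s\n", sketch::xqaEligible(crossAttn) ? "true" : "false");
    return 0;
}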

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 12 additions & 0 deletions

@@ -106,6 +106,13 @@ struct XQAParams
 
     void* quant_q_buffer_ptr = nullptr;
 
+    int num_tokens;
+    // for cross attention
+    int32_t const* encoder_input_lengths = nullptr;
+    // void const* cross_kv = nullptr;
+    // int32_t cross_kv_length = 0;
+    // int32_t num_encoder_tokens = 0;
+
     cudaStream_t stream = 0;
 
     std::string toString() const
@@ -175,6 +182,11 @@ struct XQAParams
            << "total_num_input_tokens :" << total_num_input_tokens << std ::endl
            << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
            << "fp8_out_scale :" << fp8_out_scale << std ::endl
+           << "encoder_input_lengths: " << encoder_input_lengths
+           << std::endl
+           //<< "cross_kv: " << cross_kv << std::endl
+           //<< "cross_kv_length: " << cross_kv_length << std::endl
+           //<< "num_encoder_tokens: " << num_encoder_tokens << std::endl
            << "stream :" << stream;
 
        return ss.str();
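
One detail worth noting about the toString() addition: it streams the encoder_input_lengths pointer itself, so the log line will contain a device address rather than the per-sequence lengths. A small self-contained illustration of that stream behaviour (the values here are made up):

// Streaming an int32_t const* into a stringstream prints the pointer value, not the array contents.
#include <cstdint>
#include <iostream>
#include <sstream>

int main()
{
    int32_t lengths[] = {1500, 1500}; // illustrative encoder lengths
    int32_t const* encoder_input_lengths = lengths;

    std::stringstream ss;
    ss << "encoder_input_lengths: " << encoder_input_lengths << std::endl;
    std::cout << ss.str(); // e.g. "encoder_input_lengths: 0x7ffee2c0a4b0"
    return 0;
}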

cpp/tensorrt_llm/kernels/gptKernels.h

Lines changed: 5 additions & 0 deletions

@@ -186,6 +186,7 @@ struct BuildDecoderInfoParams
     float2* rotaryEmbeddingCoeffCache;
     // Dynamic scaling;
     int rotaryEmbeddingMaxPositions;
+    bool isCrossAttention{false};
 
     bool isBuildDecoderInfoKernelNeeded()
     {
@@ -213,6 +214,10 @@ struct BuildDecoderInfoParams
         {
             return true;
         }
+        if (isCrossAttention)
+        {
+            return true;
+        }
         // Other cases don't need to call buildDecoderInfo kernel.
         return false;
     }
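
With the new flag defaulted to false, existing call sites keep their behaviour, while a caller preparing a cross-attention layer can set it to force the buildDecoderInfo kernel to run. A simplified, illustrative-only stand-in for BuildDecoderInfoParams showing that decision path:

// Sketch only: fmhaTileCounterNeeded is a placeholder for the struct's other triggers.
#include <iostream>

struct BuildDecoderInfoParamsSketch
{
    bool fmhaTileCounterNeeded = false;
    bool isCrossAttention = false; // flag introduced by this commit

    bool isBuildDecoderInfoKernelNeeded() const
    {
        if (fmhaTileCounterNeeded)
        {
            return true;
        }
        if (isCrossAttention)
        {
            return true;
        }
        // Other cases don't need to call buildDecoderInfo kernel.
        return false;
    }
};

int main()
{
    BuildDecoderInfoParamsSketch params;
    params.isCrossAttention = true; // e.g. a Whisper cross-attention layer
    std::cout << std::boolalpha << params.isBuildDecoderInfoKernelNeeded() << '\n'; // prints: true
    return 0;
}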

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h

Lines changed: 44 additions & 39 deletions

@@ -1348,9 +1348,10 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
     int const batch_idx = blockIdx.z;
 
     // The decoder sequence length.
-    int const decoder_seq_len = params.seq_lens[batch_idx];
+    // Spec decoding not supported for cross-attention at the moment so we can set 1 and batch_idx here
+    int const decoder_seq_len = params.generation_phase ? 1 : params.seq_lens[batch_idx];
     // The decoder sequence offset.
-    int const decoder_seq_offset = params.cu_seq_lens[batch_idx];
+    int const decoder_seq_offset = params.generation_phase ? batch_idx : params.cu_seq_lens[batch_idx];
     // The decoder cache sequence length (includes the current input).
     int const decoder_cache_seq_len = params.cache_seq_lens[batch_idx];
     // The encoder sequence length.
@@ -1411,45 +1412,49 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
         }
     }
 
-    // Encoder tokens (i.e. KV tokens).
-    if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
-        && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
-    {
-        // The global token idx in all sequences.
-        int global_token_idx = token_idx + encoder_seq_offset;
-
-        // The memory offset.
-        auto const src_k_idx = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
-        auto const src_v_idx
-            = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
-
-        // Only load K,V tokens from encoder qkv input.
-        auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
-        auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
-
-        // The kv cache pointers.
-        auto k_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
-        auto v_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
-        // The vector idx in the cache block.
-        auto block_vec_idx
-            = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
-
-        // Store K and V to the cache.
-        // INT8/FP8 kv cache.
-        if constexpr (sizeof(TCache) == 1)
-        {
-            // The element index inside the block.
-            auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
-            // Store 8bits kv cache.
-            mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
-            mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
-        }
-        else
-        {
-            reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
-            reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+    if (!params.generation_phase)
+    {
+        // Encoder tokens (i.e. KV tokens).
+        if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
+            && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
+        {
+            // The global token idx in all sequences.
+            int global_token_idx = token_idx + encoder_seq_offset;
+
+            // The memory offset.
+            auto const src_k_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
+            auto const src_v_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
+
+            // Only load K,V tokens from encoder qkv input.
+            auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
+            auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
+
+            // The kv cache pointers.
+            auto k_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
+            auto v_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
+            // The vector idx in the cache block.
+            auto block_vec_idx
+                = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
+
+            // Store K and V to the cache.
+            // INT8/FP8 kv cache.
+            if constexpr (sizeof(TCache) == 1)
+            {
+                // The element index inside the block.
+                auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
+                // Store 8bits kv cache.
+                mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
+                mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
+            }
+            else
+            {
+                reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
+                reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+            }
         }
     }
 }
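
The ternaries at the top of this hunk encode the kernel's per-phase indexing: in the generation phase each sequence contributes exactly one decoder token, addressed directly by its batch index, while the context phase still reads per-sequence lengths and cumulative offsets. A host-side sketch of that mapping, using a simplified stand-in for QKVPreprocessingParams and made-up lengths:

// Host-side illustration only; not the real QKVPreprocessingParams.
#include <cstdio>
#include <vector>

struct PhaseParamsSketch
{
    bool generation_phase = false;
    std::vector<int> seq_lens;    // per-sequence decoder input lengths (context phase)
    std::vector<int> cu_seq_lens; // exclusive prefix sums of seq_lens (context phase)
};

// Mirrors the ternaries added in updateKVCacheForCrossAttention: one token per sequence,
// offset by batch_idx, when running in the generation phase (spec decoding not supported yet).
void decoderTokenRange(PhaseParamsSketch const& p, int batch_idx, int& seq_len, int& seq_offset)
{
    seq_len = p.generation_phase ? 1 : p.seq_lens[batch_idx];
    seq_offset = p.generation_phase ? batch_idx : p.cu_seq_lens[batch_idx];
}

int main()
{
    PhaseParamsSketch ctx;
    ctx.seq_lens = {4, 7};    // illustrative prompt lengths for a batch of two
    ctx.cu_seq_lens = {0, 4}; // exclusive prefix sums

    PhaseParamsSketch gen;
    gen.generation_phase = true;

    int len = 0;
    int off = 0;
    decoderTokenRange(ctx, 1, len, off);
    std::printf("context phase, batch 1: len=%d offset=%d\n", len, off); // len=7 offset=4
    decoderTokenRange(gen, 1, len, off);
    std::printf("generation phase, batch 1: len=%d offset=%d\n", len, off); // len=1 offset=1
    return 0;
}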
