
Commit 603173d

Respond to review comments

Signed-off-by: Dom Brown <[email protected]>

1 parent 310eb4b · commit 603173d

File tree: 3 files changed, +2 −19 lines


cpp/tensorrt_llm/kernels/gptKernels.h

Lines changed: 0 additions & 6 deletions
@@ -186,7 +186,6 @@ struct BuildDecoderInfoParams
     float2* rotaryEmbeddingCoeffCache;
     // Dynamic scaling;
     int rotaryEmbeddingMaxPositions;
-    bool isCrossAttention{false};

     bool isBuildDecoderInfoKernelNeeded()
     {
@@ -214,10 +213,6 @@ struct BuildDecoderInfoParams
         {
             return true;
         }
-        if (isCrossAttention)
-        {
-            return true;
-        }
         // Other cases don't need to call buildDecoderInfo kernel.
         return false;
     }
@@ -269,7 +264,6 @@ struct BuildDecoderInfoParams
         ss << "rotaryEmbeddingInvFreqCache: " << rotaryEmbeddingInvFreqCache << std::endl;
         ss << "rotaryEmbeddingCoeffCache: " << rotaryEmbeddingCoeffCache << std::endl;
         ss << "rotaryEmbeddingMaxPositions: " << rotaryEmbeddingMaxPositions << std::endl;
-        ss << "isCrossAttention: " << isCrossAttention << std::endl;

         return ss.str();
     }
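
For context, `isBuildDecoderInfoKernelNeeded()` is the predicate the dispatcher consults to decide whether the buildDecoderInfo kernel must be launched at all; with the cross-attention check gone, only the remaining conditions drive that decision. A minimal sketch of the same gating pattern, using hypothetical placeholder fields rather than the real BuildDecoderInfoParams members:

    #include <iostream>

    // Sketch only: a params struct exposing a predicate that lets the caller
    // skip a kernel launch when no field requires it. The field names below are
    // illustrative placeholders, not actual TensorRT-LLM members.
    struct DecoderInfoParamsSketch
    {
        bool needsPaddingOffsets{false}; // hypothetical condition
        bool needsRotaryInvFreq{false};  // hypothetical condition

        bool isBuildDecoderInfoKernelNeeded() const
        {
            if (needsPaddingOffsets)
            {
                return true;
            }
            if (needsRotaryInvFreq)
            {
                return true;
            }
            // Other cases don't need to call the kernel.
            return false;
        }
    };

    int main()
    {
        DecoderInfoParamsSketch params;
        params.needsRotaryInvFreq = true;
        if (params.isBuildDecoderInfoKernelNeeded())
        {
            std::cout << "launch buildDecoderInfo kernel" << std::endl;
        }
        return 0;
    }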

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h

Lines changed: 2 additions & 1 deletion
@@ -1357,7 +1357,8 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
     // The encoder sequence length.
     int const encoder_seq_len = params.encoder_seq_lens[batch_idx];
     // The encoder sequence offset.
-    int const encoder_seq_offset = params.cu_kv_seq_lens[batch_idx];
+    // Not needed in Gen phase
+    int const encoder_seq_offset = params.generation_phase ? -1 : params.cu_kv_seq_lens[batch_idx];
     // The maximum sequence length of encoder and decoder.
     int const max_seq_len = max(decoder_seq_len, encoder_seq_len);

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 0 additions & 12 deletions
@@ -114,12 +114,8 @@ QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams co

     // Cross-attention only.

-    preprocessingParms.cu_kv_seq_lens = cu_kv_seqlens;
     preprocessingParms.encoder_seq_lens = params.encoder_input_lengths;

-    // Not available in generation phase
-    preprocessingParms.mrope_rotary_cos_sin = nullptr;
-
     return preprocessingParms;
 }

@@ -355,14 +351,6 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
     decoder_params.rotaryEmbeddingInvFreq = launchParams.rotary_inv_freq_buf;
     decoder_params.rotaryEmbeddingInvFreqCache = params.rotary_embedding_inv_freq_cache;
     decoder_params.rotaryEmbeddingMaxPositions = params.rotary_embedding_max_positions;
-    decoder_params.isCrossAttention = params.cross_attention;
-
-    if (params.cross_attention)
-    {
-        // cross attention only
-        decoder_params.maxEncoderQSeqLength = params.max_past_kv_length;
-        decoder_params.encoderPaddingOffsets = nullptr;
-    }

     // The rotary_embedding_inv_freq_cache for QKVPreprocessing.
     // Use the params.rotary_embedding_inv_freq_cache input when the buildDecoderInfoKernel is skipped.
