Skip to content

Commit 215d5b2

Browse files
committed
Nitpicks from CodeRabbit
Signed-off-by: Dom Brown <[email protected]>
1 parent aa9b222 commit 215d5b2

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

cpp/tensorrt_llm/kernels/gptKernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ struct BuildDecoderInfoParams
269269
ss << "rotaryEmbeddingInvFreqCache: " << rotaryEmbeddingInvFreqCache << std::endl;
270270
ss << "rotaryEmbeddingCoeffCache: " << rotaryEmbeddingCoeffCache << std::endl;
271271
ss << "rotaryEmbeddingMaxPositions: " << rotaryEmbeddingMaxPositions << std::endl;
272+
ss << "isCrossAttention: " << isCrossAttention << std::endl;
272273

273274
return ss.str();
274275
}

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams co
7979
preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
8080
preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
8181
preprocessingParms.kv_cache_scale_factors = nullptr;
82-
preprocessingParms.spec_decoding_position_offsets = params.spec_decoding_position_offsets;
82+
preprocessingParms.spec_decoding_position_offsets
83+
= params.cross_attention ? nullptr : params.spec_decoding_position_offsets;
8384
preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
8485
// Scalar parameters.
8586
preprocessingParms.batch_size = int(batch_beam_size);
@@ -115,12 +116,6 @@ QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams co
115116

116117
preprocessingParms.cu_kv_seq_lens = cu_kv_seqlens;
117118
preprocessingParms.encoder_seq_lens = params.encoder_input_lengths;
118-
preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
119-
preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
120-
121-
preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
122-
preprocessingParms.spec_decoding_position_offsets = nullptr;
123-
preprocessingParms.logn_scaling = params.logn_scaling_ptr;
124119

125120
// Not available in generation phase
126121
preprocessingParms.mrope_rotary_cos_sin = nullptr;

0 commit comments

Comments
 (0)