Skip to content

Commit a477bff

Browse files
committed
fix
Signed-off-by: Tian Zheng <[email protected]>
1 parent ee2b58e commit a477bff

File tree

6 files changed: +29 additions, -37 deletions

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1461,7 +1461,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     decoder_params.blockSparseParams = mBlockSparseParams;
     decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
     decoder_params.quantScaleO = params.attention_output_orig_quant;
-    decoder_params.dequantScaleQKv = params.kv_scale_quant_orig;
+    decoder_params.dequantScaleQkv = params.kv_scale_quant_orig;
     decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
     decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
     decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
@@ -1594,7 +1594,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
     preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-    preprocessingParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
+    preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -309,7 +309,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
     preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-    preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
+    preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
     preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
     // Scalar parameters.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ class XQAKernelList
     preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-    preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
+    preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
     preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
     // Scalar parameters.

cpp/tensorrt_llm/kernels/gptKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -279,7 +279,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
     int const q_scale_idx = 0;
     int const k_scale_idx = params.separateQkvScales ? 1 : 0;
     int const v_scale_idx = params.separateQkvScales ? 2 : 0;
-    float dequantScaleQ = params.dequantScaleQkv ? params.dequantScaleQ[q_scale_idx] : 1.f;
+    float dequantScaleQ = params.dequantScaleQkv ? params.dequantScaleQkv[q_scale_idx] : 1.f;
     float dequantScaleK = params.dequantScaleQkv ? params.dequantScaleQkv[k_scale_idx] : 1.f;
     float dequantScaleV = params.dequantScaleQkv ? params.dequantScaleQkv[v_scale_idx] : 1.f;

cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -235,7 +235,7 @@ struct QKVPreprocessingParams
        << *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT,
            runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2})));
     ss << "rotary_coef_cache_buffer: " << rotary_coef_cache_buffer << std::endl;
-    ss << "kv_scale_orig_quant: " << kv_scale_orig_quant << std::endl;
+    ss << "qkv_scale_orig_quant: " << qkv_scale_orig_quant << std::endl;
     ss << "spec_decoding_position_offsets: " << spec_decoding_position_offsets << std::endl;
     ss << "batch_size: " << batch_size << std::endl;
     ss << "max_input_seq_len: " << max_input_seq_len << std::endl;

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h

Lines changed: 23 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -559,7 +559,7 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
         if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
         {
             mmha::convert_from_float(
-                &scaleOrigQuant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant[0] : 1.0f);
+                &scaleOrigQuant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant[0] : 1.0f);
         }

         if constexpr (FP8_OUTPUT)
@@ -611,13 +611,8 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
                 params.kv_cache_block_scales_buffer.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
             auto* vBlockScales = reinterpret_cast<uint8_t*>(
                 params.kv_cache_block_scales_buffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache));
-            float kSecondLevelSF = params.kv_scale_orig_quant[1];
-            float vSecondLevelSF = params.kv_scale_orig_quant[2];
-            if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0
-                && threadIdx.y == 0 && threadIdx.z == 0)
-            {
-                printf("kSecondLevelSF: %f, vSecondLevelSF: %f\n", kSecondLevelSF, vSecondLevelSF);
-            }
+            float kSecondLevelSF = params.qkv_scale_orig_quant[1];
+            float vSecondLevelSF = params.qkv_scale_orig_quant[2];
             auto& kPacked = reinterpret_cast<PackedVec<T>&>(k_to_cache);
             auto& vPacked = reinterpret_cast<PackedVec<T>&>(v);
             quantizeAndWriteFP4KVCache<T>(kBlockScales, vBlockScales, reinterpret_cast<uint32_t*>(kDst),
@@ -644,17 +639,18 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
             params.fmha_tile_counter[0] = 0u;
         }
         // Take the quantization scales into consideration.
+        float q_scale_quant_orig, k_scale_quant_orig, v_scale_quant_orig;
         if constexpr (ENABLE_4BITS_CACHE)
         {
-            float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[1] : 1.f;
-            float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[2] : 1.f;
+            q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[1] : 1.f;
+            v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[2] : 1.f;
         }
         else
         {
-            float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
         }
         float o_scale_orig_quant = params.o_scale_orig_quant ? params.o_scale_orig_quant[0] : 1.f;
         if (params.fmha_bmm1_scale)
@@ -966,7 +962,7 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
         if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
         {
             mmha::convert_from_float(
-                &scaleOrigQuant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant[0] : 1.0f);
+                &scaleOrigQuant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant[0] : 1.0f);
         }

         if constexpr (FP8_OUTPUT)
@@ -1011,7 +1007,7 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
             // Cast float scale to dst data type.
             using TScale = typename mmha::kv_cache_scale_type_t<T, TCache>::Type;
             TScale scaleOrigQuant;
-            mmha::convert_from_float(&scaleOrigQuant, params.kv_scale_orig_quant[0]);
+            mmha::convert_from_float(&scaleOrigQuant, params.qkv_scale_orig_quant[0]);
             // Store 8bits kv cache.
             mmha::store_8bits_vec(kDst, k, inBlockIdx, scaleOrigQuant);
             mmha::store_8bits_vec(vDst, v, inBlockIdx, scaleOrigQuant);
@@ -1022,14 +1018,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
                 params.kv_cache_block_scales_buffer.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
             auto* vBlockScales = reinterpret_cast<uint8_t*>(
                 params.kv_cache_block_scales_buffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache));
-            float kSecondLevelSF = params.kv_scale_orig_quant[1];
-            float vSecondLevelSF = params.kv_scale_orig_quant[2];
-            if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0
-                && threadIdx.y == 0 && threadIdx.z == 0)
-            {
-                printf("kSecondLevelSF: %f, vSecondLevelSF: %f\n", kSecondLevelSF, vSecondLevelSF);
-            }
-
+            float kSecondLevelSF = params.qkv_scale_orig_quant[1];
+            float vSecondLevelSF = params.qkv_scale_orig_quant[2];
             auto& kPacked = reinterpret_cast<PackedVec<T>&>(k);
             auto& vPacked = reinterpret_cast<PackedVec<T>&>(v);
             quantizeAndWriteFP4KVCache<T>(kBlockScales, vBlockScales, reinterpret_cast<uint32_t*>(kDst),
@@ -1055,17 +1045,18 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
             params.fmha_tile_counter[0] = 0u;
         }
         // Take the quantization scales into consideration.
+        float q_scale_quant_orig, k_scale_quant_orig, v_scale_quant_orig;
         if constexpr (ENABLE_4BITS_CACHE)
         {
-            float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[1] : 1.f;
-            float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[2] : 1.f;
+            q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[1] : 1.f;
+            v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[2] : 1.f;
         }
         else
         {
-            float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
-            float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
+            v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig[0] : 1.f;
         }
         float o_scale_orig_quant = params.o_scale_orig_quant ? params.o_scale_orig_quant[0] : 1.f;
         if (params.fmha_bmm1_scale)
@@ -1406,7 +1397,8 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
     [[maybe_unused]] TScale scale_orig_quant;
     if constexpr (sizeof(TCache) == 1 || FP8_OUTPUT)
     {
-        mmha::convert_from_float(&scale_orig_quant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant[0] : 1.0f);
+        mmha::convert_from_float(
+            &scale_orig_quant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant[0] : 1.0f);
     }

     // For loop in the sequence length dimension.

Comments (0): there are no comments on commit a477bff.