@@ -559,7 +559,7 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
559
559
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
560
560
{
561
561
mmha::convert_from_float (
562
- &scaleOrigQuant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant [0 ] : 1 .0f );
562
+ &scaleOrigQuant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant [0 ] : 1 .0f );
563
563
}
564
564
565
565
if constexpr (FP8_OUTPUT)
@@ -611,13 +611,8 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
611
611
params.kv_cache_block_scales_buffer .getKBlockPtr (batch_idx, token_idx_in_kv_cache));
612
612
auto * vBlockScales = reinterpret_cast <uint8_t *>(
613
613
params.kv_cache_block_scales_buffer .getVBlockPtr (batch_idx, token_idx_in_kv_cache));
614
- float kSecondLevelSF = params.kv_scale_orig_quant [1 ];
615
- float vSecondLevelSF = params.kv_scale_orig_quant [2 ];
616
- if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0
617
- && threadIdx.y == 0 && threadIdx.z == 0 )
618
- {
619
- printf (" kSecondLevelSF: %f, vSecondLevelSF: %f\n " , kSecondLevelSF , vSecondLevelSF);
620
- }
614
+ float kSecondLevelSF = params.qkv_scale_orig_quant [1 ];
615
+ float vSecondLevelSF = params.qkv_scale_orig_quant [2 ];
621
616
auto & kPacked = reinterpret_cast <PackedVec<T>&>(k_to_cache);
622
617
auto & vPacked = reinterpret_cast <PackedVec<T>&>(v);
623
618
quantizeAndWriteFP4KVCache<T>(kBlockScales , vBlockScales, reinterpret_cast <uint32_t *>(kDst ),
@@ -644,17 +639,18 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
644
639
params.fmha_tile_counter [0 ] = 0u ;
645
640
}
646
641
// Take the quantization scales into consideration.
642
+ float q_scale_quant_orig, k_scale_quant_orig, v_scale_quant_orig;
647
643
if constexpr (ENABLE_4BITS_CACHE)
648
644
{
649
- float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
650
- float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [1 ] : 1 .f ;
651
- float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [2 ] : 1 .f ;
645
+ q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
646
+ k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [1 ] : 1 .f ;
647
+ v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [2 ] : 1 .f ;
652
648
}
653
649
else
654
650
{
655
- float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
656
- float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
657
- float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
651
+ q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
652
+ k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
653
+ v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
658
654
}
659
655
float o_scale_orig_quant = params.o_scale_orig_quant ? params.o_scale_orig_quant [0 ] : 1 .f ;
660
656
if (params.fmha_bmm1_scale )
@@ -966,7 +962,7 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
966
962
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
967
963
{
968
964
mmha::convert_from_float (
969
- &scaleOrigQuant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant [0 ] : 1 .0f );
965
+ &scaleOrigQuant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant [0 ] : 1 .0f );
970
966
}
971
967
972
968
if constexpr (FP8_OUTPUT)
@@ -1011,7 +1007,7 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
1011
1007
// Cast float scale to dst data type.
1012
1008
using TScale = typename mmha::kv_cache_scale_type_t <T, TCache>::Type;
1013
1009
TScale scaleOrigQuant;
1014
- mmha::convert_from_float (&scaleOrigQuant, params.kv_scale_orig_quant [0 ]);
1010
+ mmha::convert_from_float (&scaleOrigQuant, params.qkv_scale_orig_quant [0 ]);
1015
1011
// Store 8bits kv cache.
1016
1012
mmha::store_8bits_vec (kDst , k, inBlockIdx, scaleOrigQuant);
1017
1013
mmha::store_8bits_vec (vDst, v, inBlockIdx, scaleOrigQuant);
@@ -1022,14 +1018,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
1022
1018
params.kv_cache_block_scales_buffer .getKBlockPtr (batch_idx, token_idx_in_kv_cache));
1023
1019
auto * vBlockScales = reinterpret_cast <uint8_t *>(
1024
1020
params.kv_cache_block_scales_buffer .getVBlockPtr (batch_idx, token_idx_in_kv_cache));
1025
- float kSecondLevelSF = params.kv_scale_orig_quant [1 ];
1026
- float vSecondLevelSF = params.kv_scale_orig_quant [2 ];
1027
- if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0
1028
- && threadIdx.y == 0 && threadIdx.z == 0 )
1029
- {
1030
- printf (" kSecondLevelSF: %f, vSecondLevelSF: %f\n " , kSecondLevelSF , vSecondLevelSF);
1031
- }
1032
-
1021
+ float kSecondLevelSF = params.qkv_scale_orig_quant [1 ];
1022
+ float vSecondLevelSF = params.qkv_scale_orig_quant [2 ];
1033
1023
auto & kPacked = reinterpret_cast <PackedVec<T>&>(k);
1034
1024
auto & vPacked = reinterpret_cast <PackedVec<T>&>(v);
1035
1025
quantizeAndWriteFP4KVCache<T>(kBlockScales , vBlockScales, reinterpret_cast <uint32_t *>(kDst ),
@@ -1055,17 +1045,18 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
1055
1045
params.fmha_tile_counter [0 ] = 0u ;
1056
1046
}
1057
1047
// Take the quantization scales into consideration.
1048
+ float q_scale_quant_orig, k_scale_quant_orig, v_scale_quant_orig;
1058
1049
if constexpr (ENABLE_4BITS_CACHE)
1059
1050
{
1060
- float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1061
- float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [1 ] : 1 .f ;
1062
- float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [2 ] : 1 .f ;
1051
+ q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1052
+ k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [1 ] : 1 .f ;
1053
+ v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [2 ] : 1 .f ;
1063
1054
}
1064
1055
else
1065
1056
{
1066
- float q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1067
- float k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1068
- float v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1057
+ q_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1058
+ k_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1059
+ v_scale_quant_orig = params.qkv_scale_quant_orig ? params.qkv_scale_quant_orig [0 ] : 1 .f ;
1069
1060
}
1070
1061
float o_scale_orig_quant = params.o_scale_orig_quant ? params.o_scale_orig_quant [0 ] : 1 .f ;
1071
1062
if (params.fmha_bmm1_scale )
@@ -1406,7 +1397,8 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
1406
1397
[[maybe_unused]] TScale scale_orig_quant;
1407
1398
if constexpr (sizeof (TCache) == 1 || FP8_OUTPUT)
1408
1399
{
1409
- mmha::convert_from_float (&scale_orig_quant, params.kv_scale_orig_quant ? params.kv_scale_orig_quant [0 ] : 1 .0f );
1400
+ mmha::convert_from_float (
1401
+ &scale_orig_quant, params.qkv_scale_orig_quant ? params.qkv_scale_orig_quant [0 ] : 1 .0f );
1410
1402
}
1411
1403
1412
1404
// For loop in the sequence length dimension.
0 commit comments