@@ -212,6 +212,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
         }
         xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
+    }
     else
     {
         xqaParams.kv_cache_data_type = xqaParams.data_type;
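DATA_TYPE_E2M1 is the 4-bit FP4 element format (1 sign, 2 exponent, 1 mantissa bit); in NVFP4 it is paired with per-group E4M3 scale factors, which is why the later hunks introduce a separate block-scale cache. A rough sketch of the per-element layout, assuming the standard E2M1 encoding:

    // Assumed standard E2M1 encoding (4-bit float):
    //   bit 3: sign, bits 2-1: exponent, bit 0: mantissa
    //   representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6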
@@ -924,6 +928,9 @@ int AttentionOp::mlaGeneration(
         generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
         generation_params.host_secondary_pool_pointer, generation_params.block_offsets);
 
+    // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
+    auto kv_scale_cache_buffer = KVBlockArray();
+
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;
@@ -1200,7 +1207,7 @@ int AttentionOp::mlaGeneration(
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
         xqaParams.stream = stream;
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         return 0;
     }
     else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -1274,8 +1281,23 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     float const q_scaling = mQScaling;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    int elemBits;
+    if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
+    {
+        elemBits = 8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        elemBits = 4;
+    }
+    else
+    {
+        elemBits = sizeof(T) * 8;
+    }
+    auto sizePerToken = mNumKVHeads * headSize * elemBits / 8 /* bits */;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
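With the branch above, sizePerToken is simply bits-per-element times the number of KV elements per token, divided by 8. A worked example, assuming illustrative values mNumKVHeads = 8 and headSize = 128 (not taken from the diff):

    // Illustrative arithmetic only (8 KV heads, head size 128 are assumed values):
    //   full precision (T = half): 8 * 128 * 16 / 8 = 2048 bytes per token
    //   INT8 / FP8 KV cache:       8 * 128 *  8 / 8 = 1024 bytes per token
    //   NVFP4 KV cache:            8 * 128 *  4 / 8 =  512 bytes per token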
@@ -1284,6 +1306,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
                 params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
                 params.host_secondary_pool_pointer, params.block_offsets);
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    params.block_offsets);
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
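The scale cache above reuses the KVBlockArray layout with sizePerToken / 8 bytes per token. A minimal sketch of that factor, assuming NVFP4 groups 16 FP4 elements under one 1-byte E4M3 scale:

    // Assumed NVFP4 block scaling: 16 elements share one 1-byte scale factor.
    //   16 elements * 4 bits = 64 bits = 8 bytes of packed FP4 data per scale byte,
    //   so scaleBytesPerToken = fp4BytesPerToken / 8 == sizePerToken / 8.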
@@ -1292,6 +1322,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
 
@@ -1430,8 +1461,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     decoder_params.blockSparseParams = mBlockSparseParams;
     decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
     decoder_params.quantScaleO = params.attention_output_orig_quant;
-    decoder_params.dequantScaleQ = params.kv_scale_quant_orig;
-    decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
+    decoder_params.dequantScaleQKv = params.kv_scale_quant_orig;
+    decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
     decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
     decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
     decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
@@ -1489,9 +1520,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         sync_check_cuda_error(stream);
     }
 
-    KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
-        ? KvCacheDataType::INT8
-        : (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
+    KvCacheDataType cache_type{KvCacheDataType::BASE};
+    if (mKVCacheQuantMode.hasInt8KvCache())
+    {
+        cache_type = KvCacheDataType::INT8;
+    }
+    else if (mKVCacheQuantMode.hasFp8KvCache())
+    {
+        cache_type = KvCacheDataType::FP8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        cache_type = KvCacheDataType::NVFP4;
+    }
 
     cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
     int const attention_seq_len_1 = params.input_seq_length; // q length
@@ -1540,6 +1581,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
     preprocessingParams.q_output = q_buf_2_;
     preprocessingParams.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
     preprocessingParams.qkv_bias = params.qkv_bias;
     preprocessingParams.tokens_info = decoder_params.tokensInfo;
     preprocessingParams.seq_lens = params.context_lengths;
@@ -1552,7 +1594,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
     preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;
 
@@ -1702,6 +1744,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         else
         {
             fmhaParams.pagedKvCache = kv_cache_buffer;
+            fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
         }
     }
     fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
@@ -2048,8 +2091,24 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     int32_t const batch_beam = params.beam_width * params.num_requests;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    int elemBits;
+    if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
+    {
+        elemBits = 8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        elemBits = 4;
+    }
+    else
+    {
+        elemBits = sizeof(T) * 8;
+    }
+
+    auto const sizePerToken = mNumKVHeads * headSize * elemBits / 8 /* bits */;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2059,13 +2118,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
                 params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
                 params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
                 reinterpret_cast<BufferDataType*>(params.block_offsets));
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    reinterpret_cast<BufferDataType*>(params.block_offsets));
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
            using BufferDataType = typename KVCacheBuffer::DataType;
            kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
                params.cyclic_attention_window_size, params.sink_token_length, false,
                reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
     sync_check_cuda_error(stream);
@@ -2137,7 +2205,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
             xqaParams.output = mhaOutput;
             xqaParams.qkv = attention_input;
         }
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
         {
            this->template ulyssesGenerationPostprocess<T>(
@@ -2154,6 +2222,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
        {
            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
        }
+       else if (mKVCacheQuantMode.hasFp4KvCache())
+       {
+           TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
+       }
    }
 
    // This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv
@@ -2419,6 +2491,10 @@ int AttentionOp::initialize() noexcept
    TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
        "fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
 
+   // Check requirements for FP4 KV cache.
+   TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
+       "mFP8ContextFMHA must be enabled if FP4 KV cache is enabled.");
+
    TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
    TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
        "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@@ -2495,7 +2571,10 @@ int AttentionOp::initialize() noexcept
        {
            fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
        }
-       // TODO: add FP4 KV cache support.
+       else if (mKVCacheQuantMode.hasFp4KvCache())
+       {
+           fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
+       }
    }
    // The output dtype.
    fmhaParams.dataTypeOut = data_type;
@@ -2697,6 +2776,11 @@ int AttentionOp::initialize() noexcept
            fixedParams.kvDataType = DATA_TYPE_E4M3;
            fixedParams.mathDataType = DATA_TYPE_E4M3;
        }
+       else if (mKVCacheQuantMode.hasFp4KvCache())
+       {
+           fixedParams.kvDataType = DATA_TYPE_E2M1;
+           fixedParams.mathDataType = DATA_TYPE_E4M3;
+       }
        else
        {
            fixedParams.kvDataType = fixedParams.inputDataType;