Commit ee2b58e

Add NVFP4 KV cache support
Signed-off-by: Tian Zheng <[email protected]>
1 parent 907c180 commit ee2b58e

125 files changed: +711 additions, -362 deletions

(Large commit: only a subset of the changed files is shown below.)

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 4 additions & 2 deletions
@@ -496,7 +496,7 @@ class KVCacheBlockPool
         , sizePerHead(sizePerHead)
         , tokensPerBlock(tokensPerBlock)
         , quantSize(quantSize)
-        , blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize)
+        , blockSize(numKvHeads * sizePerHead * tokensPerBlock)
         , primaryPtr(std::move(primaryPtr))
         , secondaryPtr(std::move(secondaryPtr))
         , containsBlockScales(containsBlockScales)

@@ -1251,6 +1251,8 @@ class BaseKVCacheManager
 
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers() const = 0;
 
+    [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers() const = 0;
+
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping() const = 0;
 
     virtual void getBlockOffsetsOfBatch(

@@ -1555,7 +1557,7 @@ class KVCacheManager : public BaseKVCacheManager
         return mLayerToPoolMapping;
     }
 
-    [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const
+    [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const override
     {
         // TODO: add a new optional model input so the attention plugin can access these
         return mBlockScalePoolPointers;

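For illustration (not part of the commit): with NVFP4, two 4-bit values share one uint8 container, and the sizePerHead handed to KVCacheBlockPool is already expressed in containers (see the kvCacheManager.cpp change below), which is why blockSize no longer divides by quantSize. A minimal standalone C++ sketch of that bookkeeping, with hypothetical pool dimensions:

#include <cstdint>
#include <iostream>

int main()
{
    // Hypothetical pool geometry, chosen only to illustrate the arithmetic.
    std::int64_t const numKvHeads = 4;
    std::int64_t const headDim = 128;
    std::int64_t const tokensPerBlock = 32;

    // NVFP4: two 4-bit elements share one uint8 container.
    std::int64_t const numEltsPerContainer = 2;
    std::int64_t const sizePerHead = headDim / numEltsPerContainer; // 64 containers per head

    // blockSize(numKvHeads * sizePerHead * tokensPerBlock): containers per K (or V) block.
    std::int64_t const blockSize = numKvHeads * sizePerHead * tokensPerBlock; // 8192 containers
    std::cout << "containers per block: " << blockSize << "\n";
    return 0;
}
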
cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 21 additions & 9 deletions
@@ -608,6 +608,16 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
         mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
     }
 
+#ifdef ENABLE_FP4
+    SizeType32 const numEltsPerContainer = mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
+    if (numEltsPerContainer == 2)
+    {
+        TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
+    }
+#else
+    SizeType32 const numEltsPerContainer = 1;
+#endif
+
     size_t poolIndex = 0;
     for (auto const [numKvHeads, numLayers] : numLayersPerPool)
     {

@@ -618,7 +628,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
                 mLayerToPoolIndex[layerIdx] = poolIndex;
             }
         }
-        mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1);
+        mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock, 1);
         ++poolIndex;
     }
 

@@ -703,14 +713,20 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co
 
 void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
 {
+
+#ifdef ENABLE_FP4
+    SizeType32 const numEltsPerContainer = mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
+#else
+    SizeType32 const numEltsPerContainer = 1;
+#endif
     auto num_pools = mPools.size();
     for (size_t i = 0; i < num_pools; ++i)
     {
         auto& kv_pool = mPools[i];
-        TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0,
-            "Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize.");
-
-        mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead,
+        TLLM_CHECK_WITH_INFO((kv_pool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0,
+            "Cannot use FP4 quantization since kv_pool.sizePerHead is not divisible by FP4 quantBlockSize.");
+        auto blockScaleSizePerHead = kv_pool.sizePerHead * numEltsPerContainer / quantBlockSize;
+        mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, blockScaleSizePerHead,
             kv_pool.tokensPerBlock, quantBlockSize,
             /*primaryPool=*/nullptr,
             /*secondaryPool=*/nullptr,

@@ -745,10 +761,6 @@ void WindowBlockManager::allocatePools(bool useUvm)
 
         if (poolIsFP4)
         {
-            TLLM_CHECK_WITH_INFO(blockSize % 2 == 0, "Block size must be divisible by 2 for FP4 KV cache.");
-            // Divide by 2. We can't create FP4 buffers directly, so we'll have to create a uint8 buffer with
-            // half the expected number of elements.
-            blockSize /= 2;
             poolDtype = nvinfer1::DataType::kINT8;
         }
 

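For illustration (not part of the commit): a standalone sketch of the block-scale pool sizing performed by createBlockScalePools(), assuming the usual NVFP4 layout of one scale per 16 unpacked elements (quantBlockSize = 16); the head dimension is hypothetical.

#include <cassert>
#include <cstdint>
#include <iostream>

int main()
{
    // Hypothetical head dimension; quantBlockSize = 16 matches the usual NVFP4 block size.
    std::int64_t const headDim = 128;
    std::int64_t const quantBlockSize = 16;
    std::int64_t const numEltsPerContainer = 2;                     // two FP4 values per uint8 container
    std::int64_t const sizePerHead = headDim / numEltsPerContainer; // KV data pool: 64 containers per head

    // Mirrors createBlockScalePools(): one scale per quantBlockSize original (unpacked) elements.
    assert((sizePerHead * numEltsPerContainer) % quantBlockSize == 0);
    std::int64_t const blockScaleSizePerHead = sizePerHead * numEltsPerContainer / quantBlockSize; // 8 scales per head

    std::cout << "data containers per head: " << sizePerHead << "\n";
    std::cout << "scale entries per head:   " << blockScaleSizePerHead << "\n";
    return 0;
}
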
cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 97 additions & 13 deletions
@@ -212,6 +212,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
         }
         xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
+    }
     else
     {
         xqaParams.kv_cache_data_type = xqaParams.data_type;

@@ -924,6 +928,9 @@ int AttentionOp::mlaGeneration(
         generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
         generation_params.host_secondary_pool_pointer, generation_params.block_offsets);
 
+    // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
+    auto kv_scale_cache_buffer = KVBlockArray();
+
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;

@@ -1200,7 +1207,7 @@ int AttentionOp::mlaGeneration(
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
         xqaParams.stream = stream;
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         return 0;
     }
     else if (mIsSpecDecodingEnabled && mUseSpecDecoding)

@@ -1274,8 +1281,23 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     float const q_scaling = mQScaling;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    int elemBits;
+    if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
+    {
+        elemBits = 8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        elemBits = 4;
+    }
+    else
+    {
+        elemBits = sizeof(T) * 8;
+    }
+    auto sizePerToken = mNumKVHeads * headSize * elemBits / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)

@@ -1284,6 +1306,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
                 params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
                 params.host_secondary_pool_pointer, params.block_offsets);
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    params.block_offsets);
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {

@@ -1292,6 +1322,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
 

@@ -1430,8 +1461,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     decoder_params.blockSparseParams = mBlockSparseParams;
     decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
     decoder_params.quantScaleO = params.attention_output_orig_quant;
-    decoder_params.dequantScaleQ = params.kv_scale_quant_orig;
-    decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
+    decoder_params.dequantScaleQKv = params.kv_scale_quant_orig;
+    decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
     decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
     decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
     decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;

@@ -1489,9 +1520,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         sync_check_cuda_error(stream);
     }
 
-    KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
-        ? KvCacheDataType::INT8
-        : (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
+    KvCacheDataType cache_type{KvCacheDataType::BASE};
+    if (mKVCacheQuantMode.hasInt8KvCache())
+    {
+        cache_type = KvCacheDataType::INT8;
+    }
+    else if (mKVCacheQuantMode.hasFp8KvCache())
+    {
+        cache_type = KvCacheDataType::FP8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        cache_type = KvCacheDataType::NVFP4;
+    }
 
     cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
     int const attention_seq_len_1 = params.input_seq_length; // q length

@@ -1540,6 +1581,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
     preprocessingParams.q_output = q_buf_2_;
     preprocessingParams.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
     preprocessingParams.qkv_bias = params.qkv_bias;
     preprocessingParams.tokens_info = decoder_params.tokensInfo;
     preprocessingParams.seq_lens = params.context_lengths;

@@ -1552,7 +1594,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
     preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;
 

@@ -1702,6 +1744,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         else
         {
             fmhaParams.pagedKvCache = kv_cache_buffer;
+            fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
         }
     }
     fmhaParams.cuQSeqLenPtr = cu_q_seqlens;

@@ -2048,8 +2091,24 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     int32_t const batch_beam = params.beam_width * params.num_requests;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    int elemBits;
+    if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
+    {
+        elemBits = 8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        elemBits = 4;
+    }
+    else
+    {
+        elemBits = sizeof(T) * 8;
+    }
+
+    auto const sizePerToken = mNumKVHeads * headSize * elemBits / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)

@@ -2059,13 +2118,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
                 params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
                 params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
                 reinterpret_cast<BufferDataType*>(params.block_offsets));
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    reinterpret_cast<BufferDataType*>(params.block_offsets));
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
             using BufferDataType = typename KVCacheBuffer::DataType;
             kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
     sync_check_cuda_error(stream);

@@ -2137,7 +2205,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
             xqaParams.output = mhaOutput;
             xqaParams.qkv = attention_input;
         }
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
         {
             this->template ulyssesGenerationPostprocess<T>(

@@ -2154,6 +2222,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
        {
            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
        }
+       else if (mKVCacheQuantMode.hasFp4KvCache())
+       {
+           TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
+       }
     }
 
     // This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv

@@ -2419,6 +2491,10 @@ int AttentionOp::initialize() noexcept
     TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
         "fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
 
+    // Check requirements for FP4 KV cache.
+    TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
+        "mFP8ContextFMHA must enable if FP4 KV cache is enabled");
+
     TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
     TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
         "Unsupported data type, pre SM 80 GPUs do not support bfloat16");

@@ -2495,7 +2571,10 @@ int AttentionOp::initialize() noexcept
         {
             fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
         }
-        // TODO: add FP4 KV cache support.
+        else if (mKVCacheQuantMode.hasFp4KvCache())
+        {
+            fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
+        }
     }
     // The output dtype.
     fmhaParams.dataTypeOut = data_type;

@@ -2697,6 +2776,11 @@ int AttentionOp::initialize() noexcept
         fixedParams.kvDataType = DATA_TYPE_E4M3;
         fixedParams.mathDataType = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        fixedParams.kvDataType = DATA_TYPE_E2M1;
+        fixedParams.mathDataType = DATA_TYPE_E4M3;
+    }
     else
     {
         fixedParams.kvDataType = fixedParams.inputDataType;

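For illustration (not part of the commit): the per-token sizing used above, restated as a standalone sketch. kvElemBits is 4 for the NVFP4 KV cache, 8 for INT8/FP8 KV caches, and 8 * sizeof(T) otherwise; the factor of 8 between data bytes and scale bytes assumes 16-element NVFP4 blocks with one-byte scales, which is what the sizePerToken / 8 passed to the block-scale KVBlockArray implies. Concrete numbers are hypothetical.

#include <cstdint>
#include <iostream>

// Bytes of quantized K/V data stored per token (data pool only).
std::int64_t kvBytesPerToken(std::int64_t numKvHeads, std::int64_t headSize, std::int64_t kvElemBits)
{
    return numKvHeads * headSize * kvElemBits / 8;
}

int main()
{
    // Hypothetical attention geometry.
    std::int64_t const numKvHeads = 8;
    std::int64_t const headSize = 128;

    std::int64_t const fp4Bytes = kvBytesPerToken(numKvHeads, headSize, /*kvElemBits=*/4); // 512 bytes
    // One E4M3 scale byte per 16 FP4 elements -> one eighth of the 4-bit data bytes.
    std::int64_t const fp4ScaleBytes = fp4Bytes / 8; // 64 bytes

    std::cout << "FP4 KV bytes per token:    " << fp4Bytes << "\n";
    std::cout << "FP4 scale bytes per token: " << fp4ScaleBytes << "\n";
    return 0;
}
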
cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,8 @@ class AttentionOp
         kernels::KVBlockArray::DataType* block_offsets = nullptr;
         void* host_primary_pool_pointer = nullptr;
         void* host_secondary_pool_pointer = nullptr;
+        void* host_primary_block_scale_pool_pointer = nullptr;
+        void* host_secondary_block_scale_pool_pointer = nullptr;
         int32_t num_tokens = 0;
         int32_t max_blocks_per_sequence = 0;
         int32_t const* sequence_lengths = nullptr;

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,8 @@ struct MHARunnerParams
     void const* kvPtr;
     // The paged kv cache array.
     KVBlockArray pagedKvCache;
+    // The paged kv cache array for scaling factor.
+    KVBlockArray pagedKvSfCache;
     // The output buffer ptr.
     void* outputPtr;
     // The output scaling factor buffer ptr. (only used for FP4 output)

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 1 addition & 2 deletions
@@ -309,8 +309,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
     preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
-    preprocessingParams.kv_cache_scale_factors = nullptr;
+    preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
     preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
     // Scalar parameters.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 1 addition & 2 deletions
@@ -224,8 +224,7 @@ class XQAKernelList
     preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
-    preprocessingParams.kv_cache_scale_factors = nullptr;
+    preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
     preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
     // Scalar parameters.

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
     TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized.");
     // Convert from MHAFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams
     void const* kvPoolPtr = nullptr;
+    void const* kvSfPoolPtr = nullptr;
     void const* kvPageIdxPtr = nullptr;
     auto qkvLayout = kernels::QkvLayout::PackedQkv;
     int32_t maxBlocksPerSeq = 0;

@@ -144,6 +145,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
         qkvLayout = kernels::QkvLayout::PagedKv;
         auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
         kvPoolPtr = pagedKvCache.mPrimaryPoolPtr;
+        kvSfPoolPtr = runnerParams.pagedKvSfCache.mPrimaryPoolPtr;
         kvPageIdxPtr = reinterpret_cast<int const*>(pagedKvCache.data);
         maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
         numTokensPerBlock = pagedKvCache.mTokensPerBlock;

@@ -164,6 +166,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
     tllmRunnerParams.kPtr = nullptr;
     tllmRunnerParams.vPtr = nullptr;
     tllmRunnerParams.kvPtr = kvPoolPtr;
+    tllmRunnerParams.kvSfPtr = kvSfPoolPtr;
     tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
     tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
     tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);

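For illustration (not part of the commit): the constraints enforced by the new checks, restated as a hypothetical standalone helper; the function name and error handling are illustrative, not TensorRT-LLM API.

#include <stdexcept>

// Hypothetical pre-flight check mirroring the TLLM_CHECK_WITH_INFO calls added above:
// the NVFP4 KV cache path requires paged KV blocks and the FP8 context FMHA path.
void validateNvfp4KvCacheConfig(bool hasFp4KvCache, bool pagedKvCache, bool fp8ContextFmha)
{
    if (!hasFp4KvCache)
    {
        return; // Nothing extra to check for other KV cache types.
    }
    if (!pagedKvCache)
    {
        throw std::invalid_argument("FP4 KV cache only supports paged KV.");
    }
    if (!fp8ContextFmha)
    {
        throw std::invalid_argument("FP8 context FMHA must be enabled when FP4 KV cache is enabled.");
    }
}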