@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
     std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
     SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
     , mTokensPerBlock{tokensPerBlock}
     , mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
 {
     auto const uniqueWindowSizeToLayers
         = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
+
+    TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
+        "KV Cache Connector is not supported with multiple window sizes");
+
     auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());

     mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
         mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
             sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
             onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
-            copyOnPartialReuse);
+            copyOnPartialReuse, kvCacheConnectorManager);
     }

     auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
     SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
     SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mDataType{dtype}
     , mWindowSize{windowSize}
     , mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
     , mTotalInputTokens{0.0}
     , mEnablePartialReuse{enablePartialReuse}
     , mCopyOnPartialReuse{copyOnPartialReuse}
+    , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
 {
     std::map<SizeType32, SizeType32> numLayersPerPool;

@@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
     auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
     mReusedTokens += static_cast<double>(prepopulatedPromptLen);
     mTotalInputTokens += static_cast<double>(uniqueTokens.size());
-    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
-    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
-        inputLength, prepopulatedPromptLen);
+
+    SizeType32 numConnectorMatchedTokens = 0;
+
+    // If we're using a KV cache connector, check if any additional blocks can be loaded.
+    if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
+    {
+        numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
+    }
+
+    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
+    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
+        llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
 }

 // There are two versions of BlockManager::addSequence function.
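Note on the hunk above: the connector path only changes how much of the prompt is reported as already cached. Whatever getNumNewMatchedTokens() returns is added on top of the tokens matched by local block reuse before setPrepopulatedPromptLen() is called. The sketch below shows one way a connector manager backed by an external KV cache might implement that hook; it is a minimal sketch only, assuming getNumNewMatchedTokens() is a virtual method of kv_connector::KvCacheConnectorManager, and RemoteCacheClient is a hypothetical lookup service that is not part of this change.

    // Hypothetical sketch, not part of this PR.
    class RemoteKvCacheConnectorManager : public kv_connector::KvCacheConnectorManager
    {
    public:
        SizeType32 getNumNewMatchedTokens(LlmRequest const& llmRequest, SizeType32 prepopulatedPromptLen) override
        {
            // Ask the external cache how long a prefix of this request's tokens it can supply.
            SizeType32 const remoteMatched = mRemoteClient.matchedPrefixLength(llmRequest.getTokens(0));
            // Report only tokens that local block reuse has not already covered.
            return std::max<SizeType32>(0, remoteMatched - prepopulatedPromptLen);
        }

    private:
        RemoteCacheClient mRemoteClient; // hypothetical external KV cache lookup service
    };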
@@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
 void WindowBlockManager::addSequence(
     GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
 {
+    if (mKvCacheConnectorManager)
+    {
+        TLLM_LOG_WARNING(
+            "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
+            "ignored.");
+    }
+
     auto const requestId = sequence.getRequestId();
     auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
     TLLM_CHECK(emplaceDone);
@@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
         maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
         std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
         enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
-        copyOnPartialReuse)
+        copyOnPartialReuse, kvCacheConnectorManager)
 {
 }

@@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mMaxBeamWidth(maxBeamWidth)
     , mDataType(dtype)
     , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
         std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
         mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
-        enablePartialReuse, copyOnPartialReuse)
+        enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
     // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
     , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
 {
@@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
     SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
         maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
         std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
-        std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
+        std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
 {
 }

@@ -2383,6 +2409,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
     return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
 }

+runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
+{
+    TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
+        "getUniquePrimaryPool is only supported for a single window size");
+    return mBlockManager.getPrimaryPool(0);
+}
+
 runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
 {
     return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
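Because the constructor check earlier in this diff restricts a connector to a single attention window size, such a configuration has exactly one primary pool, and the new getUniquePrimaryPool() returns it directly. A minimal usage sketch, assuming a manager built with one window size (variable names are illustrative):

    // Sketch only: grab the single primary pool tensor that a connector implementation
    // could read from or write into; the TLLM_CHECK_WITH_INFO above fires otherwise.
    runtime::ITensor::SharedPtr primaryPool = kvCacheManager.getUniquePrimaryPool();
    auto const poolShape = primaryPool->getShape(); // exact dimensions depend on the cache configuration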
@@ -2462,4 +2495,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
     auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
     return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
 }
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
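For call sites, the visible change is one new trailing constructor argument that is threaded down to each WindowBlockManager. A hedged construction sketch follows: the argument values are placeholders, RemoteKvCacheConnectorManager is the hypothetical implementation sketched earlier, and whether the new parameter defaults to nullptr (so existing callers compile unchanged) is decided in the header, which this diff does not show. Note that, per the warning added in the reuse-disabled WindowBlockManager::addSequence overload, block reuse should be enabled when a connector is supplied; otherwise the connector is ignored.

    // Sketch only: placeholder values, mirroring the parameter order of the delegating
    // KVCacheManager constructor shown above.
    auto connectorManager = std::make_shared<RemoteKvCacheConnectorManager>(/* ... */);
    auto kvCacheManager = std::make_shared<KVCacheManager>(numKvHeadsPerLayer, sizePerHead, tokensPerBlock,
        blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
        sinkTokenLength, stream, maxSequenceLength, /*enableBlockReuse=*/true, onboardBlocks, CacheType::kSELF,
        /*secondaryOffloadMinPriority=*/std::nullopt, /*eventManager=*/nullptr, /*enablePartialReuse=*/true,
        /*copyOnPartialReuse=*/true, connectorManager);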