Skip to content

Commit ba694eb

Browse files
committed
[KV Cache Manager] Dead code elimination, we no longer record/fetch the blocks with hashmap
WindowBlockManager::mCachedBlocksRoot is responsible for the bookkeeping of the KVCacheBlock, and the mNextBlocks is now the actual hash map that fetches the block. Signed-off-by: eopXD <[email protected]>
1 parent 907c180 commit ba694eb

File tree

3 files changed

+14
-311
lines changed

3 files changed

+14
-311
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,7 @@ class WindowBlockManager
536536
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
537537
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
538538
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
539-
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
540-
bool copyOnPartialReuse);
539+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
541540

542541
~WindowBlockManager();
543542

@@ -633,11 +632,6 @@ class WindowBlockManager
633632
return mAllBlocksById.at(blockId);
634633
}
635634

636-
[[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const
637-
{
638-
return mContextBlocksByHash.equal_range(hash);
639-
}
640-
641635
[[nodiscard]] SizeType32 getTokensPerBlock() const noexcept
642636
{
643637
return mTokensPerBlock;
@@ -723,10 +717,6 @@ class WindowBlockManager
723717
//! \param blockIds Id of each block.
724718
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
725719

726-
void addBlockToHashMap(BlockPtr const& block);
727-
728-
void removeBlockFromHashMap(BlockPtr const& block);
729-
730720
[[nodiscard]] bool verifyQueueIntegrity();
731721

732722
// Only needed when sliding window attention + paged context fmha are used together.
@@ -808,8 +798,6 @@ class WindowBlockManager
808798
SizeType32 mTokensPerBlock;
809799
// List of all blocks by idx
810800
std::vector<BlockPtr> mAllBlocksById;
811-
// List of all context blocks by hash
812-
BlockMap mContextBlocksByHash;
813801
// Dummy block acting as root for BlockToken searches
814802
BlockPtr mCachedBlocksRoot;
815803
// KV cache type (self or cross)
@@ -841,8 +829,6 @@ class WindowBlockManager
841829
double mReusedTokens;
842830
// Total number of input tokens
843831
double mTotalInputTokens;
844-
// Whether or not to maintain a hashmap of blocks.
845-
bool mEnableHashKey;
846832
// Whether blocks that are partially matched should be reused.
847833
bool mEnablePartialReuse;
848834
// Whether partially matched blocks that are already in use should be copied and reused.
@@ -863,8 +849,8 @@ class BlockManager
863849
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
864850
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
865851
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
866-
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
867-
bool enablePartialReuse = true, bool copyOnPartialReuse = true);
852+
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
853+
bool copyOnPartialReuse = true);
868854

869855
BlockManager(BlockManager const&) = delete;
870856
BlockManager& operator=(BlockManager const&) = delete;
@@ -1081,11 +1067,6 @@ class BlockManager
10811067
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
10821068
}
10831069

1084-
[[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const
1085-
{
1086-
return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash);
1087-
}
1088-
10891070
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
10901071
{
10911072
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
@@ -1096,16 +1077,6 @@ class BlockManager
10961077
return getPool(poolIdx).containsBlockScales;
10971078
}
10981079

1099-
void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize)
1100-
{
1101-
mWindowBlockManagers.at(windowSize).addBlockToHashMap(block);
1102-
}
1103-
1104-
void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize)
1105-
{
1106-
mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block);
1107-
}
1108-
11091080
//! \brief Store context blocks
11101081
void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest);
11111082

@@ -1385,8 +1356,8 @@ class KVCacheManager : public BaseKVCacheManager
13851356
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
13861357
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13871358
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
1388-
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
1389-
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
1359+
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1360+
bool copyOnpartialReuse = true);
13901361

13911362
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13921363
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1405,8 +1376,8 @@ class KVCacheManager : public BaseKVCacheManager
14051376
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
14061377
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
14071378
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
1408-
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
1409-
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
1379+
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1380+
bool copyOnpartialReuse = true);
14101381

14111382
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
14121383
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1692,8 +1663,6 @@ class KVCacheManager : public BaseKVCacheManager
16921663
std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences;
16931664
// Whether to cache KV pages for reuse
16941665
bool mEnableBlockReuse;
1695-
// Whether enable finding blocks by their hash, ignored when reuse enabled
1696-
bool mEnableHashKey;
16971666
// Mutex to protect access to mSequences
16981667
mutable std::mutex mSequencesMtx;
16991668
// buffers for static tensors, will be created after allocating pools

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 7 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
504504
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
505505
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
506506
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
507-
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
508-
bool copyOnPartialReuse)
507+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
509508
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
510509
, mTokensPerBlock{tokensPerBlock}
511510
, mEventManager{std::move(eventManager)}
@@ -530,7 +529,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
530529
TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks...
531530
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
532531
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
533-
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enableHashKey, enablePartialReuse,
532+
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
534533
copyOnPartialReuse);
535534
}
536535

@@ -573,8 +572,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
573572
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
574573
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
575574
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
576-
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
577-
bool copyOnPartialReuse)
575+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
578576
: mDataType{dtype}
579577
, mWindowSize{windowSize}
580578
, mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,7 +594,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
596594
, mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)}
597595
, mReusedTokens{0.0}
598596
, mTotalInputTokens{0.0}
599-
, mEnableHashKey{enableHashKey}
600597
, mEnablePartialReuse{enablePartialReuse}
601598
, mCopyOnPartialReuse{copyOnPartialReuse}
602599
{
@@ -920,50 +917,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const
920917
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
921918
}
922919

923-
void WindowBlockManager::addBlockToHashMap(BlockPtr const& block)
924-
{
925-
if (!mEnableHashKey)
926-
{
927-
return;
928-
}
929-
auto range = mContextBlocksByHash.equal_range(block->getHash());
930-
for (auto it = range.first; it != range.second; ++it)
931-
{
932-
if (it->second == block)
933-
{
934-
// TODO: change to assert when reused block is added only once
935-
TLLM_LOG_TRACE(
936-
"Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
937-
return;
938-
}
939-
}
940-
TLLM_LOG_TRACE(
941-
"Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
942-
mContextBlocksByHash.emplace(block->getHash(), std::move(block));
943-
}
944-
945-
void WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block)
946-
{
947-
if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty())
948-
{
949-
// Hash key not enabled / Empty block
950-
return;
951-
}
952-
auto range = mContextBlocksByHash.equal_range(block->getHash());
953-
TLLM_LOG_TRACE(
954-
"Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
955-
for (auto it = range.first; it != range.second; ++it)
956-
{
957-
if (it->second == block)
958-
{
959-
mContextBlocksByHash.erase(it);
960-
return;
961-
}
962-
}
963-
// TODO: should be unreachable
964-
TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash());
965-
}
966-
967920
void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize)
968921
{
969922
mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock);
@@ -1104,7 +1057,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11041057
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
11051058
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(),
11061059
matchingBlockId);
1107-
addBlockToHashMap(matchingBlock);
11081060
}
11091061
searchRoot = nullptr; // no matching needed for following blocks
11101062
}
@@ -1114,7 +1066,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11141066
mEvictionPolicy->claimBlock(
11151067
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
11161068
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId);
1117-
addBlockToHashMap(matchingBlock);
11181069
searchRoot = matchingBlock;
11191070
}
11201071
onboardBlock(matchingBlock);
@@ -1145,7 +1096,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11451096
++blockItr;
11461097
}
11471098
freeBlock->setHash();
1148-
addBlockToHashMap(freeBlock);
11491099
++mMissedBlocks;
11501100
}
11511101
}
@@ -1169,7 +1119,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11691119
++blockItr;
11701120
}
11711121
freeBlock->setHash();
1172-
addBlockToHashMap(freeBlock);
11731122
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d",
11741123
mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi);
11751124
}
@@ -1369,9 +1318,7 @@ void WindowBlockManager::storeBlocks(
13691318
if (oldHash != newHash)
13701319
{
13711320
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
1372-
removeBlockFromHashMap(block);
13731321
block->setHash(newHash);
1374-
addBlockToHashMap(block);
13751322
}
13761323
searchRoot = block;
13771324
}
@@ -1408,7 +1355,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp
14081355
if (!block->hasRefs())
14091356
{
14101357
mEvictionPolicy->releaseBlock(block);
1411-
removeBlockFromHashMap(block);
14121358
}
14131359
}
14141360

@@ -1473,7 +1419,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
14731419
if (!block->hasRefs())
14741420
{
14751421
mEvictionPolicy->releaseBlock(block, true);
1476-
removeBlockFromHashMap(block);
14771422
}
14781423
// Remove block from allocated blocks
14791424
allocatedBlocks.pop_back();
@@ -1616,7 +1561,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence)
16161561
if (!block->hasRefs())
16171562
{
16181563
mEvictionPolicy->releaseBlock(block);
1619-
removeBlockFromHashMap(block);
16201564
}
16211565
}
16221566
// Remove stored block ids in sequence
@@ -1682,8 +1626,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16821626
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
16831627
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16841628
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1685-
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
1686-
bool copyOnPartialReuse)
1629+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
16871630
: mMaxBeamWidth(maxBeamWidth)
16881631
, mDataType(dtype)
16891632
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1693,10 +1636,9 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16931636
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
16941637
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
16951638
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
1696-
enableHashKey, enablePartialReuse, copyOnPartialReuse)
1639+
enablePartialReuse, copyOnPartialReuse)
16971640
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
16981641
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
1699-
, mEnableHashKey{enableHashKey}
17001642
{
17011643
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
17021644
!= maxAttentionWindowVec.end());
@@ -1716,12 +1658,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
17161658
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
17171659
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
17181660
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1719-
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
1720-
bool copyOnPartialReuse)
1661+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
17211662
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
17221663
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
17231664
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
1724-
std::move(eventManager), enableHashKey, enablePartialReuse, copyOnPartialReuse)
1665+
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
17251666
{
17261667
}
17271668

@@ -2085,30 +2026,6 @@ void KVCacheManager::addSequence(
20852026
llmRequest->mRequestId);
20862027
}
20872028
mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
2088-
if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1)
2089-
{
2090-
constexpr SizeType32 beamIdx = 0;
2091-
auto const& blockIds = sequence.getCacheBlockIds(windowSize).at(beamIdx);
2092-
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
2093-
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(
2094-
uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), true);
2095-
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
2096-
auto tokensPerBlock = static_cast<size_t>(getTokensPerBlock());
2097-
for (size_t i = 0; i < blockIds.size(); i++)
2098-
{
2099-
auto const& block = mBlockManager.getBlockById(blockIds[i], windowSize);
2100-
if (i < blockKeys.size())
2101-
{
2102-
block->setBlockKey(blockKeys[i], blockKeys[i].uniqueTokens.size() == tokensPerBlock);
2103-
}
2104-
else
2105-
{
2106-
block->setBlockKey({}, false);
2107-
}
2108-
block->setHash();
2109-
mBlockManager.addBlockToHashMap(block, windowSize);
2110-
}
2111-
}
21122029
}
21132030
cacheBlockOffsets(sequence, windowSize);
21142031
}

0 commit comments

Comments
 (0)