Skip to content

Commit 1cd1082

Browse files
committed
[KV Cache Manager] Dead code elimination: blocks are no longer recorded in or fetched from a hash map
WindowBlockManager::mCachedBlocksRoot is responsible for the bookkeeping of KVCacheBlock instances, and mNextBlocks is now the actual hash map used to fetch blocks. Signed-off-by: eopXD <[email protected]>
1 parent 907c180 commit 1cd1082

File tree

3 files changed

+0
-263
lines changed

3 files changed

+0
-263
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -633,11 +633,6 @@ class WindowBlockManager
633633
return mAllBlocksById.at(blockId);
634634
}
635635

636-
[[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const
637-
{
638-
return mContextBlocksByHash.equal_range(hash);
639-
}
640-
641636
[[nodiscard]] SizeType32 getTokensPerBlock() const noexcept
642637
{
643638
return mTokensPerBlock;
@@ -723,10 +718,6 @@ class WindowBlockManager
723718
//! \param blockIds Id of each block.
724719
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
725720

726-
void addBlockToHashMap(BlockPtr const& block);
727-
728-
void removeBlockFromHashMap(BlockPtr const& block);
729-
730721
[[nodiscard]] bool verifyQueueIntegrity();
731722

732723
// Only needed when sliding window attention + paged context fmha are used together.
@@ -808,8 +799,6 @@ class WindowBlockManager
808799
SizeType32 mTokensPerBlock;
809800
// List of all blocks by idx
810801
std::vector<BlockPtr> mAllBlocksById;
811-
// List of all context blocks by hash
812-
BlockMap mContextBlocksByHash;
813802
// Dummy block acting as root for BlockToken searches
814803
BlockPtr mCachedBlocksRoot;
815804
// KV cache type (self or cross)
@@ -1081,11 +1070,6 @@ class BlockManager
10811070
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
10821071
}
10831072

1084-
[[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const
1085-
{
1086-
return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash);
1087-
}
1088-
10891073
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
10901074
{
10911075
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
@@ -1096,16 +1080,6 @@ class BlockManager
10961080
return getPool(poolIdx).containsBlockScales;
10971081
}
10981082

1099-
void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize)
1100-
{
1101-
mWindowBlockManagers.at(windowSize).addBlockToHashMap(block);
1102-
}
1103-
1104-
void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize)
1105-
{
1106-
mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block);
1107-
}
1108-
11091083
//! \brief Store context blocks
11101084
void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest);
11111085

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -920,50 +920,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const
920920
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
921921
}
922922

923-
void WindowBlockManager::addBlockToHashMap(BlockPtr const& block)
924-
{
925-
if (!mEnableHashKey)
926-
{
927-
return;
928-
}
929-
auto range = mContextBlocksByHash.equal_range(block->getHash());
930-
for (auto it = range.first; it != range.second; ++it)
931-
{
932-
if (it->second == block)
933-
{
934-
// TODO: change to assert when reused block is added only once
935-
TLLM_LOG_TRACE(
936-
"Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
937-
return;
938-
}
939-
}
940-
TLLM_LOG_TRACE(
941-
"Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
942-
mContextBlocksByHash.emplace(block->getHash(), std::move(block));
943-
}
944-
945-
void WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block)
946-
{
947-
if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty())
948-
{
949-
// Hash key not enabled / Empty block
950-
return;
951-
}
952-
auto range = mContextBlocksByHash.equal_range(block->getHash());
953-
TLLM_LOG_TRACE(
954-
"Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
955-
for (auto it = range.first; it != range.second; ++it)
956-
{
957-
if (it->second == block)
958-
{
959-
mContextBlocksByHash.erase(it);
960-
return;
961-
}
962-
}
963-
// TODO: should be unreachable
964-
TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash());
965-
}
966-
967923
void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize)
968924
{
969925
mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock);
@@ -1104,7 +1060,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11041060
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
11051061
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(),
11061062
matchingBlockId);
1107-
addBlockToHashMap(matchingBlock);
11081063
}
11091064
searchRoot = nullptr; // no matching needed for following blocks
11101065
}
@@ -1114,7 +1069,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11141069
mEvictionPolicy->claimBlock(
11151070
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
11161071
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId);
1117-
addBlockToHashMap(matchingBlock);
11181072
searchRoot = matchingBlock;
11191073
}
11201074
onboardBlock(matchingBlock);
@@ -1145,7 +1099,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11451099
++blockItr;
11461100
}
11471101
freeBlock->setHash();
1148-
addBlockToHashMap(freeBlock);
11491102
++mMissedBlocks;
11501103
}
11511104
}
@@ -1169,7 +1122,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
11691122
++blockItr;
11701123
}
11711124
freeBlock->setHash();
1172-
addBlockToHashMap(freeBlock);
11731125
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d",
11741126
mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi);
11751127
}
@@ -1369,9 +1321,7 @@ void WindowBlockManager::storeBlocks(
13691321
if (oldHash != newHash)
13701322
{
13711323
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
1372-
removeBlockFromHashMap(block);
13731324
block->setHash(newHash);
1374-
addBlockToHashMap(block);
13751325
}
13761326
searchRoot = block;
13771327
}
@@ -1408,7 +1358,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp
14081358
if (!block->hasRefs())
14091359
{
14101360
mEvictionPolicy->releaseBlock(block);
1411-
removeBlockFromHashMap(block);
14121361
}
14131362
}
14141363

@@ -1473,7 +1422,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
14731422
if (!block->hasRefs())
14741423
{
14751424
mEvictionPolicy->releaseBlock(block, true);
1476-
removeBlockFromHashMap(block);
14771425
}
14781426
// Remove block from allocated blocks
14791427
allocatedBlocks.pop_back();
@@ -1616,7 +1564,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence)
16161564
if (!block->hasRefs())
16171565
{
16181566
mEvictionPolicy->releaseBlock(block);
1619-
removeBlockFromHashMap(block);
16201567
}
16211568
}
16221569
// Remove stored block ids in sequence
@@ -2106,7 +2053,6 @@ void KVCacheManager::addSequence(
21062053
block->setBlockKey({}, false);
21072054
}
21082055
block->setHash();
2109-
mBlockManager.addBlockToHashMap(block, windowSize);
21102056
}
21112057
}
21122058
}

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 0 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -3053,189 +3053,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest)
30533053
assertBlocks(seq3, {4}, {6});
30543054
}
30553055

3056-
namespace
3057-
{
3058-
KVCacheManager setupKvCacheManagerForHashTest(bool enableBlockReuse)
3059-
{
3060-
auto constexpr numLayers = 2;
3061-
auto constexpr numHeads = 2;
3062-
auto constexpr sizePerHead = 64;
3063-
auto constexpr tokensPerBlock = 4;
3064-
auto constexpr maxNumSequences = 8;
3065-
auto constexpr maxBeamWidth = 1;
3066-
auto constexpr sinkTokenLength = 0;
3067-
auto const stream = std::make_shared<tr::CudaStream>();
3068-
3069-
auto constexpr maxBlocksPerSeq = 8;
3070-
auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq;
3071-
auto constexpr maxAttentionWindow = maxNumTokens;
3072-
3073-
auto constexpr blocksInPrimaryPool = 16;
3074-
auto constexpr blocksInSecondaryPool = 0;
3075-
3076-
auto constexpr onboardBlocks = true;
3077-
3078-
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
3079-
3080-
return KVCacheManager(std::vector<SizeType32>(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
3081-
maxNumSequences, maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt,
3082-
nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks,
3083-
CacheType::kSELF, std::nullopt, nullptr,
3084-
/*enableHashKey*/ true);
3085-
}
3086-
3087-
std::vector<size_t> getHashAndRetrieveBlocksByHashTest(
3088-
BlockManager const& blockManager, std::vector<KVCacheBlock::IdType> const& blockIds, SizeType32 windowSize)
3089-
{
3090-
std::vector<size_t> blockHashes;
3091-
for (auto blockId : blockIds)
3092-
{
3093-
blockHashes.emplace_back(blockManager.getBlockById(blockId, windowSize)->getHash());
3094-
}
3095-
std::vector<BlockPtr> blockPtrs;
3096-
for (auto hash : blockHashes)
3097-
{
3098-
auto range = blockManager.getBlocksByHash(hash, windowSize);
3099-
BlockPtr const prevBlock = blockPtrs.empty() ? nullptr : blockPtrs.back();
3100-
BlockPtr thisBlock = nullptr;
3101-
for (auto it = range.first; it != range.second; ++it)
3102-
{
3103-
if (it->second->getPrevBlockInSeq() == prevBlock)
3104-
{
3105-
thisBlock = it->second;
3106-
break;
3107-
}
3108-
}
3109-
EXPECT_NE(thisBlock, nullptr);
3110-
blockPtrs.emplace_back(thisBlock);
3111-
}
3112-
EXPECT_EQ(blockHashes.size(), blockPtrs.size());
3113-
for (size_t i = 0; i < blockHashes.size(); i++)
3114-
{
3115-
EXPECT_EQ(blockManager.getBlockById(blockIds[i], windowSize), blockPtrs[i]);
3116-
}
3117-
return blockHashes;
3118-
}
3119-
} // namespace
3120-
3121-
TEST_F(KVCacheManagerTest, KVCacheManagerHashKeyTest)
3122-
{
3123-
auto kvCacheManager = setupKvCacheManagerForHashTest(false);
3124-
3125-
auto const& blockManager = kvCacheManager.getBlockManager();
3126-
3127-
SizeType32 constexpr maxNewTokens = 4;
3128-
3129-
// prepare tokens with token[i] = 1000 + i
3130-
TokenIdType constexpr firstToken = 1000;
3131-
3132-
auto constexpr beamWidth = 1;
3133-
tr::SamplingConfig const samplingConfig{beamWidth};
3134-
bool constexpr isStreaming{false};
3135-
3136-
SizeType32 requestId = 0;
3137-
int inputLength = 16;
3138-
auto inputTokens = std::make_shared<VecTokens>(inputLength);
3139-
std::iota(inputTokens->begin(), inputTokens->end(), firstToken);
3140-
auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3141-
auto constexpr beamIdx = 0;
3142-
3143-
///////////////////////////////////////////////////////////////////////////
3144-
// add a request and then remove it without reuse
3145-
kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest);
3146-
GenerationRequest const& seq = kvCacheManager.getSequence(requestId);
3147-
EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0);
3148-
3149-
auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager);
3150-
3151-
auto& blockIds = seq.getCacheBlockIds(onlyWindowSize).at(beamIdx);
3152-
EXPECT_THAT(blockIds, ::testing::ElementsAreArray({0, 1, 2, 3}));
3153-
3154-
// get blocks by hash and try to retrieve them by hash
3155-
auto blockHashes = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds, onlyWindowSize);
3156-
3157-
EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest));
3158-
3159-
// blocks are all removed
3160-
for (auto hash : blockHashes)
3161-
{
3162-
auto range = blockManager.getBlocksByHash(hash, onlyWindowSize);
3163-
EXPECT_EQ(range.first, range.second);
3164-
}
3165-
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
3166-
}
3167-
3168-
TEST_F(KVCacheManagerTest, KVCacheManagerHashKeyWithReuseTest)
3169-
{
3170-
auto kvCacheManager = setupKvCacheManagerForHashTest(true);
3171-
3172-
auto const& blockManager = kvCacheManager.getBlockManager();
3173-
3174-
SizeType32 constexpr maxNewTokens = 4;
3175-
3176-
// prepare tokens with token[i] = 1000 + i
3177-
TokenIdType constexpr firstToken = 1000;
3178-
3179-
auto constexpr beamWidth = 1;
3180-
tr::SamplingConfig const samplingConfig{beamWidth};
3181-
bool constexpr isStreaming{false};
3182-
3183-
SizeType32 requestId = 0;
3184-
int inputLength = 16;
3185-
auto inputTokens = std::make_shared<VecTokens>(inputLength);
3186-
std::iota(inputTokens->begin(), inputTokens->end(), firstToken);
3187-
auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3188-
auto constexpr beamIdx = 0;
3189-
3190-
///////////////////////////////////////////////////////////////////////////
3191-
// add a request and then remove it with reuse
3192-
kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest);
3193-
GenerationRequest const& seq0 = kvCacheManager.getSequence(requestId);
3194-
EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0);
3195-
3196-
EXPECT_EQ(blockManager.getNumPools(), 1);
3197-
auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager);
3198-
3199-
auto& blockIds0 = seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx);
3200-
EXPECT_THAT(blockIds0, ::testing::ElementsAreArray({0, 1, 2, 3}));
3201-
3202-
// get blocks by hash and try to retrieve them by hash
3203-
auto blockHashes = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds0, onlyWindowSize);
3204-
3205-
EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest));
3206-
3207-
// TODO: Make reused blocks accessible by hash, after sequence removed. Test here.
3208-
3209-
///////////////////////////////////////////////////////////////////////////
3210-
// add a new request with same prefix
3211-
requestId = 1;
3212-
inputLength = 20;
3213-
inputTokens->resize(inputLength);
3214-
std::iota(inputTokens->begin(), inputTokens->end(), firstToken);
3215-
llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3216-
kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest);
3217-
GenerationRequest const& seq1 = kvCacheManager.getSequence(requestId);
3218-
EXPECT_EQ(llmRequest->getContextCurrentPosition(), 15);
3219-
auto& blockIds1 = seq1.getCacheBlockIds(onlyWindowSize).at(beamIdx);
3220-
EXPECT_THAT(blockIds1, ::testing::ElementsAreArray({0, 1, 2, 3, 4}));
3221-
3222-
std::ignore = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds1, onlyWindowSize);
3223-
3224-
// blocks are reused, so reused blocks are still accessible by previous hashes
3225-
for (size_t i = 0; i < 4; i++)
3226-
{
3227-
auto range = blockManager.getBlocksByHash(blockHashes[i], onlyWindowSize);
3228-
EXPECT_NE(range.first, range.second);
3229-
}
3230-
// evicted block is not accessible
3231-
{
3232-
size_t i = 4;
3233-
auto range = blockManager.getBlocksByHash(blockHashes[i], onlyWindowSize);
3234-
EXPECT_EQ(range.first, range.second);
3235-
}
3236-
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 5);
3237-
}
3238-
32393056
TEST_F(KVCacheManagerTest, KVCacheManagerEventStream)
32403057
{
32413058
auto constexpr numLayers = 12;

0 commit comments

Comments
 (0)