feat: Add support for KVCache transfer from KVCache reuse path #6348
@@ -115,6 +115,8 @@ struct BlockKey
    // Each extra key is a pair of (mm_hash, start_offset_in_block)
    std::vector<MmKey> extraKeys;

    size_t hash{0};

    BlockKey() = default;

    explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)

@@ -127,6 +129,11 @@ struct BlockKey
        }
    }

    explicit BlockKey(size_t hash)
        : hash{hash}
    {
    }

    explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
        std::vector<MmKey> extraKeys = {})
        : usesExtraIds{usesExtraIds}

@@ -164,6 +171,10 @@ struct BlockKeyHasher
    std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
    {
        if (blockKey.hash != 0)
        {
            return blockKey.hash;
        }
        return hash(blockKey, parentHash);
    }
};
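A minimal sketch of how the new hash short-circuit behaves (illustrative only; `BlockKey`, `BlockKeyHasher`, and `VecTokens` are the types declared in this header, and the hash value is made up): a key constructed from a precomputed hash is returned as-is by the hasher, while a token-based key still goes through the full `hash(blockKey, parentHash)` computation.

```cpp
// Illustrative sketch, assuming this header is included and that VecTokens is a
// vector of token ids; the hash value below is made up for the example.
using namespace tensorrt_llm::batch_manager::kv_cache_manager;

void blockKeyHashSketch()
{
    BlockKey remoteKey{size_t{0x5D1F}};       // key carrying a hash received from a peer
    BlockKey localKey{VecTokens{1, 2, 3, 4}}; // hash member stays 0

    BlockKeyHasher hasher;
    auto h1 = hasher(remoteKey); // short-circuits and returns 0x5D1F unchanged
    auto h2 = hasher(localKey);  // falls through to hash(localKey, /*parentHash=*/0)
    (void) h1;
    (void) h2;
}
```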
@@ -566,6 +577,8 @@ class WindowBlockManager
    void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

    void pinBlocks(GenerationRequest& sequence);
Comment on lines +580 to +581

🛠️ Refactor suggestion: Add documentation for the new `pinBlocks` method.

Add documentation before the method declaration:

+    //! \brief Pin blocks associated with a sequence to prevent eviction.
+    //! \param sequence The generation request whose blocks should be pinned.
+    //! \details This method marks blocks as pinned in the KV cache to ensure they
+    //!          remain available for reuse across multiple requests.
     void pinBlocks(GenerationRequest& sequence);
    //! \brief Release blocks of the sequence.
    void releaseBlocks(GenerationRequest& sequence);

@@ -737,6 +750,9 @@ class WindowBlockManager
        return 0;
    }

    [[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
        std::vector<size_t> const& hashes) const;

private:
    //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
    void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);

@@ -883,6 +899,8 @@ class BlockManager
    void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);

    void pinBlocks(GenerationRequest& sequence);

    void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

    void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,

@@ -1067,6 +1085,12 @@ class BlockManager
        return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
    }

    [[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
        std::vector<size_t> const& hashes, SizeType32 windowSize) const
    {
        return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByHashes(hashes);
    }

    [[nodiscard]] SizeType32 getNumPrimaryBlocks() const
    {
        return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });

@@ -1201,6 +1225,8 @@ class BaseKVCacheManager
    [[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
        = 0;

    virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
Comment on lines +1228 to +1229

🛠️ Refactor suggestion: Document the pure virtual `pinBlocks` method.

Add documentation before the declaration:

+    //! \brief Pin blocks associated with a request to prevent eviction.
+    //! \param requestId The ID of the request whose blocks should be pinned.
+    //! \details Implementations should mark all blocks associated with the request
+    //!          as pinned to ensure they remain available for potential reuse.
     virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
    /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
    virtual void addToken(LlmRequest::RequestIdType requestId) = 0;

@@ -1338,6 +1364,10 @@ class BaseKVCacheManager
    [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0;

    [[nodiscard]] virtual CacheType getCacheType() const = 0;

    [[nodiscard]] virtual std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
        std::vector<size_t> const& hashes, SizeType32 windowSize) const
        = 0;
};

class KVCacheManager : public BaseKVCacheManager

@@ -1588,6 +1618,8 @@ class KVCacheManager : public BaseKVCacheManager
    [[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
        SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);

    void pinBlocks(LlmRequest::RequestIdType requestId);
Comment on lines +1621 to +1622

Missing `override` on the `pinBlocks` declaration. The method overrides the pure virtual `BaseKVCacheManager::pinBlocks`, so mark it `override`:

-    void pinBlocks(LlmRequest::RequestIdType requestId);
+    void pinBlocks(LlmRequest::RequestIdType requestId) override;
    /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
    ///
    /// @param sequenceLength The total length of the sequence (input and output).

@@ -1625,6 +1657,12 @@ class KVCacheManager : public BaseKVCacheManager
        mBlockManager.flushIterationEvents();
    }

    std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
        std::vector<size_t> const& hashes, SizeType32 windowSize) const override
    {
        return mBlockManager.findBlocksInReuseTreeByHashes(hashes, windowSize);
    }

    /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
    ///
    /// @param inputLength The number of input tokens in the sequence.
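A hedged usage sketch of the new lookup (illustrative only; `manager` is assumed to be a `BaseKVCacheManager&`, and `peerBlockHashes` / `windowSize` are hypothetical names for values supplied by the transfer layer): both the optional and the contained pointer should be checked before use, as the review comments in this PR also point out.

```cpp
// Illustrative sketch; not part of the PR.
auto found = manager.findBlocksInReuseTreeByHashes(peerBlockHashes, windowSize);
if (found && *found)                       // check the optional and the pointer before use
{
    auto lastBlock = *found;               // deepest block matching the hash chain
    auto prev = lastBlock->getPrevBlock(); // earlier blocks reachable via getPrevBlock()
    (void) prev;
}
```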
kvCacheUtils.h
@@ -48,6 +48,28 @@ class BlockRange
        return BlockRange(cacheManager, blockIds, requestId);
    }

    static BlockRange fromReuseTree(BaseKVCacheManager const& cacheManager, std::vector<size_t> const& allBlockHashes,
        std::vector<size_t> const& requestedBlockHashes)
    {
        auto const windowSize = firstWindowSize(cacheManager);
        auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
        // TODO: handle the case where the last block is not found
        TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
        // Assume the last block is the requested block
        std::vector<SizeType32> blockIds;
Comment on lines +51 to +59

Dereferencing the result before checking leads to UB when not found. findBlocksInReuseTreeByHashes(...) is dereferenced before verifying that it contains a value and a non-null pointer. Check first, then dereference. Apply this diff:

-    auto const windowSize = firstWindowSize(cacheManager);
-    auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
-    // TODO: handle the case where the last block is not found
-    TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
+    auto const windowSize = firstWindowSize(cacheManager);
+    auto lastBlockOpt = cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
+    TLLM_CHECK_WITH_INFO(lastBlockOpt, "Couldn't find the requested block in the reuse tree");
+    auto lastBlock = *lastBlockOpt;
+    TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
        for (auto const& hash : requestedBlockHashes)
        {
            if (lastBlock->getHash() != hash)
            {
                return BlockRange(cacheManager, {}, 0);
            }
            blockIds.emplace_back(lastBlock->getBlockId());
            lastBlock = lastBlock->getPrevBlock();
            TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found");
        }
        return BlockRange(cacheManager, blockIds, 0);
    }
Comment on lines +60 to +71

🛠️ Refactor suggestion: Null check must precede the dereference; avoid returning an "empty" range with a bogus requestId.

Apply this loop fix and fail fast on mismatch:

-    for (auto const& hash : requestedBlockHashes)
+    for (auto const& hash : requestedBlockHashes)
     {
-        if (lastBlock->getHash() != hash)
-        {
-            return BlockRange(cacheManager, {}, 0);
-        }
-        blockIds.emplace_back(lastBlock->getBlockId());
-        lastBlock = lastBlock->getPrevBlock();
-        TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found");
+        TLLM_CHECK_WITH_INFO(lastBlock, "Requested hash sequence exceeds reuse chain length");
+        if (lastBlock->getHash() != hash)
+        {
+            TLLM_THROW("Requested block hash mismatch in reuse tree (expected %zu, got %zu).",
+                hash, lastBlock->getHash());
+        }
+        blockIds.emplace_back(lastBlock->getBlockId());
+        lastBlock = lastBlock->getPrevBlock();
     }
-    return BlockRange(cacheManager, blockIds, 0);
+    return BlockRange(cacheManager, blockIds, 0);

Follow-up: to make this robust, pass the real requestId into fromReuseTree; using 0 is fragile. See the callsite changes suggested in cacheFormatter.cpp. If you prefer a minimal change now, keep requestId=0 only when numPools==1 and ensure callers never trigger updatePoolIdx() with a window size change.
    BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
        : mManager{nullptr}
        , mPool{std::move(pool)}
cacheFormatter.cpp
@@ -39,37 +39,33 @@
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
    std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
{
    size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
    size_t requestBlockNum = requestedBlockHashes.size();
    constexpr SizeType32 beam{0};
    auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
    auto poolNum = cacheManager->getBlockManager().getNumPools();
    if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
    if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
    {
        // disable selective cache transfer for poolNum > 1
        auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
        return blockRange;
    }
    if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
    {
        // handle block reuse, the prefix blocks are reused
        // TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num
        auto const& ids = blockRange.getBlockIds();
        blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()});
    }
    return blockRange;
    return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes);
}
Comment on lines +42 to 54

🛠️ Refactor suggestion: Remove the unused variable and pass the requestId into the reuse-tree pathway.

Apply this diff:

-BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
-    std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
+BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
+    std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
 {
-    size_t requestBlockNum = requestedBlockHashes.size();
     constexpr SizeType32 beam{0};
     auto poolNum = cacheManager->getBlockManager().getNumPools();
     if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
     {
         auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
         return blockRange;
     }
-    return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes);
+    return BlockRange::fromReuseTree(*cacheManager, llmRequest.mRequestId, allBlockHashes, requestedBlockHashes);
 }

Note: This assumes adding an overload of BlockRange::fromReuseTree that accepts requestId. See the kvCacheUtils.h comment for the corresponding constructor change.
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
{
    auto poolNum = cacheManager->getBlockManager().getNumPools();
    if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
    if (poolNum == 1 && cacheManager->isEnableBlockReuse())
    {
        return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
    }
    else
    {
        constexpr SizeType32 beam{0};
        return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
    }
    return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
}
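Taken together, the sender and receiver paths above select block ranges as follows. This condensed sketch only illustrates the decision logic, not a drop-in replacement; it assumes the same helpers and arguments are in scope and uses the fromReuseTree signature as written in this PR (before the review's suggested requestId overload):

```cpp
// Condensed view of the selection logic (sketch only).
bool const selective = cacheManager->getBlockManager().getNumPools() == 1
    && cacheManager->isEnableBlockReuse();
constexpr SizeType32 beam{0};

// Sender: transfer only the blocks matched through the reuse tree when selective,
// otherwise fall back to every block of the request.
auto sendRange = selective
    ? BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes)
    : BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);

// Receiver: write into newly allocated blocks when selective,
// otherwise receive into every block of the request.
auto recvRange = selective
    ? BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId)
    : BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
```

In the selective case the transfer appears to stay limited to the blocks the receiver actually requested, rather than resending the whole sequence.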
bool CacheFormatter::needSendCache(

@@ -155,13 +151,17 @@ void CacheFormatter::format(TransferSession& session)
    auto const& selfConfig = session.getSelfState().getCacheState().value();
    auto const& destConfig = session.getOtherState().getCacheState().value();
    auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
    auto& allBlockHashes = session.getAllBlockHashes();
    auto& requestedBlockHashes = session.getRequestedBlockHashes();
    TLLM_CHECK_WITH_INFO(allBlockHashes.size() >= requestedBlockHashes.size(),
        "allBlockHashes must be greater than or equal to requestedBlockHashes");
    auto& bufferManager = session.getBufferManager();
    if (!needSendCache(selfConfig, destConfig, selfIdx))
    {
        return;
    }
    auto& blockManager = mCacheManager->getBlockManager();
    auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
    auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, allBlockHashes, requestedBlockHashes);

    auto const numPools = blockManager.getNumPools();
    // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
🛠️ Refactor suggestion: Document the purpose of the hash-only constructor.

The constructor that takes only a hash value needs documentation explaining when it should be used. Add documentation before the constructor.
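A plausible sketch of what such documentation could look like (the wording here is illustrative, not the reviewer's suggested text):

```cpp
//! \brief Construct a BlockKey that carries only a precomputed hash.
//! \param hash Hash value computed elsewhere, e.g. received from a peer during
//!        KV cache transfer. BlockKeyHasher returns this value directly instead
//!        of hashing tokens.
explicit BlockKey(size_t hash)
    : hash{hash}
{
}
```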