feat: Add support for KVCache transfer from KVCache reuse path #6348

Draft · wants to merge 3 commits into base: main
11 changes: 6 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -34,8 +34,8 @@ namespace tensorrt_llm::batch_manager

class ContextProgress;
class BaseCacheTransceiver;
class DataResponder;
class DataRequester;
class CacheSender;
class CacheReceiver;

class CacheTransceiverFactory
{
@@ -110,15 +110,16 @@ class CacheTransceiver : public BaseCacheTransceiver

void setContextState(LlmRequest* llmRequest);

std::unique_ptr<DataResponder> mDataResponder;
std::unique_ptr<DataRequester> mDataRequester;
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
mMpiGroupTPInDPComm;
executor::kv_cache::CommState const* mCommState;
std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
// std::unique_ptr<CacheServer> mCacheServer;
std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
std::optional<executor::CacheTransceiverConfig> mCacheTransceiverConfig;
std::unique_ptr<kv_cache_manager::CacheTransBufferManager> mCacheTransBufferManager;
38 changes: 38 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -115,6 +115,8 @@ struct BlockKey
// Each extra key is a pair of (mm_hash, start_offset_in_block)
std::vector<MmKey> extraKeys;

size_t hash{0};

BlockKey() = default;

explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)
@@ -127,6 +129,11 @@
}
}

explicit BlockKey(size_t hash)
: hash{hash}
{
}
Comment on lines +132 to +135
🛠️ Refactor suggestion

Document the purpose of the hash-only constructor.

The constructor that takes only a hash value needs documentation explaining when it should be used.

Add documentation:

+    /// @brief Construct a BlockKey with a precomputed hash.
+    /// @param hash The precomputed hash value.
+    /// @details Used when retrieving blocks by hash from the reuse tree.
     explicit BlockKey(size_t hash)
         : hash{hash}
     {
     }
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h around lines 132 to
135, add a documentation comment for the explicit BlockKey(size_t hash)
constructor that explains it constructs a key from a precomputed hash (used for
lookups/comparisons where only the hash is available), clarifies it does not
populate other key fields, notes any invariants or expected source of the hash
(e.g., result of BlockKey::hash() or external hashing function), and documents
intended usage scenarios (fast-lookup, temporary key for map/set operations) and
ownership/lifetime implications.
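
For orientation, here is a minimal standalone sketch (not part of this PR) of how a hash-only key interacts with BlockKeyHasher, whose operator() (later in this diff) returns a nonzero precomputed hash as-is and never re-hashes the token fields. The hash literal below is arbitrary.

#include <cassert>

void hashShortCircuitExample()
{
    BlockKey probe{/*hash=*/0x9e3779b97f4a7c15ull}; // hash-only key; token fields stay empty
    BlockKeyHasher hasher;
    assert(hasher(probe) == 0x9e3779b97f4a7c15ull); // precomputed hash is returned unchanged
}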


explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
std::vector<MmKey> extraKeys = {})
: usesExtraIds{usesExtraIds}
@@ -164,6 +171,10 @@ struct BlockKeyHasher

std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
{
if (blockKey.hash != 0)
{
return blockKey.hash;
}
return hash(blockKey, parentHash);
}
};
@@ -566,6 +577,8 @@ class WindowBlockManager

void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

void pinBlocks(GenerationRequest& sequence);

Comment on lines +580 to +581
🛠️ Refactor suggestion

Add documentation for the new pinBlocks method.

The new pinBlocks method lacks documentation explaining its purpose, behavior, and when it should be called. This is particularly important as it appears to be part of the public API.

Add documentation before the method declaration:

+    //! \brief Pin blocks associated with a sequence to prevent eviction.
+    //! \param sequence The generation request whose blocks should be pinned.
+    //! \details This method marks blocks as pinned in the KV cache to ensure they
+    //!          remain available for reuse across multiple requests.
     void pinBlocks(GenerationRequest& sequence);
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h around lines 569-570,
the public method declaration void pinBlocks(GenerationRequest& sequence); is
missing documentation; add a brief doxygen-style comment immediately above the
declaration describing purpose (what pinning does), expected behavior (which
blocks are pinned, lifetime and thread-safety expectations), when callers should
invoke it (e.g., before generation, after allocation), parameters (explain
GenerationRequest reference and ownership/constness), return/exception behavior
(whether it can fail, how failures are reported), and any side effects (impact
on memory and eviction). Ensure the comment is concise, uses @param and
@throws/@note as appropriate, and matches the style of surrounding method docs
in this header.
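
As a usage illustration only (a hypothetical call site, not code from this PR), the intended pattern appears to be pinning before a transfer so eviction cannot reclaim the blocks mid-send; pairing it with the existing releaseBlocks afterwards is an assumption here.

// Hypothetical helper; `manager` and `sequence` are assumed to be in scope.
void sendWithPinnedBlocks(WindowBlockManager& manager, GenerationRequest& sequence)
{
    manager.pinBlocks(sequence);     // mark the sequence's blocks as non-evictable
    // ... perform the KV-cache transfer ...
    manager.releaseBlocks(sequence); // assumed counterpart once the transfer completes
}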

//! \brief Release blocks of the sequence.
void releaseBlocks(GenerationRequest& sequence);

@@ -737,6 +750,9 @@ class WindowBlockManager
return 0;
}

[[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
std::vector<size_t> const& hashes) const;

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -883,6 +899,8 @@ class BlockManager

void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);

void pinBlocks(GenerationRequest& sequence);

void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
@@ -1067,6 +1085,12 @@ class BlockManager
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
}

[[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
std::vector<size_t> const& hashes, SizeType32 windowSize) const
{
return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByHashes(hashes);
}

[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
@@ -1201,6 +1225,8 @@ class BaseKVCacheManager
[[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
= 0;

virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;

Comment on lines +1228 to +1229
🛠️ Refactor suggestion

Document the pure virtual pinBlocks method in the base class.

The pure virtual pinBlocks method in BaseKVCacheManager lacks documentation explaining the contract that derived classes must fulfill.

+    //! \brief Pin blocks associated with a request to prevent eviction.
+    //! \param requestId The ID of the request whose blocks should be pinned.
+    //! \details Implementations should mark all blocks associated with the request
+    //!          as pinned to ensure they remain available for potential reuse.
     virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h around lines
1218-1219, the pure virtual method `pinBlocks(LlmRequest::RequestIdType
requestId)` in BaseKVCacheManager is undocumented; add a brief doxygen-style
comment above the declaration that states the method's contract: what pinning
means, when it will be called, expected behavior (e.g., mark blocks for the
given requestId as pinned to prevent eviction), whether it must be thread-safe,
return/exception expectations (void, should not throw), and any required side
effects or preconditions so derived classes implement consistent behavior.

/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
virtual void addToken(LlmRequest::RequestIdType requestId) = 0;

@@ -1338,6 +1364,10 @@ class BaseKVCacheManager
[[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0;

[[nodiscard]] virtual CacheType getCacheType() const = 0;

[[nodiscard]] virtual std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
std::vector<size_t> const& hashes, SizeType32 windowSize) const
= 0;
};

class KVCacheManager : public BaseKVCacheManager
@@ -1588,6 +1618,8 @@ class KVCacheManager : public BaseKVCacheManager
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);

void pinBlocks(LlmRequest::RequestIdType requestId);

Comment on lines +1621 to +1622
⚠️ Potential issue

Missing override specifier on the pinBlocks declaration.

The pinBlocks method is declared without the override keyword, which could lead to subtle bugs if the signature doesn't match the base class exactly.

-    void pinBlocks(LlmRequest::RequestIdType requestId);
+    void pinBlocks(LlmRequest::RequestIdType requestId) override;
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h around lines
1607-1608, the method declaration for pinBlocks is missing the override
specifier; update the declaration to match and explicitly mark it as overriding
the base class (e.g. change to "void pinBlocks(LlmRequest::RequestIdType
requestId) override;"), and verify the signature (const/&, noexcept) exactly
matches the base class method so the override is valid.
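
A self-contained illustration of the pitfall the comment describes (stand-in types; RequestIdType is assumed to be a 64-bit unsigned integer here):

#include <cstdint>

struct BaseManager
{
    virtual void pinBlocks(std::uint64_t requestId) {}
    virtual ~BaseManager() = default;
};

struct DerivedManager : BaseManager
{
    // Parameter type typo: this silently declares a NEW overload and does NOT
    // override BaseManager::pinBlocks, yet it still compiles.
    void pinBlocks(std::uint32_t requestId) {}

    // With `override`, the same typo becomes a compile-time error:
    // void pinBlocks(std::uint32_t requestId) override; // error: does not override
};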

/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
///
/// @param sequenceLength The total length of the sequence (input and output).
@@ -1625,6 +1657,12 @@ class KVCacheManager : public BaseKVCacheManager
mBlockManager.flushIterationEvents();
}

std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByHashes(
std::vector<size_t> const& hashes, SizeType32 windowSize) const override
{
return mBlockManager.findBlocksInReuseTreeByHashes(hashes, windowSize);
}

/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
///
/// @param inputLength The number of input tokens in the sequence.
22 changes: 22 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -48,6 +48,28 @@ class BlockRange
return BlockRange(cacheManager, blockIds, requestId);
}

static BlockRange fromReuseTree(BaseKVCacheManager const& cacheManager, std::vector<size_t> const& allBlockHashes,
std::vector<size_t> const& requestedBlockHashes)
{
auto const windowSize = firstWindowSize(cacheManager);
auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
// TODO: handle the case where the last block is not found
TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
// Assume the last block is the requested block
std::vector<SizeType32> blockIds;
Comment on lines +51 to +59
⚠️ Potential issue

Dereferencing the result before checking leads to UB when not found

findBlocksInReuseTreeByHashes(...) is dereferenced before verifying it contains a value or non-null pointer. Check first, then dereference.

Apply this diff:

-        auto const windowSize = firstWindowSize(cacheManager);
-        auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
-        // TODO: handle the case where the last block is not found
-        TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
+        auto const windowSize = firstWindowSize(cacheManager);
+        auto lastBlockOpt = cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
+        TLLM_CHECK_WITH_INFO(lastBlockOpt, "Couldn't find the requested block in the reuse tree");
+        auto lastBlock = *lastBlockOpt;
+        TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h around lines 51 to 59,
the code dereferences the result of
cacheManager.findBlocksInReuseTreeByHashes(...) before checking it, which can
cause UB if the result is empty/null; change the code to first store the call
result in a local variable, verify it contains a value (or non-null pointer) and
only then dereference it, and if it is missing handle the error path (use
TLLM_CHECK_WITH_INFO or return/throw appropriately) instead of dereferencing
unconditionally.
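
The underlying hazard in isolation (a generic sketch, unrelated to the PR's types):

#include <optional>

int unsafeGet(std::optional<int> const& value)
{
    return *value;             // UB if `value` is empty: dereference without a check
}

int safeGet(std::optional<int> const& value)
{
    return value.value_or(-1); // checked access via value_or (or test has_value() first)
}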

for (auto const& hash : requestedBlockHashes)
{
if (lastBlock->getHash() != hash)
{
return BlockRange(cacheManager, {}, 0);
}
blockIds.emplace_back(lastBlock->getBlockId());
lastBlock = lastBlock->getPrevBlock();
TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found");
}
return BlockRange(cacheManager, blockIds, 0);
}
Comment on lines +60 to +71
🛠️ Refactor suggestion

⚠️ Potential issue

Null check must precede dereference; avoid returning an “empty” range with bogus requestId

  • The loop dereferences lastBlock before checking it.
  • Returning BlockRange(cacheManager, {}, 0) can later cause mRequestId=0 to be used by updatePoolIdx(), yielding invalid sequence access or hard-to-trace faults.

Apply this loop fix and fail-fast on mismatch:

-        for (auto const& hash : requestedBlockHashes)
+        for (auto const& hash : requestedBlockHashes)
         {
-            if (lastBlock->getHash() != hash)
-            {
-                return BlockRange(cacheManager, {}, 0);
-            }
-            blockIds.emplace_back(lastBlock->getBlockId());
-            lastBlock = lastBlock->getPrevBlock();
-            TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found");
+            TLLM_CHECK_WITH_INFO(lastBlock, "Requested hash sequence exceeds reuse chain length");
+            if (lastBlock->getHash() != hash)
+            {
+                TLLM_THROW("Requested block hash mismatch in reuse tree (expected %zu, got %zu).",
+                    hash, lastBlock->getHash());
+            }
+            blockIds.emplace_back(lastBlock->getBlockId());
+            lastBlock = lastBlock->getPrevBlock();
         }
-        return BlockRange(cacheManager, blockIds, 0);
+        return BlockRange(cacheManager, blockIds, 0);

Follow-up: to make this robust, pass the real requestId into fromReuseTree; using 0 is fragile. See the callsite changes suggested in cacheFormatter.cpp.

If you prefer minimal change now, keep requestId=0 only when numPools==1 and ensure callers never trigger updatePoolIdx() with a window size change.

🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h around lines 60-71, the
loop dereferences lastBlock before checking it and returns a BlockRange with
requestId=0 which can lead to invalid updatePoolIdx() usage; fix by validating
lastBlock (TLLM_CHECK_WITH_INFO or if-check) before any dereference inside the
loop, fail-fast and return an error/empty range immediately on a hash mismatch
without touching lastBlock, and stop using a hardcoded requestId=0 — instead
pass the actual requestId into fromReuseTree (preferred) or only retain
requestId=0 when numPools==1 and callers cannot trigger updatePoolIdx() with a
window size change.


BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
: mManager{nullptr}
, mPool{std::move(pool)}
10 changes: 0 additions & 10 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1831,16 +1831,6 @@ class GenericLlmRequest
}
}

void setRequestedBlockHashes(std::vector<size_t> hashes)
{
mRequestedBlockHashes = std::move(hashes);
}

[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
{
return mRequestedBlockHashes;
}

void setIsDummyRequest(bool isDummyRequest)
{
mIsDummyRequest = isDummyRequest;
1 change: 0 additions & 1 deletion cpp/tensorrt_llm/batch_manager/CMakeLists.txt
@@ -24,7 +24,6 @@ set(SRCS
createNewDecoderRequests.cpp
contextProgress.cpp
dataTransceiver.cpp
dataTransceiverImpl.cpp
decoderBuffers.cpp
encoderBuffers.cpp
guidedDecoder.cpp
32 changes: 16 additions & 16 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -39,37 +39,33 @@
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
{
size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
size_t requestBlockNum = requestedBlockHashes.size();
constexpr SizeType32 beam{0};
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
auto poolNum = cacheManager->getBlockManager().getNumPools();
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
{
// disable selective cache transfer for poolNum > 1
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
return blockRange;
}
if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
{
// handle block reuse, the prefix blocks are reused
// TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num
auto const& ids = blockRange.getBlockIds();
blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()});
}
return blockRange;
return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes);
}
Comment on lines +42 to 54
🛠️ Refactor suggestion

Remove unused variable and pass requestId into reuse-tree pathway

  • requestBlockNum is computed but unused.
  • fromReuseTree should ideally know the requestId to avoid fragile mRequestId=0 behavior in BlockRange.

Apply this diff:

-BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
-    std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
+BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
+    std::vector<size_t> const& allBlockHashes, std::vector<size_t> const& requestedBlockHashes)
 {
-    size_t requestBlockNum = requestedBlockHashes.size();
     constexpr SizeType32 beam{0};
     auto poolNum = cacheManager->getBlockManager().getNumPools();
     if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
     {
         auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
         return blockRange;
     }
-    return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes);
+    return BlockRange::fromReuseTree(*cacheManager, llmRequest.mRequestId, allBlockHashes, requestedBlockHashes);
 }

Note: This assumes adding an overload of BlockRange::fromReuseTree that accepts requestId. See kvCacheUtils.h comment for the corresponding constructor change.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp around lines 42 to 54,
remove the unused local requestBlockNum variable and change the reuse-tree
return to pass the requestId (llmRequest.mRequestId) into
BlockRange::fromReuseTree so it uses the request context instead of relying on
mRequestId=0; this assumes you have the overload
BlockRange::fromReuseTree(*cacheManager, requestId, allBlockHashes,
requestedBlockHashes) available per the kvCacheUtils.h change.
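
Combining the review's suggestions, a sketch of the proposed fromReuseTree overload (an assumption about the follow-up, not code in this PR): check the optional before dereferencing, fail fast on a hash mismatch, and thread the real request id through instead of hardcoding 0.

static BlockRange fromReuseTree(BaseKVCacheManager const& cacheManager,
    LlmRequest::RequestIdType requestId, std::vector<size_t> const& allBlockHashes,
    std::vector<size_t> const& requestedBlockHashes)
{
    auto const windowSize = firstWindowSize(cacheManager);
    auto lastBlockOpt = cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize);
    TLLM_CHECK_WITH_INFO(lastBlockOpt.has_value() && *lastBlockOpt,
        "Couldn't find the requested block in the reuse tree");
    auto lastBlock = *lastBlockOpt;
    std::vector<SizeType32> blockIds;
    for (auto const& hash : requestedBlockHashes)
    {
        TLLM_CHECK_WITH_INFO(lastBlock, "Requested hash sequence exceeds reuse chain length");
        TLLM_CHECK_WITH_INFO(lastBlock->getHash() == hash, "Requested block hash mismatch in reuse tree");
        blockIds.emplace_back(lastBlock->getBlockId());
        lastBlock = lastBlock->getPrevBlock();
    }
    return BlockRange(cacheManager, blockIds, requestId); // real id instead of hardcoded 0
}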


BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
{

auto poolNum = cacheManager->getBlockManager().getNumPools();
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
if (poolNum == 1 && cacheManager->isEnableBlockReuse())
{
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
}
else
{
constexpr SizeType32 beam{0};
return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
}
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
}

bool CacheFormatter::needSendCache(
@@ -155,13 +151,17 @@ void CacheFormatter::format(TransferSession& session)
auto const& selfConfig = session.getSelfState().getCacheState().value();
auto const& destConfig = session.getOtherState().getCacheState().value();
auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
auto& allBlockHashes = session.getAllBlockHashes();
auto& requestedBlockHashes = session.getRequestedBlockHashes();
TLLM_CHECK_WITH_INFO(allBlockHashes.size() >= requestedBlockHashes.size(),
"allBlockHashes must be greater than or equal to requestedBlockHashes");
auto& bufferManager = session.getBufferManager();
if (!needSendCache(selfConfig, destConfig, selfIdx))
{
return;
}
auto& blockManager = mCacheManager->getBlockManager();
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, allBlockHashes, requestedBlockHashes);

auto const numPools = blockManager.getNumPools();
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...