Skip to content

Commit bd2757d

Browse files
committed
Initial iteration for supporting block hash transfer
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent d1c5f80 commit bd2757d

File tree

13 files changed

+324
-217
lines changed

13 files changed

+324
-217
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ struct BlockKey
115115
// Each extra key is a pair of (mm_hash, start_offset_in_block)
116116
std::vector<MmKey> extraKeys;
117117

118+
// Precomputed hash for this key. 0 means "not set": BlockKeyHasher::operator() returns any
// non-zero value stored here as-is instead of hashing the key contents.
size_t hash{0};
119+
118120
BlockKey() = default;
119121

120122
explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)
@@ -127,6 +129,11 @@ struct BlockKey
127129
}
128130
}
129131

132+
// Builds a key from a precomputed hash only; no other member is set by this constructor.
// NOTE(review): presumably used to look up blocks by a hash received from a peer - confirm.
// NOTE(review): confirm that BlockKey equality also accounts for `hash`; if it does not, two
// hash-only keys with different hashes could compare equal.
explicit BlockKey(size_t hash)
133+
: hash{hash}
134+
{
135+
}
136+
130137
explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
131138
std::vector<MmKey> extraKeys = {})
132139
: usesExtraIds{usesExtraIds}
@@ -164,6 +171,10 @@ struct BlockKeyHasher
164171

165172
// Returns the hash for blockKey: the precomputed blockKey.hash when it is non-zero,
// otherwise the result of hash(blockKey, parentHash).
std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
166173
{
// A non-zero precomputed hash short-circuits the computation; parentHash is not used on this path.
// NOTE(review): 0 doubles as the "not set" sentinel, so a precomputed hash whose value is
// exactly 0 falls through to hash() below - confirm this is acceptable.
174+
if (blockKey.hash != 0)
175+
{
176+
return blockKey.hash;
177+
}
167178
return hash(blockKey, parentHash);
168179
}
169180
};
@@ -566,6 +577,8 @@ class WindowBlockManager
566577

567578
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
568579

580+
void pinBlocks(GenerationRequest& sequence);
581+
569582
//! \brief Release blocks of the sequence.
570583
void releaseBlocks(GenerationRequest& sequence);
571584

@@ -737,6 +750,8 @@ class WindowBlockManager
737750
return 0;
738751
}
739752

753+
[[nodiscard]] std::optional<KVCacheBlock> findBlocksInReuseTreeByHashes(std::vector<size_t> const& hashes) const;
754+
740755
private:
741756
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
742757
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -883,6 +898,8 @@ class BlockManager
883898

884899
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
885900

901+
void pinBlocks(GenerationRequest& sequence);
902+
886903
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
887904

888905
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
@@ -1201,6 +1218,8 @@ class BaseKVCacheManager
12011218
[[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
12021219
= 0;
12031220

1221+
virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
1222+
12041223
/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
12051224
virtual void addToken(LlmRequest::RequestIdType requestId) = 0;
12061225

@@ -1588,6 +1607,8 @@ class KVCacheManager : public BaseKVCacheManager
15881607
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
15891608
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
15901609

1610+
//! \brief Implements BaseKVCacheManager::pinBlocks for this manager.
//! \param requestId Id of the request whose blocks are affected.
void pinBlocks(LlmRequest::RequestIdType requestId) override;
1611+
15911612
/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
15921613
///
15931614
/// @param sequenceLength The total length of the sequence (input and output).

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ class BlockRange
8080
mBlockIds = std::move(blockIds);
8181
}
8282

83+
void setBlockIdsFromHashes(std::vector<size_t> blockHashes)
84+
{
85+
TLLM_CHECK(mManager);
86+
std::vector<SizeType32> blockIds;
87+
blockIds.reserve(blockHashes.size());
88+
auto& blockManager = mManager->getBlockManager();
89+
for (auto hash : blockHashes)
90+
{
91+
blockIds.emplace_back(blockManager.getBlockByHash(hash, mWindowSize)->getId());
92+
}
93+
mBlockIds = std::move(blockIds);
94+
}
95+
8396
[[nodiscard]] std::vector<size_t> getBlockHashes() const
8497
{
8598
TLLM_CHECK(mManager);

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,16 +1831,6 @@ class GenericLlmRequest
18311831
}
18321832
}
18331833

1834-
void setRequestedBlockHashes(std::vector<size_t> hashes)
1835-
{
1836-
mRequestedBlockHashes = std::move(hashes);
1837-
}
1838-
1839-
[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
1840-
{
1841-
return mRequestedBlockHashes;
1842-
}
1843-
18441834
void setIsDummyRequest(bool isDummyRequest)
18451835
{
18461836
mIsDummyRequest = isDummyRequest;

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,20 @@
3939
namespace tensorrt_llm::batch_manager::kv_cache_manager
4040
{
4141

42-
//! \brief Select the blocks of a request whose KV cache should be sent.
//! \param cacheManager KV cache manager that owns the request's blocks.
//! \param llmRequest Request whose blocks are transferred.
//! \param blockHashes Hashes of the blocks to send.
//! \return All blocks of beam 0 when there is more than one pool, when block reuse is disabled,
//!         or when blockHashes is empty or does not select fewer blocks than the request holds;
//!         otherwise only the blocks resolved from blockHashes.
BlockRange getBlockRangeForSending(
    BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, std::vector<size_t> const& blockHashes)
{
    constexpr SizeType32 beam{0};
    auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);

    auto const poolNum = cacheManager->getBlockManager().getNumPools();
    if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
    {
        // Send everything: selective transfer is not used for poolNum > 1 or without block reuse.
        return blockRange;
    }

    auto const requestBlockNum = blockHashes.size();
    if (requestBlockNum > 0 && requestBlockNum < blockRange.size())
    {
        // blockHashes holds hashes, not block ids, so they must be resolved to ids
        // rather than handed to setBlockIds directly.
        blockRange.setBlockIdsFromHashes(blockHashes);
    }
    return blockRange;
}
@@ -64,12 +61,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
6461
{
6562

6663
auto poolNum = cacheManager->getBlockManager().getNumPools();
67-
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
64+
if (poolNum == 1 && cacheManager->isEnableBlockReuse())
65+
{
66+
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
67+
}
68+
else
6869
{
6970
constexpr SizeType32 beam{0};
7071
return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
7172
}
72-
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
7373
}
7474

7575
bool CacheFormatter::needSendCache(
@@ -155,13 +155,14 @@ void CacheFormatter::format(TransferSession& session)
155155
auto const& selfConfig = session.getSelfState().getCacheState().value();
156156
auto const& destConfig = session.getOtherState().getCacheState().value();
157157
auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
158+
auto& blockHashes = session.getBlockHashes();
158159
auto& bufferManager = session.getBufferManager();
159160
if (!needSendCache(selfConfig, destConfig, selfIdx))
160161
{
161162
return;
162163
}
163164
auto& blockManager = mCacheManager->getBlockManager();
164-
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
165+
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, blockHashes);
165166

166167
auto const numPools = blockManager.getNumPools();
167168
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...

0 commit comments

Comments
 (0)