Skip to content

Commit 1a21dba

Browse files
committed
Initial iteration for supporting block hash transfer
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent d1c5f80 commit 1a21dba

File tree

13 files changed

+304
-217
lines changed

13 files changed

+304
-217
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,8 @@ class WindowBlockManager
566566

567567
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
568568

569+
void pinBlocks(GenerationRequest& sequence);
570+
569571
//! \brief Release blocks of the sequence.
570572
void releaseBlocks(GenerationRequest& sequence);
571573

@@ -883,6 +885,8 @@ class BlockManager
883885

884886
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
885887

888+
void pinBlocks(GenerationRequest& sequence);
889+
886890
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
887891

888892
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
@@ -1067,6 +1071,16 @@ class BlockManager
10671071
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
10681072
}
10691073

1074+
[[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const
1075+
{
1076+
return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash);
1077+
}
1078+
1079+
[[nodiscard]] BlockPtr const& getBlockFromReuseTreeByHash(size_t hash, SizeType32 windowSize) const
1080+
{
1081+
return mWindowBlockManagers.at(windowSize).getBlockFromReuseTreeByHash(hash);
1082+
}
1083+
10701084
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
10711085
{
10721086
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
@@ -1201,6 +1215,8 @@ class BaseKVCacheManager
12011215
[[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
12021216
= 0;
12031217

1218+
virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
1219+
12041220
/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
12051221
virtual void addToken(LlmRequest::RequestIdType requestId) = 0;
12061222

@@ -1588,6 +1604,8 @@ class KVCacheManager : public BaseKVCacheManager
15881604
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
15891605
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
15901606

1607+
void pinBlocks(LlmRequest::RequestIdType requestId);
1608+
15911609
/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
15921610
///
15931611
/// @param sequenceLength The total length of the sequence (input and output).

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ class BlockRange
8080
mBlockIds = std::move(blockIds);
8181
}
8282

83+
void setBlockIdsFromHashes(std::vector<size_t> blockHashes)
84+
{
85+
TLLM_CHECK(mManager);
86+
std::vector<SizeType32> blockIds;
87+
blockIds.reserve(blockHashes.size());
88+
auto& blockManager = mManager->getBlockManager();
89+
for (auto hash : blockHashes)
90+
{
91+
blockIds.emplace_back(blockManager.getBlockByHash(hash, mWindowSize)->getId());
92+
}
93+
mBlockIds = std::move(blockIds);
94+
}
95+
8396
[[nodiscard]] std::vector<size_t> getBlockHashes() const
8497
{
8598
TLLM_CHECK(mManager);

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,16 +1831,6 @@ class GenericLlmRequest
18311831
}
18321832
}
18331833

1834-
void setRequestedBlockHashes(std::vector<size_t> hashes)
1835-
{
1836-
mRequestedBlockHashes = std::move(hashes);
1837-
}
1838-
1839-
[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
1840-
{
1841-
return mRequestedBlockHashes;
1842-
}
1843-
18441834
void setIsDummyRequest(bool isDummyRequest)
18451835
{
18461836
mIsDummyRequest = isDummyRequest;

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,20 @@
3939
namespace tensorrt_llm::batch_manager::kv_cache_manager
4040
{
4141

42-
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
42+
BlockRange getBlockRangeForSending(
43+
BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, std::vector<size_t> const& blockHashes)
4344
{
44-
size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
45+
size_t requestBlockNum = blockHashes.size();
4546
constexpr SizeType32 beam{0};
4647
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
4748
auto poolNum = cacheManager->getBlockManager().getNumPools();
48-
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
49+
if (poolNum > 1 || !cacheManager->isEnableBlockReuse())
4950
{
50-
// disable selective cache transfer for poolNum > 1
5151
return blockRange;
5252
}
5353
if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
5454
{
55-
// handle block reuse, the prefix blocks are reused
56-
// TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num
57-
auto const& ids = blockRange.getBlockIds();
58-
blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()});
55+
blockRange.setBlockIds(blockHashes);
5956
}
6057
return blockRange;
6158
}
@@ -64,12 +61,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
6461
{
6562

6663
auto poolNum = cacheManager->getBlockManager().getNumPools();
67-
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
64+
if (poolNum == 1 && cacheManager->isEnableBlockReuse())
65+
{
66+
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
67+
}
68+
else
6869
{
6970
constexpr SizeType32 beam{0};
7071
return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
7172
}
72-
return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
7373
}
7474

7575
bool CacheFormatter::needSendCache(
@@ -155,13 +155,14 @@ void CacheFormatter::format(TransferSession& session)
155155
auto const& selfConfig = session.getSelfState().getCacheState().value();
156156
auto const& destConfig = session.getOtherState().getCacheState().value();
157157
auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
158+
auto& blockHashes = session.getBlockHashes();
158159
auto& bufferManager = session.getBufferManager();
159160
if (!needSendCache(selfConfig, destConfig, selfIdx))
160161
{
161162
return;
162163
}
163164
auto& blockManager = mCacheManager->getBlockManager();
164-
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
165+
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, blockHashes);
165166

166167
auto const numPools = blockManager.getNumPools();
167168
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...

0 commit comments

Comments (0)