Skip to content

Commit 7381f1d

Browse files
authored
[TRTLLM-5059][feat] Add KV cache reuse support for multimodal models (#5444)
Only Qwen models are supported in this PR.
1 parent 4a0951f commit 7381f1d

File tree

8 files changed

+716
-12
lines changed

8 files changed

+716
-12
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "tensorrt_llm/runtime/worldConfig.h"
3232
#include <NvInferRuntime.h>
3333

34+
#include <array>
3435
#include <cstdint>
3536
#include <limits>
3637
#include <list>
@@ -68,6 +69,9 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
6869
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
6970
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
7071

72+
// Type alias for multimodal hash key (hash array + start offset)
73+
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
74+
7175
template <typename T>
7276
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
7377

@@ -107,6 +111,10 @@ struct BlockKey
107111
std::optional<LoraTaskIdType> loraTaskId = std::nullopt;
108112
VecUniqueTokens uniqueTokens;
109113

114+
// Extra keys for multimodal data (similar to VLLM's approach)
115+
// Each extra key is a pair of (mm_hash, start_offset_in_block)
116+
std::vector<MmKey> extraKeys;
117+
110118
BlockKey() = default;
111119

112120
explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)
@@ -119,23 +127,25 @@ struct BlockKey
119127
}
120128
}
121129

122-
BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens)
123-
: usesExtraIds(usesExtraIds)
130+
explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
131+
std::vector<MmKey> extraKeys = {})
132+
: usesExtraIds{usesExtraIds}
124133
, loraTaskId{loraTaskId}
125134
, uniqueTokens{std::move(uniqueTokens)}
135+
, extraKeys{std::move(extraKeys)}
126136
{
127137
}
128138

129139
bool operator==(BlockKey const& other) const noexcept
130140
{
131-
return (
132-
usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens);
141+
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
142+
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
133143
}
134144

135145
int partialMatch(BlockKey const& other) const noexcept
136146
{
137147
SizeType32 numMatched{0};
138-
if (loraTaskId == other.loraTaskId)
148+
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
139149
{
140150
auto [matchEnd, otherMatchEnd] = std::mismatch(
141151
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,82 @@ std::list<std::vector<T>> chopVectorIntoBlocks(
7676
return blockedVectors;
7777
}
7878

79+
inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept
80+
{
81+
return static_cast<uint8_t>((hashPart >> (24 - byteIdx * 8)) & 0xFF);
82+
}
83+
84+
std::vector<MmKey> generateBlockHashExtraKeys(
85+
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx)
86+
{
87+
auto const multimodalHashes = llmRequest.getMultimodalHashes();
88+
auto const multimodalPositions = llmRequest.getMultimodalPositions();
89+
auto const multimodalLengths = llmRequest.getMultimodalLengths();
90+
91+
if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
92+
|| (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
93+
|| !(*multimodalLengths) || (*multimodalLengths)->empty())
94+
{
95+
return {};
96+
}
97+
98+
if ((*multimodalHashes)->size() != (*multimodalPositions)->size()
99+
|| (*multimodalPositions)->size() != (*multimodalLengths)->size())
100+
{
101+
TLLM_LOG_WARNING("Multimodal data arrays have mismatched sizes");
102+
return {};
103+
}
104+
105+
std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
106+
extraKeys.reserve((*multimodalPositions)->size());
107+
std::array<uint8_t, 32> mmHashArray;
108+
109+
for (size_t i = 0; i < (*multimodalPositions)->size(); ++i)
110+
{
111+
auto const& startPos = (*(*multimodalPositions))[i];
112+
auto const& length = (*(*multimodalLengths))[i];
113+
auto const& mmHashVector = (*(*multimodalHashes))[i];
114+
115+
TLLM_CHECK_WITH_INFO(mmHashVector.size() == 8, "Multimodal hash vector has unexpected size: %zu (expected 8)",
116+
mmHashVector.size());
117+
118+
// mmHashVector[j] comes from Python's int(hex_chunk, 16)
119+
// where hex_chunk like "00010203" means 0x00 is MSB and 0x03 is LSB (big endian)
120+
// Convert 8x 32-bit integers into a 32-byte array preserving Blake3 hash byte order
121+
// Example: hashPart = 0x00010203 → mmHashArray[0:3] = [0x00, 0x01, 0x02, 0x03]
122+
for (size_t j = 0; j < 8; ++j)
123+
{
124+
auto const& hashPart = mmHashVector[j];
125+
for (uint8_t byteIdx = 0; byteIdx < 4; ++byteIdx)
126+
{
127+
mmHashArray[j * 4 + byteIdx] = getNthByte(hashPart, byteIdx);
128+
}
129+
}
130+
131+
// Check if this multimodal content overlaps with the current block
132+
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
133+
{
134+
SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
135+
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
136+
}
137+
}
138+
139+
return extraKeys;
140+
}
141+
79142
std::vector<BlockKey> buildBlockKeys(
80143
std::list<VecUniqueTokens>& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
81144
{
82145
std::vector<BlockKey> blockKeys;
146+
147+
SizeType32 currentTokenIdx = 0;
83148
for (auto& uniqueTokens : blockedUniqueTokens)
84149
{
85-
blockKeys.emplace_back(
86-
llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), std::move(uniqueTokens));
150+
auto extraKeys = generateBlockHashExtraKeys(llmRequest, currentTokenIdx, currentTokenIdx + uniqueTokens.size());
151+
currentTokenIdx += uniqueTokens.size();
152+
153+
blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
154+
std::move(uniqueTokens), std::move(extraKeys));
87155
}
88156
return blockKeys;
89157
}
@@ -92,9 +160,11 @@ std::vector<BlockKey> buildBlockKeys(
92160

93161
namespace tensorrt_llm::batch_manager::kv_cache_manager
94162
{
95-
96163
size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept
97164
{
165+
// Hashing algorithm adapted from StackOverflow:
166+
// https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key
167+
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
98168
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);
99169

100170
for (auto const& uniqueToken : blockKey.uniqueTokens)
@@ -122,7 +192,36 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
122192
c = c ^ (c >> 31);
123193
seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
124194
}
125-
// TODO: support external hashes for multimodal
195+
196+
// Add extra keys for multimodal data mixing in external multimodal item hash and token offset within this sequence
197+
// block
198+
if (!blockKey.extraKeys.empty())
199+
{
200+
for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
201+
{
202+
// Hash the multimodal hash array in 32-bit chunks (more efficient)
203+
for (size_t i = 0; i < 32; i += 4)
204+
{
205+
// Combine 4 bytes into a 32-bit word (construct as little endian order)
206+
uint32_t word = static_cast<uint32_t>(mmHash[i]) | (static_cast<uint32_t>(mmHash[i + 1]) << 8)
207+
| (static_cast<uint32_t>(mmHash[i + 2]) << 16) | (static_cast<uint32_t>(mmHash[i + 3]) << 24);
208+
209+
// Mix the word into the seed
210+
word = ((word >> 16) ^ word) * 0x45d9f3b;
211+
word = ((word >> 16) ^ word) * 0x45d9f3b;
212+
word = (word >> 16) ^ word;
213+
seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2);
214+
}
215+
216+
// Hash the start offset
217+
uint64_t e = static_cast<uint64_t>(startOffset);
218+
e = (e ^ (e >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
219+
e = (e ^ (e >> 27)) * UINT64_C(0x94d049bb133111eb);
220+
e = e ^ (e >> 31);
221+
seed ^= e + 0x9e3779b9 + (seed << 6) + (seed >> 2);
222+
}
223+
}
224+
126225
return seed;
127226
}
128227

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,182 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
10341034
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
10351035
}
10361036

1037+
TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
1038+
{
1039+
using VecTokenExtraIds = LlmRequest::VecTokenExtraIds;
1040+
1041+
auto constexpr numLayers = 12;
1042+
auto constexpr numKvHeads = 6;
1043+
auto constexpr sizePerHead = 16;
1044+
auto constexpr tokensPerBlock = 4;
1045+
auto constexpr maxBlocksPerSeq = 4;
1046+
auto constexpr blocksInPrimaryPool = 16;
1047+
auto constexpr blocksInSecondaryPool = 0;
1048+
auto constexpr maxNumSequences = 8;
1049+
auto const stream = std::make_shared<tr::CudaStream>();
1050+
auto constexpr onboardBlocks = true;
1051+
auto constexpr numReturnSequences = 1;
1052+
auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
1053+
auto constexpr beamWidth = 1;
1054+
1055+
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
1056+
1057+
BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
1058+
maxNumSequences, stream, maxAttentionWindow, beamWidth,
1059+
std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0,
1060+
onboardBlocks);
1061+
blockManager.allocatePools(false);
1062+
1063+
EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock);
1064+
EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool);
1065+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
1066+
1067+
SizeType32 constexpr maxNewTokens{0};
1068+
tr::SamplingConfig const samplingConfig{beamWidth};
1069+
bool constexpr isStreaming{false};
1070+
1071+
// Create multimodal hash data (256-bit hash = 8 int32 values)
1072+
auto multimodalHashes = std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1073+
{0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666} // Hash 1
1074+
});
1075+
auto multimodalPositions
1076+
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2}); // Start at token 2
1077+
auto multimodalLengths = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4}); // Length 4 tokens
1078+
// assume prompt id starts from 100
1079+
auto inputTokens = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 0, 1, 2});
1080+
auto const inputLength = static_cast<SizeType32>(inputTokens->size());
1081+
LlmRequest::RequestIdType requestId{0};
1082+
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1083+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1084+
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
1085+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
1086+
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
1087+
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1088+
1089+
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
1090+
1091+
///////////////////////////////////////////////////////////////////////////
1092+
// add request and then remove it
1093+
auto constexpr beamIdx = 0;
1094+
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
1095+
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
1096+
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
1097+
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
1098+
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
1099+
llmRequest0->addNewToken(3, beamIdx);
1100+
llmRequest0->addNewToken(4, beamIdx);
1101+
auto numTokens = llmRequest0->getNumTokens(beamIdx);
1102+
auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
1103+
EXPECT_EQ(numBlocks, 3);
1104+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
1105+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
1106+
1107+
// Input: [100, 101, 102, 103, 104, 105, 0, 1, 2] (9 tokens)
1108+
// Multimodal: starts at token 2, length 4 → [102, 103, 104, 105]
1109+
1110+
// Block 0: [100, 101, 102, 103] ← Contains multimodal (102, 103)
1111+
// Block 1: [104, 105, 0, 1] ← Contains multimodal (104, 105)
1112+
// Block 2: [2, 3, 4] ← No multimodal
1113+
blockManager.releaseBlocks(seq0, llmRequest0);
1114+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
1115+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
1116+
1117+
///////////////////////////////////////////////////////////////////////////
1118+
// new request with same tokens and same multimodal hash - should reuse
1119+
requestId = 1;
1120+
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1121+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1122+
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
1123+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
1124+
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
1125+
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1126+
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
1127+
1128+
// should reuse blocks 0, 1 and get new block 3
1129+
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
1130+
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
1131+
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
1132+
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
1133+
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
1134+
llmRequest1->addNewToken(3, beamIdx);
1135+
llmRequest1->addNewToken(4, beamIdx);
1136+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
1137+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
1138+
// block 3 matches block 2 and will be freed
1139+
blockManager.releaseBlocks(seq1, llmRequest1);
1140+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
1141+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
1142+
1143+
///////////////////////////////////////////////////////////////////////////
1144+
// Test Case 2: Different multimodal hash
1145+
requestId = 2;
1146+
auto multimodalHashes2
1147+
= std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1148+
{0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2
1149+
});
1150+
auto multimodalPositions2
1151+
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2}); // Start at token 2
1152+
auto multimodalLengths2 = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4}); // Length 4 tokens
1153+
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1154+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1155+
multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt,
1156+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
1157+
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
1158+
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1159+
1160+
GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
1161+
// no reuse, get new blocks 4, 5, 6
1162+
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
1163+
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
1164+
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
1165+
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
1166+
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6}));
1167+
llmRequest2->addNewToken(9, beamIdx);
1168+
numTokens = llmRequest2->getNumTokens(beamIdx);
1169+
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
1170+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
1171+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
1172+
1173+
///////////////////////////////////////////////////////////////////////////
1174+
// Test Case 3: Multiple multimodal hashes and partial reuse
1175+
requestId = 3;
1176+
auto multimodalHashes3
1177+
= std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1178+
{0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666}, // Hash 1
1179+
{0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2
1180+
});
1181+
auto multimodalPositions3
1182+
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2, 4}); // Start at token 2 and 4
1183+
auto multimodalLengths3
1184+
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2, 2}); // Length 2 tokens
1185+
1186+
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1187+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1188+
multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt,
1189+
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
1190+
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
1191+
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1192+
GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
1193+
// reuse block 0, get new blocks 7, 8
1194+
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
1195+
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
1196+
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
1197+
EXPECT_EQ(llmRequest3->getContextCurrentPosition(),
1198+
tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset
1199+
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8}));
1200+
llmRequest3->addNewToken(11, beamIdx);
1201+
numTokens = llmRequest3->getNumTokens(beamIdx);
1202+
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
1203+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2);
1204+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2);
1205+
1206+
// clean up
1207+
blockManager.releaseBlocks(seq2, llmRequest2);
1208+
blockManager.releaseBlocks(seq3, llmRequest3);
1209+
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
1210+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
1211+
}
1212+
10371213
TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
10381214
{
10391215
// tc::Logger::getLogger()->setLevel(tc::Logger::Level::DEBUG);

0 commit comments

Comments
 (0)