@@ -1034,6 +1034,182 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
1034
1034
EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool);
1035
1035
}
1036
1036
1037
+ TEST_F (KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
1038
+ {
1039
+ using VecTokenExtraIds = LlmRequest::VecTokenExtraIds;
1040
+
1041
+ auto constexpr numLayers = 12 ;
1042
+ auto constexpr numKvHeads = 6 ;
1043
+ auto constexpr sizePerHead = 16 ;
1044
+ auto constexpr tokensPerBlock = 4 ;
1045
+ auto constexpr maxBlocksPerSeq = 4 ;
1046
+ auto constexpr blocksInPrimaryPool = 16 ;
1047
+ auto constexpr blocksInSecondaryPool = 0 ;
1048
+ auto constexpr maxNumSequences = 8 ;
1049
+ auto const stream = std::make_shared<tr::CudaStream>();
1050
+ auto constexpr onboardBlocks = true ;
1051
+ auto constexpr numReturnSequences = 1 ;
1052
+ auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
1053
+ auto constexpr beamWidth = 1 ;
1054
+
1055
+ auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
1056
+
1057
+ BlockManager blockManager (std::vector (numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
1058
+ maxNumSequences, stream, maxAttentionWindow, beamWidth,
1059
+ std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF , 0 ,
1060
+ onboardBlocks);
1061
+ blockManager.allocatePools (false );
1062
+
1063
+ EXPECT_EQ (blockManager.getTokensPerBlock (), tokensPerBlock);
1064
+ EXPECT_EQ (blockManager.getMaxNumBlocks (), blocksInPrimaryPool);
1065
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool);
1066
+
1067
+ SizeType32 constexpr maxNewTokens{0 };
1068
+ tr::SamplingConfig const samplingConfig{beamWidth};
1069
+ bool constexpr isStreaming{false };
1070
+
1071
+ // Create multimodal hash data (256-bit hash = 8 int32 values)
1072
+ auto multimodalHashes = std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1073
+ {0x12345678 , -0x6F543211 , 0x11111111 , 0x22222222 , 0x33333333 , 0x44444444 , 0x55555555 , 0x66666666 } // Hash 1
1074
+ });
1075
+ auto multimodalPositions
1076
+ = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2 }); // Start at token 2
1077
+ auto multimodalLengths = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4 }); // Length 4 tokens
1078
+ // assume prompt id starts from 100
1079
+ auto inputTokens = std::make_shared<VecTokens>(VecTokens{100 , 101 , 102 , 103 , 104 , 105 , 0 , 1 , 2 });
1080
+ auto const inputLength = static_cast <SizeType32>(inputTokens->size ());
1081
+ LlmRequest::RequestIdType requestId{0 };
1082
+ auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1083
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1084
+ multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
1085
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false , false , false , std::nullopt,
1086
+ std::nullopt, false , std::nullopt, false , std::nullopt, false , std::nullopt, 0.5 , std::nullopt, std::nullopt,
1087
+ std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1088
+
1089
+ GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata ()};
1090
+
1091
+ // /////////////////////////////////////////////////////////////////////////
1092
+ // add request and then remove it
1093
+ auto constexpr beamIdx = 0 ;
1094
+ auto promptLen0 = llmRequest0->getNumTokens (beamIdx);
1095
+ auto numContextBlocks0 = tc::ceilDiv (promptLen0, blockManager.getTokensPerBlock ());
1096
+ blockManager.addSequence (seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
1097
+ EXPECT_EQ (llmRequest0->getContextCurrentPosition (), 0 );
1098
+ EXPECT_THAT (seq0.getCacheBlockIds (maxAttentionWindow).at (beamIdx), ::testing::ElementsAreArray ({0 , 1 , 2 }));
1099
+ llmRequest0->addNewToken (3 , beamIdx);
1100
+ llmRequest0->addNewToken (4 , beamIdx);
1101
+ auto numTokens = llmRequest0->getNumTokens (beamIdx);
1102
+ auto numBlocks = tc::ceilDiv (numTokens, tokensPerBlock);
1103
+ EXPECT_EQ (numBlocks, 3 );
1104
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), numBlocks);
1105
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool - numBlocks);
1106
+
1107
+ // Input: [100, 101, 102, 103, 104, 105, 0, 1, 2] (9 tokens)
1108
+ // Multimodal: starts at token 2, length 4 → [102, 103, 104, 105]
1109
+
1110
+ // Block 0: [100, 101, 102, 103] ← Contains multimodal (102, 103)
1111
+ // Block 1: [104, 105, 0, 1] ← Contains multimodal (104, 105)
1112
+ // Block 2: [2, 3, 4] ← No multimodal
1113
+ blockManager.releaseBlocks (seq0, llmRequest0);
1114
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 0 );
1115
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool);
1116
+
1117
+ // /////////////////////////////////////////////////////////////////////////
1118
+ // new request with same tokens and same multimodal hash - should reuse
1119
+ requestId = 1 ;
1120
+ auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1121
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1122
+ multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
1123
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false , false , false , std::nullopt,
1124
+ std::nullopt, false , std::nullopt, false , std::nullopt, false , std::nullopt, 0.5 , std::nullopt, std::nullopt,
1125
+ std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1126
+ GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata ()};
1127
+
1128
+ // should reuse blocks 0, 1 and get new block 3
1129
+ auto promptLen1 = llmRequest1->getNumTokens (beamIdx);
1130
+ auto numContextBlocks1 = tc::ceilDiv (promptLen1, blockManager.getTokensPerBlock ());
1131
+ blockManager.addSequence (seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
1132
+ EXPECT_EQ (llmRequest1->getContextCurrentPosition (), 2 * tokensPerBlock);
1133
+ EXPECT_THAT (seq1.getCacheBlockIds (maxAttentionWindow).at (beamIdx), ::testing::ElementsAreArray ({0 , 1 , 3 }));
1134
+ llmRequest1->addNewToken (3 , beamIdx);
1135
+ llmRequest1->addNewToken (4 , beamIdx);
1136
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), numBlocks);
1137
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool - numBlocks);
1138
+ // block 3 matches block 2 and will be freed
1139
+ blockManager.releaseBlocks (seq1, llmRequest1);
1140
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 0 );
1141
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool);
1142
+
1143
+ // /////////////////////////////////////////////////////////////////////////
1144
+ // Test Case 2: Different multimodal hash
1145
+ requestId = 2 ;
1146
+ auto multimodalHashes2
1147
+ = std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1148
+ {0x45678123 , 0x23456789 , 0x34567890 , 0x12121212 , 0x56565656 , 0x78787878 , 0x54545454 , 0x67676767 } // Hash 2
1149
+ });
1150
+ auto multimodalPositions2
1151
+ = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2 }); // Start at token 2
1152
+ auto multimodalLengths2 = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4 }); // Length 4 tokens
1153
+ auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1154
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1155
+ multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt,
1156
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false , false , false , std::nullopt,
1157
+ std::nullopt, false , std::nullopt, false , std::nullopt, false , std::nullopt, 0.5 , std::nullopt, std::nullopt,
1158
+ std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1159
+
1160
+ GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata ()};
1161
+ // no reuse, get new blocks 4, 5, 6
1162
+ auto promptLen2 = llmRequest2->getNumTokens (beamIdx);
1163
+ auto numContextBlocks2 = tc::ceilDiv (promptLen2, blockManager.getTokensPerBlock ());
1164
+ blockManager.addSequence (seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
1165
+ EXPECT_EQ (llmRequest2->getContextCurrentPosition (), 0 );
1166
+ EXPECT_THAT (seq2.getCacheBlockIds (maxAttentionWindow).at (beamIdx), ::testing::ElementsAreArray ({4 , 5 , 6 }));
1167
+ llmRequest2->addNewToken (9 , beamIdx);
1168
+ numTokens = llmRequest2->getNumTokens (beamIdx);
1169
+ numBlocks = tc::ceilDiv (numTokens, tokensPerBlock);
1170
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), numBlocks);
1171
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool - numBlocks);
1172
+
1173
+ // /////////////////////////////////////////////////////////////////////////
1174
+ // Test Case 3: Multiple multimodal hashes and partial reuse
1175
+ requestId = 3 ;
1176
+ auto multimodalHashes3
1177
+ = std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
1178
+ {0x12345678 , -0x6F543211 , 0x11111111 , 0x22222222 , 0x33333333 , 0x44444444 , 0x55555555 , 0x66666666 }, // Hash 1
1179
+ {0x45678123 , 0x23456789 , 0x34567890 , 0x12121212 , 0x56565656 , 0x78787878 , 0x54545454 , 0x67676767 } // Hash 2
1180
+ });
1181
+ auto multimodalPositions3
1182
+ = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2 , 4 }); // Start at token 2 and 4
1183
+ auto multimodalLengths3
1184
+ = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2 , 2 }); // Length 2 tokens
1185
+
1186
+ auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
1187
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
1188
+ multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt,
1189
+ std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false , false , false , std::nullopt,
1190
+ std::nullopt, false , std::nullopt, false , std::nullopt, false , std::nullopt, 0.5 , std::nullopt, std::nullopt,
1191
+ std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
1192
+ GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata ()};
1193
+ // reuse block 0, get new blocks 7, 8
1194
+ auto promptLen3 = llmRequest3->getNumTokens (beamIdx);
1195
+ auto numContextBlocks3 = tc::ceilDiv (promptLen3, blockManager.getTokensPerBlock ());
1196
+ blockManager.addSequence (seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
1197
+ EXPECT_EQ (llmRequest3->getContextCurrentPosition (),
1198
+ tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset
1199
+ EXPECT_THAT (seq3.getCacheBlockIds (maxAttentionWindow).at (beamIdx), ::testing::ElementsAreArray ({0 , 7 , 8 }));
1200
+ llmRequest3->addNewToken (11 , beamIdx);
1201
+ numTokens = llmRequest3->getNumTokens (beamIdx);
1202
+ numBlocks = tc::ceilDiv (numTokens, tokensPerBlock);
1203
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), numBlocks * 2 );
1204
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool - numBlocks * 2 );
1205
+
1206
+ // clean up
1207
+ blockManager.releaseBlocks (seq2, llmRequest2);
1208
+ blockManager.releaseBlocks (seq3, llmRequest3);
1209
+ EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 0 );
1210
+ EXPECT_EQ (blockManager.getNumFreeBlocks (), blocksInPrimaryPool);
1211
+ }
1212
+
1037
1213
TEST_F (KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
1038
1214
{
1039
1215
// tc::Logger::getLogger()->setLevel(tc::Logger::Level::DEBUG);
0 commit comments