@@ -3052,189 +3052,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest)
30523052 assertBlocks (seq3, {4 }, {6 });
30533053}
30543054
3055- namespace
3056- {
3057- KVCacheManager setupKvCacheManagerForHashTest (bool enableBlockReuse)
3058- {
3059- auto constexpr numLayers = 2 ;
3060- auto constexpr numHeads = 2 ;
3061- auto constexpr sizePerHead = 64 ;
3062- auto constexpr tokensPerBlock = 4 ;
3063- auto constexpr maxNumSequences = 8 ;
3064- auto constexpr maxBeamWidth = 1 ;
3065- auto constexpr sinkTokenLength = 0 ;
3066- auto const stream = std::make_shared<tr::CudaStream>();
3067-
3068- auto constexpr maxBlocksPerSeq = 8 ;
3069- auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq;
3070- auto constexpr maxAttentionWindow = maxNumTokens;
3071-
3072- auto constexpr blocksInPrimaryPool = 16 ;
3073- auto constexpr blocksInSecondaryPool = 0 ;
3074-
3075- auto constexpr onboardBlocks = true ;
3076-
3077- auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
3078-
3079- return KVCacheManager (std::vector<SizeType32>(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
3080- maxNumSequences, maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt,
3081- nvinfer1::DataType::kHALF , sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks,
3082- CacheType::kSELF , std::nullopt, nullptr ,
3083- /* enableHashKey*/ true );
3084- }
3085-
3086- std::vector<size_t > getHashAndRetrieveBlocksByHashTest (
3087- BlockManager const & blockManager, std::vector<KVCacheBlock::IdType> const & blockIds, SizeType32 windowSize)
3088- {
3089- std::vector<size_t > blockHashes;
3090- for (auto blockId : blockIds)
3091- {
3092- blockHashes.emplace_back (blockManager.getBlockById (blockId, windowSize)->getHash ());
3093- }
3094- std::vector<BlockPtr> blockPtrs;
3095- for (auto hash : blockHashes)
3096- {
3097- auto range = blockManager.getBlocksByHash (hash, windowSize);
3098- BlockPtr const prevBlock = blockPtrs.empty () ? nullptr : blockPtrs.back ();
3099- BlockPtr thisBlock = nullptr ;
3100- for (auto it = range.first ; it != range.second ; ++it)
3101- {
3102- if (it->second ->getPrevBlockInSeq () == prevBlock)
3103- {
3104- thisBlock = it->second ;
3105- break ;
3106- }
3107- }
3108- EXPECT_NE (thisBlock, nullptr );
3109- blockPtrs.emplace_back (thisBlock);
3110- }
3111- EXPECT_EQ (blockHashes.size (), blockPtrs.size ());
3112- for (size_t i = 0 ; i < blockHashes.size (); i++)
3113- {
3114- EXPECT_EQ (blockManager.getBlockById (blockIds[i], windowSize), blockPtrs[i]);
3115- }
3116- return blockHashes;
3117- }
3118- } // namespace
3119-
3120- TEST_F (KVCacheManagerTest, KVCacheManagerHashKeyTest)
3121- {
3122- auto kvCacheManager = setupKvCacheManagerForHashTest (false );
3123-
3124- auto const & blockManager = kvCacheManager.getBlockManager ();
3125-
3126- SizeType32 constexpr maxNewTokens = 4 ;
3127-
3128- // prepare tokens with token[i] = 1000 + i
3129- TokenIdType constexpr firstToken = 1000 ;
3130-
3131- auto constexpr beamWidth = 1 ;
3132- tr::SamplingConfig const samplingConfig{beamWidth};
3133- bool constexpr isStreaming{false };
3134-
3135- SizeType32 requestId = 0 ;
3136- int inputLength = 16 ;
3137- auto inputTokens = std::make_shared<VecTokens>(inputLength);
3138- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3139- auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3140- auto constexpr beamIdx = 0 ;
3141-
3142- // /////////////////////////////////////////////////////////////////////////
3143- // add a request and then remove it without reuse
3144- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3145- GenerationRequest const & seq = kvCacheManager.getSequence (requestId);
3146- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 0 );
3147-
3148- auto const onlyWindowSize = theOnlyWindowSize (kvCacheManager);
3149-
3150- auto & blockIds = seq.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3151- EXPECT_THAT (blockIds, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 }));
3152-
3153- // get blocks by hash and try to retrieve them by hash
3154- auto blockHashes = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds, onlyWindowSize);
3155-
3156- EXPECT_NO_THROW (kvCacheManager.removeSequence (requestId, llmRequest));
3157-
3158- // blocks are all removed
3159- for (auto hash : blockHashes)
3160- {
3161- auto range = blockManager.getBlocksByHash (hash, onlyWindowSize);
3162- EXPECT_EQ (range.first , range.second );
3163- }
3164- EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 0 );
3165- }
3166-
3167- TEST_F (KVCacheManagerTest, KVCacheManagerHashKeyWithReuseTest)
3168- {
3169- auto kvCacheManager = setupKvCacheManagerForHashTest (true );
3170-
3171- auto const & blockManager = kvCacheManager.getBlockManager ();
3172-
3173- SizeType32 constexpr maxNewTokens = 4 ;
3174-
3175- // prepare tokens with token[i] = 1000 + i
3176- TokenIdType constexpr firstToken = 1000 ;
3177-
3178- auto constexpr beamWidth = 1 ;
3179- tr::SamplingConfig const samplingConfig{beamWidth};
3180- bool constexpr isStreaming{false };
3181-
3182- SizeType32 requestId = 0 ;
3183- int inputLength = 16 ;
3184- auto inputTokens = std::make_shared<VecTokens>(inputLength);
3185- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3186- auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3187- auto constexpr beamIdx = 0 ;
3188-
3189- // /////////////////////////////////////////////////////////////////////////
3190- // add a request and then remove it with reuse
3191- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3192- GenerationRequest const & seq0 = kvCacheManager.getSequence (requestId);
3193- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 0 );
3194-
3195- EXPECT_EQ (blockManager.getNumPools (), 1 );
3196- auto const onlyWindowSize = theOnlyWindowSize (kvCacheManager);
3197-
3198- auto & blockIds0 = seq0.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3199- EXPECT_THAT (blockIds0, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 }));
3200-
3201- // get blocks by hash and try to retrieve them by hash
3202- auto blockHashes = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds0, onlyWindowSize);
3203-
3204- EXPECT_NO_THROW (kvCacheManager.removeSequence (requestId, llmRequest));
3205-
3206- // TODO: Make reused blocks accessible by hash, after sequence removed. Test here.
3207-
3208- // /////////////////////////////////////////////////////////////////////////
3209- // add a new request with same prefix
3210- requestId = 1 ;
3211- inputLength = 20 ;
3212- inputTokens->resize (inputLength);
3213- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3214- llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3215- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3216- GenerationRequest const & seq1 = kvCacheManager.getSequence (requestId);
3217- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 15 );
3218- auto & blockIds1 = seq1.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3219- EXPECT_THAT (blockIds1, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 , 4 }));
3220-
3221- std::ignore = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds1, onlyWindowSize);
3222-
3223- // blocks are reused, so reused blocks are still accessible by previous hashes
3224- for (size_t i = 0 ; i < 4 ; i++)
3225- {
3226- auto range = blockManager.getBlocksByHash (blockHashes[i], onlyWindowSize);
3227- EXPECT_NE (range.first , range.second );
3228- }
3229- // evicted block is not accessible
3230- {
3231- size_t i = 4 ;
3232- auto range = blockManager.getBlocksByHash (blockHashes[i], onlyWindowSize);
3233- EXPECT_EQ (range.first , range.second );
3234- }
3235- EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 5 );
3236- }
3237-
32383055TEST_F (KVCacheManagerTest, KVCacheManagerEventStream)
32393056{
32403057 auto constexpr numLayers = 12 ;
0 commit comments