@@ -203,27 +203,62 @@ CacheTransBufferManager::CacheTransBufferManager(
203203 if (maxNumTokens.has_value ())
204204 {
205205 TLLM_CHECK (maxNumTokens.value () % tokensPerBlock == 0 );
206- auto dataSize = common::getDTypeSize (mDataType );
207- auto kvCacheByteSizePerTokenPerLayer = mCacheManager ->getBlockManager ().getBlockSize (0 ) / tokensPerBlock
208- * (mCacheManager ->getCacheType () == CacheType::kSELFKONLY ? 1 : 2 ) * dataSize;
206+
207+ // Transmission always uses FP16 for mixed precision support
208+ size_t transmissionDataSize = common::getDTypeSize (nvinfer1::DataType::kHALF );
209+ size_t storageDataSize = common::getDTypeSize (mDataType );
210+
211+ TLLM_LOG_INFO (" === Buffer Size Calculation Debug ===" );
212+ TLLM_LOG_INFO (" maxNumTokens: %ld" , maxNumTokens.value ());
213+ TLLM_LOG_INFO (" tokensPerBlock: %ld" , tokensPerBlock);
214+
215+ // getBlockSize() returns volume of [numKvHeads, tokensPerBlock, sizePerHead] for ONE cache (K or V)
216+ // We need to multiply by KV factor (2 for SELF, 1 for SELFKONLY) to get both K and V
217+ size_t blockSizeInElements = mCacheManager ->getBlockManager ().getBlockSize (0 );
218+ size_t kvFactor = (mCacheManager ->getCacheType () == CacheType::kSELFKONLY ? 1 : 2 );
219+
220+ TLLM_LOG_INFO (" getBlockSize(0): %ld elements (for ONE cache)" , blockSizeInElements);
221+ TLLM_LOG_INFO (" kvFactor: %ld (1=K only, 2=K+V)" , kvFactor);
222+ TLLM_LOG_INFO (" storageDataSize (mDataType): %ld bytes" , storageDataSize);
223+ TLLM_LOG_INFO (" transmissionDataSize (FP16): %ld bytes" , transmissionDataSize);
224+
225+ // Calculate bytes per token for both K and V in transmission format (FP16)
226+ size_t bytesPerTokenInStorage = (blockSizeInElements * kvFactor * storageDataSize) / tokensPerBlock;
227+ size_t bytesPerTokenInTransmission = (blockSizeInElements * kvFactor * transmissionDataSize) / tokensPerBlock;
228+ auto kvCacheByteSizePerTokenPerLayer = bytesPerTokenInTransmission;
229+
230+ TLLM_LOG_INFO (" bytesPerTokenInStorage: %ld" , bytesPerTokenInStorage);
231+ TLLM_LOG_INFO (" bytesPerTokenInTransmission: %ld" , bytesPerTokenInTransmission);
232+ TLLM_LOG_INFO (" kvCacheByteSizePerTokenPerLayer: %ld" , kvCacheByteSizePerTokenPerLayer);
233+
209234 for (auto layerId = 0 ; layerId < mCacheManager ->getBlockManager ().getNumLayers (); layerId++)
210235 {
211236 auto poolIdx = mCacheManager ->getBlockManager ().getLayerPoolIdx (layerId);
212237 auto windowSize = static_cast <size_t >(mCacheManager ->getBlockManager ().getPoolWindowSize (poolIdx));
213- auto alignedWindowSize = (windowSize + tokensPerBlock - 1 ) / tokensPerBlock * tokensPerBlock;
214- auto validTokenNum = (alignedWindowSize < maxNumTokens.value () ? alignedWindowSize : maxNumTokens.value ());
215- // if windowSize % (tokensPerBlock) !=0
216- validTokenNum += tokensPerBlock; // add one more block
217-
238+ auto validTokenNum = (windowSize < maxNumTokens.value () ? windowSize : maxNumTokens.value ());
239+
240+ TLLM_LOG_INFO (" Layer %d: poolIdx=%d, windowSize=%ld, validTokenNum=%ld" ,
241+ layerId, poolIdx, windowSize, validTokenNum);
242+ TLLM_LOG_INFO (" Layer %d: adding %ld bytes (validTokenNum=%ld * kvCacheBytesPerToken=%ld)" ,
243+ layerId, validTokenNum * kvCacheByteSizePerTokenPerLayer, validTokenNum, kvCacheByteSizePerTokenPerLayer);
244+
218245 bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
219246 }
247+
248+ TLLM_LOG_INFO (" Total bufferSizeFromMaxNumToken: %ld bytes" , bufferSizeFromMaxNumToken);
249+
220250 }
221251
222252 mTransferBufferSize
223253 = maxNumTokens.has_value () ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer ();
254+
255+ TLLM_LOG_INFO (" HERE HERE HERE HERE HERE --------- " );
256+ TLLM_LOG_INFO (" mTransferBufferSize:%ld" , mTransferBufferSize );
257+ TLLM_LOG_INFO (" HERE HERE HERE HERE HERE --------- " );
258+
224259 mOnlyUseDynamicBuffer = mTransferBufferSize == 0 ;
225260 mRecvBufferCount = common::getEnvRequestKVCacheConcurrent () ? common::getEnvKVCacheRecvBufferCount () : 1 ;
226- mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum ();
261+ mSendBufferCount = common::getEnvParallelCacheSend () ? common:: getEnvKVCacheSendMaxConcurrenceNum () : 1 ;
227262 mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer () || common::getEnvKVCacheTransferUseAsyncBuffer ())
228263 && FabricMemory::supportFbaricMemory ();
229264 if (mUseFabricMemory )
@@ -241,6 +276,7 @@ CacheTransBufferManager::CacheTransBufferManager(
241276 allocateBuffer ();
242277}
243278
279+
244280size_t CacheTransBufferManager::preAllocBufferSize (size_t tokensPerBlock,
245281 std::map<SizeType32, SizeType32> const & cacheSizeBytesPerTokenPerWindow,
246282 std::optional<executor::CacheTransceiverConfig> const & cacheTransceiverConfig)
@@ -365,8 +401,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
365401 }
366402 else
367403 {
404+
368405 retSplitCaches.push_back (bufferManagerToUse.gpu (
369- runtime::ITensor::makeShape ({static_cast <int64_t >(targetBufferEleSizes[i])}), mDataType ));
406+ runtime::ITensor::makeShape ({static_cast <int64_t >(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF )); // Overriding for now
407+
408+
409+ // retSplitCaches.push_back(bufferManagerToUse.gpu(
410+ // runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
370411 }
371412 }
372413 TLLM_LOG_DEBUG (" getOrAllocateBuffers bufferCoverTargetNum:%d" , bufferCoverTargetNum);
@@ -384,8 +425,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
384425 {
385426 for (int i = 0 ; i < targetNum; i++)
386427 {
428+
387429 retSplitCaches.push_back (bufferManagerToUse.gpu (
388- runtime::ITensor::makeShape ({static_cast <int64_t >(targetBufferEleSizes[i])}), mDataType ));
430+ runtime::ITensor::makeShape ({static_cast <int64_t >(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF ));
431+
432+
433+ // retSplitCaches.push_back(bufferManagerToUse.gpu(
434+ // runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType)); // Overriding for now
389435 }
390436 bufferCoverTargetNum = targetNum;
391437 }
@@ -399,36 +445,45 @@ void CacheTransBufferManager::allocateBuffer()
399445 {
400446 return ;
401447 }
402- mBufferEleSize = mTransferBufferSize / common::getDTypeSize (mDataType );
448+ mBufferEleSize = mTransferBufferSize / common::getDTypeSize (nvinfer1::DataType::kHALF ); // Overriding for now
449+ // mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
403450 mConcurrenceSendResource .mBufferIndexFlag .resize (mSendBufferCount , 0 );
404451 mConcurrenceRecvResource .mBufferIndexFlag .resize (mRecvBufferCount , 0 );
405452 if (mUseFabricMemory )
406453 {
407454 mFabricMemory .reserve (mSendBufferCount + mRecvBufferCount );
408- for (size_t i = 0 ; i < mSendBufferCount ; i++)
455+ for (size_t i = 0 ; i < mSendBufferCount ; i++)
409456 {
410457 mFabricMemory .emplace_back (std::make_unique<FabricMemory>(mTransferBufferSize ));
411- mConcurrenceSendResource .mBuffers [i] = runtime::ITensor::wrap (mFabricMemory .back ()->getPtr (), mDataType ,
458+ mConcurrenceSendResource .mBuffers [i] = runtime::ITensor::wrap (mFabricMemory .back ()->getPtr (), nvinfer1::DataType:: kHALF , // Overriding for now
412459 runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mBufferEleSize );
460+ // mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
461+ // runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
413462 }
414463 for (size_t i = 0 ; i < mRecvBufferCount ; i++)
415464 {
416465 mFabricMemory .emplace_back (std::make_unique<FabricMemory>(mTransferBufferSize ));
417- mConcurrenceRecvResource .mBuffers [i] = runtime::ITensor::wrap (mFabricMemory .back ()->getPtr (), mDataType ,
466+ mConcurrenceRecvResource .mBuffers [i] = runtime::ITensor::wrap (mFabricMemory .back ()->getPtr (), nvinfer1::DataType:: kHALF , // Overriding for now
418467 runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mBufferEleSize );
468+ // mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
469+ // runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
419470 }
420471 }
421472 else if (common::getEnvKVCacheTransferUseAsyncBuffer ())
422473 {
423474 for (size_t i = 0 ; i < mSendBufferCount ; i++)
424475 {
425476 mConcurrenceSendResource .mBuffers [i]
426- = mBufferManager .gpu (runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mDataType );
477+ = mBufferManager .gpu (runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), nvinfer1::DataType::kHALF ); // Overriding for now
478+ // mConcurrenceSendResource.mBuffers[i]
479+ // = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
427480 }
428481 for (size_t i = 0 ; i < mRecvBufferCount ; i++)
429482 {
430483 mConcurrenceRecvResource .mBuffers [i]
431- = mBufferManager .gpu (runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mDataType );
484+ = mBufferManager .gpu (runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), nvinfer1::DataType::kHALF ); // Overriding for now
485+ // mConcurrenceRecvResource.mBuffers[i]
486+ // = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
432487 }
433488 mBufferManager .getStream ().synchronize ();
434489 }
@@ -437,12 +492,16 @@ void CacheTransBufferManager::allocateBuffer()
437492 for (size_t i = 0 ; i < mSendBufferCount ; i++)
438493 {
439494 mConcurrenceSendResource .mBuffers [i] = mBufferManager .gpuSync (
440- runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mDataType );
495+ runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), nvinfer1::DataType::kHALF ); // Overriding for now
496+ // mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
497+ // runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
441498 }
442499 for (size_t i = 0 ; i < mRecvBufferCount ; i++)
443500 {
444501 mConcurrenceRecvResource .mBuffers [i] = mBufferManager .gpuSync (
445- runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), mDataType );
502+ runtime::ITensor::makeShape ({static_cast <int64_t >(mBufferEleSize )}), nvinfer1::DataType::kHALF ); // Overriding for now
503+ // mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
504+ // runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
446505 }
447506 }
448507}
0 commit comments