Skip to content

Commit 7afe234

Browse files
Fix commit history
Signed-off-by: Timothy Gao <[email protected]>
1 parent ad6c8a0 commit 7afe234

File tree

9 files changed

+503
-47
lines changed

9 files changed

+503
-47
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
3131
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
3232
#include "tensorrt_llm/executor/executor.h"
33+
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
3334
#include "tensorrt_llm/runtime/iTensor.h"
3435
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
3536
#include <algorithm>
@@ -497,6 +498,102 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
497498
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
498499
}
499500

501+
namespace {
502+
std::string dataTypeToString(nvinfer1::DataType type)
503+
{
504+
switch (type)
505+
{
506+
case nvinfer1::DataType::kINT64: return "INT64";
507+
case nvinfer1::DataType::kINT32: return "INT32";
508+
case nvinfer1::DataType::kFLOAT: return "FP32";
509+
case nvinfer1::DataType::kBF16: return "BF16";
510+
case nvinfer1::DataType::kHALF: return "FP16";
511+
case nvinfer1::DataType::kINT8: return "INT8";
512+
case nvinfer1::DataType::kUINT8: return "UINT8";
513+
case nvinfer1::DataType::kFP8: return "FP8";
514+
case nvinfer1::DataType::kBOOL: return "BOOL";
515+
default: return "UNKNOWN";
516+
}
517+
}
518+
} // namespace
519+
520+
void convertKVCachePrecisionVector( //TODO: Instead of iterating and calling kernal, concat into a single tensor and call the kernel once
521+
std::vector<runtime::ITensor::SharedPtr>& blocks,
522+
BaseCacheFormatter::CacheState const& srcConfig,
523+
BaseCacheFormatter::CacheState const& destConfig,
524+
runtime::BufferManager const& bufferManager)
525+
{
526+
auto const srcDataType = srcConfig.getDataType();
527+
auto const destDataType = destConfig.getDataType();
528+
auto stream = bufferManager.getStream().get();
529+
530+
TLLM_LOG_INFO("Converting %zu blocks from %s to %s", blocks.size(),
531+
dataTypeToString(srcDataType).c_str(), dataTypeToString(destDataType).c_str());
532+
533+
for (size_t i = 0; i < blocks.size(); ++i)
534+
{
535+
auto& block = blocks[i];
536+
auto blockShape = block->getShape();
537+
auto blockVolume = runtime::ITensor::volume(blockShape);
538+
539+
auto tempBuffer = bufferManager.gpu(blockShape, destDataType);
540+
541+
if (srcDataType == nvinfer1::DataType::kHALF && destDataType == nvinfer1::DataType::kFP8)
542+
{
543+
kernels::invokeConversion<__nv_fp8_e4m3, half>(
544+
reinterpret_cast<__nv_fp8_e4m3*>(tempBuffer->data()),
545+
reinterpret_cast<half const*>(block->data()),
546+
blockVolume,
547+
nullptr,
548+
stream
549+
);
550+
}
551+
else if (srcDataType == nvinfer1::DataType::kBF16 && destDataType == nvinfer1::DataType::kFP8)
552+
{
553+
kernels::invokeConversion<__nv_fp8_e4m3, __nv_bfloat16>(
554+
reinterpret_cast<__nv_fp8_e4m3*>(tempBuffer->data()),
555+
reinterpret_cast<__nv_bfloat16 const*>(block->data()),
556+
blockVolume,
557+
nullptr,
558+
stream
559+
);
560+
}
561+
else if (srcDataType == nvinfer1::DataType::kFP8 && destDataType == nvinfer1::DataType::kHALF)
562+
{
563+
// FP8 -> FP16 conversion for transmission buffer (Overriding for now)
564+
kernels::invokeConversion<half, __nv_fp8_e4m3>(
565+
reinterpret_cast<half*>(tempBuffer->data()),
566+
reinterpret_cast<__nv_fp8_e4m3 const*>(block->data()),
567+
blockVolume,
568+
nullptr,
569+
stream
570+
);
571+
}
572+
else if (srcDataType == nvinfer1::DataType::kBF16 && destDataType == nvinfer1::DataType::kHALF)
573+
{
574+
// BF16 -> FP16 conversion for transmission buffer (Overriding for now)
575+
kernels::invokeConversion<half, __nv_bfloat16>(
576+
reinterpret_cast<half*>(tempBuffer->data()),
577+
reinterpret_cast<__nv_bfloat16 const*>(block->data()),
578+
blockVolume,
579+
nullptr,
580+
stream
581+
);
582+
}
583+
else
584+
{
585+
TLLM_LOG_WARNING("Unsupported conversion %s -> %s, skipping",
586+
dataTypeToString(srcDataType).c_str(),
587+
dataTypeToString(destDataType).c_str());
588+
continue;
589+
}
590+
591+
block = std::move(tempBuffer);
592+
}
593+
594+
bufferManager.getStream().synchronize();
595+
}
596+
500597
void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
501598
{
502599
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
@@ -744,7 +841,14 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
744841
{
745842
preAllocRecvBuffer = mCacheTransBufferManager->getRecvBuffer(cacheBufferId);
746843
TLLM_CHECK(preAllocRecvBuffer != nullptr);
747-
TLLM_CHECK(preAllocRecvBuffer->getDataType() == dataType);
844+
845+
846+
// TLLM_CHECK(preAllocRecvBuffer->getDataType() == dataType); <-- KEY ASSERT CHANGED HERE
847+
848+
TLLM_LOG_INFO("============= RECEIVE ON GEN SIDE SETTINGS =============");
849+
TLLM_LOG_INFO("preAllocRecvBuffer data type: %s", dataTypeToString(preAllocRecvBuffer->getDataType()).c_str());
850+
TLLM_LOG_INFO("dataType: %s", dataTypeToString(dataType).c_str());
851+
TLLM_LOG_INFO("============= RECEIVE ON GEN SIDE SETTINGS =============");
748852
}
749853

750854
auto recvBufferFun = [&](int deviceId, size_t processIdx)
@@ -854,6 +958,15 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
854958
{
855959
NVTX3_SCOPED_RANGE(formatInputConcatenate);
856960

961+
if (destConfig.getDataType() != selfConfig.getDataType() && common::getEnvEnableKVCachePrecisionConversion())
962+
{
963+
TLLM_LOG_INFO("WE ARE TAKING THIS PATH, CONVERSING.......");
964+
NVTX3_SCOPED_RANGE(kvCacheRecvPrecisionConv);
965+
convertKVCachePrecisionVector(recvSplitCaches, destConfig, selfConfig, bufferManager);
966+
TLLM_LOG_INFO("After conversion, recvSplitCaches[0] dtype: %s",
967+
dataTypeToString(recvSplitCaches[0]->getDataType()).c_str());
968+
}
969+
857970
executor::kv_cache::concatKvCacheV2Dispatch(
858971
recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
859972

@@ -873,11 +986,12 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
873986

874987
[[nodiscard]] bool CacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
875988
{
876-
if (selfConfig.getDataType() != destConfig.getDataType())
877-
{
878-
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: selfConfig.getDataType() != destConfig.getDataType()");
879-
return false;
880-
}
989+
// TODO: Change
990+
// if (selfConfig.getDataType() != destConfig.getDataType()) // Overriding for now
991+
// {
992+
// TLLM_LOG_WARNING("CacheFormatter::inquireSupport: selfConfig.getDataType() != destConfig.getDataType()");
993+
// return false;
994+
// }
881995

882996
std::unordered_set<SizeType32> setVecSelf{
883997
selfConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), selfConfig.getModelConfig().mNbKvHeadsPerLayer.end()};

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,27 +203,62 @@ CacheTransBufferManager::CacheTransBufferManager(
203203
if (maxNumTokens.has_value())
204204
{
205205
TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
206-
auto dataSize = common::getDTypeSize(mDataType);
207-
auto kvCacheByteSizePerTokenPerLayer = mCacheManager->getBlockManager().getBlockSize(0) / tokensPerBlock
208-
* (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2) * dataSize;
206+
207+
// Transmission always uses FP16 for mixed precision support
208+
size_t transmissionDataSize = common::getDTypeSize(nvinfer1::DataType::kHALF);
209+
size_t storageDataSize = common::getDTypeSize(mDataType);
210+
211+
TLLM_LOG_INFO("=== Buffer Size Calculation Debug ===");
212+
TLLM_LOG_INFO("maxNumTokens: %ld", maxNumTokens.value());
213+
TLLM_LOG_INFO("tokensPerBlock: %ld", tokensPerBlock);
214+
215+
// getBlockSize() returns volume of [numKvHeads, tokensPerBlock, sizePerHead] for ONE cache (K or V)
216+
// We need to multiply by KV factor (2 for SELF, 1 for SELFKONLY) to get both K and V
217+
size_t blockSizeInElements = mCacheManager->getBlockManager().getBlockSize(0);
218+
size_t kvFactor = (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2);
219+
220+
TLLM_LOG_INFO("getBlockSize(0): %ld elements (for ONE cache)", blockSizeInElements);
221+
TLLM_LOG_INFO("kvFactor: %ld (1=K only, 2=K+V)", kvFactor);
222+
TLLM_LOG_INFO("storageDataSize (mDataType): %ld bytes", storageDataSize);
223+
TLLM_LOG_INFO("transmissionDataSize (FP16): %ld bytes", transmissionDataSize);
224+
225+
// Calculate bytes per token for both K and V in transmission format (FP16)
226+
size_t bytesPerTokenInStorage = (blockSizeInElements * kvFactor * storageDataSize) / tokensPerBlock;
227+
size_t bytesPerTokenInTransmission = (blockSizeInElements * kvFactor * transmissionDataSize) / tokensPerBlock;
228+
auto kvCacheByteSizePerTokenPerLayer = bytesPerTokenInTransmission;
229+
230+
TLLM_LOG_INFO("bytesPerTokenInStorage: %ld", bytesPerTokenInStorage);
231+
TLLM_LOG_INFO("bytesPerTokenInTransmission: %ld", bytesPerTokenInTransmission);
232+
TLLM_LOG_INFO("kvCacheByteSizePerTokenPerLayer: %ld", kvCacheByteSizePerTokenPerLayer);
233+
209234
for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++)
210235
{
211236
auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
212237
auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
213-
auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
214-
auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
215-
// if windowSize % (tokensPerBlock) !=0
216-
validTokenNum += tokensPerBlock; // add one more block
217-
238+
auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value());
239+
240+
TLLM_LOG_INFO("Layer %d: poolIdx=%d, windowSize=%ld, validTokenNum=%ld",
241+
layerId, poolIdx, windowSize, validTokenNum);
242+
TLLM_LOG_INFO("Layer %d: adding %ld bytes (validTokenNum=%ld * kvCacheBytesPerToken=%ld)",
243+
layerId, validTokenNum * kvCacheByteSizePerTokenPerLayer, validTokenNum, kvCacheByteSizePerTokenPerLayer);
244+
218245
bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
219246
}
247+
248+
TLLM_LOG_INFO("Total bufferSizeFromMaxNumToken: %ld bytes", bufferSizeFromMaxNumToken);
249+
220250
}
221251

222252
mTransferBufferSize
223253
= maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
254+
255+
TLLM_LOG_INFO("HERE HERE HERE HERE HERE --------- ");
256+
TLLM_LOG_INFO("mTransferBufferSize:%ld", mTransferBufferSize);
257+
TLLM_LOG_INFO("HERE HERE HERE HERE HERE --------- ");
258+
224259
mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
225260
mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
226-
mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
261+
mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
227262
mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
228263
&& FabricMemory::supportFbaricMemory();
229264
if (mUseFabricMemory)
@@ -241,6 +276,7 @@ CacheTransBufferManager::CacheTransBufferManager(
241276
allocateBuffer();
242277
}
243278

279+
244280
size_t CacheTransBufferManager::preAllocBufferSize(size_t tokensPerBlock,
245281
std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
246282
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
@@ -365,8 +401,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
365401
}
366402
else
367403
{
404+
368405
retSplitCaches.push_back(bufferManagerToUse.gpu(
369-
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
406+
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF)); // Overriding for now
407+
408+
409+
// retSplitCaches.push_back(bufferManagerToUse.gpu(
410+
// runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
370411
}
371412
}
372413
TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
@@ -384,8 +425,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
384425
{
385426
for (int i = 0; i < targetNum; i++)
386427
{
428+
387429
retSplitCaches.push_back(bufferManagerToUse.gpu(
388-
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
430+
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF));
431+
432+
433+
// retSplitCaches.push_back(bufferManagerToUse.gpu(
434+
// runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType)); // Overriding for now
389435
}
390436
bufferCoverTargetNum = targetNum;
391437
}
@@ -399,36 +445,45 @@ void CacheTransBufferManager::allocateBuffer()
399445
{
400446
return;
401447
}
402-
mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
448+
mBufferEleSize = mTransferBufferSize / common::getDTypeSize(nvinfer1::DataType::kHALF); // Overriding for now
449+
// mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
403450
mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
404451
mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
405452
if (mUseFabricMemory)
406453
{
407454
mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
408-
for (size_t i = 0; i < mSendBufferCount; i++)
455+
for (size_t i = 0; i < mSendBufferCount; i++)
409456
{
410457
mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
411-
mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
458+
mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), nvinfer1::DataType::kHALF, // Overriding for now
412459
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
460+
// mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
461+
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
413462
}
414463
for (size_t i = 0; i < mRecvBufferCount; i++)
415464
{
416465
mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
417-
mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
466+
mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), nvinfer1::DataType::kHALF, // Overriding for now
418467
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
468+
// mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
469+
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
419470
}
420471
}
421472
else if (common::getEnvKVCacheTransferUseAsyncBuffer())
422473
{
423474
for (size_t i = 0; i < mSendBufferCount; i++)
424475
{
425476
mConcurrenceSendResource.mBuffers[i]
426-
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
477+
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
478+
// mConcurrenceSendResource.mBuffers[i]
479+
// = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
427480
}
428481
for (size_t i = 0; i < mRecvBufferCount; i++)
429482
{
430483
mConcurrenceRecvResource.mBuffers[i]
431-
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
484+
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
485+
// mConcurrenceRecvResource.mBuffers[i]
486+
// = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
432487
}
433488
mBufferManager.getStream().synchronize();
434489
}
@@ -437,12 +492,16 @@ void CacheTransBufferManager::allocateBuffer()
437492
for (size_t i = 0; i < mSendBufferCount; i++)
438493
{
439494
mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
440-
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
495+
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
496+
// mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
497+
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
441498
}
442499
for (size_t i = 0; i < mRecvBufferCount; i++)
443500
{
444501
mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
445-
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
502+
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
503+
// mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
504+
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
446505
}
447506
}
448507
}

cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,4 +450,10 @@ bool getEnvDisableChunkedAttentionInGenPhase()
450450
return getBoolEnv("TRTLLM_DISABLE_CHUNKED_ATTENTION_IN_GEN_PHASE");
451451
}
452452

453+
bool getEnvEnableKVCachePrecisionConversion()
454+
{
455+
static bool const enableKVCachePrecisionConversion = getBoolEnv("TRTLLM_ENABLE_KVCACHE_PRECISION_CONVERSION");
456+
return enableKVCachePrecisionConversion;
457+
}
458+
453459
} // namespace tensorrt_llm::common

cpp/tensorrt_llm/common/envUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,7 @@ bool getEnvDisaggBenchmarkGenOnly();
136136
// Whether to disable the chunked-attention in the generation phase.
137137
bool getEnvDisableChunkedAttentionInGenPhase();
138138

139+
// Whether to enable KV cache precision conversion during transfers.
140+
bool getEnvEnableKVCachePrecisionConversion();
141+
139142
} // namespace tensorrt_llm::common

0 commit comments

Comments
 (0)