126 changes: 120 additions & 6 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -30,6 +30,7 @@
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <algorithm>
@@ -497,6 +498,102 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
}

namespace {
std::string dataTypeToString(nvinfer1::DataType type)
{
switch (type)
{
case nvinfer1::DataType::kINT64: return "INT64";
case nvinfer1::DataType::kINT32: return "INT32";
case nvinfer1::DataType::kFLOAT: return "FP32";
case nvinfer1::DataType::kBF16: return "BF16";
case nvinfer1::DataType::kHALF: return "FP16";
case nvinfer1::DataType::kINT8: return "INT8";
case nvinfer1::DataType::kUINT8: return "UINT8";
case nvinfer1::DataType::kFP8: return "FP8";
case nvinfer1::DataType::kBOOL: return "BOOL";
default: return "UNKNOWN";
}
}
} // namespace

void convertKVCachePrecisionVector( // TODO: Instead of iterating and calling the kernel per block, concat into a single tensor and call the kernel once (see the sketch after this hunk)
std::vector<runtime::ITensor::SharedPtr>& blocks,
BaseCacheFormatter::CacheState const& srcConfig,
BaseCacheFormatter::CacheState const& destConfig,
runtime::BufferManager const& bufferManager)
{
auto const srcDataType = srcConfig.getDataType();
auto const destDataType = destConfig.getDataType();
auto stream = bufferManager.getStream().get();

TLLM_LOG_INFO("Converting %zu blocks from %s to %s", blocks.size(),
dataTypeToString(srcDataType).c_str(), dataTypeToString(destDataType).c_str());

for (size_t i = 0; i < blocks.size(); ++i)
{
auto& block = blocks[i];
auto blockShape = block->getShape();
auto blockVolume = runtime::ITensor::volume(blockShape);

auto tempBuffer = bufferManager.gpu(blockShape, destDataType);

if (srcDataType == nvinfer1::DataType::kHALF && destDataType == nvinfer1::DataType::kFP8)
{
kernels::invokeConversion<__nv_fp8_e4m3, half>(
reinterpret_cast<__nv_fp8_e4m3*>(tempBuffer->data()),
reinterpret_cast<half const*>(block->data()),
blockVolume,
nullptr,
stream
);
}
else if (srcDataType == nvinfer1::DataType::kBF16 && destDataType == nvinfer1::DataType::kFP8)
{
kernels::invokeConversion<__nv_fp8_e4m3, __nv_bfloat16>(
reinterpret_cast<__nv_fp8_e4m3*>(tempBuffer->data()),
reinterpret_cast<__nv_bfloat16 const*>(block->data()),
blockVolume,
nullptr,
stream
);
}
else if (srcDataType == nvinfer1::DataType::kFP8 && destDataType == nvinfer1::DataType::kHALF)
{
// FP8 -> FP16 conversion for transmission buffer (Overriding for now)
kernels::invokeConversion<half, __nv_fp8_e4m3>(
reinterpret_cast<half*>(tempBuffer->data()),
reinterpret_cast<__nv_fp8_e4m3 const*>(block->data()),
blockVolume,
nullptr,
stream
);
}
else if (srcDataType == nvinfer1::DataType::kBF16 && destDataType == nvinfer1::DataType::kHALF)
{
// BF16 -> FP16 conversion for transmission buffer (Overriding for now)
kernels::invokeConversion<half, __nv_bfloat16>(
reinterpret_cast<half*>(tempBuffer->data()),
reinterpret_cast<__nv_bfloat16 const*>(block->data()),
blockVolume,
nullptr,
stream
);
}
else
{
TLLM_LOG_WARNING("Unsupported conversion %s -> %s, skipping",
dataTypeToString(srcDataType).c_str(),
dataTypeToString(destDataType).c_str());
continue;
}

block = std::move(tempBuffer);
}

bufferManager.getStream().synchronize();
}

void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
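
The TODO on convertKVCachePrecisionVector suggests fusing the per-block launches. Below is a minimal sketch of that batched variant; it assumes all blocks share one shape, that `runtime::ITensor::slice` and `BufferManager::copy` can serve as the gather/scatter, and it hands back 1-D views (reshape omitted). The helper name is illustrative, not part of this PR.

```cpp
// Sketch only: pack all blocks into one staging tensor, convert with a single
// kernel launch, then hand each block a view into the converted output.
void convertKVCachePrecisionBatched(std::vector<runtime::ITensor::SharedPtr>& blocks,
    nvinfer1::DataType destDataType, runtime::BufferManager const& bufferManager)
{
    if (blocks.empty())
    {
        return;
    }
    auto const perBlock = static_cast<size_t>(runtime::ITensor::volume(blocks[0]->getShape()));
    auto const total = static_cast<int64_t>(perBlock * blocks.size());
    runtime::ITensor::SharedPtr stagingIn
        = bufferManager.gpu(runtime::ITensor::makeShape({total}), blocks[0]->getDataType());
    runtime::ITensor::SharedPtr stagingOut
        = bufferManager.gpu(runtime::ITensor::makeShape({total}), destDataType);

    // Gather: copy the blocks back-to-back into the staging input.
    for (size_t i = 0; i < blocks.size(); ++i)
    {
        auto dst = runtime::ITensor::slice(stagingIn, i * perBlock, perBlock);
        bufferManager.copy(*blocks[i], *dst);
    }

    // One launch for the whole batch (FP16 -> FP8 shown; the other dtype pairs
    // would dispatch exactly as in the per-block loop in this PR).
    kernels::invokeConversion<__nv_fp8_e4m3, half>(
        reinterpret_cast<__nv_fp8_e4m3*>(stagingOut->data()),
        reinterpret_cast<half const*>(stagingIn->data()), perBlock * blocks.size(), nullptr,
        bufferManager.getStream().get());

    // Scatter: each block becomes a view into the converted staging output.
    for (size_t i = 0; i < blocks.size(); ++i)
    {
        blocks[i] = runtime::ITensor::slice(stagingOut, i * perBlock, perBlock);
    }
    bufferManager.getStream().synchronize();
}
```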
@@ -744,7 +841,14 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
{
preAllocRecvBuffer = mCacheTransBufferManager->getRecvBuffer(cacheBufferId);
TLLM_CHECK(preAllocRecvBuffer != nullptr);
 
// Relaxed while mixed-precision transfer is prototyped: the recv buffer is
// allocated in the FP16 transmission dtype, which may differ from the
// request's storage dtype.
// TLLM_CHECK(preAllocRecvBuffer->getDataType() == dataType);
 
TLLM_LOG_INFO("Recv-side settings: preAllocRecvBuffer dtype: %s, request dtype: %s",
dataTypeToString(preAllocRecvBuffer->getDataType()).c_str(), dataTypeToString(dataType).c_str());
}

auto recvBufferFun = [&](int deviceId, size_t processIdx)
@@ -854,6 +958,15 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
{
NVTX3_SCOPED_RANGE(formatInputConcatenate);

if (destConfig.getDataType() != selfConfig.getDataType() && common::getEnvEnableKVCachePrecisionConversion())
{
TLLM_LOG_INFO("WE ARE TAKING THIS PATH, CONVERSING.......");
NVTX3_SCOPED_RANGE(kvCacheRecvPrecisionConv);
convertKVCachePrecisionVector(recvSplitCaches, destConfig, selfConfig, bufferManager);
TLLM_LOG_INFO("After conversion, recvSplitCaches[0] dtype: %s",
dataTypeToString(recvSplitCaches[0]->getDataType()).c_str());
}

executor::kv_cache::concatKvCacheV2Dispatch(
recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager);

@@ -873,11 +986,12 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess

[[nodiscard]] bool CacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
{
if (selfConfig.getDataType() != destConfig.getDataType())
{
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: selfConfig.getDataType() != destConfig.getDataType()");
return false;
}
// TODO: Re-enable as a convertible-pair check once mixed-precision transfer
// is finalized; disabled for now so that differing dtypes are accepted.
// if (selfConfig.getDataType() != destConfig.getDataType())
// {
// TLLM_LOG_WARNING("CacheFormatter::inquireSupport: selfConfig.getDataType() != destConfig.getDataType()");
// return false;
// }

std::unordered_set<SizeType32> setVecSelf{
selfConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), selfConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
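With the equality check above disabled, inquireSupport accepts any dtype pair, including ones convertKVCachePrecisionVector cannot handle. One way to re-guard it is an allowlist of exactly the pairs this PR implements; a sketch, where the helper name isConvertibleKVCacheDType is hypothetical:

```cpp
#include <NvInferRuntime.h>
#include <set>
#include <utility>

// Hypothetical guard: accept equal dtypes plus the four conversions that
// convertKVCachePrecisionVector implements in this PR.
bool isConvertibleKVCacheDType(nvinfer1::DataType src, nvinfer1::DataType dst)
{
    using DT = nvinfer1::DataType;
    if (src == dst)
    {
        return true;
    }
    static std::set<std::pair<DT, DT>> const kSupported{
        {DT::kHALF, DT::kFP8}, {DT::kBF16, DT::kFP8},
        {DT::kFP8, DT::kHALF}, {DT::kBF16, DT::kHALF}};
    return kSupported.count({src, dst}) > 0;
}
```

inquireSupport could then return false when `!isConvertibleKVCacheDType(selfConfig.getDataType(), destConfig.getDataType())`, keeping the warning-and-reject behavior for genuinely unsupported pairs.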
97 changes: 78 additions & 19 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -203,27 +203,62 @@ CacheTransBufferManager::CacheTransBufferManager(
if (maxNumTokens.has_value())
{
TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
auto dataSize = common::getDTypeSize(mDataType);
auto kvCacheByteSizePerTokenPerLayer = mCacheManager->getBlockManager().getBlockSize(0) / tokensPerBlock
* (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2) * dataSize;

// Transmission always uses FP16 for mixed precision support
size_t transmissionDataSize = common::getDTypeSize(nvinfer1::DataType::kHALF);
size_t storageDataSize = common::getDTypeSize(mDataType);

TLLM_LOG_INFO("=== Buffer Size Calculation Debug ===");
TLLM_LOG_INFO("maxNumTokens: %ld", maxNumTokens.value());
TLLM_LOG_INFO("tokensPerBlock: %ld", tokensPerBlock);

// getBlockSize() returns volume of [numKvHeads, tokensPerBlock, sizePerHead] for ONE cache (K or V)
// We need to multiply by KV factor (2 for SELF, 1 for SELFKONLY) to get both K and V
size_t blockSizeInElements = mCacheManager->getBlockManager().getBlockSize(0);
size_t kvFactor = (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2);

TLLM_LOG_INFO("getBlockSize(0): %ld elements (for ONE cache)", blockSizeInElements);
TLLM_LOG_INFO("kvFactor: %ld (1=K only, 2=K+V)", kvFactor);
TLLM_LOG_INFO("storageDataSize (mDataType): %ld bytes", storageDataSize);
TLLM_LOG_INFO("transmissionDataSize (FP16): %ld bytes", transmissionDataSize);

// Calculate bytes per token for both K and V in transmission format (FP16)
size_t bytesPerTokenInStorage = (blockSizeInElements * kvFactor * storageDataSize) / tokensPerBlock;
size_t bytesPerTokenInTransmission = (blockSizeInElements * kvFactor * transmissionDataSize) / tokensPerBlock;
auto kvCacheByteSizePerTokenPerLayer = bytesPerTokenInTransmission;

TLLM_LOG_INFO("bytesPerTokenInStorage: %ld", bytesPerTokenInStorage);
TLLM_LOG_INFO("bytesPerTokenInTransmission: %ld", bytesPerTokenInTransmission);
TLLM_LOG_INFO("kvCacheByteSizePerTokenPerLayer: %ld", kvCacheByteSizePerTokenPerLayer);

for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++)
{
auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
// if windowSize % (tokensPerBlock) !=0
validTokenNum += tokensPerBlock; // add one more block

auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value());

TLLM_LOG_INFO("Layer %d: poolIdx=%d, windowSize=%ld, validTokenNum=%ld",
layerId, poolIdx, windowSize, validTokenNum);
TLLM_LOG_INFO("Layer %d: adding %ld bytes (validTokenNum=%ld * kvCacheBytesPerToken=%ld)",
layerId, validTokenNum * kvCacheByteSizePerTokenPerLayer, validTokenNum, kvCacheByteSizePerTokenPerLayer);

bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
}

TLLM_LOG_INFO("Total bufferSizeFromMaxNumToken: %ld bytes", bufferSizeFromMaxNumToken);

}

mTransferBufferSize
= maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();

TLLM_LOG_INFO("HERE HERE HERE HERE HERE --------- ");
TLLM_LOG_INFO("mTransferBufferSize:%ld", mTransferBufferSize);
TLLM_LOG_INFO("HERE HERE HERE HERE HERE --------- ");

mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
&& FabricMemory::supportFbaricMemory();
if (mUseFabricMemory)
@@ -241,6 +276,7 @@ CacheTransBufferManager::CacheTransBufferManager(
allocateBuffer();
}


size_t CacheTransBufferManager::preAllocBufferSize(size_t tokensPerBlock,
std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
@@ -365,8 +401,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
}
else
{

retSplitCaches.push_back(bufferManagerToUse.gpu(
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF)); // Overriding for now


// retSplitCaches.push_back(bufferManagerToUse.gpu(
// runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
}
}
TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
@@ -384,8 +425,13 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
{
for (int i = 0; i < targetNum; i++)
{

retSplitCaches.push_back(bufferManagerToUse.gpu(
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType));
runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), nvinfer1::DataType::kHALF));


// retSplitCaches.push_back(bufferManagerToUse.gpu(
// runtime::ITensor::makeShape({static_cast<int64_t>(targetBufferEleSizes[i])}), mDataType)); // Overriding for now
}
bufferCoverTargetNum = targetNum;
}
@@ -399,36 +445,45 @@ void CacheTransBufferManager::allocateBuffer()
{
return;
}
mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
mBufferEleSize = mTransferBufferSize / common::getDTypeSize(nvinfer1::DataType::kHALF); // Overriding for now
// mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
if (mUseFabricMemory)
{
mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
for (size_t i = 0; i < mSendBufferCount; i++)
for (size_t i = 0; i < mSendBufferCount; i++)
{
mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), nvinfer1::DataType::kHALF, // Overriding for now
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
// mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
}
for (size_t i = 0; i < mRecvBufferCount; i++)
{
mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), nvinfer1::DataType::kHALF, // Overriding for now
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
// mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
}
}
else if (common::getEnvKVCacheTransferUseAsyncBuffer())
{
for (size_t i = 0; i < mSendBufferCount; i++)
{
mConcurrenceSendResource.mBuffers[i]
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
// mConcurrenceSendResource.mBuffers[i]
// = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
}
for (size_t i = 0; i < mRecvBufferCount; i++)
{
mConcurrenceRecvResource.mBuffers[i]
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
= mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
// mConcurrenceRecvResource.mBuffers[i]
// = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
}
mBufferManager.getStream().synchronize();
}
Expand All @@ -437,12 +492,16 @@ void CacheTransBufferManager::allocateBuffer()
for (size_t i = 0; i < mSendBufferCount; i++)
{
mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
// mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
}
for (size_t i = 0; i < mRecvBufferCount; i++)
{
mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), nvinfer1::DataType::kHALF); // Overriding for now
// mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
// runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mDataType);
}
}
}
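The kHALF override is repeated at every allocation site in this file. One way to centralize it, gated on the same env flag this PR adds, is a small accessor; the method name getTransmissionDataType is hypothetical:

```cpp
// Hypothetical central switch: all "Overriding for now" sites would call this
// instead of hardcoding nvinfer1::DataType::kHALF.
nvinfer1::DataType CacheTransBufferManager::getTransmissionDataType() const
{
    if (common::getEnvEnableKVCachePrecisionConversion())
    {
        // Fixed wire format so mixed-precision endpoints agree on buffer sizes.
        return nvinfer1::DataType::kHALF;
    }
    return mDataType; // pre-PR behavior: transmit in the storage dtype
}
```

allocateBuffer would then compute `mBufferEleSize = mTransferBufferSize / common::getDTypeSize(getTransmissionDataType());`, and the commented-out mDataType variants could be deleted rather than kept alongside each override.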
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.cpp
@@ -450,4 +450,10 @@ bool getEnvDisableChunkedAttentionInGenPhase()
return getBoolEnv("TRTLLM_DISABLE_CHUNKED_ATTENTION_IN_GEN_PHASE");
}

bool getEnvEnableKVCachePrecisionConversion()
{
static bool const enableKVCachePrecisionConversion = getBoolEnv("TRTLLM_ENABLE_KVCACHE_PRECISION_CONVERSION");
return enableKVCachePrecisionConversion;
}

} // namespace tensorrt_llm::common
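
Note that the function-local static latches the first lookup, so TRTLLM_ENABLE_KVCACHE_PRECISION_CONVERSION must be set before the first call; later changes to the environment are ignored for the process lifetime. A minimal demonstration, assuming getBoolEnv treats "1" as truthy as the other flags in this file do:

```cpp
#include <cstdio>
#include <cstdlib>

#include "tensorrt_llm/common/envUtils.h"

int main()
{
    setenv("TRTLLM_ENABLE_KVCACHE_PRECISION_CONVERSION", "1", /*overwrite=*/1);
    bool const first = tensorrt_llm::common::getEnvEnableKVCachePrecisionConversion();
    unsetenv("TRTLLM_ENABLE_KVCACHE_PRECISION_CONVERSION");
    bool const second = tensorrt_llm::common::getEnvEnableKVCachePrecisionConversion();
    // Prints first=1 second=1: the value was latched on the first call.
    std::printf("first=%d second=%d\n", first, second);
    return 0;
}
```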
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
@@ -136,4 +136,7 @@ bool getEnvDisaggBenchmarkGenOnly();
// Whether to disable the chunked-attention in the generation phase.
bool getEnvDisableChunkedAttentionInGenPhase();

// Whether to enable KV cache precision conversion during transfers.
bool getEnvEnableKVCachePrecisionConversion();

} // namespace tensorrt_llm::common