From 00aabf92d761b73e03f0f3966bc1d7548ced2651 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:53:08 -0700 Subject: [PATCH 1/6] Rename data -> cache Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/cacheTransceiver.h | 4 +-- .../batch_manager/cacheTransceiver.cpp | 27 +++++++++--------- .../batch_manager/dataTransceiver.cpp | 28 +++++++++---------- .../batch_manager/dataTransceiverImpl.cpp | 24 ++++++++-------- .../batch_manager/dataTransceiverImpl.h | 8 +++--- .../agent_utils/connection.cpp | 6 ++-- .../agent_utils/connection.h | 4 +-- .../multi_gpu/cacheTransceiverTest.cpp | 18 ++++++------ 8 files changed, 59 insertions(+), 60 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index c39fee6f940..91e096264ca 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -110,9 +110,9 @@ class CacheTransceiver : public BaseCacheTransceiver void setContextState(LlmRequest* llmRequest); - std::unique_ptr mDataResponder; + std::unique_ptr mCacheSender; std::unique_ptr mDataRequester; - std::vector>> mResponderFutures; + std::vector>> mSenderFutures; std::vector>> mRequesterFutures; mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr}; std::shared_ptr mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm, diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 48ac605a3fd..d30486b9edc 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -195,10 +195,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa auto makeFormatter = [cacheManager, isMLA, this]() { return 
createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); }; - mDataResponder = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + mCacheSender = std::make_unique( + std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); mDataRequester = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); initializeCommState(); } @@ -214,7 +214,7 @@ CacheTransceiver::~CacheTransceiver() void CacheTransceiver::initializeCommState() { - mCommState = std::addressof(mDataResponder->getCommState()); + mCommState = std::addressof(mCacheSender->getCommState()); } void CacheTransceiver::setContextState(LlmRequest* llmRequest) @@ -250,8 +250,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest) return; } setContextState(llmRequest); - auto future = mDataResponder->respondAndSendAsync(*llmRequest); - mResponderFutures.emplace_back(llmRequest, std::move(future)); + auto future = mCacheSender->respondAndSendAsync(*llmRequest); + mSenderFutures.emplace_back(llmRequest, std::move(future)); } void CacheTransceiver::respondAndSendLayerWise( @@ -266,8 +266,8 @@ void CacheTransceiver::respondAndSendLayerWise( llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); setContextState(llmRequest.get()); - auto future = mDataResponder->respondAndSendAsync(*llmRequest); - mResponderFutures.emplace_back(llmRequest.get(), std::move(future)); + auto future = mCacheSender->respondAndSendAsync(*llmRequest); + mSenderFutures.emplace_back(llmRequest.get(), std::move(future)); } } @@ -373,7 +373,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe bool blockAll = !atLeastRequestNum.has_value(); auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? 
mMpiGroupTPInDPComm : mMpiGroupTensorParaComm; std::vector contextCompleteRequestIds; - for (auto&& [request, future] : mResponderFutures) + for (auto&& [request, future] : mSenderFutures) { if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { @@ -413,23 +413,22 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe // Make sure there are at least atLeastRequestNum requests in toCompleteIdSet. // This will preserve the order of insertion for KVCache transfer requests. - for (auto it = mResponderFutures.begin(); - atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size()) && it != mResponderFutures.end(); - ++it) + for (auto it = mSenderFutures.begin(); + atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it) { auto& [request, future] = *it; toCompleteIdSet.insert(request->mRequestId); } // Complete all the requests in toCompleteIdSet - for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();) + for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();) { auto& [request, future] = *it; if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end())) { future.get(); request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE); - it = mResponderFutures.erase(it); + it = mSenderFutures.erase(it); } else { diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 522ec80f84a..572a194598d 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -148,14 +148,14 @@ class DataResponder::Impl auto future = promise.get_future(); { { - std::unique_lock lkResp(mResponderMutex); + std::unique_lock lkResp(mSenderMutex); mReadyResponses.emplace( llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)}); } std::unique_lock lkCond(mCondMutex); mAnyReady = true; } - 
mResponderCv.notify_all(); + mSenderCv.notify_all(); return future; } @@ -208,7 +208,7 @@ class DataResponder::Impl if (!mAnyReady) { std::unique_lock lk(mCondMutex); - mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); } if (mTerminate) { @@ -263,7 +263,7 @@ class DataResponder::Impl "mReadyResponses size is: %zu. mpi rank :%d ", mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank()); std::unique_lock lk(mCondMutex); - mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); } } } @@ -285,13 +285,13 @@ class DataResponder::Impl } // We don't have to wait for the future. If another thread is sending data, it won't pay attention // to the terminate flag. - mResponderCv.notify_all(); + mSenderCv.notify_all(); } void removeResponse(std::map::iterator it) { { - std::unique_lock lkResp(mResponderMutex); + std::unique_lock lkResp(mSenderMutex); mReadyResponses.erase(it); } if (mReadyResponses.empty()) @@ -313,16 +313,16 @@ class DataResponder::Impl [[nodiscard]] std::map::iterator getCurrentResponse() { - std::unique_lock lk(mResponderMutex); + std::unique_lock lk(mSenderMutex); return mReadyResponses.find(getCurrentRequestId()); } private: std::optional mCurrentRequest; std::map mReadyResponses; - std::mutex mResponderMutex, mCondMutex; + std::mutex mSenderMutex, mCondMutex; std::atomic mAnyReady{false}, mTerminate{false}; - std::condition_variable mResponderCv; + std::condition_variable mSenderCv; std::future mResponseFuture; std::unique_ptr mSender; std::unordered_map mRemainSendCount; @@ -333,9 +333,9 @@ class DataRequester::Impl { public: Impl(std::unique_ptr receiver) - : mReceiver{std::move(receiver)} + : mCacheReceiver{std::move(receiver)} { - TLLM_CHECK(mReceiver); + TLLM_CHECK(mCacheReceiver); TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); } @@ -400,8 +400,8 @@ class 
DataRequester::Impl llmRequest.getContextPhaseParams().value().getReqId()); llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now()); TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - auto session = mReceiver->sendRequestInfo(llmRequest); - mReceiver->receiveSync(session); + auto session = mCacheReceiver->sendRequestInfo(llmRequest); + mCacheReceiver->receiveSync(session); llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now()); TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), @@ -507,7 +507,7 @@ class DataRequester::Impl } } - std::unique_ptr mReceiver; + std::unique_ptr mCacheReceiver; int mDeviceId{-1}; std::vector> mRequestFutures; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 1a5c7fab4dd..0fa3ca16f54 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -47,7 +47,7 @@ static fs::path getTransferOutputPath(char const* tag) return {}; } -DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, +CacheSenderImpl::CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) : mManager{manager} , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} @@ -58,7 +58,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); } -[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo() +[[nodiscard]] RequestInfo CacheSenderImpl::recvRequestInfo() { using DataContext = tensorrt_llm::executor::kv_cache::DataContext; auto* agentConnectionManager = dynamic_cast(mManager); @@ -111,7 +111,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, return info; } -void DataSenderImpl::sendSync(LlmRequest const& llmRequest) +void 
CacheSenderImpl::sendSync(LlmRequest const& llmRequest) { auto it = mRequestToSession.find(llmRequest.mRequestId); TLLM_CHECK(it != mRequestToSession.end()); @@ -120,24 +120,24 @@ void DataSenderImpl::sendSync(LlmRequest const& llmRequest) mFormatter->format(session); } -[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const +[[nodiscard]] executor::kv_cache::CommState const& CacheSenderImpl::getCommState() const { return mSelfState.getCommState().value(); } -void DataSenderImpl::setCommState(executor::kv_cache::CommState commState) +void CacheSenderImpl::setCommState(executor::kv_cache::CommState commState) { mSelfState.setCommState(std::move(commState)); } -[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const +[[nodiscard]] size_t CacheSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const { auto it = mRequestToSession.find(requestId); TLLM_CHECK(it != mRequestToSession.end()); return it->second.getConnections().size(); } -void DataSenderImpl::release(LlmRequest::RequestIdType requestId) +void CacheSenderImpl::release(LlmRequest::RequestIdType requestId) { auto it = mRequestToSession.find(requestId); TLLM_CHECK(it != mRequestToSession.end()); @@ -156,7 +156,7 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId) mRequestToSession.erase(it); } -DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, +CacheReceiverImpl::CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) : mManager{manager} , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} @@ -167,7 +167,7 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage TLLM_CHECK(mFormatter); } -TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) +TransferSession 
CacheReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) { uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId(); auto const& contextState = llmRequest.getDataTransceiverState(); @@ -233,7 +233,7 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); } -void DataReceiverImpl::receiveSync(TransferSession& session) +void CacheReceiverImpl::receiveSync(TransferSession& session) { mFormatter->unformat(session); if (!common::getEnvKVCacheTransferOutputPath().empty()) @@ -250,7 +250,7 @@ void DataReceiverImpl::receiveSync(TransferSession& session) } } -void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) +void CacheReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) { std::ostringstream oss; RequestInfo::serialize(info, oss); @@ -262,7 +262,7 @@ void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* con connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize); } -std::unique_ptr const& DataReceiverImpl::getReceiveCacheResource( +std::unique_ptr const& CacheReceiverImpl::getReceiveCacheResource( LlmRequest const& llmRequest) { std::scoped_lock lock(mProcessIoResouceMutex); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h index 2f277f14fff..2e2e320c72e 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @@ -42,12 +42,12 @@ struct TransceiverTag using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter; -class DataSenderImpl : public DataSender, public TransceiverTag +class CacheSenderImpl : public DataSender, public TransceiverTag { public: using SizeType32 = tensorrt_llm::runtime::SizeType32; - 
DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter); [[nodiscard]] RequestInfo recvRequestInfo() override; @@ -72,12 +72,12 @@ class DataSenderImpl : public DataSender, public TransceiverTag std::ofstream mMeasuresFile; }; -class DataReceiverImpl : public DataReceiver, public TransceiverTag +class CacheReceiverImpl : public DataReceiver, public TransceiverTag { public: using SizeType32 = tensorrt_llm::runtime::SizeType32; - DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter); TransferSession sendRequestInfo(LlmRequest const& llmRequest) override; diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp index c64d85e1523..851d116eed6 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp @@ -81,7 +81,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size MemoryDesc srcDesc{ reinterpret_cast(data), size, static_cast(mAgentConnectionManager->getDeviceId())}; MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}}; - auto dstBaseDesc = mSenderState.mReceiverBufferDesc; + auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc; MemoryDesc dstDesc{dstBaseDesc.getAddr() + (mSenderState.validSegmentIdx * size), size, dstBaseDesc.getDeviceId()}; TLLM_LOG_DEBUG( "send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx); @@ -137,9 +137,9 @@ void 
AgentConnection::sendRequestAndBufferInfo( mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str()); } -void AgentConnection::setSenderState(MemoryDesc mReceiverBufferDesc, int validSegmentIdx) +void AgentConnection::setSenderState(MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx) { - mSenderState.mReceiverBufferDesc = mReceiverBufferDesc; + mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc; mSenderState.validSegmentIdx = validSegmentIdx; } diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h index 8f73631d1e8..714730c2c29 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h @@ -175,7 +175,7 @@ class AgentConnection : public Connection void recv(DataContext const& ctx, void* data, size_t size) const override; void sendRequestAndBufferInfo( batch_manager::RequestInfo& requestInfo, std::optional cacheBufferId, int validConnectionIdx); - void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx); + void setSenderState(MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx); [[nodiscard]] std::optional getCacheBufferId() const; void setHasLoadRemoteAgent(bool hasLoadRemoteAgent); [[nodiscard]] bool hasLoadRemoteAgent() const; @@ -186,7 +186,7 @@ class AgentConnection : public Connection struct SenderState { - MemoryDesc mReceiverBufferDesc{nullptr, 0, 0}; + MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0}; int validSegmentIdx{0}; SenderState() = default; }; diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 4b513ae57f9..acd3304fcf6 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -394,14 +394,14 @@ class SymmetricalCacheTest : public 
::testing::Test // NOLINT(cppcoreguidelines- mCacheTransBufferManager = std::make_unique(mManager.get(), maxNumTokens); if (isSender) { - mResponder = std::make_unique( - std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, + mSender = std::make_unique( + std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); } else { mRequester = std::make_unique( - std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, + std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); } } @@ -432,7 +432,7 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- // fill cache with tokens (= request length), for reuse test TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes())); } - mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest)); + mFutures.emplace_back(mSender->respondAndSendAsync(*llmRequest)); } else { @@ -457,7 +457,7 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- SizeType32 mMaxNumSequences{}; std::unique_ptr mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mResponder; + std::unique_ptr mSender; std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCommState; @@ -764,12 +764,12 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(std::make_unique( + mSender = std::make_unique(std::make_unique( mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); } else { - mRequester = std::make_unique(std::make_unique( + mRequester = std::make_unique(std::make_unique( mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); } @@ -904,7 +904,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamrespondAndSendAsync(*llmRequest); + auto future = 
mSender->respondAndSendAsync(*llmRequest); return future; } @@ -1112,7 +1112,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mResponder; + std::unique_ptr mSender; std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCacheState; From 74f5b63f27ce8f8aa68eb8b55bd30369945dc59e Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Thu, 24 Jul 2025 19:22:41 -0700 Subject: [PATCH 2/6] Refactor dataTransceiver classes Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/cacheTransceiver.h | 8 +- cpp/tensorrt_llm/batch_manager/CMakeLists.txt | 1 - .../batch_manager/cacheFormatter.h | 191 +++++++++- .../batch_manager/cacheTransceiver.cpp | 16 +- .../batch_manager/dataTransceiver.cpp | 335 ++++++++++++++---- .../batch_manager/dataTransceiver.h | 191 +++------- .../batch_manager/dataTransceiverImpl.cpp | 285 --------------- .../batch_manager/dataTransceiverImpl.h | 113 ------ .../ucx_utils/connection.cpp | 2 +- cpp/tests/unit_tests/executor/ucxCommTest.cpp | 2 - .../multi_gpu/cacheTransceiverTest.cpp | 33 +- 11 files changed, 524 insertions(+), 653 deletions(-) delete mode 100644 cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp delete mode 100644 cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index 91e096264ca..2e2cbe13d17 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -34,8 +34,8 @@ namespace tensorrt_llm::batch_manager class ContextProgress; class BaseCacheTransceiver; -class DataResponder; -class DataRequester; +class CacheSender; +class CacheReceiver; class CacheTransceiverFactory { @@ -110,8 +110,8 @@ class CacheTransceiver : public 
BaseCacheTransceiver void setContextState(LlmRequest* llmRequest); - std::unique_ptr mCacheSender; - std::unique_ptr mDataRequester; + std::unique_ptr mCacheSender; + std::unique_ptr mCacheReceiver; std::vector>> mSenderFutures; std::vector>> mRequesterFutures; mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr}; diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt index 5f7d774c0b0..b0e5b2ddf6b 100644 --- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt +++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt @@ -24,7 +24,6 @@ set(SRCS createNewDecoderRequests.cpp contextProgress.cpp dataTransceiver.cpp - dataTransceiverImpl.cpp decoderBuffers.cpp encoderBuffers.cpp guidedDecoder.cpp diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index 8ae8ee5f2ca..ac675848b41 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -18,11 +18,11 @@ #pragma once #include "cacheTransBuffer.h" -#include "dataTransceiver.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/cacheCommunicator.h" #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/runtime/bufferManager.h" @@ -38,6 +38,135 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); +using DataContext = tensorrt_llm::executor::kv_cache::DataContext; +using Connection = tensorrt_llm::executor::kv_cache::Connection; +using SizeType32 = tensorrt_llm::runtime::SizeType32; + +class TransferSession +{ +public: + struct Measure + { + double delay; // from 
last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + + TransferSession(std::vector connections, DataContext dataContext, + executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, + runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr) + : mConnections(std::move(connections)) + , mDataContext(dataContext) + , mSelfState(&selfState) + , mOtherState(std::move(otherState)) + , mBufferManager(&bufferManager) + , mRequest(llmRequest) + { + TLLM_CHECK(!mConnections.empty()); + } + + [[nodiscard]] std::vector const& getConnections() const + { + return mConnections; + } + + // should be called only during the initialization of the TransferSession + void setConnection(size_t idx, Connection const* conn) + { + mConnections.at(idx) = conn; + } + + [[nodiscard]] DataContext const& getDataContext() const + { + return mDataContext; + } + + [[nodiscard]] executor::DataTransceiverState const& getSelfState() const + { + return *mSelfState; + } + + [[nodiscard]] executor::DataTransceiverState const& getOtherState() const + { + return mOtherState; + } + + [[nodiscard]] runtime::BufferManager const& getBufferManager() const + { + return *mBufferManager; + } + + void send(size_t idx, void const* data, size_t size) + { + mConnections.at(idx)->send(mDataContext, data, size); + } + + void recv(size_t idx, void* data, size_t size) + { + mConnections.at(idx)->recv(mDataContext, data, size); + } + + [[nodiscard]] LlmRequest const& getLlmRequest() const + { + TLLM_CHECK(mRequest != nullptr); + return *mRequest; + } + + // in CacheSender, the LlmRequest is not available until the sendSync is called + void setLlmRequest(LlmRequest const& llmRequest) + { + mRequest = &llmRequest; + } + + void appendMeasure(double delay, double duration, size_t size) + { + if (!mRecordMeasure) + { + return; + } + auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps + 
mMeasures.emplace_back(Measure{delay, duration, bandwidth}); + } + + // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file + void exportMeasure(std::ofstream& outFile, bool isContext) const + { + if (mMeasures.empty()) + { + return; + } + // write header if not exist + if (outFile.tellp() == 0) + { + outFile << "RequestID"; + for (size_t i = 0; i < mMeasures.size(); i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + } + // write measures + TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); + auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); + outFile << reqId; + for (auto const& measure : mMeasures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n' << std::flush; + } + +private: + std::vector mConnections; + DataContext mDataContext; + executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender + executor::DataTransceiverState mOtherState; + runtime::BufferManager const* mBufferManager; + LlmRequest const* mRequest; + std::vector mMeasures; + bool mRecordMeasure{false}; +}; + // Used to support the cache transmission with different layouts and different protocols. 
class BaseCacheFormatter { @@ -78,6 +207,66 @@ class BaseCacheFormatter virtual ~BaseCacheFormatter() = default; }; +class KvCacheMeasureHelper +{ +public: + KvCacheMeasureHelper(std::string output_path) + : mOutputPath(std::move(output_path)) + { + } + + void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double duration, size_t size) + { + auto bandwidth = size * 8 / (duration / 1000) / 1e9; + if (mOutputPath.empty()) + { + return; + } + + std::lock_guard lock(mMutex); + mRequestKVCacheTranfserMeasure[requestId].emplace_back(duration, bandwidth); + } + + ~KvCacheMeasureHelper() + { + if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) + { + auto rank = mpi::MpiComm::world().getRank(); + std::string outFilePath = mOutputPath + "rank_" + std::to_string(rank) + ".txt"; + std::ofstream outFile(outFilePath); + + TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); + + size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); + + outFile << "RequestID"; + for (size_t i = 0; i < numTransferMeasure; i++) + { + outFile << ",TimeDuration,Bandwidth"; + } + outFile << '\n'; + + for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) + { + outFile << requestID; + + for (auto const& [time, bandwidth] : measures) + { + outFile << "," << time << "," << bandwidth; + } + outFile << '\n'; + } + + outFile.close(); + } + } + +private: + std::map>> mRequestKVCacheTranfserMeasure; + std::string mOutputPath; + std::mutex mMutex; +}; + // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the // parallel topology is completely identical, making it the preferred method. 
class CacheFormatter final : public BaseCacheFormatter diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index d30486b9edc..f22ee779ce8 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -37,7 +37,6 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/batch_manager/contextProgress.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/mlaCacheFormatter.h" @@ -195,10 +194,9 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa auto makeFormatter = [cacheManager, isMLA, this]() { return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); }; - mCacheSender = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); - mDataRequester = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + mCacheSender = std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()); + mCacheReceiver + = std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()); initializeCommState(); } @@ -250,7 +248,7 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest) return; } setContextState(llmRequest); - auto future = mCacheSender->respondAndSendAsync(*llmRequest); + auto future = mCacheSender->sendAsync(*llmRequest); mSenderFutures.emplace_back(llmRequest, std::move(future)); } @@ -266,7 +264,7 @@ void CacheTransceiver::respondAndSendLayerWise( llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); setContextState(llmRequest.get()); - auto future = 
mCacheSender->respondAndSendAsync(*llmRequest); + auto future = mCacheSender->sendAsync(*llmRequest); mSenderFutures.emplace_back(llmRequest.get(), std::move(future)); } } @@ -275,7 +273,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest) { TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest()); { - auto future = mDataRequester->requestAndReceiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(*llmRequest); future.get(); } llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE); @@ -293,7 +291,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest) return; } - auto future = mDataRequester->requestAndReceiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(*llmRequest); mRequesterFutures.emplace_back(llmRequest, std::move(future)); llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS); } diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 572a194598d..136692b0363 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/utils.h" +#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" #include #include @@ -34,6 +35,26 @@ namespace tensorrt_llm::batch_manager using kv_cache_manager::BlockRange; using runtime::SizeType32; +using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager; +using DataContext = tensorrt_llm::executor::kv_cache::DataContext; + +static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) +{ + constexpr int32_t kDATA_TAG{43}; + return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); +} + +struct ReceiveCacheResource +{ + runtime::BufferManager mBufferManager; + 
runtime::CudaEvent mCudaEvent; + + ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent) + : mBufferManager(bufferManager) + , mCudaEvent(std::move(cudaEvent)) + { + } +}; RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState) : mRequestId{requestId} @@ -91,58 +112,26 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) return totalSize; } -void TransferSession::appendMeasure(double delay, double duration, size_t size) -{ - if (!mRecordMeasure) - { - return; - } - auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps - mMeasures.emplace_back(Measure{delay, duration, bandwidth}); -} - -void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const -{ - if (mMeasures.empty()) - { - return; - } - // write header if not exist - if (outFile.tellp() == 0) - { - outFile << "RequestID"; - for (size_t i = 0; i < mMeasures.size(); i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - } - // write measures - TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); - auto reqId = isContext ? 
mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); - outFile << reqId; - for (auto const& measure : mMeasures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n' << std::flush; -} - -class DataResponder::Impl +class CacheSender::Impl { public: using RequestIdType = LlmRequest::RequestIdType; - Impl(std::unique_ptr sender) - : mSender{std::move(sender)} + Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mManager{manager} + , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} + , mFormatter{std::move(formatter)} + , mBufferManager{std::make_shared()} { - TLLM_CHECK(mSender); + TLLM_CHECK(mManager); + TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); mCurrentRequest = std::nullopt; mResponseFuture = std::async(std::launch::async, &Impl::response, this); } - [[nodiscard]] std::future respondAndSendAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future sendAsync(LlmRequest& llmRequest) { std::promise promise; auto future = promise.get_future(); @@ -161,12 +150,90 @@ class DataResponder::Impl [[nodiscard]] executor::kv_cache::CommState const& getCommState() const { - return mSender->getCommState(); + return mSelfState.getCommState().value(); } void setCommState(executor::kv_cache::CommState commState) { - mSender->setCommState(std::move(commState)); + mSelfState.setCommState(std::move(commState)); + } + + [[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const + { + auto it = mRequestToSession.find(requestId); + TLLM_CHECK(it != mRequestToSession.end()); + return it->second.getConnections().size(); + } + + void release(LlmRequest::RequestIdType requestId) + { + auto it = mRequestToSession.find(requestId); + TLLM_CHECK(it != 
mRequestToSession.end()); + std::unique_lock lk(mMtxForMap); + mRequestToSession.erase(it); + } + + [[nodiscard]] RequestInfo recvRequestInfo() + { + auto* agentConnectionManager = dynamic_cast(mManager); + bool isAgent = agentConnectionManager != nullptr; + + auto agentRecvFun = [&](RequestInfo& requestInfo) + { + auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); + return connection; + }; + TransceiverTag::Id id; + RequestInfo info; + auto const* connection = isAgent ? agentRecvFun(info) + : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + if (!isAgent) + { + TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND); + std::uint64_t infoSize{0}; + connection->recv( + executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + std::string serializedInfo; + serializedInfo.resize(infoSize); + connection->recv( + executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + std::istringstream iss(serializedInfo); + info = RequestInfo::deserialize(iss); + } + + auto requestId = info.getRequestId(); + TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport( + mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()), + "Disagg server does not currently support these cacheState, please check the cacheState of the context and " + "gen " + "executors"); + auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(), + mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx()) + .mIRanks; + int peerIdx = std::distance(peerRelativeRanks.begin(), + std::find( + peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx())); + { + std::unique_lock lk(mMtxForMap); + auto it = mRequestToSession.find(requestId); + if (it == mRequestToSession.end()) + { + auto session = TransferSession(std::vector(peerRelativeRanks.size(), 
nullptr), + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + it = mRequestToSession.emplace(requestId, std::move(session)).first; + } + it->second.setConnection(peerIdx, connection); + } + return info; + } + + void sendSync(LlmRequest const& llmRequest) + { + auto it = mRequestToSession.find(llmRequest.mRequestId); + TLLM_CHECK(it != mRequestToSession.end()); + auto& session = it->second; + session.setLlmRequest(llmRequest); + mFormatter->format(session); } ~Impl() @@ -186,8 +253,8 @@ class DataResponder::Impl try { TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - mSender->sendSync(*resp.mRequest); - mSender->release(id); + sendSync(*resp.mRequest); + release(id); resp.mPromise.set_value(); } catch (std::exception const& e) @@ -217,14 +284,14 @@ class DataResponder::Impl std::vector blockHashes; if (!isSending() && !mReadyResponses.empty()) { - auto const& requestInfo = mSender->recvRequestInfo(); + auto const& requestInfo = recvRequestInfo(); auto reqId = requestInfo.getRequestId(); blockHashes = requestInfo.getBlockHashes(); mCurrentRequest = reqId; if (mRemainSendCount.find(reqId) == mRemainSendCount.end()) { - mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId); + mRemainSendCount[reqId] = getCounterpartsCount(reqId); } } auto it = getCurrentResponse(); @@ -245,12 +312,12 @@ class DataResponder::Impl { // TODO: Use a thread pool and check for thread safety. 
std::thread( - &DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) + &CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) .detach(); } else { - DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); + CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); } removeResponse(it); } @@ -269,7 +336,7 @@ class DataResponder::Impl } catch (std::exception const& err) { - TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what()); + TLLM_LOG_ERROR("Exception in CacheSender response: %s", err.what()); for (auto& it : mReadyResponses) { it.second.mPromise.set_exception(std::current_exception()); @@ -324,25 +391,36 @@ class DataResponder::Impl std::atomic mAnyReady{false}, mTerminate{false}; std::condition_variable mSenderCv; std::future mResponseFuture; - std::unique_ptr mSender; std::unordered_map mRemainSendCount; int mDeviceId{-1}; + + executor::kv_cache::ConnectionManager* mManager; + std::map mRequestToSession; + executor::DataTransceiverState mSelfState; + std::unique_ptr mFormatter; + std::mutex mMtxForMap; + runtime::BufferManager mBufferManager; }; -class DataRequester::Impl +class CacheReceiver::Impl { public: - Impl(std::unique_ptr receiver) - : mCacheReceiver{std::move(receiver)} + Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mManager{manager} + , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} + , mFormatter{std::move(formatter)} + , mBufferManager{std::make_shared()} { - TLLM_CHECK(mCacheReceiver); + TLLM_CHECK(mManager); + TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); } - [[nodiscard]] std::future requestAndReceiveAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future receiveAsync(LlmRequest& llmRequest) { // TODO: Modify the 
implementation here to avoid frequent thread creation. - return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest)); + return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest)); } [[nodiscard]] std::future requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest) @@ -361,7 +439,7 @@ class DataRequester::Impl { mInstanceToAsyncResource.emplace(processInfo, std::make_unique()); - auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this, + auto requestFuture = std::async(std::launch::async, &CacheReceiver::Impl::request, this, std::ref(*mInstanceToAsyncResource.at(processInfo))); mRequestFutures.emplace_back(std::move(requestFuture)); } @@ -379,6 +457,107 @@ class DataRequester::Impl } } + void receiveSync(TransferSession& session) + { + mFormatter->unformat(session); + } + + TransferSession sendRequestInfo(LlmRequest const& llmRequest) + { + uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId(); + auto const& contextState = llmRequest.getDataTransceiverState(); + auto const& commState = contextState.getCommState().value(); + auto const& destCacheState = contextState.getCacheState().value(); + TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState), + "Disagg server does not currently support these cacheState."); + + RequestInfo requestInfo(requestId, mSelfState); + + auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() + || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); + if (!disableSelectiveCacheTransfer) + { + auto* cacheManager = mFormatter->getCacheManager(); + auto blockRange + = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); + } + + auto* agentConnectionManager = dynamic_cast(mManager); + std::optional 
cacheBufferId = std::nullopt; + if (agentConnectionManager != nullptr) + { + cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv(); + TLLM_CHECK(cacheBufferId.has_value()); + // memory Desp , validSegmentIdx send + } + auto counterParts = mFormatter->getCounterparts( + mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState); + + auto connections = mManager->getConnections(commState); + std::vector counterPartConnections; + for (auto index : counterParts) + { + auto const* connection = connections.at(index); + counterPartConnections.emplace_back(connection); + } + auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(), + mSelfState.getCommState().value().getSelfIdx(), destCacheState); + for (size_t i = 0; i < counterPartConnections.size(); i++) + { + auto const* connection = counterPartConnections[i]; + // if Manager is agentConnectionManager, then send request info to agent + auto* agentConnectionManager = dynamic_cast(mManager); + if (agentConnectionManager != nullptr) + { + // TODO: index -> validConnectionIdx conversion + auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin(); + auto* agentConnection = dynamic_cast(connection); + TLLM_CHECK(agentConnection != nullptr); + TLLM_CHECK(cacheBufferId.has_value()); + const_cast(agentConnection) + ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx); + } + else + { + sendRequestInfo(connection, requestInfo); + } + } + auto const& resource = getReceiveCacheResource(llmRequest); + return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, + contextState, resource->mBufferManager, &llmRequest); + } + + std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest) + { + std::scoped_lock lock(mProcessIoResouceMutex); + 
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value()); + std::string processString = "default"; + if (common::getEnvRequestKVCacheConcurrent()) + { + processString = llmRequest.getDataTransceiverState().getCommState()->toString(); + } + if (mProcessToResources.find(processString) == mProcessToResources.end()) + { + mProcessToResources.emplace(processString, + std::make_unique( + runtime::BufferManager{std::make_shared()}, runtime::CudaEvent{})); + } + return mProcessToResources.at(processString); + } + + void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) + { + std::ostringstream oss; + RequestInfo::serialize(info, oss); + auto const& serializedInfo = oss.str(); + std::size_t const infoSize = serializedInfo.size(); + TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND}; + connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + } + ~Impl() { for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource) @@ -400,8 +579,8 @@ class DataRequester::Impl llmRequest.getContextPhaseParams().value().getReqId()); llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now()); TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - auto session = mCacheReceiver->sendRequestInfo(llmRequest); - mCacheReceiver->receiveSync(session); + auto session = sendRequestInfo(llmRequest); + receiveSync(session); llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now()); TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), @@ -498,7 +677,7 @@ class DataRequester::Impl } catch (std::exception const& err) { - TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s", + TLLM_LOG_ERROR("Exception in CacheReceiver 
request(): request id:%ld , request context id:%ld : %s", requestAndPromise.mRequest->mRequestId, requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what()); requestAndPromise.mPromise->set_exception(std::current_exception()); @@ -507,45 +686,51 @@ class DataRequester::Impl } } - std::unique_ptr mCacheReceiver; int mDeviceId{-1}; - std::vector> mRequestFutures; std::unordered_map> mInstanceToAsyncResource; + executor::kv_cache::ConnectionManager* mManager; + executor::DataTransceiverState mSelfState; + std::unique_ptr mFormatter; + std::unordered_map> mProcessToResources; + std::mutex mProcessIoResouceMutex; + runtime::BufferManager mBufferManager; }; -DataResponder::DataResponder(std::unique_ptr sender) - : mImpl{std::make_unique(std::move(sender))} +CacheSender::CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mImpl{std::make_unique(manager, selfCacheState, selfIndex, std::move(formatter))} { } -std::future DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const +std::future CacheSender::sendAsync(LlmRequest& llmRequest) const { - return mImpl->respondAndSendAsync(llmRequest); + return mImpl->sendAsync(llmRequest); } -executor::kv_cache::CommState const& DataResponder::getCommState() const +executor::kv_cache::CommState const& CacheSender::getCommState() const { return mImpl->getCommState(); } -void DataResponder::setCommState(executor::kv_cache::CommState commState) +void CacheSender::setCommState(executor::kv_cache::CommState commState) { mImpl->setCommState(std::move(commState)); } -DataResponder::~DataResponder() = default; +CacheSender::~CacheSender() = default; -DataRequester::DataRequester(std::unique_ptr receiver) - : mImpl{std::make_unique(std::move(receiver))} +CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager, + executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, 
std::unique_ptr formatter) + : mImpl{std::make_unique(manager, selfCacheState, selfIndex, std::move(formatter))} { } -std::future DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const +std::future CacheReceiver::receiveAsync(LlmRequest& llmRequest) const { return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest); } -DataRequester::~DataRequester() = default; +CacheReceiver::~CacheReceiver() = default; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index ef66cd1382d..14c8302e8d2 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -21,6 +21,7 @@ #include #include +#include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/envUtils.h" @@ -38,6 +39,23 @@ namespace tensorrt_llm::batch_manager using DataContext = tensorrt_llm::executor::kv_cache::DataContext; using Connection = tensorrt_llm::executor::kv_cache::Connection; using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager; +using SizeType32 = tensorrt_llm::runtime::SizeType32; +using TransferSession = kv_cache_manager::TransferSession; + +struct TransceiverTag +{ + enum class Id : uint64_t + { + REQUEST_SEND = 1, + TERMINATION = 2 + }; + + static constexpr int32_t kID_TAG{19}; + static constexpr int32_t kINFO_SIZE_TAG{22}; + static constexpr int32_t kINFO_TAG{32}; +}; + +using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter; // Used to store the information that needs to be sent to the context executor to ensure the generation // executor smoothly receives the data. 
@@ -94,187 +112,70 @@ class RequestInfo executor::DataTransceiverState mTransState; }; -class TransferSession -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - TransferSession(std::vector connections, DataContext dataContext, - executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) - : mConnections(std::move(connections)) - , mDataContext(dataContext) - , mSelfState(&selfState) - , mOtherState(std::move(otherState)) - , mBufferManager(&bufferManager) - , mRequest(llmRequest) - , mRecordMeasure(recordMeasure) - { - TLLM_CHECK(!mConnections.empty()); - } - - [[nodiscard]] std::vector const& getConnections() const - { - return mConnections; - } - - // should be called only during the initialization of the TransferSession - void setConnection(size_t idx, Connection const* conn) - { - mConnections.at(idx) = conn; - } - - [[nodiscard]] DataContext const& getDataContext() const - { - return mDataContext; - } - - [[nodiscard]] executor::DataTransceiverState const& getSelfState() const - { - return *mSelfState; - } - - [[nodiscard]] executor::DataTransceiverState const& getOtherState() const - { - return mOtherState; - } - - [[nodiscard]] runtime::BufferManager const& getBufferManager() const - { - return *mBufferManager; - } - - void send(size_t idx, void const* data, size_t size) - { - mConnections.at(idx)->send(mDataContext, data, size); - } - - void recv(size_t idx, void* data, size_t size) - { - mConnections.at(idx)->recv(mDataContext, data, size); - } - - [[nodiscard]] LlmRequest const& getLlmRequest() const - { - TLLM_CHECK(mRequest != nullptr); - return *mRequest; - } - - // in DataSender, the LlmRequest is not available until the sendSync is called - void setLlmRequest(LlmRequest const& 
llmRequest) - { - mRequest = &llmRequest; - } - - void appendMeasure(double delay, double duration, size_t size); - // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file - void exportMeasure(std::ofstream& outFile, bool isContext) const; - -private: - std::vector mConnections; - DataContext mDataContext; - executor::DataTransceiverState const* mSelfState; // stored in DataRequester/DataResponder - executor::DataTransceiverState mOtherState; - runtime::BufferManager const* mBufferManager; - LlmRequest const* mRequest; - bool mRecordMeasure; - std::vector mMeasures; -}; - -// Operators required for data transmission in specific communication protocols. -class DataSender +class CacheSender { public: - /// @brief Receive the request information. - /// @return The request information. - [[nodiscard]] virtual RequestInfo recvRequestInfo() = 0; + /// @brief Constructor. + CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter); - /// @brief Synchronously send data. - /// @param llmRequest The request object to which the data belongs. - virtual void sendSync(LlmRequest const& llmRequest) = 0; + /// @brief Asynchronously respond to the request and send data. + /// @param llmRequest Request object. Its data should be ready when called, and the data for this request + /// should remain valid until future synchronization. + /// @return Once the data is fully sent, the future object will become valid. + [[nodiscard]] std::future sendAsync(LlmRequest& llmRequest) const; /// @brief Return the internal communicator status. /// @return The communicator status. - [[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const = 0; + [[nodiscard]] executor::kv_cache::CommState const& getCommState() const; /// @brief Reset the internal communicator status. /// @param commState The communicator status. 
- virtual void setCommState(executor::kv_cache::CommState commState) = 0; - - [[nodiscard]] virtual size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const = 0; + void setCommState(executor::kv_cache::CommState commState); - virtual void release(LlmRequest::RequestIdType requestId) = 0; + /// @brief Receive the request information. + /// @return The request information. + [[nodiscard]] RequestInfo recvRequestInfo(); - /// @brief Destructor. - virtual ~DataSender() = default; -}; + /// @brief Synchronously send data. + /// @param llmRequest The request object to which the data belongs. + void sendSync(LlmRequest const& llmRequest); -// Operators required for data transmission in specific communication protocols. -class DataReceiver -{ -public: /// @brief Send the request information. /// @param llmRequest The request object to which the information belongs. - virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest) = 0; + [[nodiscard]] TransferSession sendRequestInfo(LlmRequest const& llmRequest); /// @brief Synchronously receive data. /// @param session The transfer session. - virtual void receiveSync(TransferSession& session) = 0; - - /// @brief Destructor. - virtual ~DataReceiver() = default; -}; - -class DataResponder -{ -public: - /// @brief Constructor. - /// @param sender The sender used at the underlying level. - explicit DataResponder(std::unique_ptr sender); - - /// @brief Asynchronously respond to the request and send data. - /// @param llmRequest Request object. Its data should be ready when called, and the data for this request - /// should remain valid until future synchronization. - /// @return Once the data is fully sent, the future object will become valid. - [[nodiscard]] std::future respondAndSendAsync(LlmRequest& llmRequest) const; - - /// @brief Return the internal communicator status. - /// @return The communicator status. 
- [[nodiscard]] executor::kv_cache::CommState const& getCommState() const; - - /// @brief Reset the internal communicator status. - /// @param commState The communicator status. - void setCommState(executor::kv_cache::CommState commState); + void receiveSync(TransferSession& session); /// @brief Destructor. - ~DataResponder(); + ~CacheSender(); private: class Impl; std::unique_ptr mImpl; }; -class DataRequester +class CacheReceiver { public: /// @brief Constructor. - /// @param receiver The receiver used at the underlying level. - explicit DataRequester(std::unique_ptr receiver); + CacheReceiver(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter); /// @brief Asynchronously send a request to receive data. /// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the /// data for this request should remain intact only after future synchronization. /// @return Once the data is fully received, the future object will become valid. - [[nodiscard]] std::future requestAndReceiveAsync(LlmRequest& llmRequest) const; + [[nodiscard]] std::future receiveAsync(LlmRequest& llmRequest) const; + + TransferSession sendRequestInfo(LlmRequest const& llmRequest); + void receiveSync(TransferSession& session); /// @brief Destructor. - ~DataRequester(); + ~CacheReceiver(); private: class Impl; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp deleted file mode 100644 index 0fa3ca16f54..00000000000 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ /dev/null @@ -1,285 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataTransceiverImpl.h" - -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" -#include "tensorrt_llm/runtime/utils/mpiUtils.h" - -#include - -namespace tensorrt_llm::batch_manager -{ - -static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) -{ - constexpr int32_t kDATA_TAG{43}; - return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); -} - -namespace fs = std::filesystem; - -static fs::path getTransferOutputPath(char const* tag) -{ - auto outputPath = common::getEnvKVCacheTransferOutputPath(); - if (!outputPath.empty()) - { - auto rank = mpi::MpiComm::world().getRank(); - auto path = fs::path(outputPath); - fs::create_directories(path); - return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv"); - } - return {}; -} - -CacheSenderImpl::CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, - executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) - : mManager{manager} - , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} - , mFormatter(std::move(formatter)) - , mBufferManager{std::make_shared()} -{ - TLLM_CHECK(mManager); - TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); -} - -[[nodiscard]] RequestInfo CacheSenderImpl::recvRequestInfo() -{ - using DataContext = tensorrt_llm::executor::kv_cache::DataContext; - auto* agentConnectionManager = dynamic_cast(mManager); - bool isAgent = agentConnectionManager != nullptr; - - auto 
agentRecvFun = [&](RequestInfo& requestInfo) - { - auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); - return connection; - }; - Id id; - RequestInfo info; - auto const* connection - = isAgent ? agentRecvFun(info) : mManager->recvConnect(DataContext{kID_TAG}, &id, sizeof(id)); - if (!isAgent) - { - TLLM_CHECK(id == Id::REQUEST_SEND); - std::uint64_t infoSize{0}; - connection->recv(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - std::string serializedInfo; - serializedInfo.resize(infoSize); - connection->recv(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize); - std::istringstream iss(serializedInfo); - info = RequestInfo::deserialize(iss); - } - - auto requestId = info.getRequestId(); - TLLM_CHECK_WITH_INFO( - mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()), - "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen " - "executors"); - auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(), - mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx()) - .mIRanks; - int peerIdx = std::distance(peerRelativeRanks.begin(), - std::find( - peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx())); - { - std::unique_lock lk(mMtxForMap); - auto it = mRequestToSession.find(requestId); - if (it == mRequestToSession.end()) - { - auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, - !common::getEnvKVCacheTransferOutputPath().empty()); - it = mRequestToSession.emplace(requestId, std::move(session)).first; - } - it->second.setConnection(peerIdx, connection); - } - return info; -} - -void CacheSenderImpl::sendSync(LlmRequest 
const& llmRequest) -{ - auto it = mRequestToSession.find(llmRequest.mRequestId); - TLLM_CHECK(it != mRequestToSession.end()); - auto& session = it->second; - session.setLlmRequest(llmRequest); - mFormatter->format(session); -} - -[[nodiscard]] executor::kv_cache::CommState const& CacheSenderImpl::getCommState() const -{ - return mSelfState.getCommState().value(); -} - -void CacheSenderImpl::setCommState(executor::kv_cache::CommState commState) -{ - mSelfState.setCommState(std::move(commState)); -} - -[[nodiscard]] size_t CacheSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const -{ - auto it = mRequestToSession.find(requestId); - TLLM_CHECK(it != mRequestToSession.end()); - return it->second.getConnections().size(); -} - -void CacheSenderImpl::release(LlmRequest::RequestIdType requestId) -{ - auto it = mRequestToSession.find(requestId); - TLLM_CHECK(it != mRequestToSession.end()); - std::unique_lock lk(mMtxForMap); - if (!common::getEnvKVCacheTransferOutputPath().empty()) - { - if (!mMeasuresFile.is_open()) - { - auto outputPath = getTransferOutputPath("send"); - mMeasuresFile.open(outputPath); - TLLM_CHECK_WITH_INFO( - mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); - } - it->second.exportMeasure(mMeasuresFile, true); - } - mRequestToSession.erase(it); -} - -CacheReceiverImpl::CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, - executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) - : mManager{manager} - , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} - , mFormatter(std::move(formatter)) -{ - TLLM_CHECK(mManager); - TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); - TLLM_CHECK(mFormatter); -} - -TransferSession CacheReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) -{ - uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId(); - auto const& 
contextState = llmRequest.getDataTransceiverState(); - auto const& commState = contextState.getCommState().value(); - auto const& destCacheState = contextState.getCacheState().value(); - TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState), - "Disagg server does not currently support these cacheState."); - - RequestInfo requestInfo(requestId, mSelfState); - - auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() - || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); - if (!disableSelectiveCacheTransfer) - { - auto* cacheManager = mFormatter->getCacheManager(); - auto blockRange - = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); - requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); - } - - auto* agentConnectionManager = dynamic_cast(mManager); - std::optional cacheBufferId = std::nullopt; - if (agentConnectionManager != nullptr) - { - cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv(); - TLLM_CHECK(cacheBufferId.has_value()); - // memory Desp , validSegmentIdx send - } - auto counterParts = mFormatter->getCounterparts( - mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState); - - auto connections = mManager->getConnections(commState); - std::vector counterPartConnections; - for (auto index : counterParts) - { - auto const* connection = connections.at(index); - counterPartConnections.emplace_back(connection); - } - auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(), - mSelfState.getCommState().value().getSelfIdx(), destCacheState); - for (size_t i = 0; i < counterPartConnections.size(); i++) - { - auto const* connection = counterPartConnections[i]; - // if Manager is agentConnectionManager, then send request info to agent - auto* agentConnectionManager = 
dynamic_cast(mManager); - if (agentConnectionManager != nullptr) - { - // TODO: index -> validConnectionIdx conversion - auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin(); - auto* agentConnection = dynamic_cast(connection); - TLLM_CHECK(agentConnection != nullptr); - TLLM_CHECK(cacheBufferId.has_value()); - const_cast(agentConnection) - ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx); - } - else - { - sendRequestInfo(connection, requestInfo); - } - } - auto const& resource = getReceiveCacheResource(llmRequest); - return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); -} - -void CacheReceiverImpl::receiveSync(TransferSession& session) -{ - mFormatter->unformat(session); - if (!common::getEnvKVCacheTransferOutputPath().empty()) - { - std::unique_lock lock(mMeasuresFileMutex); - if (!mMeasuresFile.is_open()) - { - auto outputPath = getTransferOutputPath("recv"); - mMeasuresFile.open(outputPath); - TLLM_CHECK_WITH_INFO( - mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); - } - session.exportMeasure(mMeasuresFile, false); - } -} - -void CacheReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) -{ - std::ostringstream oss; - RequestInfo::serialize(info, oss); - auto const& serializedInfo = oss.str(); - std::size_t const infoSize = serializedInfo.size(); - Id id{Id::REQUEST_SEND}; - connection->send(executor::kv_cache::DataContext{kID_TAG}, &id, sizeof(id)); - connection->send(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize); -} - -std::unique_ptr const& CacheReceiverImpl::getReceiveCacheResource( - LlmRequest 
const& llmRequest) -{ - std::scoped_lock lock(mProcessIoResouceMutex); - TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value()); - std::string processString = "default"; - if (common::getEnvRequestKVCacheConcurrent()) - { - processString = llmRequest.getDataTransceiverState().getCommState()->toString(); - } - if (mProcessToResources.find(processString) == mProcessToResources.end()) - { - mProcessToResources.emplace(processString, - std::make_unique( - runtime::BufferManager{std::make_shared()}, runtime::CudaEvent{})); - } - - return mProcessToResources.at(processString); -} - -} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h deleted file mode 100644 index 2e2e320c72e..00000000000 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "cacheFormatter.h" -#include "cacheTransBuffer.h" -#include "dataTransceiver.h" -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" - -#include - -namespace tensorrt_llm::batch_manager -{ -struct TransceiverTag -{ - enum class Id : uint64_t - { - REQUEST_SEND = 1, - TERMINATION = 2 - }; - - static constexpr int32_t kID_TAG{19}; - static constexpr int32_t kINFO_SIZE_TAG{22}; - static constexpr int32_t kINFO_TAG{32}; -}; - -using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter; - -class CacheSenderImpl : public DataSender, public TransceiverTag -{ -public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; - - CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, - SizeType32 selfIndex, std::unique_ptr formatter); - - [[nodiscard]] RequestInfo recvRequestInfo() override; - - void sendSync(LlmRequest const& llmRequest) override; - - [[nodiscard]] executor::kv_cache::CommState const& getCommState() const override; - - void setCommState(executor::kv_cache::CommState commState) override; - - [[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const override; - - void release(LlmRequest::RequestIdType requestId) override; - -private: - executor::kv_cache::ConnectionManager* mManager; - std::map mRequestToSession; - executor::DataTransceiverState mSelfState; - std::unique_ptr mFormatter; - std::mutex mMtxForMap; - runtime::BufferManager mBufferManager; - std::ofstream mMeasuresFile; -}; - -class CacheReceiverImpl : public DataReceiver, public TransceiverTag -{ -public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; - - CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, - SizeType32 selfIndex, std::unique_ptr formatter); - - TransferSession sendRequestInfo(LlmRequest const& llmRequest) override; - - void 
receiveSync(TransferSession& session) override; - -private: - struct ReceiveCacheResource - { - runtime::BufferManager mBufferManager; - runtime::CudaEvent mCudaEvent; - - ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent) - : mBufferManager(bufferManager) - , mCudaEvent(std::move(cudaEvent)) - { - } - }; - - static void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info); - - [[nodiscard]] std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest); - - executor::kv_cache::ConnectionManager* mManager; - executor::DataTransceiverState mSelfState; - std::unique_ptr mFormatter; - std::unordered_map> mProcessToResources; - std::mutex mProcessIoResouceMutex; - std::ofstream mMeasuresFile; - std::mutex mMeasuresFileMutex; -}; - -} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp index 07a80be4a72..1c8cc4f6ad3 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp @@ -18,7 +18,7 @@ #include "ucxCacheCommunicator.h" #if ENABLE_UCX -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" +#include "tensorrt_llm/batch_manager/dataTransceiver.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/executor/cache_transmission/ucx_utils/connection.h" diff --git a/cpp/tests/unit_tests/executor/ucxCommTest.cpp b/cpp/tests/unit_tests/executor/ucxCommTest.cpp index 5895ac09472..febe61c7145 100644 --- a/cpp/tests/unit_tests/executor/ucxCommTest.cpp +++ b/cpp/tests/unit_tests/executor/ucxCommTest.cpp @@ -26,7 +26,6 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" #include 
"tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -45,7 +44,6 @@ #include #include #include -#include #include #include diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index acd3304fcf6..21f1bae1656 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -26,7 +26,6 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -156,10 +155,10 @@ TEST_F(CacheConfigTest, EqualTo) // MockTransceiverTest // --------------------------------------- -class MockDataSender : public DataSender +class MockCacheSender : public CacheSender { public: - MockDataSender() + MockCacheSender() { ON_CALL(*this, getCommState).WillByDefault(ReturnRef(mState)); ON_CALL(*this, recvRequestInfo) @@ -181,9 +180,9 @@ class MockDataSender : public DataSender static texec::kv_cache::CommState mState; }; -texec::kv_cache::CommState MockDataSender::mState; +texec::kv_cache::CommState MockCacheSender::mState; -class MockDataReceiver : public DataReceiver +class MockCacheReceiver : public CacheReceiver { public: MOCK_METHOD(TransferSession, sendRequestInfo, (LlmRequest const&), (override)); @@ -214,7 +213,7 @@ TEST_F(MockTransceiverTest, MpiResponderBasic) { GTEST_SKIP() << "mpirun with procs<=2 is required to run this test."; } - auto sender = std::make_unique(); + auto sender = std::make_unique(); EXPECT_CALL(*sender, recvRequestInfo) .WillOnce(Return(RequestInfo{0, texec::DataTransceiverState{ @@ -224,7 +223,7 @@ TEST_F(MockTransceiverTest, MpiResponderBasic) EXPECT_CALL(*sender, 
getCounterpartsCount).WillOnce(Return(1)); EXPECT_CALL(*sender, release).WillOnce(Return()); - DataResponder responder{std::move(sender)}; + CacheSender responder{std::move(sender)}; auto request = makeLlmRequest(0); auto future = responder.respondAndSendAsync(*request); future.get(); @@ -237,14 +236,14 @@ TEST_F(MockTransceiverTest, MpiRequesterBasic) { GTEST_SKIP() << "mpirun with procs<=2 is required to run this test."; } - auto receiver = std::make_unique(); + auto receiver = std::make_unique(); auto state = std::make_unique(); state->setCommState(texec::kv_cache::CommState{std::vector{0}}); EXPECT_CALL(*receiver, sendRequestInfo) .WillOnce(Return(TransferSession({nullptr}, DataContext{0}, *state, *state, tensorrt_llm::runtime::BufferManager{std::make_shared()}, nullptr))); EXPECT_CALL(*receiver, receiveSync).WillOnce(Return()); - DataRequester requester{std::move(receiver)}; + CacheReceiver requester{std::move(receiver)}; auto request = makeLlmRequest(0); auto stats = texec::ContextPhaseParams({}, 0, state.release(), std::nullopt); request->setContextPhaseParams(std::move(stats)); @@ -394,13 +393,13 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- mCacheTransBufferManager = std::make_unique(mManager.get(), maxNumTokens); if (isSender) { - mSender = std::make_unique( + mSender = std::make_unique( std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); } else { - mRequester = std::make_unique( + mRequester = std::make_unique( std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); } @@ -457,8 +456,8 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- SizeType32 mMaxNumSequences{}; std::unique_ptr mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mSender; - std::unique_ptr mRequester; + std::unique_ptr mSender; + 
std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCommState; std::vector> mFutures; @@ -764,12 +763,12 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(std::make_unique( + mSender = std::make_unique(std::make_unique( mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); } else { - mRequester = std::make_unique(std::make_unique( + mRequester = std::make_unique(std::make_unique( mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); } @@ -1112,8 +1111,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mSender; - std::unique_ptr mRequester; + std::unique_ptr mSender; + std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCacheState; std::unique_ptr mContextCommState; From 3965e5946d314ff87d6d3de8d214f105a95c692c Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 12 Aug 2025 13:05:33 -0700 Subject: [PATCH 3/6] Initial iteration for supporting block hash transfer Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/cacheTransceiver.h | 1 + .../batch_manager/kvCacheManager.h | 38 +++++ .../tensorrt_llm/batch_manager/kvCacheUtils.h | 22 +++ .../tensorrt_llm/batch_manager/llmRequest.h | 10 -- .../batch_manager/cacheFormatter.cpp | 32 ++-- .../batch_manager/cacheFormatter.h | 157 +++++++++++------- .../batch_manager/cacheTransceiver.cpp | 6 +- .../batch_manager/dataTransceiver.cpp | 146 ++++++++-------- .../batch_manager/dataTransceiver.h | 20 ++- .../batch_manager/kvCacheManager.cpp | 48 ++++++ .../batch_manager/mlaCacheFormatter.cpp | 4 +- cpp/tensorrt_llm/common/envUtils.cpp | 6 - .../agent_utils/connection.cpp | 1 - .../pybind/batch_manager/kvCacheManager.cpp | 1 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 6 +- 15 files changed, 323 insertions(+), 175 deletions(-) diff 
--git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index 2e2cbe13d17..b9a2fef9087 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -119,6 +119,7 @@ class CacheTransceiver : public BaseCacheTransceiver mMpiGroupTPInDPComm; executor::kv_cache::CommState const* mCommState; std::unique_ptr mCacheState; + // std::unique_ptr mCacheServer; std::unique_ptr mManager; std::optional mCacheTransceiverConfig; std::unique_ptr mCacheTransBufferManager; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d5da697535f..1a96ec7548f 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -115,6 +115,8 @@ struct BlockKey // Each extra key is a pair of (mm_hash, start_offset_in_block) std::vector extraKeys; + size_t hash{0}; + BlockKey() = default; explicit BlockKey(VecTokens const& tokens, std::optional loraTaskId = std::nullopt) @@ -127,6 +129,11 @@ struct BlockKey } } + explicit BlockKey(size_t hash) + : hash{hash} + { + } + explicit BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens, std::vector extraKeys = {}) : usesExtraIds{usesExtraIds} @@ -164,6 +171,10 @@ struct BlockKeyHasher std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept { + if (blockKey.hash != 0) + { + return blockKey.hash; + } return hash(blockKey, parentHash); } }; @@ -568,6 +579,8 @@ class WindowBlockManager void storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest); + void pinBlocks(GenerationRequest& sequence); + //! \brief Release blocks of the sequence. 
void releaseBlocks(GenerationRequest& sequence); @@ -739,6 +752,9 @@ class WindowBlockManager return 0; } + [[nodiscard]] std::optional> findBlocksInReuseTreeByHashes( + std::vector const& hashes) const; + private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); @@ -890,6 +906,8 @@ class BlockManager void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); + void pinBlocks(GenerationRequest& sequence); + void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize); void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx, @@ -1074,6 +1092,12 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } + [[nodiscard]] std::optional> findBlocksInReuseTreeByHashes( + std::vector const& hashes, SizeType32 windowSize) const + { + return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByHashes(hashes); + } + [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); }); @@ -1217,6 +1241,8 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const = 0; + virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0; + /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. 
virtual void addToken(LlmRequest::RequestIdType requestId) = 0; @@ -1354,6 +1380,10 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0; [[nodiscard]] virtual CacheType getCacheType() const = 0; + + [[nodiscard]] virtual std::optional> findBlocksInReuseTreeByHashes( + std::vector const& hashes, SizeType32 windowSize) const + = 0; }; class KVCacheManager : public BaseKVCacheManager @@ -1605,6 +1635,8 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock); + void pinBlocks(LlmRequest::RequestIdType requestId); + /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam. /// /// @param sequenceLength The total length of the sequence (input and output). @@ -1642,6 +1674,12 @@ class KVCacheManager : public BaseKVCacheManager mBlockManager.flushIterationEvents(); } + std::optional> findBlocksInReuseTreeByHashes( + std::vector const& hashes, SizeType32 windowSize) const override + { + return mBlockManager.findBlocksInReuseTreeByHashes(hashes, windowSize); + } + /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity. /// /// @param inputLength The number of input tokens in the sequence. 
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 2aebf77b96d..03d81cd6739 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -48,6 +48,28 @@ class BlockRange return BlockRange(cacheManager, blockIds, requestId); } + static BlockRange fromReuseTree(BaseKVCacheManager const& cacheManager, std::vector const& allBlockHashes, + std::vector const& requestedBlockHashes) + { + auto const windowSize = firstWindowSize(cacheManager); + auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize); + // TODO: handle the case where the last block is not found + TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree"); + // Assume the last block is the requested block + std::vector blockIds; + for (auto const& hash : requestedBlockHashes) + { + if (lastBlock->getHash() != hash) + { + return BlockRange(cacheManager, {}, 0); + } + blockIds.emplace_back(lastBlock->getBlockId()); + lastBlock = lastBlock->getPrevBlock(); + TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found"); + } + return BlockRange(cacheManager, blockIds, 0); + } + BlockRange(runtime::ITensor::SharedPtr pool, std::vector const& blockIds) // Only used in tests : mManager{nullptr} , mPool{std::move(pool)} diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index f069e3ac7f5..6602af405e5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1831,16 +1831,6 @@ class GenericLlmRequest } } - void setRequestedBlockHashes(std::vector hashes) - { - mRequestedBlockHashes = std::move(hashes); - } - - [[nodiscard]] std::vector const& getRequestedBlockHashes() const - { - return mRequestedBlockHashes; - } - void setIsDummyRequest(bool isDummyRequest) { mIsDummyRequest =
isDummyRequest; diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index e73e0f15411..f389c09b7c7 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -39,37 +39,33 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { -BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) +BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, + std::vector const& allBlockHashes, std::vector const& requestedBlockHashes) { - size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size(); + size_t requestBlockNum = requestedBlockHashes.size(); constexpr SizeType32 beam{0}; - auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + if (poolNum > 1 || !cacheManager->isEnableBlockReuse()) { - // disable selective cache transfer for poolNum > 1 + auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); return blockRange; } - if (requestBlockNum < blockRange.size() && requestBlockNum > 0) - { - // handle block reuse, the prefix blocks are reused - // TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num - auto const& ids = blockRange.getBlockIds(); - blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()}); - } - return blockRange; + return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes); } BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) { auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + if (poolNum == 1 && cacheManager->isEnableBlockReuse()) + { 
+ return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + } + else { constexpr SizeType32 beam{0}; return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); } - return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); } bool CacheFormatter::needSendCache( @@ -155,13 +151,17 @@ void CacheFormatter::format(TransferSession& session) auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); + auto& allBlockHashes = session.getAllBlockHashes(); + auto& requestedBlockHashes = session.getRequestedBlockHashes(); + TLLM_CHECK_WITH_INFO(allBlockHashes.size() >= requestedBlockHashes.size(), + "allBlockHashes must be greater than or equal to requestedBlockHashes"); auto& bufferManager = session.getBufferManager(); if (!needSendCache(selfConfig, destConfig, selfIdx)) { return; } auto& blockManager = mCacheManager->getBlockManager(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, allBlockHashes, requestedBlockHashes); auto const numPools = blockManager.getNumPools(); // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1... 
diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index ac675848b41..0aaeeb742ec 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -34,10 +34,86 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { -BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); +BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, + std::vector const& allBlockHashes, std::vector const& requestedBlockHashes); BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); +class KvCacheMeasureHelper +{ +public: + struct Measure + { + double delay; // from last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + + KvCacheMeasureHelper(std::string output_path) + : mOutputPath(std::move(output_path)) + { + } + + void markAsSender(bool isSender) + { + mIsSender = isSender; + } + + void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size) + { + auto bandwidth = size * 8 / (duration / 1000) / 1e9; + if (mOutputPath.empty()) + { + return; + } + + std::lock_guard lock(mMutex); + mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth}); + } + + ~KvCacheMeasureHelper() + { + if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) + { + TLLM_CHECK(mIsSender.has_value()); + auto rank = mpi::MpiComm::world().getRank(); + std::string outFilePath + = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? 
"send" : "recv") + ".csv"; + std::ofstream outFile(outFilePath); + + TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); + + size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); + + outFile << "RequestID"; + for (size_t i = 0; i < numTransferMeasure; i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + + for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) + { + outFile << requestID; + + for (auto const& measure : measures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n'; + } + + outFile.close(); + } + } + +private: + std::map> mRequestKVCacheTranfserMeasure; + std::string mOutputPath; + std::mutex mMutex; + std::optional mIsSender; +}; + using DataContext = tensorrt_llm::executor::kv_cache::DataContext; using Connection = tensorrt_llm::executor::kv_cache::Connection; using SizeType32 = tensorrt_llm::runtime::SizeType32; @@ -54,12 +130,15 @@ class TransferSession TransferSession(std::vector connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr) + runtime::BufferManager const& bufferManager, std::vector allBlockHashes = {}, + std::vector requestedBlockHashes = {}, LlmRequest const* llmRequest = nullptr) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) + , mAllBlockHashes(std::move(allBlockHashes)) + , mRequestedBlockHashes(std::move(requestedBlockHashes)) , mRequest(llmRequest) { TLLM_CHECK(!mConnections.empty()); @@ -156,15 +235,27 @@ class TransferSession outFile << '\n' << std::flush; } + [[nodiscard]] std::vector const& getAllBlockHashes() const + { + return mAllBlockHashes; + } + + [[nodiscard]] std::vector 
const& getRequestedBlockHashes() const + { + return mRequestedBlockHashes; + } + private: std::vector mConnections; DataContext mDataContext; executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender executor::DataTransceiverState mOtherState; runtime::BufferManager const* mBufferManager; - LlmRequest const* mRequest; std::vector mMeasures; bool mRecordMeasure{false}; + std::vector mAllBlockHashes; + std::vector mRequestedBlockHashes; + LlmRequest const* mRequest; }; // Used to support the cache transmission with different layouts and different protocols. @@ -207,66 +298,6 @@ class BaseCacheFormatter virtual ~BaseCacheFormatter() = default; }; -class KvCacheMeasureHelper -{ -public: - KvCacheMeasureHelper(std::string output_path) - : mOutputPath(std::move(output_path)) - { - } - - void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double duration, size_t size) - { - auto bandwidth = size * 8 / (duration / 1000) / 1e9; - if (mOutputPath.empty()) - { - return; - } - - std::lock_guard lock(mMutex); - mRequestKVCacheTranfserMeasure[requestId].emplace_back(duration, bandwidth); - } - - ~KvCacheMeasureHelper() - { - if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) - { - auto rank = mpi::MpiComm::world().getRank(); - std::string outFilePath = mOutputPath + "rank_" + std::to_string(rank) + ".txt"; - std::ofstream outFile(outFilePath); - - TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); - - size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); - - outFile << "RequestID"; - for (size_t i = 0; i < numTransferMeasure; i++) - { - outFile << ",TimeDuration,Bandwidth"; - } - outFile << '\n'; - - for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) - { - outFile << requestID; - - for (auto const& [time, bandwidth] : measures) - { - outFile << "," << time << "," << bandwidth; - } - outFile << '\n'; - } - - outFile.close(); - } - } - 
-private: - std::map>> mRequestKVCacheTranfserMeasure; - std::string mOutputPath; - std::mutex mMutex; -}; - // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the // parallel topology is completely identical, making it the preferred method. class CacheFormatter final : public BaseCacheFormatter diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index f22ee779ce8..5ee1a07bcb7 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -106,12 +106,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) , mCacheTransceiverConfig{cacheTransceiverConfig} { + // mCacheServer = std::make_unique(CacheServerConfig{cacheManager}); using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter; - if (worldConfig.isPipelineParallel()) - { - mMpiGroupPipeParaComm = std::make_shared( - mMpiGroupComm->split(worldConfig.getTensorParallelRank(), worldConfig.getPipelineParallelRank())); - } if (worldConfig.isTensorParallel()) { mMpiGroupTensorParaComm = std::make_shared( diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 136692b0363..d48d02de8cc 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -62,17 +62,19 @@ RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTran { } -RequestInfo::RequestInfo( - LlmRequest::RequestIdType requestId, std::vector blockHashes, executor::DataTransceiverState transState) +RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockHashes, + executor::DataTransceiverState transState, std::vector requestedBlockHashes) : mRequestId{requestId} - , 
mBlockHashes{std::move(blockHashes)} + , mAllBlockHashes{std::move(allBlockHashes)} + , mRequestedBlockHashes{std::move(requestedBlockHashes)} , mTransState{std::move(transState)} { } bool RequestInfo::operator==(RequestInfo const& rhs) const { - return mRequestId == rhs.mRequestId && mBlockHashes == rhs.mBlockHashes && mTransState == rhs.mTransState; + return mRequestId == rhs.mRequestId && mAllBlockHashes == rhs.mAllBlockHashes + && mRequestedBlockHashes == rhs.mRequestedBlockHashes && mTransState == rhs.mTransState; } LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept @@ -89,7 +91,8 @@ void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os) { namespace su = executor::serialize_utils; su::serialize(requestInfo.mRequestId, os); - su::serialize(requestInfo.mBlockHashes, os); + su::serialize(requestInfo.mAllBlockHashes, os); + su::serialize(requestInfo.mRequestedBlockHashes, os); su::serialize(requestInfo.mTransState, os); } @@ -97,9 +100,10 @@ RequestInfo RequestInfo::deserialize(std::istream& is) { namespace su = executor::serialize_utils; auto requestId = su::deserialize(is); - auto blockHashes = su::deserialize(is); + auto allBlockHashes = su::deserialize(is); + auto requestedBlockHashes = su::deserialize(is); auto transState = su::deserialize(is); - return RequestInfo{requestId, std::move(blockHashes), std::move(transState)}; + return RequestInfo{requestId, std::move(allBlockHashes), std::move(transState), std::move(requestedBlockHashes)}; } std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) @@ -107,7 +111,8 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) namespace su = executor::serialize_utils; std::size_t totalSize = 0; totalSize += su::serializedSize(requestInfo.mRequestId); - totalSize += su::serializedSize(requestInfo.mBlockHashes); + totalSize += su::serializedSize(requestInfo.mAllBlockHashes); + totalSize += su::serializedSize(requestInfo.mRequestedBlockHashes); 
totalSize += su::serializedSize(requestInfo.mTransState); return totalSize; } @@ -138,8 +143,8 @@ class CacheSender::Impl { { std::unique_lock lkResp(mSenderMutex); - mReadyResponses.emplace( - llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)}); + mReadyRequests.emplace(llmRequest.mRequestId, std::addressof(llmRequest)); + mReadyPromises.emplace(llmRequest.mRequestId, std::move(promise)); } std::unique_lock lkCond(mCondMutex); mAnyReady = true; @@ -178,25 +183,18 @@ class CacheSender::Impl auto* agentConnectionManager = dynamic_cast(mManager); bool isAgent = agentConnectionManager != nullptr; - auto agentRecvFun = [&](RequestInfo& requestInfo) - { - auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); - return connection; - }; TransceiverTag::Id id; RequestInfo info; - auto const* connection = isAgent ? agentRecvFun(info) + auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info) : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); if (!isAgent) { TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND); std::uint64_t infoSize{0}; - connection->recv( - executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->recv(DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); std::string serializedInfo; serializedInfo.resize(infoSize); - connection->recv( - executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + connection->recv(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); std::istringstream iss(serializedInfo); info = RequestInfo::deserialize(iss); } @@ -219,7 +217,8 @@ class CacheSender::Impl if (it == mRequestToSession.end()) { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + 
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, + info.getAllBlockHashes(), info.getRequestedBlockHashes()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -242,25 +241,42 @@ class CacheSender::Impl } private: - struct Response + void sendAndRemoveResponse(RequestIdType id) noexcept { - LlmRequest* mRequest; - std::promise mPromise; - }; + LlmRequest* request = nullptr; + std::promise promise; + + // Extract request and promise + { + std::unique_lock lkResp(mSenderMutex); + auto requestIt = mReadyRequests.find(id); + auto promiseIt = mReadyPromises.find(id); + + if (requestIt != mReadyRequests.end() && promiseIt != mReadyPromises.end()) + { + request = requestIt->second; + promise = std::move(promiseIt->second); + mReadyRequests.erase(requestIt); + mReadyPromises.erase(promiseIt); + } + } + + if (request == nullptr) + { + return; // Request not found + } - void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept - { try { TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - sendSync(*resp.mRequest); + sendSync(*request); release(id); - resp.mPromise.set_value(); + promise.set_value(); } catch (std::exception const& e) { TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what()); - resp.mPromise.set_exception(std::current_exception()); + promise.set_exception(std::current_exception()); } } @@ -281,12 +297,10 @@ class CacheSender::Impl { break; } - std::vector blockHashes; - if (!isSending() && !mReadyResponses.empty()) + if (!isSending() && !mReadyPromises.empty()) { auto const& requestInfo = recvRequestInfo(); auto reqId = requestInfo.getRequestId(); - blockHashes = requestInfo.getBlockHashes(); mCurrentRequest = reqId; if (mRemainSendCount.find(reqId) == mRemainSendCount.end()) @@ -294,32 +308,26 @@ class CacheSender::Impl mRemainSendCount[reqId] = getCounterpartsCount(reqId); } } - auto it = getCurrentResponse(); - if (it != 
mReadyResponses.end()) + auto reqId = mCurrentRequest.value(); + auto it = mReadyPromises.find(reqId); + if (it != mReadyPromises.end()) { - auto reqId = mCurrentRequest.value(); auto count = --mRemainSendCount[reqId]; TLLM_CHECK(count >= 0); if (count == 0) { mRemainSendCount.erase(reqId); - // TODO(zhengd): pass the hashes directly instead of update llmRequest - auto llmRequest = it->second.mRequest; - llmRequest->setRequestedBlockHashes(std::move(blockHashes)); - if (common::getEnvParallelCacheSend()) { // TODO: Use a thread pool and check for thread safety. - std::thread( - &CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) - .detach(); + std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, reqId).detach(); } else { - CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); + CacheSender::Impl::sendAndRemoveResponse(reqId); } - removeResponse(it); + removeResponse(reqId); } mCurrentRequest = std::nullopt; } @@ -327,8 +335,8 @@ class CacheSender::Impl { TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(), "This executor does not have a prepared KV cache for request ID: %zu, and the " - "mReadyResponses size is: %zu. mpi rank :%d ", - mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank()); + "mReadyPromises size is: %zu. 
mpi rank :%d ", + mCurrentRequest.value(), mReadyPromises.size(), mpi::MpiComm::world().getRank()); std::unique_lock lk(mCondMutex); mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); } @@ -337,9 +345,9 @@ class CacheSender::Impl catch (std::exception const& err) { TLLM_LOG_ERROR("Exception in CacheSender response: %s", err.what()); - for (auto& it : mReadyResponses) + for (auto& it : mReadyPromises) { - it.second.mPromise.set_exception(std::current_exception()); + it.second.set_exception(std::current_exception()); } } } @@ -355,13 +363,14 @@ class CacheSender::Impl mSenderCv.notify_all(); } - void removeResponse(std::map::iterator it) + void removeResponse(RequestIdType id) { { std::unique_lock lkResp(mSenderMutex); - mReadyResponses.erase(it); + mReadyRequests.erase(id); + mReadyPromises.erase(id); } - if (mReadyResponses.empty()) + if (mReadyRequests.empty()) { std::unique_lock lkCond(mCondMutex); mAnyReady = false; @@ -378,15 +387,18 @@ class CacheSender::Impl return mCurrentRequest.value(); } - [[nodiscard]] std::map::iterator getCurrentResponse() + [[nodiscard]] bool hasCurrentResponse() { - std::unique_lock lk(mSenderMutex); - return mReadyResponses.find(getCurrentRequestId()); + std::unique_lock lk(mSenderMutex); + auto requestId = getCurrentRequestId(); + return mReadyRequests.find(requestId) != mReadyRequests.end() + && mReadyPromises.find(requestId) != mReadyPromises.end(); } private: std::optional mCurrentRequest; - std::map mReadyResponses; + std::map mReadyRequests; + std::map> mReadyPromises; std::mutex mSenderMutex, mCondMutex; std::atomic mAnyReady{false}, mTerminate{false}; std::condition_variable mSenderCv; @@ -473,14 +485,15 @@ class CacheReceiver::Impl RequestInfo requestInfo(requestId, mSelfState); - auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() - || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); - if (!disableSelectiveCacheTransfer) + if 
(mFormatter->getCacheManager()->isEnableBlockReuse() + && mFormatter->getCacheManager()->getBlockManager().getNumPools() == 1) { auto* cacheManager = mFormatter->getCacheManager(); - auto blockRange - = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); - requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); + auto beam = 0; + auto allBlockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); + auto requestedBlockRange = getBlockRangeForReceiving(cacheManager, llmRequest); + requestInfo = RequestInfo( + requestId, allBlockRange.getBlockHashes(), mSelfState, requestedBlockRange.getBlockHashes()); } auto* agentConnectionManager = dynamic_cast(mManager); @@ -525,7 +538,8 @@ class CacheReceiver::Impl } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest); + contextState, resource->mBufferManager, requestInfo.getAllBlockHashes(), + requestInfo.getRequestedBlockHashes(), &llmRequest); } std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest) @@ -553,9 +567,9 @@ class CacheReceiver::Impl auto const& serializedInfo = oss.str(); std::size_t const infoSize = serializedInfo.size(); TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND}; - connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); - connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + connection->send(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + connection->send(DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->send(DataContext{TransceiverTag::kINFO_TAG}, 
serializedInfo.data(), infoSize); } ~Impl() diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 14c8302e8d2..a542ca2fa42 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -67,8 +67,8 @@ class RequestInfo /// @param transState The state of the data transceiver. RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState); - RequestInfo(LlmRequest::RequestIdType requestId, std::vector blockHashes, - executor::DataTransceiverState transState); + RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockHashes, + executor::DataTransceiverState transState, std::vector requestedBlockHashes); RequestInfo() = default; /// @brief Equality comparison operator. @@ -79,9 +79,14 @@ class RequestInfo /// @return The request ID. [[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept; - [[nodiscard]] std::vector const& getBlockHashes() const noexcept + [[nodiscard]] std::vector const& getAllBlockHashes() const noexcept { - return mBlockHashes; + return mAllBlockHashes; + } + + [[nodiscard]] std::vector const& getRequestedBlockHashes() const noexcept + { + return mRequestedBlockHashes; } /// @brief Return the state of the data transceiver. @@ -106,13 +111,16 @@ class RequestInfo // The ID used in the context phase of the current request. LlmRequest::RequestIdType mRequestId; - std::vector mBlockHashes; + // The block hashes of the request. + std::vector mAllBlockHashes; + + // The block hashes of the requested data. + std::vector mRequestedBlockHashes; // The state of the data transceiver. 
executor::DataTransceiverState mTransState; }; - class CacheSender { public: diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 364f2409bc4..fb97168cb76 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1008,6 +1008,30 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; } +std::optional> WindowBlockManager::findBlocksInReuseTreeByHashes( + std::vector const& hashes) const +{ + std::vector blockKeys; + for (auto const& hash : hashes) + { + blockKeys.emplace_back(hash); + } + + auto searchRoot = mCachedBlocksRoot; + for (auto const& blockKey : blockKeys) + { + auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr + ? searchRoot->findMatchingBlock(blockKey, false, false) + : std::make_tuple(false, 0, nullptr); + if (matchingBlock == nullptr) + { + return std::nullopt; + } + searchRoot = std::move(matchingBlock); + } + return searchRoot; +} + SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence, std::vector const& perBlockRetentions) { @@ -1468,6 +1492,24 @@ void BlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRefincRefCount(); + } +} + void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { // we store newest block for potential reuse only if: @@ -2052,6 +2094,12 @@ void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) mBlockManager.schedulingReleaseBlocks(requestId); } +void KVCacheManager::pinBlocks(RequestIdType requestId) +{ + auto& sequence = getSequence(requestId); + mBlockManager.pinBlocks(sequence); +} + SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const { auto const& sequence = 
getSequence(requestId); diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index eaa2e957e87..33d5c7fba21 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -93,6 +93,8 @@ void MLACacheFormatter::format(TransferSession& session) auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); + auto& blockHashes = session.getAllBlockHashes(); + auto& requestedBlockHashes = session.getRequestedBlockHashes(); auto const& connections = session.getConnections(); auto& bufferManager = session.getBufferManager(); TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1."); @@ -106,7 +108,7 @@ void MLACacheFormatter::format(TransferSession& session) // diff end auto const numPools = mCacheManager->getBlockManager().getNumPools(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, blockHashes, requestedBlockHashes); auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime; bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point(); diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index 59c9d2fffe4..3b3abc4e473 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -318,12 +318,6 @@ bool getEnvDisaggLayerwise() return disaggLayerwise; } -bool getEnvDisableSelectiveCacheTransfer() -{ - static bool const disableSelectiveCacheTransfer = getBoolEnv("TRTLLM_DISABLE_SELECTIVE_CACHE_TRANSFER"); - return disableSelectiveCacheTransfer; -} - bool getEnvParallelCacheSend() { static bool const parallelCacheSend = 
getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp index 851d116eed6..98d7c82a02c 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp @@ -234,7 +234,6 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc while (true) { - updateUnhandledNotifications(); std::scoped_lock lock(mNotificationMutex); auto it = mUnhandledNotifications.begin(); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index fcbfaf9c64a..11f3c247994 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -346,6 +346,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("add_sequence", &BaseKVCacheManager::addSequence) .def("remove_sequence", &BaseKVCacheManager::removeSequence) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) + .def("pin_blocks", &BaseKVCacheManager::pinBlocks) .def("get_block_pool_pointers", [](tbk::BaseKVCacheManager& self) { diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5f1e8ac147d..4225911f6bb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -181,6 +181,7 @@ def __init__(self, self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len self.print_log = model_engine.pytorch_backend_config.print_iter_log + self.block_reuse_enabled = model_engine.pytorch_backend_config.kv_cache_config.enable_block_reuse self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats self.enable_iter_req_stats = 
model_engine.pytorch_backend_config.enable_iter_req_stats self.stream_interval = model_engine.pytorch_backend_config.stream_interval @@ -1688,7 +1689,10 @@ def _handle_responses(self): if request_done: if request.is_disagg_context_transmission_state: - self.ctx_in_transmission_requests.append(request) + if self.block_reuse_enabled: + requests_to_terminate.append(request) + else: + self.ctx_in_transmission_requests.append(request) else: requests_to_terminate.append(request) else: From 05ea582320234624dd73a0c8157dafa011262410 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 26 Aug 2025 15:25:28 -0700 Subject: [PATCH 4/6] Add unittest for findBlocksInReuseTreeByHashes Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 6 +- .../tensorrt_llm/batch_manager/llmRequest.h | 3 - .../batch_manager/kvCacheManager.cpp | 18 ++++- .../batch_manager/kvCacheManagerTest.cpp | 69 +++++++++++++++++++ 4 files changed, 87 insertions(+), 9 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 1a96ec7548f..59fe0eadfed 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -143,11 +143,7 @@ struct BlockKey { } - bool operator==(BlockKey const& other) const noexcept - { - return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId - && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys); - } + bool operator==(BlockKey const& other) const noexcept; int partialMatch(BlockKey const& other) const noexcept { diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 6602af405e5..be130fa4143 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h 
@@ -2022,9 +2022,6 @@ class GenericLlmRequest // Tensors containing the additional generation output. TensorMap mAdditionalGenerationOutputTensors; - // Context request only. The hashes of the blocks that are requested by the corresponding generation request. - std::vector mRequestedBlockHashes; - bool mIsDummyRequest{false}; bool mUseDraftModel{false}; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index fb97168cb76..d1842a4e4be 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -160,6 +160,22 @@ std::vector buildBlockKeys( namespace tensorrt_llm::batch_manager::kv_cache_manager { + +bool BlockKey::operator==(BlockKey const& other) const noexcept +{ + if (hash != 0) + { + return hash == BlockKeyHasher::hash(other); + } + if (other.hash != 0) + { + return other.hash == BlockKeyHasher::hash(*this); + } + + return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens + && extraKeys == other.extraKeys); +} + size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept { // Hashing algorithm adapted from StackOverflow: @@ -395,7 +411,7 @@ void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block) std::tuple KVCacheBlock::findMatchingBlock( BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const { - if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0) + if ((blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0) && !blockKey.hash) { return {false, 0, nullptr}; } diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index a52cca097a3..0708b836586 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -528,6 +528,75 @@ 
TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare) } } +TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByHashesTest) +{ + auto constexpr numLayers = 12; + auto constexpr numKvHeads = 6; + auto constexpr sizePerHead = 128; + auto constexpr tokensPerBlock = 8; + auto constexpr blocksInPrimaryPool = 4; + auto constexpr blocksInSecondaryPool = 4; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + auto constexpr onboardBlocks = true; + + auto constexpr batchSize = 1; + auto constexpr maxBlocksPerSeq = 10; + auto constexpr bytesPerToken = 4; + auto constexpr maxAttentionWindow = 4096; + auto constexpr maxAttentionWindowAllLayer = 4096; + auto constexpr sinkTokenLen = 0; + auto constexpr canUseOneMoreBlock = true; + + SizeType32 constexpr maxNewTokens{0}; + auto constexpr beamWidth = 1; + auto constexpr beamIdx = 0; + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + + auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + + BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, + blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, + onboardBlocks); + blockManager.allocatePools(false); + + auto oneLayerBlockSize = blockManager.getBlockSize(0); + EXPECT_EQ(oneLayerBlockSize, numKvHeads * sizePerHead * tokensPerBlock); + + // Add sequence [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] (17 tokens, three blocks) + auto inputTokens = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto const inputLength = static_cast(inputTokens->size()); + LlmRequest::RequestIdType requestId{0}; + auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); + GenerationRequest seq0{requestId, inputLength, beamWidth, 
blockManager.getWindowSizesMetadata()}; + auto promptLen0 = llmRequest0->getNumTokens(beamIdx); + auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); + auto cacheBlockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); + EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); + + blockManager.storeContextBlocks(seq0, *llmRequest0); + + std::vector emptyHashes{}; + auto result = blockManager.findBlocksInReuseTreeByHashes(emptyHashes, maxAttentionWindow); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ((*result)->getBlockId(), KVCacheBlock::kCachedBlocksRootId); + + auto block0 = blockManager.getBlockById(cacheBlockIds[0], maxAttentionWindow); + // Corrupt the valid hash to guarantee a miss + std::vector badHashes{block0->getHash() ^ static_cast(0x9e3779b97f4a7c15ULL)}; + result = blockManager.findBlocksInReuseTreeByHashes(badHashes, maxAttentionWindow); + EXPECT_FALSE(result.has_value()); + + std::vector hashes{block0->getHash()}; + result = blockManager.findBlocksInReuseTreeByHashes(hashes, maxAttentionWindow); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ((*result)->getBlockId(), block0->getBlockId()); +} + #ifdef ENABLE_FP4 TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) { From 43b9285f44df68c9d010704e53d1b347f6f4b0d2 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Wed, 27 Aug 2025 11:05:45 -0700 Subject: [PATCH 5/6] fixes Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 28 ++++++++++++++ .../tensorrt_llm/batch_manager/kvCacheUtils.h | 13 +++++++ .../batch_manager/dataTransceiver.cpp | 24 +++++++++++- .../batch_manager/dataTransceiver.h | 1 + .../batch_manager/kvCacheManager.cpp | 37 ++++++------------- 
.../batch_manager/kvCacheManagerTest.cpp | 3 ++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 7 files changed, 79 insertions(+), 29 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 59fe0eadfed..d17b7f1ec4d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -75,6 +75,32 @@ using MmKey = std::pair, SizeType32>; template using OptionalRef = tensorrt_llm::common::OptionalRef; +//! \brief Split vector into list of blocks of given size. +//! \param vec vector to split +//! \param usableSize part of the vector that is processed +//! \param elementsPerBlock desired size of blocks +//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end +//! \return list of blocks +template +std::list> chopVectorIntoBlocks( + std::vector const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial) +{ + TLLM_CHECK_WITH_INFO( + usableSize <= static_cast(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size()); + std::list> blockedVectors; + auto const vecEnd = vec.begin() + usableSize; + for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock) + { + auto blockSize = std::min(elementsPerBlock, static_cast(std::distance(begin, vecEnd))); + auto end = begin + blockSize; + if (blockSize == elementsPerBlock || allowPartial) + { + blockedVectors.emplace_back(begin, end); + } + } + return blockedVectors; +} + struct TempAttentionWindowInputs { bool pagedContextFMHA; @@ -158,6 +184,8 @@ struct BlockKey } }; +std::vector buildBlockKeys(std::list& blockedUniqueTokens, LlmRequest const& llmRequest); + // Implement hash functor for BlockKey. // This allows us to use unordered_map with BlockKey as key. 
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933 diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 03d81cd6739..504ac5b403d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -52,6 +52,19 @@ class BlockRange std::vector const& requestedBlockHashes) { auto const windowSize = firstWindowSize(cacheManager); + std::cout << "allBlockHashes: " << allBlockHashes.size() << std::endl; + for (auto hash : allBlockHashes) + { + std::cout << hash << " "; + } + std::cout << std::endl; + std::cout << "requestedBlockHashes: " << requestedBlockHashes.size() << std::endl; + for (auto hash : requestedBlockHashes) + { + std::cout << hash << " "; + } + std::cout << std::endl; + auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize); // TODO: handle the case where the last block is not found TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree"); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index d48d02de8cc..92ae3782636 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -492,8 +492,28 @@ class CacheReceiver::Impl auto beam = 0; auto allBlockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); auto requestedBlockRange = getBlockRangeForReceiving(cacheManager, llmRequest); - requestInfo = RequestInfo( - requestId, allBlockRange.getBlockHashes(), mSelfState, requestedBlockRange.getBlockHashes()); + + auto const& uniqueTokens = llmRequest.getUniqueTokens(beam); + auto blockedUniqueTokens = tensorrt_llm::batch_manager::kv_cache_manager::chopVectorIntoBlocks( + uniqueTokens, uniqueTokens.size() - 1, 
mFormatter->getCacheManager()->getTokensPerBlock(), false); + auto blockKeys + = tensorrt_llm::batch_manager::kv_cache_manager::buildBlockKeys(blockedUniqueTokens, llmRequest); + std::vector allBlockHashes; + for (auto const& blockKey : blockKeys) + { + allBlockHashes.push_back(tensorrt_llm::batch_manager::kv_cache_manager::BlockKeyHasher::hash(blockKey)); + } + // Figure out the size difference for computing the requested block hashes + std::vector requestedBlockHashes; + for (auto i = 0; i < allBlockHashes.size(); i++) + { + if (i >= requestedBlockRange.getBlockHashes().size()) + { + requestedBlockHashes.push_back(allBlockHashes[i]); + } + } + + requestInfo = RequestInfo(requestId, allBlockHashes, mSelfState, requestedBlockHashes); } auto* agentConnectionManager = dynamic_cast(mManager); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index a542ca2fa42..e0ecbc7c8e1 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -41,6 +41,7 @@ using Connection = tensorrt_llm::executor::kv_cache::Connection; using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager; using SizeType32 = tensorrt_llm::runtime::SizeType32; using TransferSession = kv_cache_manager::TransferSession; +using UniqueToken = tensorrt_llm::runtime::UniqueToken; struct TransceiverTag { diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index d1842a4e4be..302c03ed0a4 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -50,32 +50,6 @@ using BlocksPerWindow = std::map> namespace { -//! \brief Split vector into list of blocks of given size. -//! \param vec vector to split -//! \param usableSize part of the vector that is processed -//! \param elementsPerBlock desired size of blocks -//! 
\param allowPartial whether to append a block smaller than `elementsPerBlock` at the end -//! \return list of blocks -template -std::list> chopVectorIntoBlocks( - std::vector const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial) -{ - TLLM_CHECK_WITH_INFO( - usableSize <= static_cast(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size()); - std::list> blockedVectors; - auto const vecEnd = vec.begin() + usableSize; - for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock) - { - auto blockSize = std::min(elementsPerBlock, static_cast(std::distance(begin, vecEnd))); - auto end = begin + blockSize; - if (blockSize == elementsPerBlock || allowPartial) - { - blockedVectors.emplace_back(begin, end); - } - } - return blockedVectors; -} - inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept { return static_cast((hashPart >> (24 - byteIdx * 8)) & 0xFF); @@ -1027,6 +1001,12 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) std::optional> WindowBlockManager::findBlocksInReuseTreeByHashes( std::vector const& hashes) const { + std::cout << "findBlocksInReuseTreeByHashes: " << hashes.size() << std::endl; + for (auto hash : hashes) + { + std::cout << hash << " "; + } + std::cout << std::endl; std::vector blockKeys; for (auto const& hash : hashes) { @@ -1036,13 +1016,16 @@ std::optional> WindowBlockManager::findBlocksInReu auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) { + std::cout << "findBlocksInReuseTreeByHashes: searching for block key: " << blockKey.hash << std::endl; auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr ? 
searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr); if (matchingBlock == nullptr) { + std::cout << "findBlocksInReuseTreeByHashes: no matching block found" << std::endl; return std::nullopt; } + std::cout << "findBlocksInReuseTreeByHashes: matching block found" << matchingBlock->getHash() << std::endl; searchRoot = std::move(matchingBlock); } return searchRoot; @@ -1330,6 +1313,8 @@ void WindowBlockManager::storeBlocks( for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) { auto const bid = blockIds[blockCnt]; + std::cout << "Storing block " << bid << " block hash: " << BlockKeyHasher::hash(blockKeys[blockCnt]) + << std::endl; TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); auto& block = mAllBlocksById[bid]; auto const& blockKey = blockKeys[blockCnt]; diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 0708b836586..2d2b4e0a9da 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -595,6 +595,9 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByHashesTest) result = blockManager.findBlocksInReuseTreeByHashes(hashes, maxAttentionWindow); ASSERT_TRUE(result.has_value()); EXPECT_EQ((*result)->getBlockId(), block0->getBlockId()); + + // Add sequence [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] (17 tokens, three blocks) + BlockKey } #ifdef ENABLE_FP4 diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 4225911f6bb..dc2d6788ec7 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -181,7 +181,6 @@ def __init__(self, self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len self.print_log = model_engine.pytorch_backend_config.print_iter_log - self.block_reuse_enabled = 
model_engine.pytorch_backend_config.kv_cache_config.enable_block_reuse self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats self.stream_interval = model_engine.pytorch_backend_config.stream_interval @@ -201,6 +200,7 @@ def __init__(self, # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) + self.block_reuse_enabled = self.kv_cache_manager.enable_block_reuse self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0 self.max_input_len = max_input_len From f1224bcf82fffb850649b5693caf9c7ee2dcdfcc Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:54:29 -0700 Subject: [PATCH 6/6] Switch from hash id to block key Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 20 ++--- .../tensorrt_llm/batch_manager/kvCacheUtils.h | 42 ++++----- .../tensorrt_llm/executor/serialization.h | 5 ++ .../batch_manager/cacheFormatter.cpp | 22 ++--- .../batch_manager/cacheFormatter.h | 22 ++--- .../batch_manager/dataTransceiver.cpp | 58 ++++++------ .../batch_manager/dataTransceiver.h | 19 ++-- .../batch_manager/kvCacheManager.cpp | 37 ++++---- .../batch_manager/mlaCacheFormatter.cpp | 6 +- cpp/tensorrt_llm/executor/serialization.cpp | 39 ++++++++ cpp/tensorrt_llm/executor/serializeUtils.h | 90 +++++++++++++++++++ .../batch_manager/kvCacheManagerTest.cpp | 15 ++-- 12 files changed, 246 insertions(+), 129 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d17b7f1ec4d..249dade0e77 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -776,8 
+776,8 @@ class WindowBlockManager return 0; } - [[nodiscard]] std::optional> findBlocksInReuseTreeByHashes( - std::vector const& hashes) const; + [[nodiscard]] std::optional> findBlocksInReuseTreeByBlockKeys( + std::vector const& blockKeys) const; private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. @@ -1116,10 +1116,10 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } - [[nodiscard]] std::optional> findBlocksInReuseTreeByHashes( - std::vector const& hashes, SizeType32 windowSize) const + [[nodiscard]] std::optional> findBlocksInReuseTreeByBlockKeys( + std::vector const& blockKeys, SizeType32 windowSize) const { - return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByHashes(hashes); + return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKeys(blockKeys); } [[nodiscard]] SizeType32 getNumPrimaryBlocks() const @@ -1405,8 +1405,8 @@ class BaseKVCacheManager [[nodiscard]] virtual CacheType getCacheType() const = 0; - [[nodiscard]] virtual std::optional> findBlocksInReuseTreeByHashes( - std::vector const& hashes, SizeType32 windowSize) const + [[nodiscard]] virtual std::optional> findBlocksInReuseTreeByBlockKeys( + std::vector const& blockKeys, SizeType32 windowSize) const = 0; }; @@ -1698,10 +1698,10 @@ class KVCacheManager : public BaseKVCacheManager mBlockManager.flushIterationEvents(); } - std::optional> findBlocksInReuseTreeByHashes( - std::vector const& hashes, SizeType32 windowSize) const override + std::optional> findBlocksInReuseTreeByBlockKeys( + std::vector const& blockKeys, SizeType32 windowSize) const override { - return mBlockManager.findBlocksInReuseTreeByHashes(hashes, windowSize); + return mBlockManager.findBlocksInReuseTreeByBlockKeys(blockKeys, windowSize); } /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity. 
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 504ac5b403d..e7a911bd30e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -48,37 +48,33 @@ class BlockRange return BlockRange(cacheManager, blockIds, requestId); } - static BlockRange fromReuseTree(BaseKVCacheManager const& cacheManager, std::vector const& allBlockHashes, - std::vector const& requestedBlockHashes) + static BlockRange fromReuseTree( + BaseKVCacheManager const& cacheManager, std::vector const& allBlockKeys, SizeType32 indexFromEnd) { auto const windowSize = firstWindowSize(cacheManager); - std::cout << "allBlockHashes: " << allBlockHashes.size() << std::endl; - for (auto hash : allBlockHashes) - { - std::cout << hash << " "; - } - std::cout << std::endl; - std::cout << "requestedBlockHashes: " << requestedBlockHashes.size() << std::endl; - for (auto hash : requestedBlockHashes) - { - std::cout << hash << " "; - } - std::cout << std::endl; - - auto lastBlock = *cacheManager.findBlocksInReuseTreeByHashes(allBlockHashes, windowSize); + // Find the last block in the reuse tree for the provided full sequence of block keys + auto lastBlock = *cacheManager.findBlocksInReuseTreeByBlockKeys(allBlockKeys, windowSize); // TODO: handle the case where the last block is not found TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree"); - // Assume the the last block is the requested block + // Validate indexFromEnd and determine how many trailing blocks to collect + auto const totalNumBlocks = static_cast(allBlockKeys.size()); + TLLM_CHECK_WITH_INFO( + indexFromEnd < totalNumBlocks, "indexFromEnd=%d is out of range (total=%d)", indexFromEnd, totalNumBlocks); + + // Number of blocks to return equals suffix length starting at the block located indexFromEnd from the end + // Example: indexFromEnd=0 -> return last block only; 
indexFromEnd=2 -> return last 3 blocks + SizeType32 const numBlocksToCollect = indexFromEnd + 1; + std::vector blockIds; - for (auto const& hash : requestedBlockHashes) + blockIds.reserve(numBlocksToCollect); + for (SizeType32 i = 0; i < numBlocksToCollect; ++i) { - if (lastBlock->getHash() != hash) + blockIds.emplace_back(lastBlock->getBlockId()); + if (i + 1 < numBlocksToCollect) { - return BlockRange(cacheManager, {}, 0); + lastBlock = lastBlock->getPrevBlock(); + TLLM_CHECK_WITH_INFO(lastBlock, "Previous block not found while traversing reuse tree"); } - blockIds.emplace_back(lastBlock->getBlockId()); - lastBlock = lastBlock->getPrevBlock(); - TLLM_CHECK_WITH_INFO(lastBlock, "Last block is not found"); } return BlockRange(cacheManager, blockIds, 0); } diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index c370a652350..1d30da2027c 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/tensor.h" @@ -36,6 +37,10 @@ struct SocketState; class Serialization { public: + // BlockKey (KV cache) + static size_t serializedSize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key); + static void serialize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key, std::ostream& os); + static tensorrt_llm::batch_manager::kv_cache_manager::BlockKey deserializeBlockKey(std::istream& is); // TimePoint [[nodiscard]] static RequestPerfMetrics::TimePoint deserializeTimePoint(std::istream& is); static void serialize(RequestPerfMetrics::TimePoint const& tp, std::ostream& os); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 
f389c09b7c7..31ab33d0fbe 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -40,9 +40,8 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, - std::vector const& allBlockHashes, std::vector const& requestedBlockHashes) + std::vector const& allBlockKeys, SizeType32 indexFromEnd) { - size_t requestBlockNum = requestedBlockHashes.size(); constexpr SizeType32 beam{0}; auto poolNum = cacheManager->getBlockManager().getNumPools(); if (poolNum > 1 || !cacheManager->isEnableBlockReuse()) @@ -50,7 +49,7 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); return blockRange; } - return BlockRange::fromReuseTree(*cacheManager, allBlockHashes, requestedBlockHashes); + return BlockRange::fromReuseTree(*cacheManager, allBlockKeys, indexFromEnd); } BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) @@ -151,18 +150,17 @@ void CacheFormatter::format(TransferSession& session) auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); - auto& allBlockHashes = session.getAllBlockHashes(); - auto& requestedBlockHashes = session.getRequestedBlockHashes(); - TLLM_CHECK_WITH_INFO(allBlockHashes.size() >= requestedBlockHashes.size(), - "allBlockHashes must be greater than or equal to requestedBlockHashes"); + auto& allBlockKeys = session.getAllBlockKeys(); + auto indexFromEnd = session.getIndexFromEnd(); + TLLM_CHECK_WITH_INFO(indexFromEnd < allBlockKeys.size(), "indexFromEnd is out of range"); auto& bufferManager = session.getBufferManager(); if (!needSendCache(selfConfig, 
destConfig, selfIdx)) { return; } auto& blockManager = mCacheManager->getBlockManager(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, allBlockHashes, requestedBlockHashes); - + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, allBlockKeys, indexFromEnd); + std::cout << "blockRange size is: " << blockRange.getBlockIds().size() << std::endl; auto const numPools = blockManager.getNumPools(); // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1... @@ -211,7 +209,11 @@ void CacheFormatter::format(TransferSession& session) std::map> inputKvCacheBlocks; for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) { - blockRange.updatePoolIdx(poolIdx); + if (numPools > 1) + { + blockRange.updatePoolIdx(poolIdx); + } + std::cout << "blockRange size is: " << blockRange.getBlockIds().size() << std::endl; SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); TLLM_CHECK_WITH_INFO(inputKvCacheBlocks.find(window) == inputKvCacheBlocks.end(), "window size already exists, which is not supported"); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index 0aaeeb742ec..89a2e223e75 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -35,7 +35,7 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, - std::vector const& allBlockHashes, std::vector const& requestedBlockHashes); + std::vector const& allBlockKeys, SizeType32 indexFromEnd); BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); @@ -130,15 +130,15 @@ class TransferSession TransferSession(std::vector connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - 
runtime::BufferManager const& bufferManager, std::vector allBlockHashes = {}, - std::vector requestedBlockHashes = {}, LlmRequest const* llmRequest = nullptr) + runtime::BufferManager const& bufferManager, std::vector allBlockKeys = {}, + SizeType32 indexFromEnd = 0, LlmRequest const* llmRequest = nullptr) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) - , mAllBlockHashes(std::move(allBlockHashes)) - , mRequestedBlockHashes(std::move(requestedBlockHashes)) + , mAllBlockKeys(std::move(allBlockKeys)) + , mIndexFromEnd(indexFromEnd) , mRequest(llmRequest) { TLLM_CHECK(!mConnections.empty()); @@ -235,14 +235,14 @@ class TransferSession outFile << '\n' << std::flush; } - [[nodiscard]] std::vector const& getAllBlockHashes() const + [[nodiscard]] std::vector const& getAllBlockKeys() const { - return mAllBlockHashes; + return mAllBlockKeys; } - [[nodiscard]] std::vector const& getRequestedBlockHashes() const + [[nodiscard]] SizeType32 getIndexFromEnd() const { - return mRequestedBlockHashes; + return mIndexFromEnd; } private: @@ -253,8 +253,8 @@ class TransferSession runtime::BufferManager const* mBufferManager; std::vector mMeasures; bool mRecordMeasure{false}; - std::vector mAllBlockHashes; - std::vector mRequestedBlockHashes; + std::vector mAllBlockKeys; + SizeType32 mIndexFromEnd; LlmRequest const* mRequest; }; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 92ae3782636..26223d056a1 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -62,19 +62,19 @@ RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTran { } -RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockHashes, - executor::DataTransceiverState transState, std::vector requestedBlockHashes) 
+RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockKeys, + executor::DataTransceiverState transState, SizeType32 indexFromEnd) : mRequestId{requestId} - , mAllBlockHashes{std::move(allBlockHashes)} - , mRequestedBlockHashes{std::move(requestedBlockHashes)} + , mAllBlockKeys{std::move(allBlockKeys)} + , mIndexFromEnd{indexFromEnd} , mTransState{std::move(transState)} { } bool RequestInfo::operator==(RequestInfo const& rhs) const { - return mRequestId == rhs.mRequestId && mAllBlockHashes == rhs.mAllBlockHashes - && mRequestedBlockHashes == rhs.mRequestedBlockHashes && mTransState == rhs.mTransState; + return mRequestId == rhs.mRequestId && mAllBlockKeys == rhs.mAllBlockKeys && mIndexFromEnd == rhs.mIndexFromEnd + && mTransState == rhs.mTransState; } LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept @@ -91,8 +91,8 @@ void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os) { namespace su = executor::serialize_utils; su::serialize(requestInfo.mRequestId, os); - su::serialize(requestInfo.mAllBlockHashes, os); - su::serialize(requestInfo.mRequestedBlockHashes, os); + su::serialize(requestInfo.mAllBlockKeys, os); + su::serialize(requestInfo.mIndexFromEnd, os); su::serialize(requestInfo.mTransState, os); } @@ -100,10 +100,10 @@ RequestInfo RequestInfo::deserialize(std::istream& is) { namespace su = executor::serialize_utils; auto requestId = su::deserialize(is); - auto allBlockHashes = su::deserialize(is); - auto requestedBlockHashes = su::deserialize(is); + auto allBlockKeys = su::deserialize(is); + auto indexFromEnd = su::deserialize(is); auto transState = su::deserialize(is); - return RequestInfo{requestId, std::move(allBlockHashes), std::move(transState), std::move(requestedBlockHashes)}; + return RequestInfo{requestId, std::move(allBlockKeys), std::move(transState), indexFromEnd}; } std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) @@ -111,8 +111,8 @@ std::size_t 
RequestInfo::serializedSize(RequestInfo const& requestInfo) namespace su = executor::serialize_utils; std::size_t totalSize = 0; totalSize += su::serializedSize(requestInfo.mRequestId); - totalSize += su::serializedSize(requestInfo.mAllBlockHashes); - totalSize += su::serializedSize(requestInfo.mRequestedBlockHashes); + totalSize += su::serializedSize(requestInfo.mAllBlockKeys); + totalSize += su::serializedSize(requestInfo.mIndexFromEnd); totalSize += su::serializedSize(requestInfo.mTransState); return totalSize; } @@ -218,7 +218,7 @@ class CacheSender::Impl { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, - info.getAllBlockHashes(), info.getRequestedBlockHashes()); + info.getAllBlockKeys(), info.getIndexFromEnd()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -495,25 +495,17 @@ class CacheReceiver::Impl auto const& uniqueTokens = llmRequest.getUniqueTokens(beam); auto blockedUniqueTokens = tensorrt_llm::batch_manager::kv_cache_manager::chopVectorIntoBlocks( - uniqueTokens, uniqueTokens.size() - 1, mFormatter->getCacheManager()->getTokensPerBlock(), false); + uniqueTokens, uniqueTokens.size() - 1, mFormatter->getCacheManager()->getTokensPerBlock(), true); auto blockKeys = tensorrt_llm::batch_manager::kv_cache_manager::buildBlockKeys(blockedUniqueTokens, llmRequest); - std::vector allBlockHashes; - for (auto const& blockKey : blockKeys) - { - allBlockHashes.push_back(tensorrt_llm::batch_manager::kv_cache_manager::BlockKeyHasher::hash(blockKey)); - } - // Figure out the size difference for computing the requested block hashes - std::vector requestedBlockHashes; - for (auto i = 0; i < allBlockHashes.size(); i++) - { - if (i >= requestedBlockRange.getBlockHashes().size()) - { - requestedBlockHashes.push_back(allBlockHashes[i]); - } - } - - requestInfo = 
RequestInfo(requestId, allBlockHashes, mSelfState, requestedBlockHashes); + // Compute indexFromEnd from the number of requested blocks + size_t totalBlockSize = allBlockRange.getBlockIds().size(); + size_t requestedBlockSize = requestedBlockRange.getBlockIds().size(); + std::cout << "requestedBlockSize: " << requestedBlockSize << std::endl; + TLLM_CHECK_WITH_INFO(requestedBlockSize > 0, "requestedBlockSize must be > 0"); + SizeType32 indexFromEnd = static_cast(requestedBlockSize - 1); + + requestInfo = RequestInfo(requestId, blockKeys, mSelfState, indexFromEnd); } auto* agentConnectionManager = dynamic_cast(mManager); @@ -558,8 +550,8 @@ class CacheReceiver::Impl } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, requestInfo.getAllBlockHashes(), - requestInfo.getRequestedBlockHashes(), &llmRequest); + contextState, resource->mBufferManager, requestInfo.getAllBlockKeys(), requestInfo.getIndexFromEnd(), + &llmRequest); } std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index e0ecbc7c8e1..63660e06855 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -42,6 +42,7 @@ using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager; using SizeType32 = tensorrt_llm::runtime::SizeType32; using TransferSession = kv_cache_manager::TransferSession; using UniqueToken = tensorrt_llm::runtime::UniqueToken; +using BlockKey = tensorrt_llm::batch_manager::kv_cache_manager::BlockKey; struct TransceiverTag { @@ -68,8 +69,8 @@ class RequestInfo /// @param transState The state of the data transceiver. 
RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState); - RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockHashes, - executor::DataTransceiverState transState, std::vector requestedBlockHashes); + RequestInfo(LlmRequest::RequestIdType requestId, std::vector allBlockKeys, + executor::DataTransceiverState transState, SizeType32 indexFromEnd); RequestInfo() = default; /// @brief Equality comparison operator. @@ -80,14 +81,14 @@ class RequestInfo /// @return The request ID. [[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept; - [[nodiscard]] std::vector const& getAllBlockHashes() const noexcept + [[nodiscard]] std::vector const& getAllBlockKeys() const noexcept { - return mAllBlockHashes; + return mAllBlockKeys; } - [[nodiscard]] std::vector const& getRequestedBlockHashes() const noexcept + [[nodiscard]] SizeType32 getIndexFromEnd() const noexcept { - return mRequestedBlockHashes; + return mIndexFromEnd; } /// @brief Return the state of the data transceiver. @@ -113,10 +114,10 @@ class RequestInfo LlmRequest::RequestIdType mRequestId; // The block hashes of the request. - std::vector mAllBlockHashes; + std::vector mAllBlockKeys; - // The block hashes of the requested data. - std::vector mRequestedBlockHashes; + // Index from end indicating how many trailing blocks to transfer (index+1) + SizeType32 mIndexFromEnd{0}; // The state of the data transceiver. 
executor::DataTransceiverState mTransState; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 302c03ed0a4..a26d6aa1093 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -113,6 +113,11 @@ std::vector generateBlockHashExtraKeys( return extraKeys; } +} // namespace + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ + std::vector buildBlockKeys( std::list& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest) { @@ -130,11 +135,6 @@ std::vector buildBlockKeys( return blockKeys; } -} // namespace - -namespace tensorrt_llm::batch_manager::kv_cache_manager -{ - bool BlockKey::operator==(BlockKey const& other) const noexcept { if (hash != 0) @@ -998,34 +998,29 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; } -std::optional> WindowBlockManager::findBlocksInReuseTreeByHashes( - std::vector const& hashes) const +std::optional> WindowBlockManager::findBlocksInReuseTreeByBlockKeys( + std::vector const& blockKeys) const { - std::cout << "findBlocksInReuseTreeByHashes: " << hashes.size() << std::endl; - for (auto hash : hashes) + std::cout << "findBlocksInReuseTreeByBlockKeys: " << blockKeys.size() << std::endl; + for (auto const& key : blockKeys) { - std::cout << hash << " "; + std::cout << key.hash << " "; } std::cout << std::endl; - std::vector blockKeys; - for (auto const& hash : hashes) - { - blockKeys.emplace_back(hash); - } auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) { - std::cout << "findBlocksInReuseTreeByHashes: searching for block key: " << blockKey.hash << std::endl; + std::cout << "findBlocksInReuseTreeByBlockKeys: searching for block key: " << blockKey.hash << std::endl; auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr - ? 
searchRoot->findMatchingBlock(blockKey, false, false) + ? searchRoot->findMatchingBlock(blockKey, true, true) : std::make_tuple(false, 0, nullptr); if (matchingBlock == nullptr) { - std::cout << "findBlocksInReuseTreeByHashes: no matching block found" << std::endl; + std::cout << "findBlocksInReuseTreeByBlockKeys: no matching block found" << std::endl; return std::nullopt; } - std::cout << "findBlocksInReuseTreeByHashes: matching block found" << matchingBlock->getHash() << std::endl; + std::cout << "findBlocksInReuseTreeByBlockKeys: matching block found" << matchingBlock->getHash() << std::endl; searchRoot = std::move(matchingBlock); } return searchRoot; @@ -1313,8 +1308,8 @@ void WindowBlockManager::storeBlocks( for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) { auto const bid = blockIds[blockCnt]; - std::cout << "Storing block " << bid << " block hash: " << BlockKeyHasher::hash(blockKeys[blockCnt]) - << std::endl; + // std::cout << "Storing block " << bid << " block hash: " << BlockKeyHasher::hash(blockKeys[blockCnt]) + // << std::endl; TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); auto& block = mAllBlocksById[bid]; auto const& blockKey = blockKeys[blockCnt]; diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 33d5c7fba21..eb5ab3f81ff 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -93,8 +93,8 @@ void MLACacheFormatter::format(TransferSession& session) auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); - auto& blockHashes = session.getAllBlockHashes(); - auto& requestedBlockHashes = session.getRequestedBlockHashes(); + auto& blockKeys = session.getAllBlockKeys(); + auto 
indexFromEnd = session.getIndexFromEnd(); auto const& connections = session.getConnections(); auto& bufferManager = session.getBufferManager(); TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1."); @@ -108,7 +108,7 @@ void MLACacheFormatter::format(TransferSession& session) // diff end auto const numPools = mCacheManager->getBlockManager().getNumPools(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, blockHashes, requestedBlockHashes); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, blockKeys, indexFromEnd); auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime; bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point(); diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index bba8d19e2f6..a0ec10304ba 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -16,6 +16,7 @@ */ #include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/requestImpl.h" @@ -2440,4 +2441,42 @@ ModelType Serialization::deserializeModelType(std::istream& is) return su::deserialize(is); } +// BlockKey (KV cache) +size_t Serialization::serializedSize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(key.usesExtraIds); + totalSize += su::serializedSize(key.loraTaskId); + totalSize += su::serializedSize(key.uniqueTokens); + // std::vector where MmKey is pair, SizeType32> + totalSize += su::serializedSize(key.extraKeys); + totalSize += su::serializedSize(key.hash); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key, std::ostream& os) 
+{ + su::serialize(key.usesExtraIds, os); + su::serialize(key.loraTaskId, os); + su::serialize(key.uniqueTokens, os); + su::serialize(key.extraKeys, os); + su::serialize(key.hash, os); +} + +tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deserializeBlockKey(std::istream& is) +{ + auto usesExtraIds = su::deserialize(is); + auto loraTaskId = su::deserialize>(is); + auto uniqueTokens = su::deserialize>(is); + auto extraKeys = su::deserialize>(is); + auto hash = su::deserialize(is); + tensorrt_llm::batch_manager::kv_cache_manager::BlockKey key; + key.usesExtraIds = usesExtraIds; + key.loraTaskId = std::move(loraTaskId); + key.uniqueTokens = std::move(uniqueTokens); + key.extraKeys = std::move(extraKeys); + key.hash = hash; + return key; +} + } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 40b50f92309..1f1e90e0a3b 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -21,12 +21,14 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/serialization.h" #include "tensorrt_llm/executor/types.h" +#include #include #include #include #include #include #include +#include #include #include @@ -74,6 +76,44 @@ struct is_variant> : std::true_type template constexpr bool is_variant_v = is_variant::value; +// Detect std::array +template +struct is_std_array : std::false_type +{ +}; + +template +struct is_std_array> : std::true_type +{ + using value_type = U; + static constexpr std::size_t size = N; +}; + +template +constexpr bool is_std_array_v = is_std_array::value; + +template +using array_value_type_t = typename is_std_array::value_type; + +template +constexpr std::size_t array_size_v = is_std_array::size; + +// Detect std::pair +template +struct is_std_pair : std::false_type +{ +}; + +template +struct is_std_pair> : std::true_type +{ + using first_type = A; + using second_type = B; 
+}; + +template +constexpr bool is_std_pair_v = is_std_pair::value; + // SerializedSize template bool constexpr hasSerializedSize(...) @@ -161,6 +201,21 @@ size_t serializedSize(T const& data) } return size; } + // std::array + else if constexpr (is_std_array_v) + { + size_t size = 0; + for (auto const& elem : data) + { + size += serializedSize(elem); + } + return size; + } + // std::pair + else if constexpr (is_std_pair_v) + { + return serializedSize(data.first) + serializedSize(data.second); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -266,6 +321,20 @@ void serialize(T const& data, std::ostream& os) serialize(element, os); } } + // std::array + else if constexpr (is_std_array_v) + { + for (auto const& element : data) + { + serialize(element, os); + } + } + // std::pair + else if constexpr (is_std_pair_v) + { + serialize(data.first, os); + serialize(data.second, os); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -575,6 +644,10 @@ T deserialize(std::istream& is) { return Serialization::deserializeUniqueToken(is); } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeBlockKey(is); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -604,6 +677,23 @@ T deserialize(std::istream& is) } return container; } + // std::array + else if constexpr (is_std_array_v) + { + T container{}; + for (std::size_t i = 0; i < array_size_v; ++i) + { + container[i] = deserialize>(is); + } + return container; + } + // std::pair + else if constexpr (is_std_pair_v) + { + auto first = deserialize::first_type>(is); + auto second = deserialize::second_type>(is); + return T{std::move(first), std::move(second)}; + } // std::variant else if constexpr (is_variant_v) { diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 2d2b4e0a9da..65d780f2910 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ 
b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -580,24 +580,21 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByHashesTest) blockManager.storeContextBlocks(seq0, *llmRequest0); - std::vector emptyHashes{}; - auto result = blockManager.findBlocksInReuseTreeByHashes(emptyHashes, maxAttentionWindow); + std::vector emptyKeys{}; + auto result = blockManager.findBlocksInReuseTreeByBlockKeys(emptyKeys, maxAttentionWindow); ASSERT_TRUE(result.has_value()); EXPECT_EQ((*result)->getBlockId(), KVCacheBlock::kCachedBlocksRootId); auto block0 = blockManager.getBlockById(cacheBlockIds[0], maxAttentionWindow); // Corrupt the valid hash to guarantee a miss - std::vector badHashes{block0->getHash() ^ static_cast(0x9e3779b97f4a7c15ULL)}; - result = blockManager.findBlocksInReuseTreeByHashes(badHashes, maxAttentionWindow); + std::vector badKeys{BlockKey{block0->getHash() ^ static_cast(0x9e3779b97f4a7c15ULL)}}; + result = blockManager.findBlocksInReuseTreeByBlockKeys(badKeys, maxAttentionWindow); EXPECT_FALSE(result.has_value()); - std::vector hashes{block0->getHash()}; - result = blockManager.findBlocksInReuseTreeByHashes(hashes, maxAttentionWindow); + std::vector keys{BlockKey{block0->getHash()}}; + result = blockManager.findBlocksInReuseTreeByBlockKeys(keys, maxAttentionWindow); ASSERT_TRUE(result.has_value()); EXPECT_EQ((*result)->getBlockId(), block0->getBlockId()); - - // Add sequence [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] (17 tokens, three blocks) - BlockKey } #ifdef ENABLE_FP4