Skip to content

Commit eed674f

Browse files
committed
Refactor dataTransceiver classes
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 214b335 commit eed674f

File tree

11 files changed

+491
-616
lines changed

11 files changed

+491
-616
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ namespace tensorrt_llm::batch_manager
3434

3535
class ContextProgress;
3636
class BaseCacheTransceiver;
37-
class DataResponder;
38-
class DataRequester;
37+
class CacheSender;
38+
class CacheReceiver;
3939

4040
class CacheTransceiverFactory
4141
{
@@ -110,8 +110,8 @@ class CacheTransceiver : public BaseCacheTransceiver
110110

111111
void setContextState(LlmRequest* llmRequest);
112112

113-
std::unique_ptr<DataResponder> mCacheSender;
114-
std::unique_ptr<DataRequester> mDataRequester;
113+
std::unique_ptr<CacheSender> mCacheSender;
114+
std::unique_ptr<CacheReceiver> mCacheReceiver;
115115
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
116116
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
117117
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ set(SRCS
2424
createNewDecoderRequests.cpp
2525
contextProgress.cpp
2626
dataTransceiver.cpp
27-
dataTransceiverImpl.cpp
2827
decoderBuffers.cpp
2928
encoderBuffers.cpp
3029
guidedDecoder.cpp

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
#pragma once
1919

2020
#include "cacheTransBuffer.h"
21-
#include "dataTransceiver.h"
2221
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2322
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
2423
#include "tensorrt_llm/common/envUtils.h"
2524
#include "tensorrt_llm/common/logger.h"
25+
#include "tensorrt_llm/executor/cacheCommunicator.h"
2626
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2727
#include "tensorrt_llm/executor/dataTransceiverState.h"
2828
#include "tensorrt_llm/runtime/bufferManager.h"
@@ -38,6 +38,88 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
3838

3939
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
4040

41+
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
42+
using Connection = tensorrt_llm::executor::kv_cache::Connection;
43+
using SizeType32 = tensorrt_llm::runtime::SizeType32;
44+
45+
class TransferSession
46+
{
47+
public:
48+
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
49+
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
50+
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
51+
: mConnections(std::move(connections))
52+
, mDataContext(dataContext)
53+
, mSelfState(&selfState)
54+
, mOtherState(std::move(otherState))
55+
, mBufferManager(&bufferManager)
56+
, mRequest(llmRequest)
57+
{
58+
TLLM_CHECK(!mConnections.empty());
59+
}
60+
61+
[[nodiscard]] std::vector<Connection const*> const& getConnections() const
62+
{
63+
return mConnections;
64+
}
65+
66+
// should be called only during the initialization of the TransferSession
67+
void setConnection(size_t idx, Connection const* conn)
68+
{
69+
mConnections.at(idx) = conn;
70+
}
71+
72+
[[nodiscard]] DataContext const& getDataContext() const
73+
{
74+
return mDataContext;
75+
}
76+
77+
[[nodiscard]] executor::DataTransceiverState const& getSelfState() const
78+
{
79+
return *mSelfState;
80+
}
81+
82+
[[nodiscard]] executor::DataTransceiverState const& getOtherState() const
83+
{
84+
return mOtherState;
85+
}
86+
87+
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
88+
{
89+
return *mBufferManager;
90+
}
91+
92+
void send(size_t idx, void const* data, size_t size)
93+
{
94+
mConnections.at(idx)->send(mDataContext, data, size);
95+
}
96+
97+
void recv(size_t idx, void* data, size_t size)
98+
{
99+
mConnections.at(idx)->recv(mDataContext, data, size);
100+
}
101+
102+
[[nodiscard]] LlmRequest const& getLlmRequest() const
103+
{
104+
TLLM_CHECK(mRequest != nullptr);
105+
return *mRequest;
106+
}
107+
108+
// in CacheSender, the LlmRequest is not available until the sendSync is called
109+
void setLlmRequest(LlmRequest const& llmRequest)
110+
{
111+
mRequest = &llmRequest;
112+
}
113+
114+
private:
115+
std::vector<Connection const*> mConnections;
116+
DataContext mDataContext;
117+
executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender
118+
executor::DataTransceiverState mOtherState;
119+
runtime::BufferManager const* mBufferManager;
120+
LlmRequest const* mRequest;
121+
};
122+
41123
// Used to support the cache transmission with different layouts and different protocols.
42124
class BaseCacheFormatter
43125
{
@@ -78,6 +160,66 @@ class BaseCacheFormatter
78160
virtual ~BaseCacheFormatter() = default;
79161
};
80162

163+
class KvCacheMeasureHelper
164+
{
165+
public:
166+
KvCacheMeasureHelper(std::string output_path)
167+
: mOutputPath(std::move(output_path))
168+
{
169+
}
170+
171+
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double duration, size_t size)
172+
{
173+
auto bandwidth = size * 8 / (duration / 1000) / 1e9;
174+
if (mOutputPath.empty())
175+
{
176+
return;
177+
}
178+
179+
std::lock_guard<std::mutex> lock(mMutex);
180+
mRequestKVCacheTranfserMeasure[requestId].emplace_back(duration, bandwidth);
181+
}
182+
183+
~KvCacheMeasureHelper()
184+
{
185+
if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
186+
{
187+
auto rank = mpi::MpiComm::world().getRank();
188+
std::string outFilePath = mOutputPath + "rank_" + std::to_string(rank) + ".txt";
189+
std::ofstream outFile(outFilePath);
190+
191+
TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
192+
193+
size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size();
194+
195+
outFile << "RequestID";
196+
for (size_t i = 0; i < numTransferMeasure; i++)
197+
{
198+
outFile << ",TimeDuration,Bandwidth";
199+
}
200+
outFile << '\n';
201+
202+
for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
203+
{
204+
outFile << requestID;
205+
206+
for (auto const& [time, bandwidth] : measures)
207+
{
208+
outFile << "," << time << "," << bandwidth;
209+
}
210+
outFile << '\n';
211+
}
212+
213+
outFile.close();
214+
}
215+
}
216+
217+
private:
218+
std::map<LlmRequest::RequestIdType, std::vector<std::pair<double, double>>> mRequestKVCacheTranfserMeasure;
219+
std::string mOutputPath;
220+
std::mutex mMutex;
221+
};
222+
81223
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
82224
// parallel topology is completely identical, making it the preferred method.
83225
class CacheFormatter final : public BaseCacheFormatter

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,9 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
194194
auto makeFormatter = [cacheManager, isMLA, this]()
195195
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
196196

197-
mCacheSender = std::make_unique<DataResponder>(
198-
std::make_unique<CacheSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
199-
mDataRequester = std::make_unique<DataRequester>(
200-
std::make_unique<CacheReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
197+
mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
198+
mCacheReceiver
199+
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
201200

202201
initializeCommState();
203202
}
@@ -249,7 +248,7 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
249248
return;
250249
}
251250
setContextState(llmRequest);
252-
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
251+
auto future = mCacheSender->sendAsync(*llmRequest);
253252
mSenderFutures.emplace_back(llmRequest, std::move(future));
254253
}
255254

@@ -265,7 +264,7 @@ void CacheTransceiver::respondAndSendLayerWise(
265264

266265
llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
267266
setContextState(llmRequest.get());
268-
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
267+
auto future = mCacheSender->sendAsync(*llmRequest);
269268
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
270269
}
271270
}
@@ -274,7 +273,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
274273
{
275274
TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
276275
{
277-
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
276+
auto future = mCacheReceiver->receiveAsync(*llmRequest);
278277
future.get();
279278
}
280279
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
@@ -292,7 +291,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
292291
return;
293292
}
294293

295-
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
294+
auto future = mCacheReceiver->receiveAsync(*llmRequest);
296295
mRequesterFutures.emplace_back(llmRequest, std::move(future));
297296
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
298297
}

0 commit comments

Comments
 (0)