Commit d1c5f80

Refactor dataTransceiver classes
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent abb41e4 commit d1c5f80

11 files changed: +524 additions, -653 deletions

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 4 deletions
@@ -34,8 +34,8 @@ namespace tensorrt_llm::batch_manager
 
 class ContextProgress;
 class BaseCacheTransceiver;
-class DataResponder;
-class DataRequester;
+class CacheSender;
+class CacheReceiver;
 
 class CacheTransceiverFactory
 {
@@ -110,8 +110,8 @@ class CacheTransceiver : public BaseCacheTransceiver
 
     void setContextState(LlmRequest* llmRequest);
 
-    std::unique_ptr<DataResponder> mCacheSender;
-    std::unique_ptr<DataRequester> mDataRequester;
+    std::unique_ptr<CacheSender> mCacheSender;
+    std::unique_ptr<CacheReceiver> mCacheReceiver;
     std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
     std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
     mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
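
Note: the renamed members keep the same per-request bookkeeping as before, where each in-flight transfer is tracked as a (request, std::future<void>) pair in mSenderFutures/mRequesterFutures and drained later. Below is a minimal, self-contained sketch of that pattern; FakeRequest and fakeSendAsync are made-up stand-ins for LlmRequest and CacheSender::sendAsync, not the real API.

#include <cstdio>
#include <future>
#include <utility>
#include <vector>

// Illustrative stand-ins for LlmRequest and CacheSender::sendAsync.
struct FakeRequest
{
    int id;
};

static std::future<void> fakeSendAsync(FakeRequest& req)
{
    return std::async(std::launch::async, [&req] { std::printf("sent KV cache for request %d\n", req.id); });
}

int main()
{
    // Mirror of mSenderFutures: keep each request next to the future of its transfer...
    std::vector<std::pair<FakeRequest*, std::future<void>>> senderFutures;
    FakeRequest requests[2] = {{0}, {1}};

    for (auto& req : requests)
    {
        senderFutures.emplace_back(&req, fakeSendAsync(req));
    }

    // ...and drain the futures later, once the transfers are expected to complete.
    for (auto& [req, future] : senderFutures)
    {
        future.get();
        (void) req;
    }
    return 0;
}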

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ set(SRCS
   createNewDecoderRequests.cpp
   contextProgress.cpp
   dataTransceiver.cpp
-  dataTransceiverImpl.cpp
   decoderBuffers.cpp
   encoderBuffers.cpp
   guidedDecoder.cpp

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 190 additions & 1 deletion
@@ -18,11 +18,11 @@
 #pragma once
 
 #include "cacheTransBuffer.h"
-#include "dataTransceiver.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/batch_manager/kvCacheUtils.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/executor/cacheCommunicator.h"
 #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
 #include "tensorrt_llm/executor/dataTransceiverState.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
@@ -38,6 +38,135 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
 
 BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
 
+using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
+using Connection = tensorrt_llm::executor::kv_cache::Connection;
+using SizeType32 = tensorrt_llm::runtime::SizeType32;
+
+class TransferSession
+{
+public:
+    struct Measure
+    {
+        double delay;     // from last token (ctx) or arrival time (gen), in ms
+        double duration;  // in ms
+        double bandwidth; // in Gbps
+    };
+
+    TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
+        executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
+        runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
+        : mConnections(std::move(connections))
+        , mDataContext(dataContext)
+        , mSelfState(&selfState)
+        , mOtherState(std::move(otherState))
+        , mBufferManager(&bufferManager)
+        , mRequest(llmRequest)
+    {
+        TLLM_CHECK(!mConnections.empty());
+    }
+
+    [[nodiscard]] std::vector<Connection const*> const& getConnections() const
+    {
+        return mConnections;
+    }
+
+    // should be called only during the initialization of the TransferSession
+    void setConnection(size_t idx, Connection const* conn)
+    {
+        mConnections.at(idx) = conn;
+    }
+
+    [[nodiscard]] DataContext const& getDataContext() const
+    {
+        return mDataContext;
+    }
+
+    [[nodiscard]] executor::DataTransceiverState const& getSelfState() const
+    {
+        return *mSelfState;
+    }
+
+    [[nodiscard]] executor::DataTransceiverState const& getOtherState() const
+    {
+        return mOtherState;
+    }
+
+    [[nodiscard]] runtime::BufferManager const& getBufferManager() const
+    {
+        return *mBufferManager;
+    }
+
+    void send(size_t idx, void const* data, size_t size)
+    {
+        mConnections.at(idx)->send(mDataContext, data, size);
+    }
+
+    void recv(size_t idx, void* data, size_t size)
+    {
+        mConnections.at(idx)->recv(mDataContext, data, size);
+    }
+
+    [[nodiscard]] LlmRequest const& getLlmRequest() const
+    {
+        TLLM_CHECK(mRequest != nullptr);
+        return *mRequest;
+    }
+
+    // in CacheSender, the LlmRequest is not available until the sendSync is called
+    void setLlmRequest(LlmRequest const& llmRequest)
+    {
+        mRequest = &llmRequest;
+    }
+
+    void appendMeasure(double delay, double duration, size_t size)
+    {
+        if (!mRecordMeasure)
+        {
+            return;
+        }
+        auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
+        mMeasures.emplace_back(Measure{delay, duration, bandwidth});
+    }
+
+    // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
+    void exportMeasure(std::ofstream& outFile, bool isContext) const
+    {
+        if (mMeasures.empty())
+        {
+            return;
+        }
+        // write header if not exist
+        if (outFile.tellp() == 0)
+        {
+            outFile << "RequestID";
+            for (size_t i = 0; i < mMeasures.size(); i++)
+            {
+                outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
+            }
+            outFile << '\n';
+        }
+        // write measures
+        TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
+        auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
+        outFile << reqId;
+        for (auto const& measure : mMeasures)
+        {
+            outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
+        }
+        outFile << '\n' << std::flush;
+    }
+
+private:
+    std::vector<Connection const*> mConnections;
+    DataContext mDataContext;
+    executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender
+    executor::DataTransceiverState mOtherState;
+    runtime::BufferManager const* mBufferManager;
+    LlmRequest const* mRequest;
+    std::vector<Measure> mMeasures;
+    bool mRecordMeasure{false};
+};
+
 // Used to support the cache transmission with different layouts and different protocols.
 class BaseCacheFormatter
 {
@@ -78,6 +207,66 @@ class BaseCacheFormatter
     virtual ~BaseCacheFormatter() = default;
 };
 
+class KvCacheMeasureHelper
+{
+public:
+    KvCacheMeasureHelper(std::string output_path)
+        : mOutputPath(std::move(output_path))
+    {
+    }
+
+    void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double duration, size_t size)
+    {
+        auto bandwidth = size * 8 / (duration / 1000) / 1e9;
+        if (mOutputPath.empty())
+        {
+            return;
+        }
+
+        std::lock_guard<std::mutex> lock(mMutex);
+        mRequestKVCacheTranfserMeasure[requestId].emplace_back(duration, bandwidth);
+    }
+
+    ~KvCacheMeasureHelper()
+    {
+        if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
+        {
+            auto rank = mpi::MpiComm::world().getRank();
+            std::string outFilePath = mOutputPath + "rank_" + std::to_string(rank) + ".txt";
+            std::ofstream outFile(outFilePath);
+
+            TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
+
+            size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size();
+
+            outFile << "RequestID";
+            for (size_t i = 0; i < numTransferMeasure; i++)
+            {
+                outFile << ",TimeDuration,Bandwidth";
+            }
+            outFile << '\n';
+
+            for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
+            {
+                outFile << requestID;
+
+                for (auto const& [time, bandwidth] : measures)
+                {
+                    outFile << "," << time << "," << bandwidth;
+                }
+                outFile << '\n';
+            }
+
+            outFile.close();
+        }
+    }
+
+private:
+    std::map<LlmRequest::RequestIdType, std::vector<std::pair<double, double>>> mRequestKVCacheTranfserMeasure;
+    std::string mOutputPath;
+    std::mutex mMutex;
+};
+
 // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
 // parallel topology is completely identical, making it the preferred method.
 class CacheFormatter final : public BaseCacheFormatter
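
Note: both measurement helpers added in this file use the same unit convention, size in bytes and duration in milliseconds, so size * 8 / (duration / 1000) / 1e9 yields Gbps, and exportMeasure emits one CSV row per request. Below is a standalone sketch of that arithmetic and row layout, with made-up numbers rather than real measurements.

#include <cstddef>
#include <cstdio>

int main()
{
    // Same conversion as TransferSession::appendMeasure: bytes and ms in, Gbps out.
    std::size_t const sizeBytes = 256ull * 1024 * 1024; // assume 256 MiB of KV-cache blocks were sent
    double const delayMs = 3.5;                         // assumed delay since last token / arrival
    double const durationMs = 100.0;                    // assumed transfer duration
    double const bandwidthGbps = sizeBytes * 8 / (durationMs / 1000) / 1e9; // ~21.47 Gbps

    // Row layout matching exportMeasure: a header line, then one line per request.
    std::printf("RequestID,Delay(ms),Duration(ms),Bandwidth(Gbps)\n");
    std::printf("%d,%g,%g,%g\n", 1, delayMs, durationMs, bandwidthGbps);
    return 0;
}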

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 7 additions & 9 deletions
@@ -37,7 +37,6 @@
 #include "tensorrt_llm/batch_manager/cacheFormatter.h"
 #include "tensorrt_llm/batch_manager/cacheTransceiver.h"
 #include "tensorrt_llm/batch_manager/contextProgress.h"
-#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
@@ -195,10 +194,9 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     auto makeFormatter = [cacheManager, isMLA, this]()
     { return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
 
-    mCacheSender = std::make_unique<DataResponder>(
-        std::make_unique<CacheSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
-    mDataRequester = std::make_unique<DataRequester>(
-        std::make_unique<CacheReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
+    mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
+    mCacheReceiver
+        = std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
 
     initializeCommState();
 }
@@ -250,7 +248,7 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
         return;
     }
     setContextState(llmRequest);
-    auto future = mCacheSender->respondAndSendAsync(*llmRequest);
+    auto future = mCacheSender->sendAsync(*llmRequest);
     mSenderFutures.emplace_back(llmRequest, std::move(future));
 }
 
@@ -266,7 +264,7 @@ void CacheTransceiver::respondAndSendLayerWise(
 
         llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
         setContextState(llmRequest.get());
-        auto future = mCacheSender->respondAndSendAsync(*llmRequest);
+        auto future = mCacheSender->sendAsync(*llmRequest);
         mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
     }
 }
@@ -275,7 +273,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
 {
     TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
     {
-        auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
+        auto future = mCacheReceiver->receiveAsync(*llmRequest);
        future.get();
    }
    llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
@@ -293,7 +291,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
         return;
     }
 
-    auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
+    auto future = mCacheReceiver->receiveAsync(*llmRequest);
     mRequesterFutures.emplace_back(llmRequest, std::move(future));
     llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
 }
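
Note: the constructor change above is the core of the refactor: CacheTransceiver now constructs CacheSender and CacheReceiver directly and calls sendAsync/receiveAsync on them, instead of wrapping separate *Impl objects in DataResponder/DataRequester. Below is a toy sketch of that wrapper removal, using hypothetical Mini* types rather than the real classes (the real constructors take a connection manager, cache state, rank and formatter, as shown in the hunk above).

#include <future>
#include <memory>

struct MiniRequest
{
    int id;
};

// Stand-in for the concrete CacheSender that is now owned directly.
class MiniCacheSender
{
public:
    std::future<void> sendAsync(MiniRequest& req)
    {
        // Placeholder for formatting and transmitting the request's KV-cache blocks.
        return std::async(std::launch::async, [&req] { (void) req; });
    }
};

int main()
{
    // Before: mCacheSender = std::make_unique<DataResponder>(std::make_unique<CacheSenderImpl>(...));
    // After:  the concrete sender is constructed and owned directly, one wrapper layer fewer.
    auto cacheSender = std::make_unique<MiniCacheSender>();

    MiniRequest req{7};
    auto future = cacheSender->sendAsync(req);
    future.get(); // the sync path waits immediately; the async path stores the future instead
    return 0;
}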
