18
18
#pragma once
19
19
20
20
#include " cacheTransBuffer.h"
21
- #include " dataTransceiver.h"
22
21
#include " tensorrt_llm/batch_manager/kvCacheManager.h"
23
22
#include " tensorrt_llm/batch_manager/kvCacheUtils.h"
24
23
#include " tensorrt_llm/common/envUtils.h"
25
24
#include " tensorrt_llm/common/logger.h"
25
+ #include " tensorrt_llm/executor/cacheCommunicator.h"
26
26
#include " tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
27
27
#include " tensorrt_llm/executor/dataTransceiverState.h"
28
28
#include " tensorrt_llm/runtime/bufferManager.h"
@@ -38,6 +38,135 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
38
38
39
39
BlockRange getBlockRangeForReceiving (BaseKVCacheManager* cacheManager, LlmRequest const & llmRequest);
40
40
41
+ using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
42
+ using Connection = tensorrt_llm::executor::kv_cache::Connection;
43
+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
44
+
45
+ class TransferSession
46
+ {
47
+ public:
48
+ struct Measure
49
+ {
50
+ double delay; // from last token (ctx) or arrival time (gen), in ms
51
+ double duration; // in ms
52
+ double bandwidth; // in Gbps
53
+ };
54
+
55
+ TransferSession (std::vector<Connection const *> connections, DataContext dataContext,
56
+ executor::DataTransceiverState const & selfState, executor::DataTransceiverState otherState,
57
+ runtime::BufferManager const & bufferManager, LlmRequest const * llmRequest = nullptr )
58
+ : mConnections (std::move(connections))
59
+ , mDataContext (dataContext)
60
+ , mSelfState (&selfState)
61
+ , mOtherState (std::move(otherState))
62
+ , mBufferManager (&bufferManager)
63
+ , mRequest (llmRequest)
64
+ {
65
+ TLLM_CHECK (!mConnections .empty ());
66
+ }
67
+
68
+ [[nodiscard]] std::vector<Connection const *> const & getConnections () const
69
+ {
70
+ return mConnections ;
71
+ }
72
+
73
+ // should be called only during the initialization of the TransferSession
74
+ void setConnection (size_t idx, Connection const * conn)
75
+ {
76
+ mConnections .at (idx) = conn;
77
+ }
78
+
79
+ [[nodiscard]] DataContext const & getDataContext () const
80
+ {
81
+ return mDataContext ;
82
+ }
83
+
84
+ [[nodiscard]] executor::DataTransceiverState const & getSelfState () const
85
+ {
86
+ return *mSelfState ;
87
+ }
88
+
89
+ [[nodiscard]] executor::DataTransceiverState const & getOtherState () const
90
+ {
91
+ return mOtherState ;
92
+ }
93
+
94
+ [[nodiscard]] runtime::BufferManager const & getBufferManager () const
95
+ {
96
+ return *mBufferManager ;
97
+ }
98
+
99
+ void send (size_t idx, void const * data, size_t size)
100
+ {
101
+ mConnections .at (idx)->send (mDataContext , data, size);
102
+ }
103
+
104
+ void recv (size_t idx, void * data, size_t size)
105
+ {
106
+ mConnections .at (idx)->recv (mDataContext , data, size);
107
+ }
108
+
109
+ [[nodiscard]] LlmRequest const & getLlmRequest () const
110
+ {
111
+ TLLM_CHECK (mRequest != nullptr );
112
+ return *mRequest ;
113
+ }
114
+
115
+ // in CacheSender, the LlmRequest is not available until the sendSync is called
116
+ void setLlmRequest (LlmRequest const & llmRequest)
117
+ {
118
+ mRequest = &llmRequest;
119
+ }
120
+
121
+ void appendMeasure (double delay, double duration, size_t size)
122
+ {
123
+ if (!mRecordMeasure )
124
+ {
125
+ return ;
126
+ }
127
+ auto bandwidth = size * 8 / (duration / 1000 ) / 1e9 ; // byte, ms => Gbps
128
+ mMeasures .emplace_back (Measure{delay, duration, bandwidth});
129
+ }
130
+
131
+ // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
132
+ void exportMeasure (std::ofstream& outFile, bool isContext) const
133
+ {
134
+ if (mMeasures .empty ())
135
+ {
136
+ return ;
137
+ }
138
+ // write header if not exist
139
+ if (outFile.tellp () == 0 )
140
+ {
141
+ outFile << " RequestID" ;
142
+ for (size_t i = 0 ; i < mMeasures .size (); i++)
143
+ {
144
+ outFile << " ,Delay(ms),Duration(ms),Bandwidth(Gbps)" ;
145
+ }
146
+ outFile << ' \n ' ;
147
+ }
148
+ // write measures
149
+ TLLM_CHECK (isContext || mRequest ->getContextPhaseParams ().has_value ());
150
+ auto reqId = isContext ? mRequest ->mRequestId : mRequest ->getContextPhaseParams ().value ().getReqId ();
151
+ outFile << reqId;
152
+ for (auto const & measure : mMeasures )
153
+ {
154
+ outFile << " ," << measure.delay << " ," << measure.duration << " ," << measure.bandwidth ;
155
+ }
156
+ outFile << ' \n ' << std::flush;
157
+ }
158
+
159
+ private:
160
+ std::vector<Connection const *> mConnections ;
161
+ DataContext mDataContext ;
162
+ executor::DataTransceiverState const * mSelfState ; // stored in CacheReceiver/CacheSender
163
+ executor::DataTransceiverState mOtherState ;
164
+ runtime::BufferManager const * mBufferManager ;
165
+ LlmRequest const * mRequest ;
166
+ std::vector<Measure> mMeasures ;
167
+ bool mRecordMeasure {false };
168
+ };
169
+
41
170
// Used to support the cache transmission with different layouts and different protocols.
42
171
class BaseCacheFormatter
43
172
{
@@ -78,6 +207,66 @@ class BaseCacheFormatter
78
207
virtual ~BaseCacheFormatter () = default ;
79
208
};
80
209
210
+ class KvCacheMeasureHelper
211
+ {
212
+ public:
213
+ KvCacheMeasureHelper (std::string output_path)
214
+ : mOutputPath (std::move(output_path))
215
+ {
216
+ }
217
+
218
+ void appendKVCacheTransfer (LlmRequest::RequestIdType requestId, double duration, size_t size)
219
+ {
220
+ auto bandwidth = size * 8 / (duration / 1000 ) / 1e9 ;
221
+ if (mOutputPath .empty ())
222
+ {
223
+ return ;
224
+ }
225
+
226
+ std::lock_guard<std::mutex> lock (mMutex );
227
+ mRequestKVCacheTranfserMeasure [requestId].emplace_back (duration, bandwidth);
228
+ }
229
+
230
+ ~KvCacheMeasureHelper ()
231
+ {
232
+ if (!mRequestKVCacheTranfserMeasure .empty () && !mOutputPath .empty ())
233
+ {
234
+ auto rank = mpi::MpiComm::world ().getRank ();
235
+ std::string outFilePath = mOutputPath + " rank_" + std::to_string (rank) + " .txt" ;
236
+ std::ofstream outFile (outFilePath);
237
+
238
+ TLLM_CHECK_WITH_INFO (outFile.is_open (), " Cannot write to file " + outFilePath);
239
+
240
+ size_t numTransferMeasure = mRequestKVCacheTranfserMeasure .begin ()->second .size ();
241
+
242
+ outFile << " RequestID" ;
243
+ for (size_t i = 0 ; i < numTransferMeasure; i++)
244
+ {
245
+ outFile << " ,TimeDuration,Bandwidth" ;
246
+ }
247
+ outFile << ' \n ' ;
248
+
249
+ for (auto const & [requestID, measures] : mRequestKVCacheTranfserMeasure )
250
+ {
251
+ outFile << requestID;
252
+
253
+ for (auto const & [time, bandwidth] : measures)
254
+ {
255
+ outFile << " ," << time << " ," << bandwidth;
256
+ }
257
+ outFile << ' \n ' ;
258
+ }
259
+
260
+ outFile.close ();
261
+ }
262
+ }
263
+
264
+ private:
265
+ std::map<LlmRequest::RequestIdType, std::vector<std::pair<double , double >>> mRequestKVCacheTranfserMeasure ;
266
+ std::string mOutputPath ;
267
+ std::mutex mMutex ;
268
+ };
269
+
81
270
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
82
271
// parallel topology is completely identical, making it the preferred method.
83
272
class CacheFormatter final : public BaseCacheFormatter
0 commit comments