18
18
#pragma once
19
19
20
20
#include " cacheTransBuffer.h"
21
- #include " dataTransceiver.h"
22
21
#include " tensorrt_llm/batch_manager/kvCacheManager.h"
23
22
#include " tensorrt_llm/batch_manager/kvCacheUtils.h"
24
23
#include " tensorrt_llm/common/envUtils.h"
25
24
#include " tensorrt_llm/common/logger.h"
25
+ #include " tensorrt_llm/executor/cacheCommunicator.h"
26
26
#include " tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
27
27
#include " tensorrt_llm/executor/dataTransceiverState.h"
28
28
#include " tensorrt_llm/runtime/bufferManager.h"
@@ -38,6 +38,88 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
38
38
39
39
BlockRange getBlockRangeForReceiving (BaseKVCacheManager* cacheManager, LlmRequest const & llmRequest);
40
40
41
+ using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
42
+ using Connection = tensorrt_llm::executor::kv_cache::Connection;
43
+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
44
+
45
+ class TransferSession
46
+ {
47
+ public:
48
+ TransferSession (std::vector<Connection const *> connections, DataContext dataContext,
49
+ executor::DataTransceiverState const & selfState, executor::DataTransceiverState otherState,
50
+ runtime::BufferManager const & bufferManager, LlmRequest const * llmRequest = nullptr )
51
+ : mConnections (std::move(connections))
52
+ , mDataContext (dataContext)
53
+ , mSelfState (&selfState)
54
+ , mOtherState (std::move(otherState))
55
+ , mBufferManager (&bufferManager)
56
+ , mRequest (llmRequest)
57
+ {
58
+ TLLM_CHECK (!mConnections .empty ());
59
+ }
60
+
61
+ [[nodiscard]] std::vector<Connection const *> const & getConnections () const
62
+ {
63
+ return mConnections ;
64
+ }
65
+
66
+ // should be called only during the initialization of the TransferSession
67
+ void setConnection (size_t idx, Connection const * conn)
68
+ {
69
+ mConnections .at (idx) = conn;
70
+ }
71
+
72
+ [[nodiscard]] DataContext const & getDataContext () const
73
+ {
74
+ return mDataContext ;
75
+ }
76
+
77
+ [[nodiscard]] executor::DataTransceiverState const & getSelfState () const
78
+ {
79
+ return *mSelfState ;
80
+ }
81
+
82
+ [[nodiscard]] executor::DataTransceiverState const & getOtherState () const
83
+ {
84
+ return mOtherState ;
85
+ }
86
+
87
+ [[nodiscard]] runtime::BufferManager const & getBufferManager () const
88
+ {
89
+ return *mBufferManager ;
90
+ }
91
+
92
+ void send (size_t idx, void const * data, size_t size)
93
+ {
94
+ mConnections .at (idx)->send (mDataContext , data, size);
95
+ }
96
+
97
+ void recv (size_t idx, void * data, size_t size)
98
+ {
99
+ mConnections .at (idx)->recv (mDataContext , data, size);
100
+ }
101
+
102
+ [[nodiscard]] LlmRequest const & getLlmRequest () const
103
+ {
104
+ TLLM_CHECK (mRequest != nullptr );
105
+ return *mRequest ;
106
+ }
107
+
108
+ // in CacheSender, the LlmRequest is not available until the sendSync is called
109
+ void setLlmRequest (LlmRequest const & llmRequest)
110
+ {
111
+ mRequest = &llmRequest;
112
+ }
113
+
114
+ private:
115
+ std::vector<Connection const *> mConnections ;
116
+ DataContext mDataContext ;
117
+ executor::DataTransceiverState const * mSelfState ; // stored in CacheReceiver/CacheSender
118
+ executor::DataTransceiverState mOtherState ;
119
+ runtime::BufferManager const * mBufferManager ;
120
+ LlmRequest const * mRequest ;
121
+ };
122
+
41
123
// Used to support the cache transmission with different layouts and different protocols.
42
124
class BaseCacheFormatter
43
125
{
@@ -78,6 +160,66 @@ class BaseCacheFormatter
78
160
virtual ~BaseCacheFormatter () = default ;
79
161
};
80
162
163
+ class KvCacheMeasureHelper
164
+ {
165
+ public:
166
+ KvCacheMeasureHelper (std::string output_path)
167
+ : mOutputPath (std::move(output_path))
168
+ {
169
+ }
170
+
171
+ void appendKVCacheTransfer (LlmRequest::RequestIdType requestId, double duration, size_t size)
172
+ {
173
+ auto bandwidth = size * 8 / (duration / 1000 ) / 1e9 ;
174
+ if (mOutputPath .empty ())
175
+ {
176
+ return ;
177
+ }
178
+
179
+ std::lock_guard<std::mutex> lock (mMutex );
180
+ mRequestKVCacheTranfserMeasure [requestId].emplace_back (duration, bandwidth);
181
+ }
182
+
183
+ ~KvCacheMeasureHelper ()
184
+ {
185
+ if (!mRequestKVCacheTranfserMeasure .empty () && !mOutputPath .empty ())
186
+ {
187
+ auto rank = mpi::MpiComm::world ().getRank ();
188
+ std::string outFilePath = mOutputPath + " rank_" + std::to_string (rank) + " .txt" ;
189
+ std::ofstream outFile (outFilePath);
190
+
191
+ TLLM_CHECK_WITH_INFO (outFile.is_open (), " Cannot write to file " + outFilePath);
192
+
193
+ size_t numTransferMeasure = mRequestKVCacheTranfserMeasure .begin ()->second .size ();
194
+
195
+ outFile << " RequestID" ;
196
+ for (size_t i = 0 ; i < numTransferMeasure; i++)
197
+ {
198
+ outFile << " ,TimeDuration,Bandwidth" ;
199
+ }
200
+ outFile << ' \n ' ;
201
+
202
+ for (auto const & [requestID, measures] : mRequestKVCacheTranfserMeasure )
203
+ {
204
+ outFile << requestID;
205
+
206
+ for (auto const & [time, bandwidth] : measures)
207
+ {
208
+ outFile << " ," << time << " ," << bandwidth;
209
+ }
210
+ outFile << ' \n ' ;
211
+ }
212
+
213
+ outFile.close ();
214
+ }
215
+ }
216
+
217
+ private:
218
+ std::map<LlmRequest::RequestIdType, std::vector<std::pair<double , double >>> mRequestKVCacheTranfserMeasure ;
219
+ std::string mOutputPath ;
220
+ std::mutex mMutex ;
221
+ };
222
+
81
223
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
82
224
// parallel topology is completely identical, making it the preferred method.
83
225
class CacheFormatter final : public BaseCacheFormatter
0 commit comments