Skip to content

Commit abb41e4

Browse files
committed
Rename data -> cache
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent b6baa9e commit abb41e4

File tree

8 files changed

+59
-60
lines changed

8 files changed

+59
-60
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,9 +110,9 @@ class CacheTransceiver : public BaseCacheTransceiver
110110

111111
void setContextState(LlmRequest* llmRequest);
112112

113-
std::unique_ptr<DataResponder> mDataResponder;
113+
std::unique_ptr<DataResponder> mCacheSender;
114114
std::unique_ptr<DataRequester> mDataRequester;
115-
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
115+
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
116116
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
117117
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
118118
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -195,10 +195,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
195195
auto makeFormatter = [cacheManager, isMLA, this]()
196196
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
197197

198-
mDataResponder = std::make_unique<DataResponder>(
199-
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
198+
mCacheSender = std::make_unique<DataResponder>(
199+
std::make_unique<CacheSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
200200
mDataRequester = std::make_unique<DataRequester>(
201-
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
201+
std::make_unique<CacheReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
202202

203203
initializeCommState();
204204
}
@@ -214,7 +214,7 @@ CacheTransceiver::~CacheTransceiver()
214214

215215
void CacheTransceiver::initializeCommState()
216216
{
217-
mCommState = std::addressof(mDataResponder->getCommState());
217+
mCommState = std::addressof(mCacheSender->getCommState());
218218
}
219219

220220
void CacheTransceiver::setContextState(LlmRequest* llmRequest)
@@ -250,8 +250,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
250250
return;
251251
}
252252
setContextState(llmRequest);
253-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
254-
mResponderFutures.emplace_back(llmRequest, std::move(future));
253+
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
254+
mSenderFutures.emplace_back(llmRequest, std::move(future));
255255
}
256256

257257
void CacheTransceiver::respondAndSendLayerWise(
@@ -266,8 +266,8 @@ void CacheTransceiver::respondAndSendLayerWise(
266266

267267
llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
268268
setContextState(llmRequest.get());
269-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
270-
mResponderFutures.emplace_back(llmRequest.get(), std::move(future));
269+
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
270+
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
271271
}
272272
}
273273

@@ -373,7 +373,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
373373
bool blockAll = !atLeastRequestNum.has_value();
374374
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
375375
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
376-
for (auto&& [request, future] : mResponderFutures)
376+
for (auto&& [request, future] : mSenderFutures)
377377
{
378378
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
379379
{
@@ -413,23 +413,22 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
413413

414414
// Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
415415
// This will preserve the order of insertion for KVCache transfer requests.
416-
for (auto it = mResponderFutures.begin();
417-
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mResponderFutures.end();
418-
++it)
416+
for (auto it = mSenderFutures.begin();
417+
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
419418
{
420419
auto& [request, future] = *it;
421420
toCompleteIdSet.insert(request->mRequestId);
422421
}
423422

424423
// Complete all the requests in toCompleteIdSet
425-
for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();)
424+
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
426425
{
427426
auto& [request, future] = *it;
428427
if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))
429428
{
430429
future.get();
431430
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
432-
it = mResponderFutures.erase(it);
431+
it = mSenderFutures.erase(it);
433432
}
434433
else
435434
{

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -148,14 +148,14 @@ class DataResponder::Impl
148148
auto future = promise.get_future();
149149
{
150150
{
151-
std::unique_lock lkResp(mResponderMutex);
151+
std::unique_lock lkResp(mSenderMutex);
152152
mReadyResponses.emplace(
153153
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
154154
}
155155
std::unique_lock lkCond(mCondMutex);
156156
mAnyReady = true;
157157
}
158-
mResponderCv.notify_all();
158+
mSenderCv.notify_all();
159159
return future;
160160
}
161161

@@ -208,7 +208,7 @@ class DataResponder::Impl
208208
if (!mAnyReady)
209209
{
210210
std::unique_lock lk(mCondMutex);
211-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
211+
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
212212
}
213213
if (mTerminate)
214214
{
@@ -263,7 +263,7 @@ class DataResponder::Impl
263263
"mReadyResponses size is: %zu. mpi rank :%d ",
264264
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
265265
std::unique_lock lk(mCondMutex);
266-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
266+
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
267267
}
268268
}
269269
}
@@ -285,13 +285,13 @@ class DataResponder::Impl
285285
}
286286
// We don't have to wait for the future. If another thread is sending data, it won't pay attention
287287
// to the terminate flag.
288-
mResponderCv.notify_all();
288+
mSenderCv.notify_all();
289289
}
290290

291291
void removeResponse(std::map<RequestIdType, Response>::iterator it)
292292
{
293293
{
294-
std::unique_lock lkResp(mResponderMutex);
294+
std::unique_lock lkResp(mSenderMutex);
295295
mReadyResponses.erase(it);
296296
}
297297
if (mReadyResponses.empty())
@@ -313,16 +313,16 @@ class DataResponder::Impl
313313

314314
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
315315
{
316-
std::unique_lock lk(mResponderMutex);
316+
std::unique_lock lk(mSenderMutex);
317317
return mReadyResponses.find(getCurrentRequestId());
318318
}
319319

320320
private:
321321
std::optional<RequestIdType> mCurrentRequest;
322322
std::map<RequestIdType, Response> mReadyResponses;
323-
std::mutex mResponderMutex, mCondMutex;
323+
std::mutex mSenderMutex, mCondMutex;
324324
std::atomic<bool> mAnyReady{false}, mTerminate{false};
325-
std::condition_variable mResponderCv;
325+
std::condition_variable mSenderCv;
326326
std::future<void> mResponseFuture;
327327
std::unique_ptr<DataSender> mSender;
328328
std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
@@ -333,9 +333,9 @@ class DataRequester::Impl
333333
{
334334
public:
335335
Impl(std::unique_ptr<DataReceiver> receiver)
336-
: mReceiver{std::move(receiver)}
336+
: mCacheReceiver{std::move(receiver)}
337337
{
338-
TLLM_CHECK(mReceiver);
338+
TLLM_CHECK(mCacheReceiver);
339339
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
340340
}
341341

@@ -400,8 +400,8 @@ class DataRequester::Impl
400400
llmRequest.getContextPhaseParams().value().getReqId());
401401
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
402402
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
403-
auto session = mReceiver->sendRequestInfo(llmRequest);
404-
mReceiver->receiveSync(session);
403+
auto session = mCacheReceiver->sendRequestInfo(llmRequest);
404+
mCacheReceiver->receiveSync(session);
405405
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
406406

407407
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
@@ -507,7 +507,7 @@ class DataRequester::Impl
507507
}
508508
}
509509

510-
std::unique_ptr<DataReceiver> mReceiver;
510+
std::unique_ptr<DataReceiver> mCacheReceiver;
511511
int mDeviceId{-1};
512512

513513
std::vector<std::future<void>> mRequestFutures;

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ static fs::path getTransferOutputPath(char const* tag)
4747
return {};
4848
}
4949

50-
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
50+
CacheSenderImpl::CacheSenderImpl(executor::kv_cache::ConnectionManager* manager,
5151
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
5252
: mManager{manager}
5353
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
@@ -58,7 +58,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
5858
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
5959
}
6060

61-
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
61+
[[nodiscard]] RequestInfo CacheSenderImpl::recvRequestInfo()
6262
{
6363
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
6464
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
@@ -111,7 +111,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
111111
return info;
112112
}
113113

114-
void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
114+
void CacheSenderImpl::sendSync(LlmRequest const& llmRequest)
115115
{
116116
auto it = mRequestToSession.find(llmRequest.mRequestId);
117117
TLLM_CHECK(it != mRequestToSession.end());
@@ -120,24 +120,24 @@ void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
120120
mFormatter->format(session);
121121
}
122122

123-
[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const
123+
[[nodiscard]] executor::kv_cache::CommState const& CacheSenderImpl::getCommState() const
124124
{
125125
return mSelfState.getCommState().value();
126126
}
127127

128-
void DataSenderImpl::setCommState(executor::kv_cache::CommState commState)
128+
void CacheSenderImpl::setCommState(executor::kv_cache::CommState commState)
129129
{
130130
mSelfState.setCommState(std::move(commState));
131131
}
132132

133-
[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
133+
[[nodiscard]] size_t CacheSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
134134
{
135135
auto it = mRequestToSession.find(requestId);
136136
TLLM_CHECK(it != mRequestToSession.end());
137137
return it->second.getConnections().size();
138138
}
139139

140-
void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
140+
void CacheSenderImpl::release(LlmRequest::RequestIdType requestId)
141141
{
142142
auto it = mRequestToSession.find(requestId);
143143
TLLM_CHECK(it != mRequestToSession.end());
@@ -156,7 +156,7 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
156156
mRequestToSession.erase(it);
157157
}
158158

159-
DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
159+
CacheReceiverImpl::CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager,
160160
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
161161
: mManager{manager}
162162
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
@@ -167,7 +167,7 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
167167
TLLM_CHECK(mFormatter);
168168
}
169169

170-
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
170+
TransferSession CacheReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
171171
{
172172
uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
173173
auto const& contextState = llmRequest.getDataTransceiverState();
@@ -233,7 +233,7 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
233233
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
234234
}
235235

236-
void DataReceiverImpl::receiveSync(TransferSession& session)
236+
void CacheReceiverImpl::receiveSync(TransferSession& session)
237237
{
238238
mFormatter->unformat(session);
239239
if (!common::getEnvKVCacheTransferOutputPath().empty())
@@ -250,7 +250,7 @@ void DataReceiverImpl::receiveSync(TransferSession& session)
250250
}
251251
}
252252

253-
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
253+
void CacheReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
254254
{
255255
std::ostringstream oss;
256256
RequestInfo::serialize(info, oss);
@@ -262,7 +262,7 @@ void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* con
262262
connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
263263
}
264264

265-
std::unique_ptr<DataReceiverImpl::ReceiveCacheResource> const& DataReceiverImpl::getReceiveCacheResource(
265+
std::unique_ptr<CacheReceiverImpl::ReceiveCacheResource> const& CacheReceiverImpl::getReceiveCacheResource(
266266
LlmRequest const& llmRequest)
267267
{
268268
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,12 @@ struct TransceiverTag
4242

4343
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
4444

45-
class DataSenderImpl : public DataSender, public TransceiverTag
45+
class CacheSenderImpl : public DataSender, public TransceiverTag
4646
{
4747
public:
4848
using SizeType32 = tensorrt_llm::runtime::SizeType32;
4949

50-
DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
50+
CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
5151
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
5252

5353
[[nodiscard]] RequestInfo recvRequestInfo() override;
@@ -72,12 +72,12 @@ class DataSenderImpl : public DataSender, public TransceiverTag
7272
std::ofstream mMeasuresFile;
7373
};
7474

75-
class DataReceiverImpl : public DataReceiver, public TransceiverTag
75+
class CacheReceiverImpl : public DataReceiver, public TransceiverTag
7676
{
7777
public:
7878
using SizeType32 = tensorrt_llm::runtime::SizeType32;
7979

80-
DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
80+
CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
8181
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
8282

8383
TransferSession sendRequestInfo(LlmRequest const& llmRequest) override;

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
8181
MemoryDesc srcDesc{
8282
reinterpret_cast<uintptr_t>(data), size, static_cast<uint32_t>(mAgentConnectionManager->getDeviceId())};
8383
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
84-
auto dstBaseDesc = mSenderState.mReceiverBufferDesc;
84+
auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc;
8585
MemoryDesc dstDesc{dstBaseDesc.getAddr() + (mSenderState.validSegmentIdx * size), size, dstBaseDesc.getDeviceId()};
8686
TLLM_LOG_DEBUG(
8787
"send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx);
@@ -137,9 +137,9 @@ void AgentConnection::sendRequestAndBufferInfo(
137137
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
138138
}
139139

140-
void AgentConnection::setSenderState(MemoryDesc mReceiverBufferDesc, int validSegmentIdx)
140+
void AgentConnection::setSenderState(MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx)
141141
{
142-
mSenderState.mReceiverBufferDesc = mReceiverBufferDesc;
142+
mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc;
143143
mSenderState.validSegmentIdx = validSegmentIdx;
144144
}
145145

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ class AgentConnection : public Connection
175175
void recv(DataContext const& ctx, void* data, size_t size) const override;
176176
void sendRequestAndBufferInfo(
177177
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx);
178-
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx);
178+
void setSenderState(MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx);
179179
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
180180
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
181181
[[nodiscard]] bool hasLoadRemoteAgent() const;
@@ -186,7 +186,7 @@ class AgentConnection : public Connection
186186

187187
struct SenderState
188188
{
189-
MemoryDesc mReceiverBufferDesc{nullptr, 0, 0};
189+
MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0};
190190
int validSegmentIdx{0};
191191
SenderState() = default;
192192
};

0 commit comments

Comments
 (0)