Skip to content

Commit 214b335

Browse files
committed
Rename data -> cache
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 428e340 commit 214b335

File tree

8 files changed: 59 additions (+59) and 60 deletions (-60).

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,9 +110,9 @@ class CacheTransceiver : public BaseCacheTransceiver
110110

111111
void setContextState(LlmRequest* llmRequest);
112112

113-
std::unique_ptr<DataResponder> mDataResponder;
113+
std::unique_ptr<DataResponder> mCacheSender;
114114
std::unique_ptr<DataRequester> mDataRequester;
115-
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
115+
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
116116
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
117117
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
118118
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -194,10 +194,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
194194
auto makeFormatter = [cacheManager, isMLA, this]()
195195
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
196196

197-
mDataResponder = std::make_unique<DataResponder>(
198-
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
197+
mCacheSender = std::make_unique<DataResponder>(
198+
std::make_unique<CacheSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
199199
mDataRequester = std::make_unique<DataRequester>(
200-
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
200+
std::make_unique<CacheReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
201201

202202
initializeCommState();
203203
}
@@ -213,7 +213,7 @@ CacheTransceiver::~CacheTransceiver()
213213

214214
void CacheTransceiver::initializeCommState()
215215
{
216-
mCommState = std::addressof(mDataResponder->getCommState());
216+
mCommState = std::addressof(mCacheSender->getCommState());
217217
}
218218

219219
void CacheTransceiver::setContextState(LlmRequest* llmRequest)
@@ -249,8 +249,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
249249
return;
250250
}
251251
setContextState(llmRequest);
252-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
253-
mResponderFutures.emplace_back(llmRequest, std::move(future));
252+
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
253+
mSenderFutures.emplace_back(llmRequest, std::move(future));
254254
}
255255

256256
void CacheTransceiver::respondAndSendLayerWise(
@@ -265,8 +265,8 @@ void CacheTransceiver::respondAndSendLayerWise(
265265

266266
llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
267267
setContextState(llmRequest.get());
268-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
269-
mResponderFutures.emplace_back(llmRequest.get(), std::move(future));
268+
auto future = mCacheSender->respondAndSendAsync(*llmRequest);
269+
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
270270
}
271271
}
272272

@@ -372,7 +372,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
372372
bool blockAll = !atLeastRequestNum.has_value();
373373
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
374374
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
375-
for (auto&& [request, future] : mResponderFutures)
375+
for (auto&& [request, future] : mSenderFutures)
376376
{
377377
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
378378
{
@@ -412,23 +412,22 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
412412

413413
// Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
414414
// This will preserve the order of insertion for KVCache transfer requests.
415-
for (auto it = mResponderFutures.begin();
416-
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mResponderFutures.end();
417-
++it)
415+
for (auto it = mSenderFutures.begin();
416+
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
418417
{
419418
auto& [request, future] = *it;
420419
toCompleteIdSet.insert(request->mRequestId);
421420
}
422421

423422
// Complete all the requests in toCompleteIdSet
424-
for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();)
423+
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
425424
{
426425
auto& [request, future] = *it;
427426
if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))
428427
{
429428
future.get();
430429
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
431-
it = mResponderFutures.erase(it);
430+
it = mSenderFutures.erase(it);
432431
}
433432
else
434433
{

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -111,14 +111,14 @@ class DataResponder::Impl
111111
auto future = promise.get_future();
112112
{
113113
{
114-
std::unique_lock lkResp(mResponderMutex);
114+
std::unique_lock lkResp(mSenderMutex);
115115
mReadyResponses.emplace(
116116
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
117117
}
118118
std::unique_lock lkCond(mCondMutex);
119119
mAnyReady = true;
120120
}
121-
mResponderCv.notify_all();
121+
mSenderCv.notify_all();
122122
return future;
123123
}
124124

@@ -171,7 +171,7 @@ class DataResponder::Impl
171171
if (!mAnyReady)
172172
{
173173
std::unique_lock lk(mCondMutex);
174-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
174+
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
175175
}
176176
if (mTerminate)
177177
{
@@ -226,7 +226,7 @@ class DataResponder::Impl
226226
"mReadyResponses size is: %zu. mpi rank :%d ",
227227
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
228228
std::unique_lock lk(mCondMutex);
229-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
229+
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
230230
}
231231
}
232232
}
@@ -248,13 +248,13 @@ class DataResponder::Impl
248248
}
249249
// We don't have to wait for the future. If another thread is sending data, it won't pay attention
250250
// to the terminate flag.
251-
mResponderCv.notify_all();
251+
mSenderCv.notify_all();
252252
}
253253

254254
void removeResponse(std::map<RequestIdType, Response>::iterator it)
255255
{
256256
{
257-
std::unique_lock lkResp(mResponderMutex);
257+
std::unique_lock lkResp(mSenderMutex);
258258
mReadyResponses.erase(it);
259259
}
260260
if (mReadyResponses.empty())
@@ -276,16 +276,16 @@ class DataResponder::Impl
276276

277277
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
278278
{
279-
std::unique_lock lk(mResponderMutex);
279+
std::unique_lock lk(mSenderMutex);
280280
return mReadyResponses.find(getCurrentRequestId());
281281
}
282282

283283
private:
284284
std::optional<RequestIdType> mCurrentRequest;
285285
std::map<RequestIdType, Response> mReadyResponses;
286-
std::mutex mResponderMutex, mCondMutex;
286+
std::mutex mSenderMutex, mCondMutex;
287287
std::atomic<bool> mAnyReady{false}, mTerminate{false};
288-
std::condition_variable mResponderCv;
288+
std::condition_variable mSenderCv;
289289
std::future<void> mResponseFuture;
290290
std::unique_ptr<DataSender> mSender;
291291
std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
@@ -296,9 +296,9 @@ class DataRequester::Impl
296296
{
297297
public:
298298
Impl(std::unique_ptr<DataReceiver> receiver)
299-
: mReceiver{std::move(receiver)}
299+
: mCacheReceiver{std::move(receiver)}
300300
{
301-
TLLM_CHECK(mReceiver);
301+
TLLM_CHECK(mCacheReceiver);
302302
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
303303
}
304304

@@ -363,8 +363,8 @@ class DataRequester::Impl
363363
llmRequest.getContextPhaseParams().value().getReqId());
364364
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
365365
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
366-
auto session = mReceiver->sendRequestInfo(llmRequest);
367-
mReceiver->receiveSync(session);
366+
auto session = mCacheReceiver->sendRequestInfo(llmRequest);
367+
mCacheReceiver->receiveSync(session);
368368
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
369369

370370
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
@@ -470,7 +470,7 @@ class DataRequester::Impl
470470
}
471471
}
472472

473-
std::unique_ptr<DataReceiver> mReceiver;
473+
std::unique_ptr<DataReceiver> mCacheReceiver;
474474
int mDeviceId{-1};
475475

476476
std::vector<std::future<void>> mRequestFutures;

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
3030
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
3131
}
3232

33-
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
33+
CacheSenderImpl::CacheSenderImpl(executor::kv_cache::ConnectionManager* manager,
3434
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
3535
: mManager{manager}
3636
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
@@ -41,7 +41,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
4141
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
4242
}
4343

44-
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
44+
[[nodiscard]] RequestInfo CacheSenderImpl::recvRequestInfo()
4545
{
4646
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
4747
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
@@ -93,7 +93,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
9393
return info;
9494
}
9595

96-
void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
96+
void CacheSenderImpl::sendSync(LlmRequest const& llmRequest)
9797
{
9898
auto it = mRequestToSession.find(llmRequest.mRequestId);
9999
TLLM_CHECK(it != mRequestToSession.end());
@@ -102,32 +102,32 @@ void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
102102
mFormatter->format(session);
103103
}
104104

105-
[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const
105+
[[nodiscard]] executor::kv_cache::CommState const& CacheSenderImpl::getCommState() const
106106
{
107107
return mSelfState.getCommState().value();
108108
}
109109

110-
void DataSenderImpl::setCommState(executor::kv_cache::CommState commState)
110+
void CacheSenderImpl::setCommState(executor::kv_cache::CommState commState)
111111
{
112112
mSelfState.setCommState(std::move(commState));
113113
}
114114

115-
[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
115+
[[nodiscard]] size_t CacheSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
116116
{
117117
auto it = mRequestToSession.find(requestId);
118118
TLLM_CHECK(it != mRequestToSession.end());
119119
return it->second.getConnections().size();
120120
}
121121

122-
void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
122+
void CacheSenderImpl::release(LlmRequest::RequestIdType requestId)
123123
{
124124
auto it = mRequestToSession.find(requestId);
125125
TLLM_CHECK(it != mRequestToSession.end());
126126
std::unique_lock<std::mutex> lk(mMtxForMap);
127127
mRequestToSession.erase(it);
128128
}
129129

130-
DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
130+
CacheReceiverImpl::CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager,
131131
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
132132
: mManager{manager}
133133
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
@@ -138,7 +138,7 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
138138
TLLM_CHECK(mFormatter);
139139
}
140140

141-
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
141+
TransferSession CacheReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
142142
{
143143
uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
144144
auto const& contextState = llmRequest.getDataTransceiverState();
@@ -204,12 +204,12 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
204204
contextState, resource->mBufferManager, &llmRequest);
205205
}
206206

207-
void DataReceiverImpl::receiveSync(TransferSession& session)
207+
void CacheReceiverImpl::receiveSync(TransferSession& session)
208208
{
209209
mFormatter->unformat(session);
210210
}
211211

212-
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
212+
void CacheReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
213213
{
214214
std::ostringstream oss;
215215
RequestInfo::serialize(info, oss);
@@ -221,7 +221,7 @@ void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* con
221221
connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
222222
}
223223

224-
std::unique_ptr<DataReceiverImpl::ReceiveCacheResource> const& DataReceiverImpl::getReceiveCacheResource(
224+
std::unique_ptr<CacheReceiverImpl::ReceiveCacheResource> const& CacheReceiverImpl::getReceiveCacheResource(
225225
LlmRequest const& llmRequest)
226226
{
227227
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,12 +40,12 @@ struct TransceiverTag
4040

4141
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
4242

43-
class DataSenderImpl : public DataSender, public TransceiverTag
43+
class CacheSenderImpl : public DataSender, public TransceiverTag
4444
{
4545
public:
4646
using SizeType32 = tensorrt_llm::runtime::SizeType32;
4747

48-
DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
48+
CacheSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
4949
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
5050

5151
[[nodiscard]] RequestInfo recvRequestInfo() override;
@@ -69,12 +69,12 @@ class DataSenderImpl : public DataSender, public TransceiverTag
6969
runtime::BufferManager mBufferManager;
7070
};
7171

72-
class DataReceiverImpl : public DataReceiver, public TransceiverTag
72+
class CacheReceiverImpl : public DataReceiver, public TransceiverTag
7373
{
7474
public:
7575
using SizeType32 = tensorrt_llm::runtime::SizeType32;
7676

77-
DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
77+
CacheReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
7878
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
7979

8080
TransferSession sendRequestInfo(LlmRequest const& llmRequest) override;

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
8181
MemoryDesc srcDesc{
8282
reinterpret_cast<uintptr_t>(data), size, static_cast<uint32_t>(mAgentConnectionManager->getDeviceId())};
8383
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
84-
auto dstBaseDesc = mSenderState.mReceiverBufferDesc;
84+
auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc;
8585
MemoryDesc dstDesc{dstBaseDesc.getAddr() + (mSenderState.validSegmentIdx * size), size, dstBaseDesc.getDeviceId()};
8686
TLLM_LOG_DEBUG(
8787
"send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx);
@@ -137,9 +137,9 @@ void AgentConnection::sendRequestAndBufferInfo(
137137
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
138138
}
139139

140-
void AgentConnection::setSenderState(MemoryDesc mReceiverBufferDesc, int validSegmentIdx)
140+
void AgentConnection::setSenderState(MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx)
141141
{
142-
mSenderState.mReceiverBufferDesc = mReceiverBufferDesc;
142+
mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc;
143143
mSenderState.validSegmentIdx = validSegmentIdx;
144144
}
145145

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ class AgentConnection : public Connection
175175
void recv(DataContext const& ctx, void* data, size_t size) const override;
176176
void sendRequestAndBufferInfo(
177177
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx);
178-
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx);
178+
void setSenderState(MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx);
179179
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
180180
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
181181
[[nodiscard]] bool hasLoadRemoteAgent() const;
@@ -186,7 +186,7 @@ class AgentConnection : public Connection
186186

187187
struct SenderState
188188
{
189-
MemoryDesc mReceiverBufferDesc{nullptr, 0, 0};
189+
MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0};
190190
int validSegmentIdx{0};
191191
SenderState() = default;
192192
};

0 commit comments

Comments
 (0)