Skip to content

Commit 97d2df6

Browse files
committed
bugfix: resolve prefetch early-termination issue in multi-TP-rank scenarios.
1 parent 08cdf97 commit 97d2df6

19 files changed

+84 / -45 lines changed

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -372,13 +372,13 @@ void CommChannel::transfer_kv_blocks(
372372

373373
class ClientStreamReceiver : public brpc::StreamInputHandler {
374374
private:
375-
std::shared_ptr<std::atomic<bool>> termination_flag_;
375+
std::shared_ptr<std::atomic<int32_t>> termination_flag_;
376376
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
377377
std::promise<void> close_promise_;
378378
std::atomic<bool> promise_set_{false};
379379

380380
public:
381-
ClientStreamReceiver(std::shared_ptr<std::atomic<bool>> termination_flag,
381+
ClientStreamReceiver(std::shared_ptr<std::atomic<int32_t>> termination_flag,
382382
std::shared_ptr<std::atomic<uint32_t>> success_cnt)
383383
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}
384384

@@ -398,10 +398,10 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
398398
int32_t success_cnt = std::stoi(msg_str);
399399

400400
if (success_cnt > 0 &&
401-
!termination_flag_->load(std::memory_order_acquire)) {
401+
termination_flag_->load(std::memory_order_acquire) > 0) {
402402
success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed);
403403
} else {
404-
termination_flag_->store(true, std::memory_order_release);
404+
termination_flag_->fetch_sub(1, std::memory_order_release);
405405
brpc::StreamClose(id);
406406
if (!promise_set_.exchange(true)) {
407407
close_promise_.set_value();
@@ -427,7 +427,7 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
427427

428428
void CommChannel::prefetch_from_storage(
429429
const std::vector<BlockTransferInfo>& block_transfer_info,
430-
std::shared_ptr<std::atomic<bool>> flag,
430+
std::shared_ptr<std::atomic<int32_t>> flag,
431431
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
432432
proto::BlockTransferInfos pb_block_transfer_info;
433433
if (!block_transfer_info_to_proto(block_transfer_info,

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ class CommChannel {
9999

100100
virtual void prefetch_from_storage(
101101
const std::vector<BlockTransferInfo>& block_transfer_info,
102-
std::shared_ptr<std::atomic<bool>> flag,
102+
std::shared_ptr<std::atomic<int32_t>> flag,
103103
std::shared_ptr<std::atomic<uint32_t>> success_cnt);
104104

105105
virtual bool get_last_step_result_async(

xllm/core/distributed_runtime/engine.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ class Engine {
9797
virtual void prefetch_from_storage(
9898
const uint32_t dp_rank,
9999
const std::vector<BlockTransferInfo>& block_transfer_info,
100-
std::shared_ptr<std::atomic<bool>> flag,
100+
std::shared_ptr<std::atomic<int32_t>> flag,
101101
std::vector<std::shared_ptr<std::atomic<uint32_t>>>* prefetch_results) {
102102
LOG(FATAL) << " prefetch_from_storage is not implemented!";
103103
};

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -518,9 +518,10 @@ void LLMEngine::transfer_kv_blocks(
518518
void LLMEngine::prefetch_from_storage(
519519
const uint32_t dp_rank,
520520
const std::vector<BlockTransferInfo>& block_transfer_info,
521-
std::shared_ptr<std::atomic<bool>> flag,
521+
std::shared_ptr<std::atomic<int32_t>> flag,
522522
std::vector<std::shared_ptr<std::atomic<uint32_t>>>* prefetch_results) {
523523
prefetch_results->reserve(dp_local_tp_size_);
524+
flag->store(dp_local_tp_size_, std::memory_order_acquire);
524525
for (auto tp_rank = 0; tp_rank < dp_local_tp_size_; ++tp_rank) {
525526
prefetch_results->emplace_back(std::make_shared<std::atomic<uint32_t>>(0));
526527
worker_clients_[tp_rank + dp_local_tp_size_ * dp_rank]

xllm/core/distributed_runtime/llm_engine.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ class LLMEngine : public Engine {
8383
void prefetch_from_storage(
8484
const uint32_t dp_rank,
8585
const std::vector<BlockTransferInfo>& block_transfer_info,
86-
std::shared_ptr<std::atomic<bool>> flag,
86+
std::shared_ptr<std::atomic<int32_t>> flag,
8787
std::vector<std::shared_ptr<std::atomic<uint32_t>>>* prefetch_results)
8888
override;
8989

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ void RemoteWorker::transfer_kv_blocks(
314314

315315
void RemoteWorker::prefetch_from_storage(
316316
const std::vector<BlockTransferInfo>& block_transfer_info,
317-
std::shared_ptr<std::atomic<bool>> flag,
317+
std::shared_ptr<std::atomic<int32_t>> flag,
318318
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
319319
copy_threadpool_.schedule(
320320
[this,

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ class RemoteWorker : public WorkerClient {
121121

122122
virtual void prefetch_from_storage(
123123
const std::vector<BlockTransferInfo>& block_transfer_info,
124-
std::shared_ptr<std::atomic<bool>> flag,
124+
std::shared_ptr<std::atomic<int32_t>> flag,
125125
std::shared_ptr<std::atomic<uint32_t>> success_cnt) override;
126126

127127
// Run the model and return the output.

xllm/core/framework/block/hierarchy_block_manager_pool.cpp

Lines changed: 30 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
5959
auto* blocks = sequence->kv_state().mutable_kv_blocks();
6060
auto* host_blocks = sequence->host_kv_state().mutable_kv_blocks();
6161

62-
if (blocks->size() == 0 || host_blocks->size() >= blocks->size()) {
62+
if (blocks->size() == 0 || host_blocks->size() > blocks->size()) {
6363
return;
6464
}
6565

@@ -148,12 +148,14 @@ void HierarchyBlockManagerPool::prefetch_from_storage(
148148
prefill_sequence->tokens());
149149
prefill_sequence->add_shared_host_kv_blocks(std::move(shared_blocks));
150150

151-
const size_t num_blocks = prefill_sequence->host_kv_state().num_kv_blocks();
152151
// round down to the nearest block number
153-
const size_t block_size = options_.block_size();
152+
size_t shared_blocks_num =
153+
prefill_sequence->host_kv_state().shared_kv_blocks_num();
154154
const size_t num_additional_blocks =
155-
prefill_sequence->num_tokens() / block_size - num_blocks;
156-
if (num_additional_blocks <= 0) {
155+
(prefill_sequence->num_tokens() + options_.block_size() - 1) /
156+
options_.block_size() -
157+
shared_blocks_num;
158+
if (num_additional_blocks <= 1) {
157159
return;
158160
}
159161

@@ -165,20 +167,19 @@ void HierarchyBlockManagerPool::prefetch_from_storage(
165167
prefill_sequence->host_kv_state().add_kv_blocks(host_blocks);
166168
PrefixCache::compute_hash_keys(
167169
prefill_sequence->tokens(),
168-
*prefill_sequence->host_kv_state().mutable_kv_blocks());
170+
*prefill_sequence->host_kv_state().mutable_kv_blocks(),
171+
shared_blocks_num);
169172

170-
if (num_additional_blocks > 0) {
173+
if (num_additional_blocks > 1) {
171174
const auto host_blocks = prefill_sequence->host_kv_state().kv_blocks();
172175
std::vector<BlockTransferInfo> block_transfer_infos;
173176
block_transfer_infos.reserve(num_additional_blocks);
174-
for (int i = host_blocks.size() - num_additional_blocks;
175-
i < host_blocks.size();
176-
i++) {
177-
block_transfer_infos.emplace_back(
178-
BlockTransferInfo(-1,
179-
host_blocks[i].id(),
180-
host_blocks[i].get_immutable_hash_value(),
181-
TransferType::G2H));
177+
for (int i = 0; i < num_additional_blocks - 1; i++) {
178+
block_transfer_infos.emplace_back(BlockTransferInfo(
179+
-1,
180+
host_blocks[shared_blocks_num + i].id(),
181+
host_blocks[shared_blocks_num + i].get_immutable_hash_value(),
182+
TransferType::G2H));
182183
}
183184

184185
engine_->prefetch_from_storage(prefill_sequence->dp_rank(),
@@ -198,8 +199,21 @@ bool HierarchyBlockManagerPool::update_prefetch_result(
198199

199200
bool prefetch_result = true;
200201
for (auto& prefill_sequence : request->sequences()) {
201-
prefetch_result &= prefill_sequence->update_prefetch_result(timeout);
202+
uint32_t success_cnt = 0;
203+
prefetch_result &=
204+
prefill_sequence->update_prefetch_result(timeout, success_cnt);
205+
206+
if (success_cnt > 0) {
207+
int32_t dp_rank = BlockManagerPool::get_dp_rank(prefill_sequence.get());
208+
auto host_blocks = prefill_sequence->host_kv_state().kv_blocks();
209+
auto cached_blocks =
210+
prefill_sequence->host_kv_state().shared_kv_blocks_num();
211+
212+
host_block_managers_[dp_rank]->cache(
213+
host_blocks.slice(cached_blocks - success_cnt, cached_blocks));
214+
}
202215
}
216+
203217
return prefetch_result;
204218
}
205219

xllm/core/framework/prefix_cache/prefix_cache.cpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,11 @@ size_t PrefixCache::insert(const Slice<int32_t>& token_ids,
125125
}
126126

127127
size_t PrefixCache::insert(const std::vector<Block>& blocks) {
128+
Slice<Block> slice(blocks);
129+
return insert(slice);
130+
}
131+
132+
size_t PrefixCache::insert(Slice<Block>& blocks) {
128133
std::vector<Murmur3Key> insert_keys;
129134
return insert(blocks, &insert_keys);
130135
}
@@ -197,7 +202,7 @@ size_t PrefixCache::insert(const Slice<int32_t>& token_ids,
197202
return n_tokens;
198203
}
199204

200-
size_t PrefixCache::insert(const std::vector<Block>& blocks,
205+
size_t PrefixCache::insert(Slice<Block>& blocks,
201206
std::vector<Murmur3Key>* insert_keys) {
202207
const int64_t now = absl::ToUnixMicros(absl::Now());
203208
DNodeList node_list;
@@ -279,7 +284,8 @@ size_t PrefixCache::evict(size_t n_blocks,
279284
}
280285

281286
uint32_t PrefixCache::compute_hash_keys(const Slice<int32_t>& token_ids,
282-
std::vector<Block>& blocks) {
287+
std::vector<Block>& blocks,
288+
const size_t cached_blocks) {
283289
if (blocks.size() == 0) {
284290
return 0;
285291
}
@@ -289,8 +295,10 @@ uint32_t PrefixCache::compute_hash_keys(const Slice<int32_t>& token_ids,
289295
LOG(ERROR) << "token ids do not cover the allocate block.";
290296
return 0;
291297
}
298+
size_t full_block_size =
299+
std::min(token_ids.size() / block_size, blocks.size());
292300

293-
for (size_t i = 0; i < token_ids.size() / block_size; i++) {
301+
for (size_t i = cached_blocks; i < full_block_size; i++) {
294302
if (i == 0) {
295303
murmur_hash3(nullptr,
296304
token_ids.slice(i * block_size, (i + 1) * block_size),
@@ -302,7 +310,7 @@ uint32_t PrefixCache::compute_hash_keys(const Slice<int32_t>& token_ids,
302310
}
303311
}
304312

305-
return token_ids.size() / block_size;
313+
return full_block_size;
306314
}
307315

308316
} // namespace xllm

xllm/core/framework/prefix_cache/prefix_cache.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ class PrefixCache {
7373
std::vector<Block>& blocks);
7474

7575
// insert the blocks with hash key into the prefix tree
76+
virtual size_t insert(Slice<Block>& blocks);
7677
virtual size_t insert(const std::vector<Block>& blocks);
7778

7879
// evict blocks hold by the prefix cache
@@ -97,15 +98,15 @@ class PrefixCache {
9798
virtual KvCacheEvent* get_upload_kvcache_events() { return nullptr; }
9899

99100
static uint32_t compute_hash_keys(const Slice<int32_t>& token_ids,
100-
std::vector<Block>& blocks);
101+
std::vector<Block>& blocks,
102+
const size_t cached_blocks = 0);
101103

102104
protected:
103105
size_t insert(const Slice<int32_t>& token_ids,
104106
std::vector<Block>& blocks,
105107
std::vector<Murmur3Key>* insert_keys);
106108

107-
size_t insert(const std::vector<Block>& blocks,
108-
std::vector<Murmur3Key>* insert_keys);
109+
size_t insert(Slice<Block>& blocks, std::vector<Murmur3Key>* insert_keys);
109110

110111
size_t evict(size_t n_blocks, std::vector<Murmur3Key>* evict_keys);
111112

0 commit comments

Comments
 (0)