Skip to content

Commit 3e06862

Browse files
committed
feat: adapt host block manager for PD.
1 parent a3bf8b3 commit 3e06862

11 files changed

+97
-86
lines changed

xllm/core/framework/block/block.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Block final {
6767
memcpy(hash_value_, hash_value, MURMUR_HASH3_VALUE_LEN);
6868
}
6969

70-
uint32_t get_hash_value_len() { return MURMUR_HASH3_VALUE_LEN; }
70+
// Returns the length of this block's hash value; this is always the
// compile-time constant MURMUR_HASH3_VALUE_LEN (the same byte count used by
// the memcpy into hash_value_). Const-qualified so it is callable on const
// Block objects.
uint32_t get_hash_value_len() const { return MURMUR_HASH3_VALUE_LEN; }
7171

7272
private:
7373
// increase reference count

xllm/core/framework/block/block_manager.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ class BlockManager {
5454

5555
virtual void deallocate(const Slice<Block>& blocks) = 0;
5656

57-
virtual void deallocate(std::vector<Block>& blocks) = 0;
58-
5957
virtual std::vector<Block> allocate(size_t num_blocks) = 0;
6058

6159
virtual std::vector<Block> allocate_shared(

xllm/core/framework/block/block_manager_impl.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,6 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
9393
}
9494
}
9595

96-
void BlockManagerImpl::deallocate(std::vector<Block>& blocks) {
97-
Slice<Block> slice(blocks);
98-
deallocate(slice);
99-
blocks.clear();
100-
}
101-
10296
bool BlockManagerImpl::has_enough_blocks(uint32_t num_blocks) {
10397
if (num_blocks <= num_free_blocks_) {
10498
return true;

xllm/core/framework/block/block_manager_impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ class BlockManagerImpl : public BlockManager {
3535

3636
void deallocate(const Slice<Block>& blocks) override;
3737

38-
void deallocate(std::vector<Block>& blocks) override;
39-
4038
// allocate shared blocks when enable prefix cache
4139
std::vector<Block> allocate_shared(
4240
const Slice<int32_t>& tokens_ids,

xllm/core/framework/block/concurrent_block_manager_impl.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ void ConcurrentBlockManagerImpl::deallocate(const Slice<Block>& blocks) {
3030
BlockManagerImpl::deallocate(blocks);
3131
}
3232

33-
void ConcurrentBlockManagerImpl::deallocate(std::vector<Block>& blocks) {
34-
std::lock_guard<std::mutex> lock(mutex_);
35-
BlockManagerImpl::deallocate(blocks);
36-
}
37-
3833
std::vector<Block> ConcurrentBlockManagerImpl::allocate_shared(
3934
const Slice<int32_t>& tokens_ids,
4035
const Slice<Block>& existed_shared_blocks) {

xllm/core/framework/block/concurrent_block_manager_impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class ConcurrentBlockManagerImpl : public BlockManagerImpl {
3030

3131
void deallocate(const Slice<Block>& blocks) override;
3232

33-
void deallocate(std::vector<Block>& blocks) override;
34-
3533
// try to share blocks among sequences with the same prefix
3634
std::vector<Block> allocate_shared(
3735
const Slice<int32_t>& tokens_ids,

xllm/core/framework/block/hierarchy_block_manager_pool.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ HierarchyBlockManagerPool::HierarchyBlockManagerPool(
4747
}
4848

4949
load_block_transfer_infos_.resize(host_block_managers_.size());
50-
offload_block_transfer_infos_.resize(host_block_managers_.size());
51-
saved_host_blocks_.resize(host_block_managers_.size());
52-
saved_device_blocks_.resize(host_block_managers_.size());
50+
offload_block_pair_queues_.resize(host_block_managers_.size());
5351
}
5452

5553
void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
@@ -68,10 +66,6 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
6866
size_t cached_block_num =
6967
sequence->host_kv_state().kv_cache_tokens_num() / options_.block_size();
7068

71-
if (host_blocks->size() > 0) {
72-
host_block_managers_[dp_rank]->cache(sequence->tokens(), *host_blocks);
73-
}
74-
7569
size_t needed_block_num =
7670
sequence->num_tokens() / options_.block_size() - host_blocks->size();
7771

@@ -88,14 +82,9 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
8882
}
8983

9084
host_blocks->at(i).set_hash_value(blocks->at(i).get_immutable_hash_value());
91-
saved_host_blocks_[dp_rank].emplace_back(std::move(host_blocks->at(i)));
92-
saved_device_blocks_[dp_rank].emplace_back(std::move(blocks->at(i)));
93-
offload_block_transfer_infos_[dp_rank].emplace_back(BlockTransferInfo(
94-
saved_device_blocks_[dp_rank].back().id(),
95-
saved_host_blocks_[dp_rank].back().id(),
96-
saved_host_blocks_[dp_rank].back().get_immutable_hash_value(),
97-
saved_host_blocks_[dp_rank].back().get_hash_value_len(),
98-
TransferType::D2G));
85+
auto block_pair = std::make_shared<OffloadBlockPair>(
86+
std::move(blocks->at(i)), std::move(host_blocks->at(i)));
87+
offload_block_pair_queues_[dp_rank].enqueue(std::move(block_pair));
9988
}
10089
host_block_managers_[dp_rank]->cache(
10190
*sequence->host_kv_state().mutable_kv_blocks());
@@ -235,36 +224,50 @@ void HierarchyBlockManagerPool::transfer_blocks(
235224
}
236225

237226
// offload blocks from device to host and kvcache store
238-
for (int i = 0; i < offload_block_transfer_infos_.size(); i++) {
239-
if (!offload_block_transfer_infos_[i].empty()) {
240-
folly::collectAll(std::move(engine_->transfer_kv_blocks(
241-
i, std::move(offload_block_transfer_infos_[i]))))
227+
for (int i = 0; i < offload_block_pair_queues_.size(); i++) {
228+
std::vector<BlockTransferInfo> transfer_infos;
229+
std::vector<Block> src_blocks;
230+
std::vector<Block> dst_blocks;
231+
232+
std::shared_ptr<OffloadBlockPair> block_pair;
233+
while (offload_block_pair_queues_[i].try_dequeue(block_pair)) {
234+
src_blocks.emplace_back(std::move(block_pair->src));
235+
dst_blocks.emplace_back(std::move(block_pair->dst));
236+
transfer_infos.emplace_back(
237+
BlockTransferInfo(src_blocks.back().id(),
238+
dst_blocks.back().id(),
239+
dst_blocks.back().get_immutable_hash_value(),
240+
TransferType::D2G));
241+
block_pair.reset();
242+
}
243+
244+
if (!transfer_infos.empty()) {
245+
folly::collectAll(
246+
std::move(engine_->transfer_kv_blocks(i, std::move(transfer_infos))))
242247
.via(folly::getGlobalCPUExecutor())
243-
.thenValue([host_blocks = std::move(saved_host_blocks_[i]),
244-
device_blocks = std::move(saved_device_blocks_[i]),
245-
host_block_mgr_ptr = host_block_managers_[i].get(),
246-
device_block_mgr_ptr = block_managers_[i].get()](
247-
std::vector<folly::Try<uint32_t>>&& results) {
248+
.thenValue([device_blocks = std::move(src_blocks),
249+
host_blocks = std::move(dst_blocks),
250+
device_block_mgr_ptr = block_managers_[i].get(),
251+
host_block_mgr_ptr = host_block_managers_[i].get()](
252+
std::vector<folly::Try<uint32_t>>&& results) mutable {
248253
for (auto&& result : results) {
249254
if (result.value() != host_blocks.size()) {
250255
LOG(FATAL) << "Offload copy fail, expected "
251256
<< host_blocks.size() << ", got " << result.value();
252257
}
253258
}
259+
260+
device_block_mgr_ptr->deallocate({device_blocks});
261+
device_blocks.clear();
262+
254263
host_block_mgr_ptr->cache(host_blocks);
255264
host_block_mgr_ptr->deallocate({host_blocks});
256-
device_block_mgr_ptr->deallocate({device_blocks});
265+
host_blocks.clear();
266+
257267
return 0;
258268
});
259269
}
260270
}
261-
262-
offload_block_transfer_infos_.clear();
263-
saved_host_blocks_.clear();
264-
saved_device_blocks_.clear();
265-
offload_block_transfer_infos_.resize(host_block_managers_.size());
266-
saved_host_blocks_.resize(host_block_managers_.size());
267-
saved_device_blocks_.resize(host_block_managers_.size());
268271
}
269272

270273
void HierarchyBlockManagerPool::get_merged_kvcache_event(

xllm/core/framework/block/hierarchy_block_manager_pool.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,31 @@ limitations under the License.
1717

1818
#include "block_manager_pool.h"
1919
#include "runtime/engine.h"
20+
#include "util/blockingconcurrentqueue.h"
2021

2122
namespace xllm {
2223

2324
class Engine;
2425

26+
struct OffloadBlockPair {
27+
OffloadBlockPair(Block& s, Block& d) : src(s), dst(d) {}
28+
29+
OffloadBlockPair(Block&& s, Block&& d)
30+
: src(std::move(s)), dst(std::move(d)) {}
31+
32+
OffloadBlockPair(Block& s) : src(s) {}
33+
34+
OffloadBlockPair(Block&& s) : src(std::move(s)) {}
35+
36+
Block src;
37+
Block dst;
38+
};
39+
2540
class HierarchyBlockManagerPool : public BlockManagerPool {
2641
public:
42+
using OffloadBlockPairQueue =
43+
moodycamel::BlockingConcurrentQueue<std::shared_ptr<OffloadBlockPair>>;
44+
2745
explicit HierarchyBlockManagerPool(const BlockManagerPool::Options& options,
2846
Engine* engine,
2947
int32_t dp_size = 1);
@@ -51,9 +69,7 @@ class HierarchyBlockManagerPool : public BlockManagerPool {
5169

5270
// BlockTransferInfo per step
5371
std::vector<std::vector<BlockTransferInfo>> load_block_transfer_infos_;
54-
std::vector<std::vector<BlockTransferInfo>> offload_block_transfer_infos_;
55-
std::vector<std::vector<Block>> saved_host_blocks_;
56-
std::vector<std::vector<Block>> saved_device_blocks_;
72+
std::vector<OffloadBlockPairQueue> offload_block_pair_queues_;
5773
};
5874

5975
} // namespace xllm

xllm/core/framework/model/model_input_params.h

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "framework/batch/batch_forward_type.h"
2525
#include "framework/request/mm_data.h"
2626
#include "npu_dp_ep_padding.h"
27+
#include "util/hash_util.h"
2728
#include "util/tensor_helper.h"
2829

2930
namespace xllm {
@@ -38,47 +39,60 @@ enum class TransferType : uint8_t {
3839
struct BlockTransferInfo {
3940
int32_t src_block_id = -1;
4041
int32_t dst_block_id = -1;
41-
uint8_t* hash_key = nullptr;
42+
uint8_t hash_key[MURMUR_HASH3_VALUE_LEN];
4243
TransferType transfer_type;
43-
uint32_t hash_key_len = -1;
4444

4545
BlockTransferInfo(int32_t src_block_id, int32_t dst_block_id) {
4646
this->src_block_id = src_block_id;
4747
this->dst_block_id = dst_block_id;
4848
}
4949

50-
BlockTransferInfo(int32_t src_block_id,
51-
int32_t dst_block_id,
52-
const uint8_t* hash_key,
53-
TransferType transfer_type) {
54-
this->src_block_id = src_block_id;
55-
this->dst_block_id = dst_block_id;
56-
this->hash_key = const_cast<uint8_t*>(hash_key);
57-
this->transfer_type = transfer_type;
50+
BlockTransferInfo(int32_t src_id,
51+
int32_t dst_id,
52+
const uint8_t* key,
53+
TransferType type)
54+
: src_block_id(src_id), dst_block_id(dst_id), transfer_type(type) {
55+
memcpy(hash_key, key, MURMUR_HASH3_VALUE_LEN);
5856
}
5957

60-
BlockTransferInfo(int32_t src_block_id,
61-
int32_t dst_block_id,
62-
const uint8_t* hash_key,
63-
uint32_t hash_key_len,
64-
TransferType transfer_type) {
65-
this->src_block_id = src_block_id;
66-
this->dst_block_id = dst_block_id;
67-
this->hash_key = new uint8_t[hash_key_len];
68-
memcpy(this->hash_key, hash_key, hash_key_len);
69-
this->transfer_type = transfer_type;
58+
BlockTransferInfo(const BlockTransferInfo& other)
59+
: src_block_id(other.src_block_id),
60+
dst_block_id(other.dst_block_id),
61+
transfer_type(other.transfer_type) {
62+
memcpy(hash_key, other.hash_key, MURMUR_HASH3_VALUE_LEN);
7063
}
7164

72-
~BlockTransferInfo() {
73-
if (hash_key_len != -1 && hash_key != nullptr) {
74-
delete[] hash_key;
75-
}
65+
// Move constructor: copies the ids, transfer type and hash key out of
// `other`, then resets other's block ids to -1 so the moved-from record no
// longer names any block.
//
// Declared noexcept (nothing here can throw) so that std::vector uses this
// constructor when it reallocates instead of falling back to the copy
// constructor.
BlockTransferInfo(BlockTransferInfo&& other) noexcept
    : src_block_id(other.src_block_id),
      dst_block_id(other.dst_block_id),
      transfer_type(other.transfer_type) {
  memcpy(hash_key, other.hash_key, MURMUR_HASH3_VALUE_LEN);

  other.src_block_id = -1;
  other.dst_block_id = -1;
}
74+
75+
// Copy assignment: overwrites the ids, transfer type and hash key with those
// of `other` and returns *this.
//
// The original body had no return statement, which is undefined behaviour
// for a function returning a reference. Self-assignment is skipped because
// memcpy must not be called with identical source and destination ranges.
BlockTransferInfo& operator=(const BlockTransferInfo& other) {
  if (this != &other) {
    src_block_id = other.src_block_id;
    dst_block_id = other.dst_block_id;
    transfer_type = other.transfer_type;
    memcpy(hash_key, other.hash_key, MURMUR_HASH3_VALUE_LEN);
  }
  return *this;
}
81+
82+
// Move assignment: takes the ids, transfer type and hash key from `other`,
// resets other's block ids to -1, and returns *this.
//
// The original body had no return statement, which is undefined behaviour
// for a function returning a reference. Self-move is a no-op; without the
// guard the object would reset its own ids to -1 and memcpy would run on
// identical ranges. Declared noexcept because nothing here can throw.
BlockTransferInfo& operator=(BlockTransferInfo&& other) noexcept {
  if (this != &other) {
    src_block_id = other.src_block_id;
    dst_block_id = other.dst_block_id;
    transfer_type = other.transfer_type;
    memcpy(hash_key, other.hash_key, MURMUR_HASH3_VALUE_LEN);

    other.src_block_id = -1;
    other.dst_block_id = -1;
  }
  return *this;
}
7791

7892
std::string to_string() const {
7993
std::string rt = ", has_key:";
8094
for (int i = 0; i < 16; i++) {
81-
rt += std::to_string(int64_t(*(hash_key + i))) + " ";
95+
rt += std::to_string(int64_t(hash_key[i])) + " ";
8296
}
8397
return std::to_string(src_block_id) + "->" + std::to_string(dst_block_id) +
8498
", " + std::to_string(uint32_t(transfer_type)) + rt;

xllm/core/runtime/params_utils.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,6 @@ uint64_t proto_to_block_transfer_info(
723723
pb_block_transfer_info.transfer_infos(i).dst_block_id(),
724724
reinterpret_cast<const uint8_t*>(
725725
pb_block_transfer_info.transfer_infos(i).hash_key().data()),
726-
pb_block_transfer_info.transfer_infos(i).hash_key().size(),
727726
TransferType(pb_block_transfer_info.transfer_type()));
728727
}
729728

@@ -737,11 +736,6 @@ bool block_transfer_info_to_proto(
737736
block_transfer_info.size());
738737
auto transfer_type = block_transfer_info[0].transfer_type;
739738
for (const BlockTransferInfo info : block_transfer_info) {
740-
if (info.hash_key == nullptr) {
741-
LOG(ERROR) << "Convert to BlockTransferInfos fail, hash key is nullptr!";
742-
return false;
743-
}
744-
745739
if (transfer_type != info.transfer_type) {
746740
LOG(ERROR) << "Convert to BlockTransferInfos fail, TransferType must be "
747741
"same, but got "

0 commit comments

Comments
 (0)