@@ -59,7 +59,7 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
5959 auto * blocks = sequence->kv_state ().mutable_kv_blocks ();
6060 auto * host_blocks = sequence->host_kv_state ().mutable_kv_blocks ();
6161
62- if (blocks->size () == 0 || host_blocks->size () >= blocks->size ()) {
62+ if (blocks->size () == 0 || host_blocks->size () > blocks->size ()) {
6363 return ;
6464 }
6565
@@ -148,12 +148,14 @@ void HierarchyBlockManagerPool::prefetch_from_storage(
148148 prefill_sequence->tokens ());
149149 prefill_sequence->add_shared_host_kv_blocks (std::move (shared_blocks));
150150
151- const size_t num_blocks = prefill_sequence->host_kv_state ().num_kv_blocks ();
152151 // round down to the nearest block number
153- const size_t block_size = options_.block_size ();
152+ size_t shared_blocks_num =
153+ prefill_sequence->host_kv_state ().shared_kv_blocks_num ();
154154 const size_t num_additional_blocks =
155- prefill_sequence->num_tokens () / block_size - num_blocks;
156- if (num_additional_blocks <= 0 ) {
155+ (prefill_sequence->num_tokens () + options_.block_size () - 1 ) /
156+ options_.block_size () -
157+ shared_blocks_num;
158+ if (num_additional_blocks <= 1 ) {
157159 return ;
158160 }
159161
@@ -165,20 +167,19 @@ void HierarchyBlockManagerPool::prefetch_from_storage(
165167 prefill_sequence->host_kv_state ().add_kv_blocks (host_blocks);
166168 PrefixCache::compute_hash_keys (
167169 prefill_sequence->tokens (),
168- *prefill_sequence->host_kv_state ().mutable_kv_blocks ());
170+ *prefill_sequence->host_kv_state ().mutable_kv_blocks (),
171+ shared_blocks_num);
169172
170- if (num_additional_blocks > 0 ) {
173+ if (num_additional_blocks > 1 ) {
171174 const auto host_blocks = prefill_sequence->host_kv_state ().kv_blocks ();
172175 std::vector<BlockTransferInfo> block_transfer_infos;
173176 block_transfer_infos.reserve (num_additional_blocks);
174- for (int i = host_blocks.size () - num_additional_blocks;
175- i < host_blocks.size ();
176- i++) {
177- block_transfer_infos.emplace_back (
178- BlockTransferInfo (-1 ,
179- host_blocks[i].id (),
180- host_blocks[i].get_immutable_hash_value (),
181- TransferType::G2H));
177+ for (int i = 0 ; i < num_additional_blocks - 1 ; i++) {
178+ block_transfer_infos.emplace_back (BlockTransferInfo (
179+ -1 ,
180+ host_blocks[shared_blocks_num + i].id (),
181+ host_blocks[shared_blocks_num + i].get_immutable_hash_value (),
182+ TransferType::G2H));
182183 }
183184
184185 engine_->prefetch_from_storage (prefill_sequence->dp_rank (),
@@ -198,8 +199,21 @@ bool HierarchyBlockManagerPool::update_prefetch_result(
198199
199200 bool prefetch_result = true ;
200201 for (auto & prefill_sequence : request->sequences ()) {
201- prefetch_result &= prefill_sequence->update_prefetch_result (timeout);
202+ uint32_t success_cnt = 0 ;
203+ prefetch_result &=
204+ prefill_sequence->update_prefetch_result (timeout, success_cnt);
205+
206+ if (success_cnt > 0 ) {
207+ int32_t dp_rank = BlockManagerPool::get_dp_rank (prefill_sequence.get ());
208+ auto host_blocks = prefill_sequence->host_kv_state ().kv_blocks ();
209+ auto cached_blocks =
210+ prefill_sequence->host_kv_state ().shared_kv_blocks_num ();
211+
212+ host_block_managers_[dp_rank]->cache (
213+ host_blocks.slice (cached_blocks - success_cnt, cached_blocks));
214+ }
202215 }
216+
203217 return prefetch_result;
204218}
205219
0 commit comments