Commit aae8e96

taran2210, michalkuligowski, and adobrzyn authored
Fix warmup break when max decode bucket bs > max num seq (#107)
Signed-off-by: taran2210 <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
Co-authored-by: Agata Dobrzyniewicz <[email protected]>
1 parent 69d4ad3 commit aae8e96
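
For context, a hedged sketch of the failure this commit addresses: warmup generates decode buckets whose batch size can exceed the `max_num_seqs` the runner's `InputBatch` was sized for, so building the warmup batch would break. The snippet below only illustrates that condition; the bucket values, tuple layout, and `max_num_reqs` are invented stand-ins, not the real runner state.

    # Illustration only: invented numbers, not real HPUModelRunner state.
    decode_buckets = [(32, 1, 128), (64, 1, 128), (128, 1, 128)]  # last entry: bs=128
    prompt_buckets = [(1, 1024, 0), (4, 1024, 0)]                 # last entry: bs=4
    max_num_reqs = 64  # InputBatch sized from max_num_seqs

    # Largest batch size any warmup bucket will request (mirrors the diff):
    max_bucket = max(decode_buckets[-1][0], prompt_buckets[-1][0])

    assert max_bucket > max_num_reqs  # 128 > 64: warmup used to break here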

File tree

1 file changed: +27 -3 lines changed

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 27 additions & 3 deletions
@@ -2960,16 +2960,37 @@ def _read_profiling_cfg(self):
 
     @torch.inference_mode()
     def warmup_model(self) -> None:
-        self.defragmenter.initialize(self.kv_caches, self.block_size)
         if not self.enable_bucketing:
             return
+
+        self.bucketing_manager.generate_prompt_buckets()
+        self.bucketing_manager.generate_decode_buckets()
+
+        max_bucket = max(self.bucketing_manager.decode_buckets[-1][0],
+                         self.bucketing_manager.prompt_buckets[-1][0])
+        if max_bucket > self.input_batch.max_num_reqs:
+            input_batch_bkp = self.input_batch
+            self.input_batch = InputBatch(
+                max_num_reqs=self.bucketing_manager.decode_buckets[-1][0],
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=[self.block_size],
+                logitsprocs=build_logitsprocs(
+                    self.vllm_config, self.device, self.pin_memory,
+                    self.is_pooling_model,
+                    self.vllm_config.model_config.logits_processors),
+            )
+
+        self.defragmenter.initialize(self.kv_caches, self.block_size)
+
         prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg()
         if prompt_profile_cfg or decode_profile_cfg:
             self._generate_profiling(prompt_profile_cfg, decode_profile_cfg)
             raise AssertionError("Finished profiling")
         kv_caches = self.kv_caches
-        self.bucketing_manager.generate_prompt_buckets()
-        self.bucketing_manager.generate_decode_buckets()
 
         if not htorch.utils.internal.is_lazy(
         ) and not self.model_config.enforce_eager:
@@ -3043,6 +3064,9 @@ def warmup_model(self) -> None:
             logger.info(msg)
         self.profiler.end()
 
+        if max_bucket > self.input_batch.max_num_reqs:
+            self.input_batch = input_batch_bkp
+
     def shutdown_inc(self):
         can_finalize_inc = self._is_quant_with_inc() and \
             (self.model.model is not None) and \
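
Taken together, the diff follows a save/swap/restore pattern: back up the real `InputBatch`, swap in one large enough for the biggest bucket, run warmup, then restore. A minimal self-contained sketch of that pattern, with simplified stand-in classes rather than the actual vllm_gaudi API:

    # Simplified stand-ins; only the swap/restore pattern mirrors the commit.
    class InputBatch:
        def __init__(self, max_num_reqs: int):
            self.max_num_reqs = max_num_reqs

    class Runner:
        def __init__(self, max_num_seqs: int, decode_buckets, prompt_buckets):
            self.input_batch = InputBatch(max_num_seqs)
            self.decode_buckets = decode_buckets
            self.prompt_buckets = prompt_buckets

        def warmup_model(self) -> None:
            max_bucket = max(self.decode_buckets[-1][0],
                             self.prompt_buckets[-1][0])
            needs_swap = max_bucket > self.input_batch.max_num_reqs
            if needs_swap:
                # Back up the real batch and swap in one sized for warmup.
                input_batch_bkp = self.input_batch
                self.input_batch = InputBatch(self.decode_buckets[-1][0])

            # ... warmup runs here against self.input_batch ...

            if needs_swap:
                # Restore the original batch once warmup finishes.
                self.input_batch = input_batch_bkp

    # Usage: buckets' first element is batch size; 128 > 64 triggers the swap.
    Runner(64, [(128, 1, 128)], [(4, 1024, 0)]).warmup_model()

Note that the sketch folds the two size checks into one `needs_swap` flag, whereas the actual diff re-evaluates `max_bucket > self.input_batch.max_num_reqs` at restore time, after `self.input_batch` has already been replaced by the larger temporary batch.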
