@@ -2960,16 +2960,37 @@ def _read_profiling_cfg(self):
29602960
29612961 @torch .inference_mode ()
29622962 def warmup_model (self ) -> None :
2963- self .defragmenter .initialize (self .kv_caches , self .block_size )
29642963 if not self .enable_bucketing :
29652964 return
2965+
2966+ self .bucketing_manager .generate_prompt_buckets ()
2967+ self .bucketing_manager .generate_decode_buckets ()
2968+
2969+ max_bucket = max (self .bucketing_manager .decode_buckets [- 1 ][0 ],
2970+ self .bucketing_manager .prompt_buckets [- 1 ][0 ])
2971+ if max_bucket > self .input_batch .max_num_reqs :
2972+ input_batch_bkp = self .input_batch
2973+ self .input_batch = InputBatch (
2974+ max_num_reqs = self .bucketing_manager .decode_buckets [- 1 ][0 ],
2975+ max_model_len = self .max_model_len ,
2976+ max_num_batched_tokens = self .max_num_tokens ,
2977+ device = self .device ,
2978+ pin_memory = self .pin_memory ,
2979+ vocab_size = self .model_config .get_vocab_size (),
2980+ block_sizes = [self .block_size ],
2981+ logitsprocs = build_logitsprocs (
2982+ self .vllm_config , self .device , self .pin_memory ,
2983+ self .is_pooling_model ,
2984+ self .vllm_config .model_config .logits_processors ),
2985+ )
2986+
2987+ self .defragmenter .initialize (self .kv_caches , self .block_size )
2988+
29662989 prompt_profile_cfg , decode_profile_cfg = self ._read_profiling_cfg ()
29672990 if prompt_profile_cfg or decode_profile_cfg :
29682991 self ._generate_profiling (prompt_profile_cfg , decode_profile_cfg )
29692992 raise AssertionError ("Finished profiling" )
29702993 kv_caches = self .kv_caches
2971- self .bucketing_manager .generate_prompt_buckets ()
2972- self .bucketing_manager .generate_decode_buckets ()
29732994
29742995 if not htorch .utils .internal .is_lazy (
29752996 ) and not self .model_config .enforce_eager :
@@ -3043,6 +3064,9 @@ def warmup_model(self) -> None:
30433064 logger .info (msg )
30443065 self .profiler .end ()
30453066
3067+ if max_bucket > self .input_batch .max_num_reqs :
3068+ self .input_batch = input_batch_bkp
3069+
30463070 def shutdown_inc (self ):
30473071 can_finalize_inc = self ._is_quant_with_inc () and \
30483072 (self .model .model is not None ) and \
0 commit comments