@@ -2960,16 +2960,37 @@ def _read_profiling_cfg(self):
2960
2960
2961
2961
@torch .inference_mode ()
2962
2962
def warmup_model (self ) -> None :
2963
- self .defragmenter .initialize (self .kv_caches , self .block_size )
2964
2963
if not self .enable_bucketing :
2965
2964
return
2965
+
2966
+ self .bucketing_manager .generate_prompt_buckets ()
2967
+ self .bucketing_manager .generate_decode_buckets ()
2968
+
2969
+ max_bucket = max (self .bucketing_manager .decode_buckets [- 1 ][0 ],
2970
+ self .bucketing_manager .prompt_buckets [- 1 ][0 ])
2971
+ if max_bucket > self .input_batch .max_num_reqs :
2972
+ input_batch_bkp = self .input_batch
2973
+ self .input_batch = InputBatch (
2974
+ max_num_reqs = self .bucketing_manager .decode_buckets [- 1 ][0 ],
2975
+ max_model_len = self .max_model_len ,
2976
+ max_num_batched_tokens = self .max_num_tokens ,
2977
+ device = self .device ,
2978
+ pin_memory = self .pin_memory ,
2979
+ vocab_size = self .model_config .get_vocab_size (),
2980
+ block_sizes = [self .block_size ],
2981
+ logitsprocs = build_logitsprocs (
2982
+ self .vllm_config , self .device , self .pin_memory ,
2983
+ self .is_pooling_model ,
2984
+ self .vllm_config .model_config .logits_processors ),
2985
+ )
2986
+
2987
+ self .defragmenter .initialize (self .kv_caches , self .block_size )
2988
+
2966
2989
prompt_profile_cfg , decode_profile_cfg = self ._read_profiling_cfg ()
2967
2990
if prompt_profile_cfg or decode_profile_cfg :
2968
2991
self ._generate_profiling (prompt_profile_cfg , decode_profile_cfg )
2969
2992
raise AssertionError ("Finished profiling" )
2970
2993
kv_caches = self .kv_caches
2971
- self .bucketing_manager .generate_prompt_buckets ()
2972
- self .bucketing_manager .generate_decode_buckets ()
2973
2994
2974
2995
if not htorch .utils .internal .is_lazy (
2975
2996
) and not self .model_config .enforce_eager :
@@ -3043,6 +3064,9 @@ def warmup_model(self) -> None:
3043
3064
logger .info (msg )
3044
3065
self .profiler .end ()
3045
3066
3067
+ if max_bucket > self .input_batch .max_num_reqs :
3068
+ self .input_batch = input_batch_bkp
3069
+
3046
3070
def shutdown_inc (self ):
3047
3071
can_finalize_inc = self ._is_quant_with_inc () and \
3048
3072
(self .model .model is not None ) and \
0 commit comments