@@ -712,19 +712,25 @@ def general_warmup(reverse: bool = False):
712712 cuda_graph_batch_sizes = sorted (self ._cuda_graph_batch_sizes ,
713713 reverse = True )
714714 # Create CUDA graphs for different draft lengths
715- draft_lengths = [self .max_draft_len ]
716- # For non-draft model, we also capture the CUDA graph instance for draft length 0,
717- # so that when we disable spec decode at runtime, we can still run the captured graph.
718- # Note that for one engine mode, we are not able to turn off spec decode at runtime.
719- if (not self .is_draft_model and self .max_draft_len > 0
720- and not self .spec_config .spec_dec_mode .use_one_engine ()
721- # Assume that speculation is always on if the user didn't give us a max_concurrency
722- # value. This will save on memory.
723- and self .spec_config .max_concurrency is not None ):
724- draft_lengths .append (0 )
725- if self .is_spec_decode and self .is_draft_model and spec_resource_manager is not None and isinstance (
726- spec_resource_manager , Eagle3ResourceManager ):
727- draft_lengths .append (self .original_max_draft_len )
715+ draft_lengths = []
716+ if self .is_draft_model :
717+ if self .model_is_wrapped and self .is_spec_decode and spec_resource_manager is not None and isinstance (
718+ spec_resource_manager , Eagle3ResourceManager ):
719+ # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
720+ draft_lengths .append (self .original_max_draft_len )
721+ else :
722+ draft_lengths .append (self .max_draft_len )
723+ else :
724+ # For non-draft model, we also capture the CUDA graph instance for draft length 0,
725+ # so that when we disable spec decode at runtime, we can still run the captured graph.
726+ # Note that for one engine mode, we are not able to turn off spec decode at runtime.
727+ if (self .max_draft_len > 0
728+ and not self .spec_config .spec_dec_mode .use_one_engine ()
729+ # Assume that speculation is always on if the user didn't give us a max_concurrency
730+ # value. This will save on memory.
731+ and self .spec_config .max_concurrency is not None ):
732+ draft_lengths .append (0 )
733+ draft_lengths .append (self .max_draft_len )
728734
729735 for bs in cuda_graph_batch_sizes :
730736 if bs > self .batch_size :
@@ -740,6 +746,7 @@ def general_warmup(reverse: bool = False):
740746 logger .info (
741747 f"Run generation only CUDA graph warmup for batch size={ bs } , draft_len={ draft_len } "
742748 )
749+
743750 self .enable_spec_decode = draft_len > 0 or self .is_draft_model
744751
745752 def _update_draft_inference_state (is_first_draft : bool ,
0 commit comments