Commit 4ad7ef1

Authored by ziyixiong-nv
[https://nvbugs/5534705][fix] Skip unnecessary CUDA graph capture (#8… (NVIDIA#8344)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 838958c commit 4ad7ef1

File tree

1 file changed (+20 -13 lines):

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 20 additions & 13 deletions
@@ -712,19 +712,25 @@ def general_warmup(reverse: bool = False):
             cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                             reverse=True)
             # Create CUDA graphs for different draft lengths
-            draft_lengths = [self.max_draft_len]
-            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-            # so that when we disable spec decode at runtime, we can still run the captured graph.
-            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (not self.is_draft_model and self.max_draft_len > 0
-                    and not self.spec_config.spec_dec_mode.use_one_engine()
-                    # Assume that speculation is always on if the user didn't give us a max_concurrency
-                    # value. This will save on memory.
-                    and self.spec_config.max_concurrency is not None):
-                draft_lengths.append(0)
-            if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                    spec_resource_manager, Eagle3ResourceManager):
-                draft_lengths.append(self.original_max_draft_len)
+            draft_lengths = []
+            if self.is_draft_model:
+                if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                        spec_resource_manager, Eagle3ResourceManager):
+                    # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                    draft_lengths.append(self.original_max_draft_len)
+                else:
+                    draft_lengths.append(self.max_draft_len)
+            else:
+                # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+                # so that when we disable spec decode at runtime, we can still run the captured graph.
+                # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+                if (self.max_draft_len > 0
+                        and not self.spec_config.spec_dec_mode.use_one_engine()
+                        # Assume that speculation is always on if the user didn't give us a max_concurrency
+                        # value. This will save on memory.
+                        and self.spec_config.max_concurrency is not None):
+                    draft_lengths.append(0)
+                draft_lengths.append(self.max_draft_len)
 
             for bs in cuda_graph_batch_sizes:
                 if bs > self.batch_size:
@@ -740,6 +746,7 @@ def general_warmup(reverse: bool = False):
                     logger.info(
                         f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                     )
+
                     self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
         def _update_draft_inference_state(is_first_draft: bool,
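
Taken together, the diff narrows which (batch size, draft length) graphs get captured: a draft model now captures exactly one draft length (the CDL drafting-loop length when it is wrapped for Eagle3, otherwise max_draft_len), while the target model captures max_draft_len plus an optional draft_len=0 graph so speculation can be turned off at runtime. A minimal standalone sketch of that selection logic follows; the function and parameter names are hypothetical flattenings of the attributes the real method reads off the model engine (self.*).

from typing import List, Optional


def select_draft_lengths(is_draft_model: bool,
                         model_is_wrapped: bool,
                         is_spec_decode: bool,
                         has_eagle3_resource_manager: bool,
                         max_draft_len: int,
                         original_max_draft_len: int,
                         uses_one_engine: bool,
                         max_concurrency: Optional[int]) -> List[int]:
    """Hypothetical standalone paraphrase of the draft_lengths selection above."""
    draft_lengths: List[int] = []
    if is_draft_model:
        if model_is_wrapped and is_spec_decode and has_eagle3_resource_manager:
            # The CDL path uses draft_len > 0 for the number of iterations
            # in the drafting loop.
            draft_lengths.append(original_max_draft_len)
        else:
            draft_lengths.append(max_draft_len)
    else:
        # Capture a draft_len == 0 graph only when spec decode can actually
        # be disabled at runtime: not one-engine mode, and the user supplied
        # a max_concurrency value (otherwise speculation is assumed always on).
        if (max_draft_len > 0 and not uses_one_engine
                and max_concurrency is not None):
            draft_lengths.append(0)
        draft_lengths.append(max_draft_len)
    return draft_lengths


# Target model that can toggle speculation at runtime: capture 0 and 4.
assert select_draft_lengths(False, False, True, False, 4, 4, False, 8) == [0, 4]
# No max_concurrency given: speculation is assumed always on, skip the 0 graph.
assert select_draft_lengths(False, False, True, False, 4, 4, False, None) == [4]
# Wrapped Eagle3 draft model (CDL path): capture only the original max draft length.
assert select_draft_lengths(True, True, True, True, 1, 4, False, 8) == [4]

Since the warmup loop captures one graph per (batch size, draft length) pair, every entry trimmed from this list skips a full round of captures across all CUDA graph batch sizes, which is where the savings in the commit title come from.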
