
Commit 8df7a26

mikeiovine authored and nv-yilinf committed
[None][feat] Optimize CUDA graph memory usage for spec decode cases (NVIDIA#6718)
Signed-off-by: Mike Iovine <[email protected]>
1 parent 630c6f3 commit 8df7a26

File tree

2 files changed: +12 −4 lines


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 2 deletions
@@ -726,8 +726,11 @@ def disable_optimization(backend: Backend):
         # For non-draft model, we also capture the CUDA graph instance for draft length 0,
         # so that when we disable spec decode at runtime, we can still run the captured graph.
         # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine(
-        ):
+        if (not self.is_draft_model and self.max_draft_len > 0
+                and not self.spec_config.spec_dec_mode.use_one_engine()
+                # Assume that speculation is always on if the user didn't give us a max_concurrency
+                # value. This will save on memory.
+                and self.spec_config.max_concurrency is not None):
             draft_lengths.append(0)

         for bs in cuda_graph_batch_sizes:
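
To see why the new `max_concurrency is not None` guard saves memory, here is a minimal standalone sketch of the draft-length selection. The function name and flat parameters are hypothetical (in the real code these are attributes on ModelEngine); the point is that the extra draft-length-0 graph is only worth capturing when speculation can actually be turned off at runtime.

from typing import List, Optional

def draft_lengths_to_capture(is_draft_model: bool,
                             max_draft_len: int,
                             uses_one_engine: bool,
                             max_concurrency: Optional[int]) -> List[int]:
    """Hypothetical sketch: which draft lengths get a captured CUDA graph."""
    lengths = [max_draft_len]
    # Capture an extra graph for draft length 0 only when spec decode can
    # really be disabled at runtime: the target model must be separate from
    # the draft model (two-engine mode), and the user must have provided a
    # max_concurrency threshold for the drafter to fall back on. With
    # max_concurrency=None, speculation is assumed to be always on, so the
    # extra graph would never run; skipping it saves capture memory.
    if (not is_draft_model and max_draft_len > 0
            and not uses_one_engine
            and max_concurrency is not None):
        lengths.append(0)
    return lengths

Under this sketch, draft_lengths_to_capture(False, 3, False, None) returns just [3], while passing max_concurrency=8 also captures the runtime fallback graph, giving [3, 0].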

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, final

 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
@@ -26,8 +26,13 @@ def prepare_draft_tokens(
         """
         raise NotImplementedError

+    @final
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
-        """Check if spec decode should be used for the current iteration."""
+        """
+        You probably don't want to override this. ModelEngine
+        assumes that speculation is always on if max_concurrency
+        is not specified by the user's spec config.
+        """
         if self.max_concurrency is not None:
             return len(requests) <= self.max_concurrency
         return True
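
The @final marker pins down the contract that the ModelEngine change above relies on: when no max_concurrency is configured, should_use_spec_decode must always return True. A small self-contained illustration of that gating behavior (the class name here is hypothetical; the real method lives on Drafter):

from typing import List, Optional

class DrafterSketch:
    """Hypothetical stand-in for Drafter, reproducing the gating logic."""

    def __init__(self, max_concurrency: Optional[int] = None):
        self.max_concurrency = max_concurrency

    def should_use_spec_decode(self, requests: List[str]) -> bool:
        # Speculate only while the active batch stays small enough;
        # with no threshold configured, speculation is always on.
        if self.max_concurrency is not None:
            return len(requests) <= self.max_concurrency
        return True

# With a threshold, a large batch disables speculation for the iteration.
assert DrafterSketch(max_concurrency=2).should_use_spec_decode(["a", "b", "c"]) is False
# Without one, ModelEngine's assumption holds: speculation is always on,
# which is why the draft-length-0 graph never needs to be captured.
assert DrafterSketch().should_use_spec_decode(["a", "b", "c"]) is True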
