[feat] Auto-enable ngram with concurrency <= 32. #6232

Merged: 11 commits, Jul 31, 2025
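The PR adds an AUTO speculative decoding mode: when selected, the LLM API swaps in a heuristic NGram configuration whose parameters depend on the configured batch size, and the quickstart script accepts it via --spec_decode_algo AUTO (exercised by the new integration test below). The following is a minimal usage sketch through the Python LLM API; the checkpoint path is a placeholder, and the generate/print calls follow standard LLM API usage rather than anything introduced by this diff.

# Minimal sketch (assumed usage, not part of the diff): enable the new AUTO
# mode through the LLM API; the checkpoint path is a placeholder.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import AutoDecodingConfig

llm = LLM(
    model="/path/to/Llama-3.1-8B-Instruct",   # placeholder checkpoint path
    speculative_config=AutoDecodingConfig(),  # resolved to NGram settings at build time
    max_batch_size=4,                         # low concurrency: longer drafts
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)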
10 changes: 6 additions & 4 deletions examples/llm-api/quickstart_advanced.py
@@ -1,10 +1,10 @@
import argparse

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
DraftTargetDecodingConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, TorchCompileConfig)

example_prompts = [
"Hello, my name is",
@@ -181,6 +181,8 @@ def setup_llm(args, **kwargs):
is_use_oldest=True,
is_public_pool=True,
)
elif spec_decode_algo == "AUTO":
spec_config = AutoDecodingConfig()
else:
spec_config = None

7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -864,7 +864,7 @@ def _prepare_and_schedule_batch(self):
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests)
self.model_engine.enable_spec_decode = self.use_spec_decode
self._prepare_draft_requests(self.active_requests)
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
@@ -965,14 +965,15 @@ def _executor_loop(self):
iter_stats=iter_stats,
iter_start_time=iter_start_time))

def _prepare_draft_requests(self, requests):
def _prepare_draft_requests(self):
try:
# Set draft tokens here to make the KV cache manager
# and scheduler aware of them.
for req in requests:
for req in self.active_requests:
if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue

req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_len

1 change: 1 addition & 0 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -18,6 +18,7 @@ class SpeculativeDecodingMode(IntEnum):
DRAFT_TARGET = auto()
USER_PROVIDED = auto()
NONE = auto()
AUTO = auto()

def is_mtp(self):
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/speculative/ngram.py
@@ -2,6 +2,7 @@

from ordered_set import OrderedSet

from tensorrt_llm.llmapi import NGramDecodingConfig
from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import *
@@ -163,10 +164,11 @@ class NGramDrafter(Drafter):

def __init__(
self,
spec_config: "NGramDecodingConfig",
spec_config: NGramDecodingConfig,
ngram_pool_manager: NGramPoolManager = None,
):
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager

@@ -175,6 +177,10 @@ def prepare_draft_tokens(
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:
# Under the auto heuristic, skip NGram drafting when more than 32 requests are scheduled.
if self.spec_config.is_auto_heuristic and len(
scheduled_requests.all_requests()) > 32:
return
# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
19 changes: 10 additions & 9 deletions tensorrt_llm/llmapi/__init__.py
@@ -4,15 +4,15 @@
from .build_cache import BuildCacheConfig
from .llm import LLM, RequestOutput
# yapf: disable
from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
CapacitySchedulerPolicy, ContextChunkingPolicy,
CudaGraphConfig, DraftTargetDecodingConfig,
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
CalibConfig, CapacitySchedulerPolicy,
ContextChunkingPolicy, CudaGraphConfig,
DraftTargetDecodingConfig, DynamicBatchConfig,
EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mpi_session import MpiCommSession
@@ -53,4 +53,5 @@
'LlmArgs',
'TorchLlmArgs',
'TrtLlmArgs',
'AutoDecodingConfig',
]
36 changes: 33 additions & 3 deletions tensorrt_llm/llmapi/llm.py
@@ -31,8 +31,8 @@
from ..logger import logger
from ..sampling_params import SamplingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
PybindMirror, TorchLlmArgs, TrtLlmArgs)
TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -995,13 +995,43 @@ def _build_model(self):
self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
self.args.cache_transceiver_config)
from tensorrt_llm._torch.pyexecutor.config import update_executor_config

spec_config = self.args.speculative_config
max_batch_size = self._executor_config.max_batch_size
# Apply default heuristic to AutoDecodingConfig based on benchmark results
# With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
# With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
# With concurrency > 32, speculative decoding is disabled.
if spec_config is not None and spec_config.decoding_type == "AUTO":
if not self.args.disable_overlap_scheduler:
logger.info(
"Disable overlap scheduler to enable Auto speculative decoding with Ngram."
)
# Benchmark results show that NGram speculative decoding outperforms the overlap scheduler at low concurrency (<= 32),
# so the overlap scheduler is disabled to enable NGram speculative decoding.
self.args.disable_overlap_scheduler = True

spec_config = NGramDecodingConfig(
max_draft_len=5 if max_batch_size <= 4 else 3,
max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
# Flag indicating that this NGramDecodingConfig was instantiated by the auto heuristic.
is_auto_heuristic=True,
)

logger.info(
f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
)

update_executor_config(
self._executor_config,
backend=self.args.backend,
pytorch_backend_config=self.args.get_pytorch_backend_config()
if self.args.backend in ["pytorch", "_autodeploy"] else None,
mapping=self.args.parallel_config.to_mapping(),
speculative_config=self.args.speculative_config,
speculative_config=spec_config,
hf_model_dir=self._hf_model_dir,
max_input_len=self.args.max_input_len,
max_seq_len=max_seq_len,
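Taken on its own, the heuristic applied in _build_model above is the following mapping from the configured max batch size to an NGram configuration; the helper function below is hypothetical and only restates the branch added in this hunk.

# Hypothetical helper (not part of the PR) restating the AUTO heuristic above.
from tensorrt_llm.llmapi import NGramDecodingConfig

def resolve_auto_ngram_config(max_batch_size: int) -> NGramDecodingConfig:
    # Batch size <= 4: longer drafts (5) with shorter matching n-grams (3);
    # otherwise: shorter drafts (3) with longer matching n-grams (5).
    # At runtime, the drafter additionally skips drafting once more than
    # 32 requests are scheduled (see the check added in ngram.py).
    return NGramDecodingConfig(
        max_draft_len=5 if max_batch_size <= 4 else 3,
        max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
        is_auto_heuristic=True,  # marks the config as heuristic-generated
    )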
37 changes: 34 additions & 3 deletions tensorrt_llm/llmapi/llm_args.py
@@ -336,6 +336,7 @@ def from_dict(cls, data: dict):
"NGram": NGramDecodingConfig,
"DraftTarget": DraftTargetDecodingConfig,
"UserProvided": UserProvidedDecodingConfig,
"AUTO": AutoDecodingConfig,
}

config_class = config_classes.get(decoding_type)
@@ -446,11 +447,13 @@ class NGramDecodingConfig(DecodingBaseConfig):
is_public_pool: bool = True
Whether to use a common pool for all requests; if False, each request keeps its own private pool.
"""

max_matching_ngram_size: int = 4
max_matching_ngram_size: int = 0
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
# Flag indicating that this NGramDecodingConfig was instantiated by the auto heuristic.
# Users should not set this flag; use AutoDecodingConfig instead.
is_auto_heuristic: bool = False

@classmethod
def from_dict(cls, data: dict):
@@ -510,6 +513,29 @@ def spec_dec_mode(self):
return TorchSpeculativeDecodingMode.MTP


class AutoDecodingConfig(DecodingBaseConfig):
"""
Configuration for auto speculative decoding.

This config automatically selects the best speculative decoding algorithm.

According to benchmark results, NGram is generally the best algorithm at low concurrency (<= 32).
Default heuristic:
With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
With concurrency > 32, speculative decoding is disabled.
"""

@classmethod
def from_dict(cls, data: dict):
return cls(**data)

decoding_type: ClassVar[str] = "AUTO"

def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"


class PybindMirror(ABC):
''' A class containing the utilities for mirroring Python classes to
pybinding classes.
@@ -872,6 +898,7 @@ def supports_backend(self, backend: str) -> bool:
MTPDecodingConfig,
NGramDecodingConfig,
UserProvidedDecodingConfig,
AutoDecodingConfig,
]]


@@ -1292,7 +1319,6 @@ def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance.
"""
kwargs = BaseLlmArgs._check_consistency(dict(kwargs))

ret = cls(**kwargs)
return ret

@@ -1621,6 +1647,11 @@ def validate_speculative_config(self):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
self.build_config.max_draft_len = self.speculative_config.max_draft_len

elif isinstance(self.speculative_config, AutoDecodingConfig):
assert self.backend in ['pytorch', '_autodeploy']
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
self.build_config.max_draft_len = self.speculative_config.max_draft_len

else:
raise ValueError(
f"Unrecognized speculative config type {type(self.speculative_config)}"
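AutoDecodingConfig defines no tunables of its own; it registers the "AUTO" decoding type, reports support only for the pytorch backend, and leaves the concrete parameters to the heuristic in llm.py. A small sketch of what the class exposes, assuming the default constructor used in the quickstart diff:

# Sketch (assumed usage): AutoDecodingConfig is a marker config; the concrete
# NGram parameters are filled in later by the heuristic in llm.py.
from tensorrt_llm.llmapi import AutoDecodingConfig

cfg = AutoDecodingConfig()
print(cfg.decoding_type)                 # "AUTO"
print(cfg.supports_backend("pytorch"))   # True
print(cfg.supports_backend("tensorrt"))  # False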
3 changes: 3 additions & 0 deletions tensorrt_llm/models/modeling_utils.py
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
EAGLE = auto()
NGRAM = auto()
USER_PROVIDED = auto()
AUTO = auto()

@staticmethod
def from_arguments(args: argparse.Namespace):
@@ -117,6 +118,8 @@ def from_arguments(args: argparse.Namespace):
return SpeculativeDecodingMode.NGRAM
elif args.speculative_decoding_mode == "user_provided":
return SpeculativeDecodingMode.USER_PROVIDED
elif args.speculative_decoding_mode == "auto":
return SpeculativeDecodingMode.AUTO
else:
assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

25 changes: 25 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1775,6 +1775,31 @@ def test_ptp_quickstart_advanced_ngram(llm_root, llm_venv, model_name,
_check_mem_usage(running_log, [27.0, 0, 0, 0])


@pytest.mark.parametrize("model_name,model_path", [
("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"),
])
def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
model_path):
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
with tempfile.NamedTemporaryFile(mode='w+t',
suffix=f".{model_name}.log",
dir="./",
delete=True,
delete_on_close=True) as running_log:
llm_venv.run_cmd([
str(example_root / "quickstart_advanced.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--spec_decode_algo",
"AUTO",
"--use_cuda_graph",
"--max_batch_size=4",
],
stdout=running_log)
_check_mem_usage(running_log, [27.0, 0, 0, 0])


@skip_post_blackwell
@pytest.mark.skip_less_device_memory(110000)
@pytest.mark.skip_less_device(8)
@@ -59,7 +59,7 @@ methods:
default: null
# Speculative decoding
speculative_config:
annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]
annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
default: null
# generation constraints
max_batch_size: