From 524d6a0ac16b453695b502161233a654a09f504c Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Sun, 20 Jul 2025 21:39:46 -0700 Subject: [PATCH 1/6] Auto-enable ngram with concurrency <= 32. Signed-off-by: Simeng Liu --- examples/llm-api/quickstart_advanced.py | 9 ++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 7 +++-- tensorrt_llm/llmapi/llm.py | 30 ++++++++++++++++++- tensorrt_llm/llmapi/llm_args.py | 8 +++-- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 5e447e6a0e..5a3f11053a 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -108,9 +108,9 @@ def add_llm_args(parser): # Speculative decoding parser.add_argument('--spec_decode_algo', type=str, default=None) - parser.add_argument('--spec_decode_max_draft_len', type=int, default=1) + parser.add_argument('--spec_decode_max_draft_len', type=int, default=0) parser.add_argument('--draft_model_dir', type=str, default=None) - parser.add_argument('--max_matching_ngram_size', type=int, default=5) + parser.add_argument('--max_matching_ngram_size', type=int, default=0) parser.add_argument('--use_one_model', default=False, action='store_true') # Relaxed acceptance @@ -152,6 +152,11 @@ def setup_llm(args, **kwargs): spec_decode_algo = args.spec_decode_algo.upper( ) if args.spec_decode_algo is not None else None + # Update spec_decode_max_draft_len to 1 if unset by the user for non-NGRAM spec_decode_algo + # NGRAM spec_decode_algo will use default heuristic to set spec_decode_max_draft_len and max_matching_ngram_size + if spec_decode_algo != "NGRAM" and args.spec_decode_max_draft_len == 0: + args.spec_decode_max_draft_len = 1 + if spec_decode_algo == 'MTP': if not args.use_one_model: print( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 715a701398..71d22096cd 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -826,7 +826,7 @@ def _executor_loop(self): self._pad_attention_dp_dummy_request() if self.drafter is not None: - self._prepare_draft_requests(self.active_requests) + self._prepare_draft_requests() scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) @@ -913,14 +913,15 @@ def _executor_loop(self): iter_stats=iter_stats, iter_start_time=iter_start_time)) - def _prepare_draft_requests(self, requests): + def _prepare_draft_requests(self): try: # Set draft tokens here to make the KV cache manager # and scheduler aware of them. 
- for req in requests: + for req in self.active_requests: if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS, LlmRequestState.DISAGG_GENERATION_INIT): continue + req.py_last_draft_tokens = req.py_draft_tokens max_draft_len = self.model_engine.spec_config.max_draft_len diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index dcf3ca9290..125a2e18c9 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -960,13 +960,41 @@ def _build_model(self): self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( self.args.cache_transceiver_config) from tensorrt_llm._torch.pyexecutor.config import update_executor_config + + spec_config = self.args.speculative_config + max_batch_size = self._executor_config.max_batch_size + # Apply heuristic to incomplete NGramDecodingConfig based on benchmark results + # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 + # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 + if spec_config.spec_dec_mode() == "NGRAM" and max_batch_size <= 32: + if not self.args.disable_overlap_scheduler: + logger.info( + "Disable overlap scheduler to enable NGram speculative decoding." + ) + # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32. + # Therefore, we disable overlap scheduler to enable NGram speculative decoding. + self.args.disable_overlap_scheduler = True + + if spec_config.max_draft_len != 0 and spec_config.max_matching_ngram_size != 0: + pass + else: + if max_batch_size <= 4: + spec_config.max_draft_len = 5 if spec_config.max_draft_len == 0 else spec_config.max_draft_len + spec_config.max_matching_ngram_size = 3 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size + elif max_batch_size <= 32: + spec_config.max_draft_len = 3 if spec_config.max_draft_len == 0 else spec_config.max_draft_len + spec_config.max_matching_ngram_size = 5 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size + logger.info( + f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}" + ) + update_executor_config( self._executor_config, backend=self.args.backend, pytorch_backend_config=self.args.get_pytorch_backend_config() if self.args.backend in ["pytorch", "_autodeploy"] else None, mapping=self.args.parallel_config.to_mapping(), - speculative_config=self.args.speculative_config, + speculative_config=spec_config, hf_model_dir=self._hf_model_dir, max_input_len=self.args.max_input_len, max_seq_len=max_seq_len, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index a563bc98f2..7104b0b8ea 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -371,8 +371,12 @@ class NGramDecodingConfig(DecodingBaseConfig): is_public_pool: bool = True Whether to use a common pool for all requests, or the pool is private for each request if False. 
""" - - max_matching_ngram_size: int = 4 + # If max_draft_len or max_matching_ngram_size are not set by user + # Default heuristic will be use + # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 + # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 + max_draft_len: int = 0 + max_matching_ngram_size: int = 0 is_keep_all: bool = True is_use_oldest: bool = True is_public_pool: bool = True From 11b11c2db80e0b89b9832ee6670e9cbbfcfc6c2a Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Mon, 21 Jul 2025 15:48:16 -0700 Subject: [PATCH 2/6] Add AutoDecodingConfig to apply the default spec_decoding heuristic with Ngram. Signed-off-by: Simeng Liu --- examples/llm-api/quickstart_advanced.py | 25 +++++++------- tensorrt_llm/_torch/speculative/interface.py | 1 + tensorrt_llm/llmapi/__init__.py | 19 +++++------ tensorrt_llm/llmapi/llm.py | 34 ++++++++++---------- tensorrt_llm/llmapi/llm_args.py | 31 +++++++++++++++++- tensorrt_llm/models/modeling_utils.py | 3 ++ 6 files changed, 74 insertions(+), 39 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 5a3f11053a..df9838f363 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -1,10 +1,10 @@ import argparse from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig, - EagleDecodingConfig, KvCacheConfig, MoeConfig, - MTPDecodingConfig, NGramDecodingConfig, - TorchCompileConfig) +from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig, + DraftTargetDecodingConfig, EagleDecodingConfig, + KvCacheConfig, MoeConfig, MTPDecodingConfig, + NGramDecodingConfig, TorchCompileConfig) example_prompts = [ "Hello, my name is", @@ -107,10 +107,14 @@ def add_llm_args(parser): parser.add_argument('--max_beam_width', type=int, default=1) # Speculative decoding - parser.add_argument('--spec_decode_algo', type=str, default=None) - parser.add_argument('--spec_decode_max_draft_len', type=int, default=0) + parser.add_argument( + '--spec_decode_algo', + type=str, + default=None, + choices=['MTP', 'EAGLE3', 'DRAFT_TARGET', 'NGRAM', 'AUTO']) + parser.add_argument('--spec_decode_max_draft_len', type=int, default=1) parser.add_argument('--draft_model_dir', type=str, default=None) - parser.add_argument('--max_matching_ngram_size', type=int, default=0) + parser.add_argument('--max_matching_ngram_size', type=int, default=5) parser.add_argument('--use_one_model', default=False, action='store_true') # Relaxed acceptance @@ -152,11 +156,6 @@ def setup_llm(args, **kwargs): spec_decode_algo = args.spec_decode_algo.upper( ) if args.spec_decode_algo is not None else None - # Update spec_decode_max_draft_len to 1 if unset by the user for non-NGRAM spec_decode_algo - # NGRAM spec_decode_algo will use default heuristic to set spec_decode_max_draft_len and max_matching_ngram_size - if spec_decode_algo != "NGRAM" and args.spec_decode_max_draft_len == 0: - args.spec_decode_max_draft_len = 1 - if spec_decode_algo == 'MTP': if not args.use_one_model: print( @@ -186,6 +185,8 @@ def setup_llm(args, **kwargs): is_use_oldest=True, is_public_pool=True, ) + elif spec_decode_algo == "AUTO": + spec_config = AutoDecodingConfig() else: spec_config = None diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 46fe18e058..d606073f00 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ 
b/tensorrt_llm/_torch/speculative/interface.py @@ -18,6 +18,7 @@ class SpeculativeDecodingMode(IntEnum): DRAFT_TARGET = auto() USER_PROVIDED = auto() NONE = auto() + AUTO = auto() def is_mtp(self): return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index 24f7ad00e7..bef7ded994 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -4,15 +4,15 @@ from .build_cache import BuildCacheConfig from .llm import LLM, RequestOutput # yapf: disable -from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig, - CapacitySchedulerPolicy, ContextChunkingPolicy, - CudaGraphConfig, DraftTargetDecodingConfig, - DynamicBatchConfig, EagleDecodingConfig, - ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs, - LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig, - MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig, - TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, - UserProvidedDecodingConfig) +from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig, + CalibConfig, CapacitySchedulerPolicy, + ContextChunkingPolicy, CudaGraphConfig, + DraftTargetDecodingConfig, DynamicBatchConfig, + EagleDecodingConfig, ExtendedRuntimePerfKnobConfig, + KvCacheConfig, LlmArgs, LookaheadDecodingConfig, + MedusaDecodingConfig, MoeConfig, MTPDecodingConfig, + NGramDecodingConfig, SchedulerConfig, TorchCompileConfig, + TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig) from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo, QuantConfig) from .mpi_session import MpiCommSession @@ -53,4 +53,5 @@ 'LlmArgs', 'TorchLlmArgs', 'TrtLlmArgs', + 'AutoDecodingConfig', ] diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 125a2e18c9..f4cf807abd 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -31,8 +31,8 @@ from ..logger import logger from ..sampling_params import SamplingParams from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING, - TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror, - TorchLlmArgs, TrtLlmArgs) + TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig, + PybindMirror, TorchLlmArgs, TrtLlmArgs) from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig, LlmBuildStats, ModelLoader, _ModelRuntimeContext) from .mpi_session import MpiPoolSession, external_mpi_comm_available @@ -963,30 +963,30 @@ def _build_model(self): spec_config = self.args.speculative_config max_batch_size = self._executor_config.max_batch_size - # Apply heuristic to incomplete NGramDecodingConfig based on benchmark results + # Apply default heuristic to AutoDecodingConfig based on benchmark results # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 - if spec_config.spec_dec_mode() == "NGRAM" and max_batch_size <= 32: + # With concurrency > 32, speculative decoding is disabled. + if spec_config is not None and spec_config.decoding_type == "AUTO" and max_batch_size <= 32: if not self.args.disable_overlap_scheduler: logger.info( - "Disable overlap scheduler to enable NGram speculative decoding." + "Disable overlap scheduler to enable Auto speculative decoding with Ngram." ) # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32. # Therefore, we disable overlap scheduler to enable NGram speculative decoding. 
self.args.disable_overlap_scheduler = True - if spec_config.max_draft_len != 0 and spec_config.max_matching_ngram_size != 0: - pass - else: - if max_batch_size <= 4: - spec_config.max_draft_len = 5 if spec_config.max_draft_len == 0 else spec_config.max_draft_len - spec_config.max_matching_ngram_size = 3 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size - elif max_batch_size <= 32: - spec_config.max_draft_len = 3 if spec_config.max_draft_len == 0 else spec_config.max_draft_len - spec_config.max_matching_ngram_size = 5 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size - logger.info( - f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}" - ) + spec_config = NGramDecodingConfig( + max_draft_len=5 if max_batch_size <= 4 else 3, + max_matching_ngram_size=3 if max_batch_size <= 4 else 5, + is_keep_all=True, + is_use_oldest=True, + is_public_pool=True, + ) + + logger.info( + f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}" + ) update_executor_config( self._executor_config, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 7104b0b8ea..3923dc6fcc 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -262,6 +262,7 @@ def from_dict(cls, data: dict): "NGram": NGramDecodingConfig, "DraftTarget": DraftTargetDecodingConfig, "UserProvided": UserProvidedDecodingConfig, + "AUTO": AutoDecodingConfig, } config_class = config_classes.get(decoding_type) @@ -439,6 +440,29 @@ def spec_dec_mode(self): return TorchSpeculativeDecodingMode.MTP +class AutoDecodingConfig(DecodingBaseConfig): + """ + Configuration for auto speculative decoding. + + This config is used to automatically select the best speculative decoding algorithm. + + According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32. + Default heuristic: + With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 + With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 + With concurrency > 32, speculative decoding is disabled. + """ + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + decoding_type: ClassVar[str] = "AUTO" + + def supports_backend(self, backend: str) -> bool: + return backend == "pytorch" + + class PybindMirror(ABC): ''' A class containing the utilities for mirroring Python classes to pybinding classes. @@ -742,6 +766,7 @@ def supports_backend(self, backend: str) -> bool: MTPDecodingConfig, NGramDecodingConfig, UserProvidedDecodingConfig, + AutoDecodingConfig, ]] @@ -1165,7 +1190,6 @@ def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs": tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance. 
""" kwargs = BaseLlmArgs._check_consistency(dict(kwargs)) - ret = cls(**kwargs) return ret @@ -1493,6 +1517,11 @@ def validate_speculative_config(self): self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED self.build_config.max_draft_len = self.speculative_config.max_draft_len + elif isinstance(self.speculative_config, AutoDecodingConfig): + assert self.backend in ['pytorch', '_autodeploy'] + self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO + self.build_config.max_draft_len = self.speculative_config.max_draft_len + else: raise ValueError( f"Unrecognized speculative config type {type(self.speculative_config)}" diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 07d930064c..7b2af7af15 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag): EAGLE = auto() NGRAM = auto() USER_PROVIDED = auto() + AUTO = auto() @staticmethod def from_arguments(args: argparse.Namespace): @@ -117,6 +118,8 @@ def from_arguments(args: argparse.Namespace): return SpeculativeDecodingMode.NGRAM elif args.speculative_decoding_mode == "user_provided": return SpeculativeDecodingMode.USER_PROVIDED + elif args.speculative_decoding_mode == "auto": + return SpeculativeDecodingMode.AUTO else: assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode From 27aa440d1f2a5153444610e29d61b1623872a306 Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Tue, 22 Jul 2025 16:34:37 -0700 Subject: [PATCH 3/6] Address comments Signed-off-by: Simeng Liu --- tensorrt_llm/llmapi/llm_args.py | 5 ----- tests/integration/defs/test_e2e.py | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 3923dc6fcc..e15e37f4e8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -372,11 +372,6 @@ class NGramDecodingConfig(DecodingBaseConfig): is_public_pool: bool = True Whether to use a common pool for all requests, or the pool is private for each request if False. 
""" - # If max_draft_len or max_matching_ngram_size are not set by user - # Default heuristic will be use - # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 - # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 - max_draft_len: int = 0 max_matching_ngram_size: int = 0 is_keep_all: bool = True is_use_oldest: bool = True diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index dfb0a1a0d1..e75e665bc4 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1773,6 +1773,31 @@ def test_ptp_quickstart_advanced_ngram(llm_root, llm_venv, model_name, _check_mem_usage(running_log, [27.0, 0, 0, 0]) +@pytest.mark.parametrize("model_name,model_path", [ + ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"), +]) +def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name, + model_path): + print(f"Testing {model_name}.") + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + with tempfile.NamedTemporaryFile(mode='w+t', + suffix=f".{model_name}.log", + dir="./", + delete=True, + delete_on_close=True) as running_log: + llm_venv.run_cmd([ + str(example_root / "quickstart_advanced.py"), + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--spec_decode_algo", + "AUTO", + "--use_cuda_graph", + "--max_batch_size=4", + ], + stdout=running_log) + _check_mem_usage(running_log, [27.0, 0, 0, 0]) + + @skip_post_blackwell @pytest.mark.skip_less_device_memory(110000) @pytest.mark.skip_less_device(8) From f5c115a1324905bad6c92e64415d4582ce9e4cb8 Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Tue, 22 Jul 2025 16:43:59 -0700 Subject: [PATCH 4/6] Resolve CI failures Signed-off-by: Simeng Liu --- examples/llm-api/quickstart_advanced.py | 6 +----- tests/unittest/api_stability/references_committed/llm.yaml | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index df9838f363..de4c9b23a8 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -107,11 +107,7 @@ def add_llm_args(parser): parser.add_argument('--max_beam_width', type=int, default=1) # Speculative decoding - parser.add_argument( - '--spec_decode_algo', - type=str, - default=None, - choices=['MTP', 'EAGLE3', 'DRAFT_TARGET', 'NGRAM', 'AUTO']) + parser.add_argument('--spec_decode_algo', type=str, default=None) parser.add_argument('--spec_decode_max_draft_len', type=int, default=1) parser.add_argument('--draft_model_dir', type=str, default=None) parser.add_argument('--max_matching_ngram_size', type=int, default=5) diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml index 66fbdabfc5..7ea5fc5695 100644 --- a/tests/unittest/api_stability/references_committed/llm.yaml +++ b/tests/unittest/api_stability/references_committed/llm.yaml @@ -59,7 +59,7 @@ methods: default: null # Speculative decoding speculative_config: - annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType] + annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, 
tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType] default: null # generation constraints max_batch_size: From 71e3af9d0e95d571df8be4327d2169de01230da1 Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Fri, 25 Jul 2025 19:56:44 -0700 Subject: [PATCH 5/6] Enable AUTO speculative decoding with Ngram in LLM for all batch sizes. Turn off spec_decoding for batch size > 32 in executor_loop. Signed-off-by: Simeng Liu --- tensorrt_llm/_torch/speculative/ngram.py | 8 +++++++- tensorrt_llm/llmapi/llm.py | 4 +++- tensorrt_llm/llmapi/llm_args.py | 3 +++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 9113900ef9..39267f5da2 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -2,6 +2,7 @@ from ordered_set import OrderedSet +from tensorrt_llm.llmapi import NGramDecodingConfig from tensorrt_llm.logger import logger from ..pyexecutor.llm_request import * @@ -163,10 +164,11 @@ class NGramDrafter(Drafter): def __init__( self, - spec_config: "NGramDecodingConfig", + spec_config: NGramDecodingConfig, ngram_pool_manager: NGramPoolManager = None, ): assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." + self.spec_config = spec_config self.max_draft_len = spec_config.max_draft_len self.spec_resource_manager = ngram_pool_manager @@ -175,6 +177,10 @@ def prepare_draft_tokens( scheduled_requests: ScheduledRequests, resource_manager: Optional[ResourceManager] = None, ) -> None: + # Disable NGram speculative decoding auto heuristic for batch size > 32. + if self.spec_config.is_auto_heuristic and len( + scheduled_requests.all_requests()) > 32: + return # Sort by request_id when py_batch_idx is None as a fallback. # This happens in the disagg case: for a set of new requests, we draft # before forward_step, so py_batch_idx is not assigned. diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index f4cf807abd..922f3f348b 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -967,7 +967,7 @@ def _build_model(self): # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 # With concurrency > 32, speculative decoding is disabled. - if spec_config is not None and spec_config.decoding_type == "AUTO" and max_batch_size <= 32: + if spec_config is not None and spec_config.decoding_type == "AUTO": if not self.args.disable_overlap_scheduler: logger.info( "Disable overlap scheduler to enable Auto speculative decoding with Ngram." @@ -982,6 +982,8 @@ def _build_model(self): is_keep_all=True, is_use_oldest=True, is_public_pool=True, + # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic. 
+ is_auto_heuristic=True, ) logger.info( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index e15e37f4e8..32e56de3da 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -376,6 +376,9 @@ class NGramDecodingConfig(DecodingBaseConfig): is_keep_all: bool = True is_use_oldest: bool = True is_public_pool: bool = True + # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic. + # User should not set this flag. Use AutoDecodingConfig instead. + is_auto_heuristic: bool = False @classmethod def from_dict(cls, data: dict): From 94eaf6aa9426f609fa4dca05182100185c435387 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Thu, 31 Jul 2025 11:06:27 -0400 Subject: [PATCH 6/6] Fix lint Signed-off-by: Mike Iovine --- tensorrt_llm/llmapi/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index a7e2814b47..bd24bfc2a5 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -31,8 +31,8 @@ from ..logger import logger from ..sampling_params import SamplingParams from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING, - TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig, PeftCacheConfig, - PybindMirror, TorchLlmArgs, TrtLlmArgs) + TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig, + PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs) from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig, LlmBuildStats, ModelLoader, _ModelRuntimeContext) from .mpi_session import MpiPoolSession, external_mpi_comm_available
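
Below is an illustrative usage sketch (not part of the patches above) showing how the new AUTO speculative decoding mode would be exercised through the LLM API once this series is applied. The model path is a placeholder, the commented thresholds restate the heuristic from the AutoDecodingConfig docstring, and the exact constructor arguments are assumptions to be checked against the merged code. It mirrors, in spirit, what the new integration test does via `quickstart_advanced.py --spec_decode_algo AUTO`.

# Hypothetical example, assuming the series above is merged; names and paths are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import AutoDecodingConfig

llm = LLM(
    model="llama-3.1-model/Llama-3.1-8B-Instruct",  # placeholder model path
    speculative_config=AutoDecodingConfig(),  # heuristic derives NGram settings from max_batch_size
    max_batch_size=4,  # <= 4: max_draft_len=5, ngram_size=3; <= 32: 3 and 5; > 32: drafting skipped per step
)
# Per patches 2 and 5, LLM._build_model() converts the AUTO config into an
# NGramDecodingConfig and disables the overlap scheduler automatically.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)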