From 27c401e661204eb5ab1f2c631896392645882319 Mon Sep 17 00:00:00 2001
From: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Date: Wed, 6 Aug 2025 13:46:29 -0700
Subject: [PATCH] [feat] Clean up ngram auto mode, add max_concurrency to DecodingBaseConfig

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py |  4 +++
 .../_torch/pyexecutor/py_executor_creator.py  |  8 +++++
 tensorrt_llm/_torch/speculative/__init__.py   |  2 ++
 .../_torch/speculative/auto_heuristic.py      | 17 ++++++++++
 tensorrt_llm/_torch/speculative/drafter.py    |  5 +++
 .../_torch/speculative/model_drafter.py       |  2 ++
 tensorrt_llm/_torch/speculative/ngram.py      |  6 +---
 tensorrt_llm/llmapi/llm.py                    | 32 +++----------------
 tensorrt_llm/llmapi/llm_args.py               | 17 +++++-----
 9 files changed, 52 insertions(+), 41 deletions(-)
 create mode 100644 tensorrt_llm/_torch/speculative/auto_heuristic.py

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 1b028d097e1..dda9858e63a 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -871,6 +871,10 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests)
             self.model_engine.enable_spec_decode = self.use_spec_decode
+            # If speculation is off, this function sets py_draft_tokens to None
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware that
+            # speculation is about to happen.
             self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index af3ee4040a5..9292a615a1e 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -170,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         )
         executor_config.enable_chunked_context = False
 
+    spec_config = executor_config.speculative_config
+    if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
+        if not spec_config.spec_dec_mode.support_overlap_scheduler():
+            logger.warning(
+                f"Disabling overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
+            )
+            executor_config.pytorch_backend_config.disable_overlap_scheduler = True
+
 
 def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
     if executor_config.mapping is None:
diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py
index 6918b573905..0856cd46d92 100644
--- a/tensorrt_llm/_torch/speculative/__init__.py
+++ b/tensorrt_llm/_torch/speculative/__init__.py
@@ -1,3 +1,4 @@
+from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -23,4 +24,5 @@
     "get_spec_resource_manager",
     "get_spec_worker",
     "update_spec_config_from_model_config",
+    "suggest_spec_config",
 ]
diff --git a/tensorrt_llm/_torch/speculative/auto_heuristic.py b/tensorrt_llm/_torch/speculative/auto_heuristic.py
new file mode 100644
index 00000000000..907909beb87
--- /dev/null
+++ b/tensorrt_llm/_torch/speculative/auto_heuristic.py
@@ -0,0 +1,17 @@
+def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig":
+    """
+    Suggests a reasonable draft-model-free speculation scheme.
+    Used when the user specifies decoding_type == "AUTO".
+
+    For now, we always use an ngram scheme, which gets disabled
+    at batch sizes above 32 (via max_concurrency=32).
+    """
+    from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig
+    return NGramDecodingConfig(
+        max_draft_len=5 if max_batch_size <= 4 else 3,
+        max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
+        max_concurrency=32,
+        is_keep_all=True,
+        is_use_oldest=True,
+        is_public_pool=True,
+    )
diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py
index 9624193d457..4f2ea0b70b1 100644
--- a/tensorrt_llm/_torch/speculative/drafter.py
+++ b/tensorrt_llm/_torch/speculative/drafter.py
@@ -9,6 +9,9 @@
 class Drafter(ABC):
     """Abstract base class for all drafter implementations."""
 
+    def __init__(self, max_concurrency: Optional[int] = None) -> None:
+        self.max_concurrency = max_concurrency
+
     @abstractmethod
     def prepare_draft_tokens(
         self,
@@ -25,4 +28,6 @@ def prepare_draft_tokens(
 
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
         """Check if spec decode should be used for the current iteration."""
+        if self.max_concurrency is not None:
+            return len(requests) <= self.max_concurrency
         return True
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index fba50db175f..7f11142c3fa 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -48,6 +48,8 @@ def __init__(
         spec_resource_manager: Optional[BaseResourceManager] = None,
         guided_decoder: Optional[GuidedDecoder] = None,
     ):
+        super().__init__(spec_config.max_concurrency)
+
         # Validate required parameters
         if draft_model_engine is None:
             raise ValueError("draft_model_engine cannot be None")
diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py
index 6ca615de34b..2de21162c43 100644
--- a/tensorrt_llm/_torch/speculative/ngram.py
+++ b/tensorrt_llm/_torch/speculative/ngram.py
@@ -168,6 +168,7 @@ def __init__(
         spec_config: NGramDecodingConfig,
         ngram_pool_manager: NGramPoolManager = None,
     ):
+        super().__init__(spec_config.max_concurrency)
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
@@ -178,11 +179,6 @@ def prepare_draft_tokens(
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
-        # Disable NGram speculative decoding auto heuristic for batch size > 32.
-        if self.spec_config.is_auto_heuristic and len(
-                scheduled_requests.all_requests()) > 32:
-            return
-
         # Sort by request_id when py_batch_idx is None as a fallback.
         # This happens in the disagg case: for a set of new requests, we draft
         # before forward_step, so py_batch_idx is not assigned.
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 321ec11bd75..12bb079eaf5 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -32,8 +32,8 @@
 from ..sampling_params import SamplingParams
 from ..scheduling_params import SchedulingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
-                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ def _build_model(self):
         spec_config = self.args.speculative_config
         max_batch_size = self._executor_config.max_batch_size
 
-        # Apply default heuristic to AutoDecodingConfig based on benchmark results
-        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        # With concurrency > 32, speculative decoding is disabled.
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            if not self.args.disable_overlap_scheduler:
-                logger.info(
-                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
-                )
-                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
-                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
-                self.args.disable_overlap_scheduler = True
-
-            spec_config = NGramDecodingConfig(
-                max_draft_len=5 if max_batch_size <= 4 else 3,
-                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
-                is_keep_all=True,
-                is_use_oldest=True,
-                is_public_pool=True,
-                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-                is_auto_heuristic=True,
-            )
-            logger.info(
-                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
-            )
-
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+
+            spec_config = suggest_spec_config(max_batch_size)
 
         update_executor_config(
             self._executor_config,
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 4e393dfdbcf..0481199a9b4 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -342,6 +342,11 @@ class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
 
+    # PyTorch only.
+    # When specified, speculation will be disabled at batch sizes above
+    # this value. Otherwise, speculation will always be on.
+    max_concurrency: Optional[int] = None
+
     @classmethod
     def from_dict(cls, data: dict):
         # dispatch to the correct decoding config
@@ -469,9 +474,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_keep_all: bool = True
     is_use_oldest: bool = True
     is_public_pool: bool = True
-    # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-    # User should not set this flag. Use AutoDecodingConfig instead.
-    is_auto_heuristic: bool = False
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -535,13 +537,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
     """
     Configuration for auto speculative decoding.
 
-    This config is used to automatically select the best speculative decoding algorithm.
+    This config automatically selects a reasonable, draft-model-free
+    speculation algorithm via a heuristic.
 
-    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
-    Default heuristic:
-        With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        With concurrency > 32, speculative decoding is disabled.
+    Attributes inherited from the base class are ignored.
     """
 
     @classmethod
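
Usage note (editor's illustration, not part of the patch): a minimal sketch of how the new knobs surface through the LLM API. The new max_concurrency field caps the batch size at which speculation stays enabled (Drafter.should_use_spec_decode returns len(requests) <= max_concurrency each iteration), and decoding_type == "AUTO" now defers to suggest_spec_config. This assumes NGramDecodingConfig and AutoDecodingConfig are exported from tensorrt_llm.llmapi (they are defined in tensorrt_llm/llmapi/llm_args.py); the model path is a placeholder.

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import AutoDecodingConfig, NGramDecodingConfig

    # Explicit ngram speculation, capped at 8 concurrent requests: once more
    # than 8 requests are active, should_use_spec_decode() returns False and
    # the executor falls back to normal decoding for that iteration.
    llm = LLM(
        model="/path/to/model",  # placeholder path
        speculative_config=NGramDecodingConfig(
            max_draft_len=3,
            max_matching_ngram_size=5,
            max_concurrency=8,
        ),
    )

    # Or defer to the auto heuristic: suggest_spec_config() picks an ngram
    # scheme sized by max_batch_size, with max_concurrency=32.
    llm_auto = LLM(model="/path/to/model", speculative_config=AutoDecodingConfig())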