[None][feat] Clean up ngram auto mode, add max_concurrency to configs #6676

Merged
merged 1 commit on Aug 7, 2025

4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -871,6 +871,10 @@ def _prepare_and_schedule_batch(self):
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests)
self.model_engine.enable_spec_decode = self.use_spec_decode
# If speculation is off, this function sets py_draft_tokens to None
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens so that the scheduler is aware that speculation
# is about to happen.
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -170,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
)
executor_config.enable_chunked_context = False

spec_config = executor_config.speculative_config
if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
if not spec_config.spec_dec_mode.support_overlap_scheduler():
logger.warning(
f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
)
executor_config.pytorch_backend_config.disable_overlap_scheduler = True


def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
if executor_config.mapping is None:
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/__init__.py
@@ -1,3 +1,4 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -23,4 +24,5 @@
"get_spec_resource_manager",
"get_spec_worker",
"update_spec_config_from_model_config",
"suggest_spec_config",
]
17 changes: 17 additions & 0 deletions tensorrt_llm/_torch/speculative/auto_heuristic.py
@@ -0,0 +1,17 @@
def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig":
"""
Suggests a reasonable draft-model-free speculation scheme.
Used when the user specifies spec_mode == AUTO.

For now, we always use an ngram scheme that gets disabled once
the batch size exceeds 32.
"""
from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig
return NGramDecodingConfig(
max_draft_len=5 if max_batch_size <= 4 else 3,
max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
max_concurrency=32,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)
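
A minimal usage sketch (not part of this diff) showing what the heuristic hands back; the values follow directly from the function above:

from tensorrt_llm._torch.speculative import suggest_spec_config

cfg = suggest_spec_config(max_batch_size=4)
# Small max_batch_size (<= 4): longer drafts, shorter matching window.
assert cfg.max_draft_len == 5
assert cfg.max_matching_ngram_size == 3
# max_concurrency=32 means speculation is skipped once more than 32 requests
# are active in an iteration (see Drafter.should_use_spec_decode below).
assert cfg.max_concurrency == 32
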
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -9,6 +9,9 @@
class Drafter(ABC):
"""Abstract base class for all drafter implementations."""

def __init__(self, max_concurrency: Optional[int] = None) -> None:
self.max_concurrency = max_concurrency

@abstractmethod
def prepare_draft_tokens(
self,
@@ -25,4 +28,6 @@ def prepare_draft_tokens(

def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
"""Check if spec decode should be used for the current iteration."""
if self.max_concurrency is not None:
return len(requests) <= self.max_concurrency
return True
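
A rough sketch of the new gate; _StubDrafter below is hypothetical and exists only to make should_use_spec_decode callable:

class _StubDrafter(Drafter):
    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None):
        pass  # no-op: only the gating behavior matters for this sketch

# `requests` is the list of active LlmRequests for the current iteration.
gated = _StubDrafter(max_concurrency=8)
gated.should_use_spec_decode(requests)      # True while len(requests) <= 8
always_on = _StubDrafter()                  # max_concurrency=None
always_on.should_use_spec_decode(requests)  # always True
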
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -48,6 +48,8 @@ def __init__(
spec_resource_manager: Optional[BaseResourceManager] = None,
guided_decoder: Optional[GuidedDecoder] = None,
):
super().__init__(spec_config.max_concurrency)

# Validate required parameters
if draft_model_engine is None:
raise ValueError("draft_model_engine cannot be None")
6 changes: 1 addition & 5 deletions tensorrt_llm/_torch/speculative/ngram.py
@@ -168,6 +168,7 @@ def __init__(
spec_config: NGramDecodingConfig,
ngram_pool_manager: NGramPoolManager = None,
):
super().__init__(spec_config.max_concurrency)
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
@@ -178,11 +179,6 @@ def prepare_draft_tokens(
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:
# Disable NGram speculative decoding auto heuristic for batch size > 32.
if self.spec_config.is_auto_heuristic and len(
scheduled_requests.all_requests()) > 32:
return

# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
32 changes: 5 additions & 27 deletions tensorrt_llm/llmapi/llm.py
@@ -32,8 +32,8 @@
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ def _build_model(self):

spec_config = self.args.speculative_config
max_batch_size = self._executor_config.max_batch_size
# Apply default heuristic to AutoDecodingConfig based on benchmark results
# With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
# With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
# With concurrency > 32, speculative decoding is disabled.
if spec_config is not None and spec_config.decoding_type == "AUTO":
if not self.args.disable_overlap_scheduler:
logger.info(
"Disable overlap scheduler to enable Auto speculative decoding with Ngram."
)
# From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
# Therefore, we disable overlap scheduler to enable NGram speculative decoding.
self.args.disable_overlap_scheduler = True

spec_config = NGramDecodingConfig(
max_draft_len=5 if max_batch_size <= 4 else 3,
max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
# Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
is_auto_heuristic=True,
)

logger.info(
f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
)
if spec_config is not None and spec_config.decoding_type == "AUTO":
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)

update_executor_config(
self._executor_config,
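
From the user's side, the AUTO path now defers entirely to suggest_spec_config. A hedged sketch (the model path is a placeholder and the LLM kwargs are assumed from the existing llmapi, not verified against this PR):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import AutoDecodingConfig

# With decoding_type == "AUTO", _build_model swaps the config for the
# ngram scheme returned by suggest_spec_config(max_batch_size).
llm = LLM(
    model="<model-dir>",
    speculative_config=AutoDecodingConfig(),
    max_batch_size=16,
)
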
17 changes: 8 additions & 9 deletions tensorrt_llm/llmapi/llm_args.py
@@ -342,6 +342,11 @@ class DecodingBaseConfig(StrictBaseModel):
max_draft_len: Optional[int] = None
speculative_model_dir: Optional[Union[str, Path]] = None

# PyTorch only.
# When specified, speculation will be disabled at batch sizes above
# this value. Otherwise, speculation will always be on.
max_concurrency: Optional[int] = None

@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct decoding config
@@ -469,9 +474,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
# Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
# User should not set this flag. Use AutoDecodingConfig instead.
is_auto_heuristic: bool = False

@classmethod
def from_dict(cls, data: dict):
@@ -535,13 +537,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
"""
Configuration for auto speculative decoding.

This config is used to automatically select the best speculative decoding algorithm.
This config automatically selects a good, draft-model-free
speculation algorithm using a built-in heuristic.

According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
Default heuristic:
With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
With concurrency > 32, speculative decoding is disabled.
Attributes that are inherited from the base class are ignored.
"""

@classmethod
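
Because max_concurrency now lives on DecodingBaseConfig, any speculation config can opt into batch-size gating. A small sketch using the NGram config from this diff:

from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig

spec = NGramDecodingConfig(
    max_draft_len=3,
    max_matching_ngram_size=5,
    # PyTorch backend only: speculation is skipped for iterations with more
    # than 16 active requests; leave it unset to keep speculation always on.
    max_concurrency=16,
)
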