
Commit e968f98

[None][feat] Clean up ngram auto mode, add max_concurrency to configs (#6676)
Signed-off-by: Mike Iovine <[email protected]>
1 parent 4055b76 commit e968f98


9 files changed, +52 −41 lines changed


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 0 deletions
@@ -871,6 +871,10 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests)
             self.model_engine.enable_spec_decode = self.use_spec_decode
+            # If speculation is off, this function sets py_draft_tokens to None
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
             self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
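The new comment documents a contract worth spelling out. Below is a minimal sketch of the described behavior, assuming a simplified request object with a settable py_draft_tokens attribute; it is not the actual _prepare_draft_requests body.

def prepare_draft_requests_sketch(active_requests, use_spec_decode: bool, max_draft_len: int) -> None:
    # Simplified stand-in for the behavior the new comment describes.
    for req in active_requests:
        if use_spec_decode:
            # Dummy draft tokens reserve space so the scheduler accounts for
            # the extra tokens speculation will add this iteration.
            req.py_draft_tokens = [0] * max_draft_len
        else:
            # Speculation is off for this iteration: no draft tokens.
            req.py_draft_tokens = None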

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 8 additions & 0 deletions
@@ -170,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         )
         executor_config.enable_chunked_context = False
 
+    spec_config = executor_config.speculative_config
+    if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
+        if not spec_config.spec_dec_mode.support_overlap_scheduler():
+            logger.warning(
+                f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
+            )
+            executor_config.pytorch_backend_config.disable_overlap_scheduler = True
+
 
 def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
     if executor_config.mapping is None:

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -23,4 +24,5 @@
     "get_spec_resource_manager",
     "get_spec_worker",
     "update_spec_config_from_model_config",
+    "suggest_spec_config",
 ]

tensorrt_llm/_torch/speculative/auto_heuristic.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig":
+    """
+    Suggests a reasonable draft model free speculation scheme.
+    Used when the user specifies spec_mode == AUTO.
+
+    For now, we always use an ngram scheme that gets disabled at
+    BS>=32.
+    """
+    from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig
+    return NGramDecodingConfig(
+        max_draft_len=5 if max_batch_size <= 4 else 3,
+        max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
+        max_concurrency=32,
+        is_keep_all=True,
+        is_use_oldest=True,
+        is_public_pool=True,
+    )
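As a quick illustration of the heuristic, a minimal sketch that only reads fields appearing in this commit:

from tensorrt_llm._torch.speculative import suggest_spec_config

# Small batches draft more aggressively with shorter matching ngrams.
small = suggest_spec_config(max_batch_size=4)
assert small.max_draft_len == 5 and small.max_matching_ngram_size == 3

# Larger batches draft less and match longer ngrams.
large = suggest_spec_config(max_batch_size=64)
assert large.max_draft_len == 3 and large.max_matching_ngram_size == 5

# Either way, speculation is gated off once more than 32 requests are active.
assert small.max_concurrency == large.max_concurrency == 32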

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 5 additions & 0 deletions
@@ -9,6 +9,9 @@
 class Drafter(ABC):
     """Abstract base class for all drafter implementations."""
 
+    def __init__(self, max_concurrency: Optional[int] = None) -> None:
+        self.max_concurrency = max_concurrency
+
     @abstractmethod
     def prepare_draft_tokens(
         self,
@@ -25,4 +28,6 @@ def prepare_draft_tokens(
 
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
         """Check if spec decode should be used for the current iteration."""
+        if self.max_concurrency is not None:
+            return len(requests) <= self.max_concurrency
         return True
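To show how the base-class change behaves, here is a minimal, self-contained sketch. The ToyDrafter class and the plain lists standing in for LlmRequest objects are hypothetical; real drafters pass spec_config.max_concurrency, as the model_drafter.py and ngram.py hunks below do.

from abc import ABC, abstractmethod
from typing import List, Optional

class Drafter(ABC):
    """Trimmed-down copy of the base class, for illustration only."""

    def __init__(self, max_concurrency: Optional[int] = None) -> None:
        self.max_concurrency = max_concurrency

    @abstractmethod
    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None) -> None:
        ...

    def should_use_spec_decode(self, requests: List) -> bool:
        # Speculation is skipped once the active batch exceeds max_concurrency.
        if self.max_concurrency is not None:
            return len(requests) <= self.max_concurrency
        return True

class ToyDrafter(Drafter):
    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None) -> None:
        pass

drafter = ToyDrafter(max_concurrency=32)
print(drafter.should_use_spec_decode(list(range(8))))   # True: 8 <= 32
print(drafter.should_use_spec_decode(list(range(64))))  # False: 64 > 32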

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ def __init__(
         spec_resource_manager: Optional[BaseResourceManager] = None,
         guided_decoder: Optional[GuidedDecoder] = None,
     ):
+        super().__init__(spec_config.max_concurrency)
+
         # Validate required parameters
         if draft_model_engine is None:
             raise ValueError("draft_model_engine cannot be None")

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 1 addition & 5 deletions
@@ -168,6 +168,7 @@ def __init__(
         spec_config: NGramDecodingConfig,
         ngram_pool_manager: NGramPoolManager = None,
     ):
+        super().__init__(spec_config.max_concurrency)
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
@@ -178,11 +179,6 @@ def prepare_draft_tokens(
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
-        # Disable NGram speculative decoding auto heuristic for batch size > 32.
-        if self.spec_config.is_auto_heuristic and len(
-                scheduled_requests.all_requests()) > 32:
-            return
-
         # Sort by request_id when py_batch_idx is None as a fallback.
         # This happens in the disagg case: for a set of new requests, we draft
         # before forward_step, so py_batch_idx is not assigned.

tensorrt_llm/llmapi/llm.py

Lines changed: 5 additions & 27 deletions
@@ -32,8 +32,8 @@
 from ..sampling_params import SamplingParams
 from ..scheduling_params import SchedulingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
-                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ def _build_model(self):
 
         spec_config = self.args.speculative_config
         max_batch_size = self._executor_config.max_batch_size
-        # Apply default heuristic to AutoDecodingConfig based on benchmark results
-        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        # With concurrency > 32, speculative decoding is disabled.
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            if not self.args.disable_overlap_scheduler:
-                logger.info(
-                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
-                )
-                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
-                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
-                self.args.disable_overlap_scheduler = True
-
-            spec_config = NGramDecodingConfig(
-                max_draft_len=5 if max_batch_size <= 4 else 3,
-                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
-                is_keep_all=True,
-                is_use_oldest=True,
-                is_public_pool=True,
-                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-                is_auto_heuristic=True,
-            )
 
-            logger.info(
-                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
-            )
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+            spec_config = suggest_spec_config(max_batch_size)
 
         update_executor_config(
             self._executor_config,
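For context, this is roughly how the AUTO path is reached from the LLM API; the model path is a placeholder, and AutoDecodingConfig is the config whose decoding_type is "AUTO".

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import AutoDecodingConfig

# With an AUTO speculative config, _build_model() now calls
# suggest_spec_config(max_batch_size) instead of assembling an
# NGramDecodingConfig inline, and no longer forces the overlap
# scheduler off here (that decision moved to py_executor_creator.py).
llm = LLM(
    model="/path/to/model",                 # placeholder path
    speculative_config=AutoDecodingConfig(),
)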

tensorrt_llm/llmapi/llm_args.py

Lines changed: 8 additions & 9 deletions
@@ -342,6 +342,11 @@ class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
 
+    # PyTorch only.
+    # When specified, speculation will be disabled at batch sizes above
+    # this value. Otherwise, speculation will always be on.
+    max_concurrency: Optional[int] = None
+
     @classmethod
     def from_dict(cls, data: dict):
         # dispatch to the correct decoding config
@@ -469,9 +474,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_keep_all: bool = True
     is_use_oldest: bool = True
     is_public_pool: bool = True
-    # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-    # User should not set this flag. Use AutoDecodingConfig instead.
-    is_auto_heuristic: bool = False
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -535,13 +537,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
     """
     Configuration for auto speculative decoding.
 
-    This config is used to automatically select the best speculative decoding algorithm.
+    This config will automatically select a good, draft-model free
+    speculation algorithm with some heuristic.
 
-    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
-    Default heuristic:
-        With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        With concurrency > 32, speculative decoding is disabled.
+    Attributes that are inherited from the base class are ignored.
     """
 
     @classmethod
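A brief sketch of the new field in use; the values here are arbitrary, and any DecodingBaseConfig subclass on the PyTorch backend can now carry max_concurrency.

from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig

# Explicit ngram speculation that is automatically disabled whenever
# more than 16 requests are in flight.
spec_config = NGramDecodingConfig(
    max_draft_len=3,
    max_matching_ngram_size=5,
    max_concurrency=16,
    is_keep_all=True,
    is_use_oldest=True,
    is_public_pool=True,
)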
