Commit 21cfc3d (parent: c61979d)

Add AutoDecodingConfig to apply the default spec_decoding heuristic with Ngram.

Signed-off-by: Simeng Liu <[email protected]>

File tree: 6 files changed, 67 additions & 33 deletions

examples/llm-api/quickstart_advanced.py

Lines changed: 11 additions & 10 deletions
@@ -1,10 +1,10 @@
 import argparse
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
-                                 MTPDecodingConfig, NGramDecodingConfig,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 DraftTargetDecodingConfig, EagleDecodingConfig,
+                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
+                                 NGramDecodingConfig, TorchCompileConfig)
 
 example_prompts = [
     "Hello, my name is",
@@ -107,7 +107,11 @@ def add_llm_args(parser):
     parser.add_argument('--max_beam_width', type=int, default=1)
 
     # Speculative decoding
-    parser.add_argument('--spec_decode_algo', type=str, default=None)
+    parser.add_argument(
+        '--spec_decode_algo',
+        type=str,
+        default=None,
+        choices=['MTP', 'EAGLE3', 'DRAFT_TARGET', 'NGRAM', 'AUTO'])
     parser.add_argument('--spec_decode_max_draft_len', type=int, default=0)
     parser.add_argument('--draft_model_dir', type=str, default=None)
     parser.add_argument('--max_matching_ngram_size', type=int, default=0)
@@ -152,11 +156,6 @@ def setup_llm(args, **kwargs):
     spec_decode_algo = args.spec_decode_algo.upper(
     ) if args.spec_decode_algo is not None else None
 
-    # Update spec_decode_max_draft_len to 1 if unset by the user for non-NGRAM spec_decode_algo
-    # NGRAM spec_decode_algo will use default heuristic to set spec_decode_max_draft_len and max_matching_ngram_size
-    if spec_decode_algo != "NGRAM" and args.spec_decode_max_draft_len == 0:
-        args.spec_decode_max_draft_len = 1
-
     if spec_decode_algo == 'MTP':
         if not args.use_one_model:
             print(
@@ -186,6 +185,8 @@ def setup_llm(args, **kwargs):
             is_use_oldest=True,
             is_public_pool=True,
         )
+    elif spec_decode_algo == "AUTO":
+        spec_config = AutoDecodingConfig()
     else:
         spec_config = None

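Usage note: with this change, passing --spec_decode_algo AUTO to the quickstart script constructs an AutoDecodingConfig with no extra tuning flags. Below is a minimal sketch of the equivalent direct LLM API call, assuming only names this commit introduces (the model path is a placeholder):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AutoDecodingConfig

# AUTO takes no tuning knobs; the NGram heuristic is applied later,
# inside LLM._build_model (see tensorrt_llm/llmapi/llm.py below).
llm = LLM(model="/path/to/model",  # placeholder path
          speculative_config=AutoDecodingConfig())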
tensorrt_llm/_torch/speculative/interface.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ class SpeculativeDecodingMode(IntEnum):
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
     NONE = auto()
+    AUTO = auto()
 
     def is_mtp(self):
         return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE

tensorrt_llm/llmapi/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -8,11 +8,11 @@
                         CapacitySchedulerPolicy, ContextChunkingPolicy,
                         CudaGraphConfig, DraftTargetDecodingConfig,
                         DynamicBatchConfig, EagleDecodingConfig,
-                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
-                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                        MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
-                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
-                        UserProvidedDecodingConfig)
+                        ExtendedRuntimePerfKnobConfig, KvCacheConfig,
+                        LlmArgs, LookaheadDecodingConfig, MedusaDecodingConfig,
+                        MoeConfig, MTPDecodingConfig, NGramDecodingConfig,
+                        SchedulerConfig, TorchCompileConfig, TorchLlmArgs,
+                        TrtLlmArgs, UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession

tensorrt_llm/llmapi/llm.py

Lines changed: 17 additions & 17 deletions
@@ -31,8 +31,8 @@
 from ..logger import logger
 from ..sampling_params import SamplingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
-                       TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -959,30 +959,30 @@ def _build_model(self):
 
         spec_config = self.args.speculative_config
         max_batch_size = self._executor_config.max_batch_size
-        # Apply heuristic to incomplete NGramDecodingConfig based on benchmark results
+        # Apply default heuristic to AutoDecodingConfig based on benchmark results
         # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
         # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        if spec_config.spec_dec_mode() == "NGRAM" and max_batch_size <= 32:
+        # With concurrency > 32, speculative decoding is disabled.
+        if spec_config is not None and spec_config.decoding_type == "AUTO" and max_batch_size <= 32:
             if not self.args.disable_overlap_scheduler:
                 logger.info(
-                    "Disable overlap scheduler to enable NGram speculative decoding."
+                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
                 )
                 # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
                 # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
                 self.args.disable_overlap_scheduler = True
 
-            if spec_config.max_draft_len != 0 and spec_config.max_matching_ngram_size != 0:
-                pass
-            else:
-                if max_batch_size <= 4:
-                    spec_config.max_draft_len = 5 if spec_config.max_draft_len == 0 else spec_config.max_draft_len
-                    spec_config.max_matching_ngram_size = 3 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size
-                elif max_batch_size <= 32:
-                    spec_config.max_draft_len = 3 if spec_config.max_draft_len == 0 else spec_config.max_draft_len
-                    spec_config.max_matching_ngram_size = 5 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size
-                logger.info(
-                    f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
-                )
+            spec_config = NGramDecodingConfig(
+                max_draft_len=5 if max_batch_size <= 4 else 3,
+                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
+                is_keep_all=True,
+                is_use_oldest=True,
+                is_public_pool=True,
+            )
+
+            logger.info(
+                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
+            )
 
         update_executor_config(
             self._executor_config,

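For readability, the heuristic above reduces to a small pure function. The following is a restatement of the committed logic, not library API; the function name is illustrative:

def resolve_auto_ngram_params(max_batch_size: int):
    # AUTO only enables speculation when the executor's max_batch_size
    # (treated as the expected concurrency) is at most 32.
    if max_batch_size > 32:
        return None  # speculative decoding stays disabled
    if max_batch_size <= 4:
        return {"max_draft_len": 5, "max_matching_ngram_size": 3}
    return {"max_draft_len": 3, "max_matching_ngram_size": 5}

The same branch also sets disable_overlap_scheduler = True, since the benchmarks cited in the comments found NGram speculation outperforms the overlap scheduler at concurrency <= 32.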
tensorrt_llm/llmapi/llm_args.py

Lines changed: 30 additions & 1 deletion
@@ -262,6 +262,7 @@ def from_dict(cls, data: dict):
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
             "UserProvided": UserProvidedDecodingConfig,
+            "AUTO": AutoDecodingConfig,
         }
 
         config_class = config_classes.get(decoding_type)
@@ -458,6 +459,29 @@ def update_from_model_config(self, model_config):
             self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
 
 
+class AutoDecodingConfig(DecodingBaseConfig):
+    """
+    Configuration for auto speculative decoding.
+
+    This config is used to automatically select the best speculative decoding algorithm.
+
+    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
+    Default heuristic:
+    With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
+    With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+    With concurrency > 32, speculative decoding is disabled.
+    """
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "AUTO"
+
+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+
 class PybindMirror(ABC):
     ''' A class containing the utilities for mirroring Python classes to
     pybinding classes.
@@ -761,6 +785,7 @@ def supports_backend(self, backend: str) -> bool:
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
+    AutoDecodingConfig,
 ]]
@@ -1178,7 +1203,6 @@ def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
             tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance.
         """
         kwargs = BaseLlmArgs._check_consistency(dict(kwargs))
-
         ret = cls(**kwargs)
         return ret

@@ -1507,6 +1531,11 @@ def validate_speculative_config(self):
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
 
+        elif isinstance(self.speculative_config, AutoDecodingConfig):
+            assert self.backend in ['pytorch', '_autodeploy']
+            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
+            self.build_config.max_draft_len = self.speculative_config.max_draft_len
+
         else:
             raise ValueError(
                 f"Unrecognized speculative config type {type(self.speculative_config)}"

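A quick sketch of the new config class in isolation, assuming the fields inherited from DecodingBaseConfig (such as max_draft_len) all carry defaults, which the validate_speculative_config branch above relies on:

from tensorrt_llm.llmapi.llm_args import AutoDecodingConfig

cfg = AutoDecodingConfig.from_dict({})  # this commit adds no required fields
assert cfg.decoding_type == "AUTO"      # class-level key used by the from_dict dispatch table
assert cfg.supports_backend("pytorch")  # per the diff, only "pytorch" returns True
assert not cfg.supports_backend("tensorrt")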
tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
     EAGLE = auto()
     NGRAM = auto()
     USER_PROVIDED = auto()
+    AUTO = auto()
 
     @staticmethod
     def from_arguments(args: argparse.Namespace):
@@ -117,6 +118,8 @@ def from_arguments(args: argparse.Namespace):
             return SpeculativeDecodingMode.NGRAM
         elif args.speculative_decoding_mode == "user_provided":
             return SpeculativeDecodingMode.USER_PROVIDED
+        elif args.speculative_decoding_mode == "auto":
+            return SpeculativeDecodingMode.AUTO
         else:
             assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

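And a matching sketch for the builder-side flag mapping (a hand-built argparse.Namespace stands in for real CLI parsing):

import argparse

from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

ns = argparse.Namespace(speculative_decoding_mode="auto")
assert SpeculativeDecodingMode.from_arguments(ns) == SpeculativeDecodingMode.AUTO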