
Commit 8cf3faa

[feat] Auto-enable ngram with concurrency <= 32. (#6232)
Signed-off-by: Simeng Liu <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
Co-authored-by: Mike Iovine <[email protected]>
Co-authored-by: Mike Iovine <[email protected]>
1 parent 8062e0f commit 8cf3faa

10 files changed: +124 −24 lines

examples/llm-api/quickstart_advanced.py

Lines changed: 6 additions & 4 deletions
@@ -1,10 +1,10 @@
 import argparse
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
-                                 MTPDecodingConfig, NGramDecodingConfig,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 DraftTargetDecodingConfig, EagleDecodingConfig,
+                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
+                                 NGramDecodingConfig, TorchCompileConfig)
 
 example_prompts = [
     "Hello, my name is",
@@ -181,6 +181,8 @@ def setup_llm(args, **kwargs):
             is_use_oldest=True,
             is_public_pool=True,
         )
+    elif spec_decode_algo == "AUTO":
+        spec_config = AutoDecodingConfig()
     else:
         spec_config = None

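For orientation, here is a minimal sketch (not part of this commit) of exercising the new AUTO option through the LLM API, mirroring what quickstart_advanced.py now does when --spec_decode_algo AUTO is passed. The model path and prompt below are placeholders, not values from the diff.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import AutoDecodingConfig

    # AutoDecodingConfig takes no tunables; the runtime picks the NGram parameters
    # (or skips drafting for large batches) according to the heuristic in this commit.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
              speculative_config=AutoDecodingConfig())
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)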
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 3 deletions
@@ -864,7 +864,7 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests)
             self.model_engine.enable_spec_decode = self.use_spec_decode
-            self._prepare_draft_requests(self.active_requests)
+            self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -965,14 +965,15 @@ def _executor_loop(self):
                             iter_stats=iter_stats,
                             iter_start_time=iter_start_time))
 
-    def _prepare_draft_requests(self, requests):
+    def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in requests:
+            for req in self.active_requests:
                 if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                      LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
+
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_len

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ class SpeculativeDecodingMode(IntEnum):
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
     NONE = auto()
+    AUTO = auto()
 
     def is_mtp(self):
         return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 7 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 from ordered_set import OrderedSet
 
+from tensorrt_llm.llmapi import NGramDecodingConfig
 from tensorrt_llm.logger import logger
 
 from ..pyexecutor.llm_request import *
@@ -163,10 +164,11 @@ class NGramDrafter(Drafter):
 
     def __init__(
         self,
-        spec_config: "NGramDecodingConfig",
+        spec_config: NGramDecodingConfig,
        ngram_pool_manager: NGramPoolManager = None,
    ):
        assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
+        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
        self.spec_resource_manager = ngram_pool_manager
 
@@ -175,6 +177,10 @@ def prepare_draft_tokens(
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
+        # Disable NGram speculative decoding auto heuristic for batch size > 32.
+        if self.spec_config.is_auto_heuristic and len(
+                scheduled_requests.all_requests()) > 32:
+            return
         # Sort by request_id when py_batch_idx is None as a fallback.
         # This happens in the disagg case: for a set of new requests, we draft
         # before forward_step, so py_batch_idx is not assigned.

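A standalone sketch (not code from this commit) of the gating rule the new early return encodes: drafting is skipped only when the config came from the auto heuristic and more than 32 requests are scheduled; an explicitly user-configured NGramDecodingConfig is never gated.

    def should_draft(is_auto_heuristic: bool, num_scheduled_requests: int) -> bool:
        # Mirrors the early return added to NGramDrafter.prepare_draft_tokens above.
        return not (is_auto_heuristic and num_scheduled_requests > 32)

    assert should_draft(True, 16)       # auto heuristic, small batch: draft
    assert not should_draft(True, 64)   # auto heuristic, large batch: skip drafting
    assert should_draft(False, 64)      # user-provided NGram config: always draft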
tensorrt_llm/llmapi/__init__.py

Lines changed: 10 additions & 9 deletions
@@ -4,15 +4,15 @@
 from .build_cache import BuildCacheConfig
 from .llm import LLM, RequestOutput
 # yapf: disable
-from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
-                       CapacitySchedulerPolicy, ContextChunkingPolicy,
-                       CudaGraphConfig, DraftTargetDecodingConfig,
-                       DynamicBatchConfig, EagleDecodingConfig,
-                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
-                       LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
-                       TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
-                       UserProvidedDecodingConfig)
+from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
+                       CalibConfig, CapacitySchedulerPolicy,
+                       ContextChunkingPolicy, CudaGraphConfig,
+                       DraftTargetDecodingConfig, DynamicBatchConfig,
+                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
+                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
+                       MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
+                       NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
+                       TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession
@@ -53,4 +53,5 @@
     'LlmArgs',
     'TorchLlmArgs',
     'TrtLlmArgs',
+    'AutoDecodingConfig',
 ]

tensorrt_llm/llmapi/llm.py

Lines changed: 33 additions & 3 deletions
@@ -31,8 +31,8 @@
 from ..logger import logger
 from ..sampling_params import SamplingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
-                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
+                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -995,13 +995,43 @@ def _build_model(self):
         self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
             self.args.cache_transceiver_config)
         from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+
+        spec_config = self.args.speculative_config
+        max_batch_size = self._executor_config.max_batch_size
+        # Apply default heuristic to AutoDecodingConfig based on benchmark results
+        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
+        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+        # With concurrency > 32, speculative decoding is disabled.
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            if not self.args.disable_overlap_scheduler:
+                logger.info(
+                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
+                )
+                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
+                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
+                self.args.disable_overlap_scheduler = True
+
+            spec_config = NGramDecodingConfig(
+                max_draft_len=5 if max_batch_size <= 4 else 3,
+                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
+                is_keep_all=True,
+                is_use_oldest=True,
+                is_public_pool=True,
+                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
+                is_auto_heuristic=True,
+            )
+
+            logger.info(
+                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
+            )
+
         update_executor_config(
             self._executor_config,
             backend=self.args.backend,
             pytorch_backend_config=self.args.get_pytorch_backend_config()
             if self.args.backend in ["pytorch", "_autodeploy"] else None,
             mapping=self.args.parallel_config.to_mapping(),
-            speculative_config=self.args.speculative_config,
+            speculative_config=spec_config,
             hf_model_dir=self._hf_model_dir,
             max_input_len=self.args.max_input_len,
             max_seq_len=max_seq_len,

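The block above maps the executor's max_batch_size to NGram parameters at build time; the per-batch "> 32 requests" cutoff is enforced at runtime in ngram.py. A small illustrative sketch of the parameter mapping (not code from this commit):

    def ngram_params_for(max_batch_size: int) -> dict:
        # concurrency <= 4: longer drafts, shorter matching n-grams
        if max_batch_size <= 4:
            return {"max_draft_len": 5, "max_matching_ngram_size": 3}
        # otherwise: shorter drafts, longer matching n-grams
        # (the > 32 case is handled per scheduled batch at runtime, not here)
        return {"max_draft_len": 3, "max_matching_ngram_size": 5}

    print(ngram_params_for(4))    # {'max_draft_len': 5, 'max_matching_ngram_size': 3}
    print(ngram_params_for(32))   # {'max_draft_len': 3, 'max_matching_ngram_size': 5}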
tensorrt_llm/llmapi/llm_args.py

Lines changed: 34 additions & 3 deletions
@@ -336,6 +336,7 @@ def from_dict(cls, data: dict):
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
             "UserProvided": UserProvidedDecodingConfig,
+            "AUTO": AutoDecodingConfig,
         }
 
         config_class = config_classes.get(decoding_type)
@@ -446,11 +447,13 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_public_pool: bool = True
         Whether to use a common pool for all requests, or the pool is private for each request if False.
     """
-
-    max_matching_ngram_size: int = 4
+    max_matching_ngram_size: int = 0
     is_keep_all: bool = True
     is_use_oldest: bool = True
     is_public_pool: bool = True
+    # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
+    # User should not set this flag. Use AutoDecodingConfig instead.
+    is_auto_heuristic: bool = False
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -510,6 +513,29 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.MTP
 
 
+class AutoDecodingConfig(DecodingBaseConfig):
+    """
+    Configuration for auto speculative decoding.
+
+    This config is used to automatically select the best speculative decoding algorithm.
+
+    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
+    Default heuristic:
+    With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
+    With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+    With concurrency > 32, speculative decoding is disabled.
+    """
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "AUTO"
+
+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+
 class PybindMirror(ABC):
     ''' A class containing the utilities for mirroring Python classes to
     pybinding classes.
@@ -872,6 +898,7 @@ def supports_backend(self, backend: str) -> bool:
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
+    AutoDecodingConfig,
 ]]
 
 
@@ -1292,7 +1319,6 @@ def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
             tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance.
         """
         kwargs = BaseLlmArgs._check_consistency(dict(kwargs))
-
         ret = cls(**kwargs)
         return ret
 
@@ -1621,6 +1647,11 @@ def validate_speculative_config(self):
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
 
+        elif isinstance(self.speculative_config, AutoDecodingConfig):
+            assert self.backend in ['pytorch', '_autodeploy']
+            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
+            self.build_config.max_draft_len = self.speculative_config.max_draft_len
+
         else:
             raise ValueError(
                 f"Unrecognized speculative config type {type(self.speculative_config)}"

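As a usage sketch (assuming a build that includes this change), the new "AUTO" decoding_type and config class behave as follows; constructing AutoDecodingConfig directly, as quickstart_advanced.py does, is equivalent for Python callers.

    from tensorrt_llm.llmapi import AutoDecodingConfig

    spec = AutoDecodingConfig.from_dict({})      # no tunables; the heuristic fills in NGram params later
    assert spec.decoding_type == "AUTO"          # resolves via the config_classes mapping above
    assert spec.supports_backend("pytorch")      # only the PyTorch backend is supported
    assert not spec.supports_backend("tensorrt")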
tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
     EAGLE = auto()
     NGRAM = auto()
     USER_PROVIDED = auto()
+    AUTO = auto()
 
     @staticmethod
     def from_arguments(args: argparse.Namespace):
@@ -117,6 +118,8 @@ def from_arguments(args: argparse.Namespace):
             return SpeculativeDecodingMode.NGRAM
         elif args.speculative_decoding_mode == "user_provided":
             return SpeculativeDecodingMode.USER_PROVIDED
+        elif args.speculative_decoding_mode == "auto":
+            return SpeculativeDecodingMode.AUTO
         else:
             assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

tests/integration/defs/test_e2e.py

Lines changed: 25 additions & 0 deletions
@@ -1775,6 +1775,31 @@ def test_ptp_quickstart_advanced_ngram(llm_root, llm_venv, model_name,
         _check_mem_usage(running_log, [27.0, 0, 0, 0])
 
 
+@pytest.mark.parametrize("model_name,model_path", [
+    ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"),
+])
+def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
+                                      model_path):
+    print(f"Testing {model_name}.")
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    with tempfile.NamedTemporaryFile(mode='w+t',
+                                     suffix=f".{model_name}.log",
+                                     dir="./",
+                                     delete=True,
+                                     delete_on_close=True) as running_log:
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+            "--spec_decode_algo",
+            "AUTO",
+            "--use_cuda_graph",
+            "--max_batch_size=4",
+        ],
+                         stdout=running_log)
+        _check_mem_usage(running_log, [27.0, 0, 0, 0])
+
+
 @skip_post_blackwell
 @pytest.mark.skip_less_device_memory(110000)
 @pytest.mark.skip_less_device(8)

tests/unittest/api_stability/references_committed/llm.yaml

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ methods:
         default: null
     # Speculative decoding
     speculative_config:
-      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]
+      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
       default: null
     # generation constraints
     max_batch_size:
