@@ -32,8 +32,8 @@
 from ..sampling_params import SamplingParams
 from ..scheduling_params import SchedulingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
-                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ def _build_model(self):
 
         spec_config = self.args.speculative_config
         max_batch_size = self._executor_config.max_batch_size
-        # Apply default heuristic to AutoDecodingConfig based on benchmark results
-        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
-        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
-        # With concurrency > 32, speculative decoding is disabled.
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            if not self.args.disable_overlap_scheduler:
-                logger.info(
-                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
-                )
-                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
-                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
-                self.args.disable_overlap_scheduler = True
-
-            spec_config = NGramDecodingConfig(
-                max_draft_len=5 if max_batch_size <= 4 else 3,
-                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
-                is_keep_all=True,
-                is_use_oldest=True,
-                is_public_pool=True,
-                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
-                is_auto_heuristic=True,
-            )
 
-            logger.info(
-                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
-            )
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+            spec_config = suggest_spec_config(max_batch_size)
 
         update_executor_config(
             self._executor_config,
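For context on what the new call does: the body of `suggest_spec_config` is not shown in this diff, so below is a minimal sketch reconstructed from the heuristic the commit removes, assuming the helper simply packages those benchmark-derived defaults. The `NGramDecodingConfig` fields and the `max_batch_size <= 4` threshold are copied from the deleted lines; the import path `tensorrt_llm.llmapi.llm_args`, the exact signature, and the docstring are assumptions.

```python
# Hypothetical sketch of suggest_spec_config; the real implementation lives in
# tensorrt_llm._torch.speculative and may differ from this reconstruction.
from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig


def suggest_spec_config(max_batch_size: int) -> NGramDecodingConfig:
    """Map the expected concurrency to NGram speculative-decoding defaults.

    Mirrors the removed inline heuristic: low concurrency gets longer
    drafts with short n-gram matches, higher concurrency the reverse.
    """
    low_concurrency = max_batch_size <= 4
    return NGramDecodingConfig(
        max_draft_len=5 if low_concurrency else 3,
        max_matching_ngram_size=3 if low_concurrency else 5,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
```

Factoring the policy into a helper also removes the side effect the old code had of flipping `self.args.disable_overlap_scheduler` inside `_build_model`; whether that behavior moved into the new helper or was dropped is not visible in this hunk.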