tensorrt_llm/_torch/pyexecutor/model_engine.py (2 additions, 1 deletion)

@@ -875,11 +875,12 @@ def _get_padded_batch(
         if available_blocks < 1:
             return 0
 
+        max_draft_len = self.max_draft_len if self.enable_spec_decode else 0
         cuda_graph_dummy_request_ids = [MAX_UINT64 - 1]
         self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
             cuda_graph_dummy_request_ids,
             is_gen=True,
-            max_num_draft_tokens=self.max_draft_len,
+            max_num_draft_tokens=max_draft_len,
             use_mrope=self.use_mrope,
             max_beam_width=self.max_beam_width)[0]
         self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
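With this change, the CUDA graph dummy request only reserves draft-token slots when speculative decoding is actually enabled for the current iteration, instead of always using `self.max_draft_len`. A minimal standalone sketch of that sizing rule (hypothetical helper, not the engine's real code; it assumes a dummy generation request holds one primary token plus any draft tokens):

# Hypothetical illustration of the sizing rule introduced above.
def dummy_request_token_count(max_draft_len: int, spec_decode_enabled: bool) -> int:
    # Only reserve draft-token slots when speculative decoding is active.
    draft_len = max_draft_len if spec_decode_enabled else 0
    return 1 + draft_len  # one primary token plus any draft tokens

assert dummy_request_token_count(3, spec_decode_enabled=True) == 4
assert dummy_request_token_count(3, spec_decode_enabled=False) == 1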
tests/integration/defs/accuracy/accuracy_core.py (2 additions)

@@ -155,6 +155,8 @@ def evaluate(self,
             spec_dec_algo = None
         elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
            spec_dec_algo = llm.args.speculative_config.decoding_type
+            if spec_dec_algo == 'AUTO':
+                spec_dec_algo = 'NGram'
         else:
             raise ValueError(
                 f"Not recognized speculative_config: {llm.args.speculative_config}."
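The accuracy harness keys its reference thresholds on the speculative decoding algorithm name; the new branch maps the 'AUTO' decoding type onto the 'NGram' entry, presumably because the AUTO policy currently falls back to an NGram drafter. A standalone sketch of that mapping (hypothetical helper, not part of accuracy_core.py):

# Hypothetical helper mirroring the branch added above.
def resolve_spec_dec_algo(decoding_type: str) -> str:
    # AUTO reuses the NGram reference accuracy thresholds.
    return 'NGram' if decoding_type == 'AUTO' else decoding_type

assert resolve_spec_dec_algo('AUTO') == 'NGram'
assert resolve_spec_dec_algo('NGram') == 'NGram'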
tests/integration/defs/accuracy/test_llm_api_pytorch.py (21 additions, 4 deletions)

@@ -21,10 +21,10 @@
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, SamplingParams,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 SamplingParams, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,

@@ -356,6 +356,23 @@ def test_guided_decoding_with_ngram(self, backend: str, mocker):
         task = JsonModeEval(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_hopper
+    def test_auto_spec_decode(self):
+        pytorch_config = {
+            "cuda_graph_config":
+            CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
+        }
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        free_gpu_memory_fraction=0.5)
+        spec_config = AutoDecodingConfig()
+        with LLM(model=self.MODEL_PATH,
+                 **pytorch_config,
+                 kv_cache_config=kv_cache_config,
+                 speculative_config=spec_config,
+                 max_batch_size=64) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.2-1B"
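Outside the test harness, the same configuration can be expressed directly against the LLM API. A usage sketch, assuming a local model path and that `LLM` is importable from `tensorrt_llm.llmapi` alongside the config classes used in the test above:

from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                 KvCacheConfig, LLM)

# Mirrors test_auto_spec_decode: AUTO speculative decoding with padded CUDA graphs.
with LLM(model="/path/to/model",  # placeholder path
         cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 32, 64],
                                           enable_padding=True),
         kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                       free_gpu_memory_fraction=0.5),
         speculative_config=AutoDecodingConfig(),
         max_batch_size=64) as llm:
    outputs = llm.generate(["The capital of France is"])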