diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 247f6da1754..ae777d711f6 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -437,6 +437,10 @@ def __init__(
         else:
             self.cache_indirection_attention = None
 
+    @property
+    def runtime_draft_len(self):
+        return self.max_draft_len if self.enable_spec_decode else 0
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
@@ -557,7 +561,7 @@ def get_torch_compile_warmup_request(batch_size,
                     list(range(batch_size)),
                     [num_tokens_per_request] * batch_size if not is_gen else None,
                     is_gen=is_gen,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(
@@ -576,7 +580,7 @@ def get_torch_compile_warmup_request(batch_size,
 
         def get_autotune_warmup_request():
             available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
             num_tokens_per_request = min(
                 min(available_tokens, self.max_seq_len - 1),
                 self.max_num_tokens)
@@ -610,14 +614,14 @@ def get_autotune_warmup_request():
                 request_ids=list(range(full_len_request_num)),
                 token_nums=[num_tokens_per_request] * full_len_request_num,
                 is_gen=False,
-                max_num_draft_tokens=self.max_draft_len)
+                max_num_draft_tokens=self.runtime_draft_len)
 
             if remaining_tokens > 0:
                 final_request = kv_cache_manager.add_dummy_requests(
                     request_ids=[full_len_request_num],
                     token_nums=[remaining_tokens],
                     is_gen=False,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                 requests += final_request
 
@@ -664,7 +668,7 @@ def disable_optimization(backend: Backend):
         # Disable cuda graph capture here so that we can properly capture it later
         with self.no_cuda_graph():
             available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
             warmup_batch_size = [1, self.batch_size // 2]
             if self.batch_size < 2:
                 warmup_batch_size = [1]
@@ -879,7 +883,7 @@ def _get_padded_batch(
             self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
                 cuda_graph_dummy_request_ids,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.runtime_draft_len,
                 use_mrope=self.use_mrope,
                 max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
@@ -1306,7 +1310,7 @@ def _prepare_tp_inputs(
                 gather_ids.extend(
                     list(
                         range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                 position_ids.extend(
                     list(
                         range(past_seen_token_num,
@@ -1322,23 +1326,23 @@ def _prepare_tp_inputs(
                 # inputs
                 # overlap scheduler can only support the speculative decoding
                 # methods with a fixed number of draft tokens
-                sequence_lengths.append(1 + self.max_draft_len)
+                sequence_lengths.append(1 + self.runtime_draft_len)
                 past_seen_token_num = request.max_beam_num_tokens - 1
-                draft_lens.append(self.max_draft_len)
+                draft_lens.append(self.runtime_draft_len)
                 gather_ids.extend(
                     list(
                         range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                 position_ids.extend(
                     list(
-                        range(past_seen_token_num,
-                              past_seen_token_num + 1 + self.max_draft_len)))
+                        range(past_seen_token_num, past_seen_token_num + 1 +
+                              self.runtime_draft_len)))
                 # previous tensor
                 previous_batch_indices.append(previous_batch_idx)
                 previous_pos_indices.extend([previous_batch_idx] *
-                                            (1 + self.max_draft_len))
+                                            (1 + self.runtime_draft_len))
                 num_cached_tokens_per_seq.append(past_seen_token_num +
-                                                 self.max_draft_len + 1)
+                                                 self.runtime_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
                 request_ids.append(request.py_request_id)
 
@@ -1412,21 +1416,21 @@ def previous_seq_slots_device():
                 previous_slots = previous_seq_slots_device()
                 # previous input ids
                 previous_batch_tokens = previous_batch_len * (
-                    1 + self.max_draft_len)
+                    1 + self.runtime_draft_len)
                 new_tokens = new_tokens_device.transpose(
                     0, 1)[previous_slots, :].flatten()
                 self.input_ids_cuda[num_tokens:num_tokens +
                                     previous_batch_tokens].copy_(
                                         new_tokens, non_blocking=True)
                 # previous draft tokens
-                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
                 self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
                                        previous_batch_draft_tokens].copy_(
                                            next_draft_tokens_device[
                                                previous_slots, :].flatten(),
                                            non_blocking=True)
                 # prepare data for the preprocess inputs
-                kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+                kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
                 previous_pos_indices_host = torch.tensor(previous_pos_indices,
                                                          dtype=torch.int,
                                                          pin_memory=True)
@@ -1451,8 +1455,8 @@ def previous_seq_slots_device():
                     extend_dummy_requests)
                 self.previous_pos_id_offsets_cuda[
                     (num_extend_reqeust_wo_dummy - previous_batch_len) *
-                    (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
-                    (1 + self.max_draft_len)].copy_(
+                    (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
+                    (1 + self.runtime_draft_len)].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
                             0:previous_batch_tokens]],
                         non_blocking=True)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 9a227c0c65b..97c591f37e9 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -336,7 +336,7 @@ def process_draft_tokens(self, request: LlmRequest,
         if request.py_draft_logits is None:
             new_token = add_token(request, new_tokens, beam=self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
-            if stop or len(request.py_draft_tokens) == 0:
+            if stop or get_draft_token_length(request) == 0:
                 return 0
 
             num_accepted = 0
@@ -360,10 +360,10 @@ def process_draft_tokens(self, request: LlmRequest,
                 request.py_draft_logits[0], generator=generator)
             target_probs = request.py_target_probs
-            p = draft_probs[torch.arange(len(request.py_draft_tokens)),
+            p = draft_probs[torch.arange(get_draft_token_length(request)),
                             request.py_draft_tokens]
             q = target_probs[:-1]
-            q = q[torch.arange(len(request.py_draft_tokens)),
+            q = q[torch.arange(get_draft_token_length(request)),
                   request.py_draft_tokens]
             accept_probs = torch.minimum(torch.ones(()), q / p)
             # Use deterministic random generation for multi-GPU consistency
@@ -374,7 +374,7 @@ def process_draft_tokens(self, request: LlmRequest,
             sample_last = True
             stop = False
             if rejected_indices.numel() == 0:
-                num_initially_accepted = len(request.py_draft_tokens)
+                num_initially_accepted = get_draft_token_length(request)
                 sample_last = False
             else:
                 num_initially_accepted = rejected_indices[0].item()
@@ -575,7 +575,7 @@ def _process_requests(self,
         logits = raw_logits[:sum_steps]
         # Collect steps per request for batched strategy
         steps_per_request = [
-            1 + len(req.py_draft_tokens) for req in requests
+            1 + get_draft_token_length(req) for req in requests
         ]
         logits = self._apply_embedding_bias(logits, requests,
                                             steps_per_request)
diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py
index e135ddfa010..dd8688ec636 100644
--- a/tests/integration/defs/accuracy/accuracy_core.py
+++ b/tests/integration/defs/accuracy/accuracy_core.py
@@ -155,6 +155,8 @@ def evaluate(self,
             spec_dec_algo = None
         elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
             spec_dec_algo = llm.args.speculative_config.decoding_type
+            if spec_dec_algo == 'AUTO':
+                spec_dec_algo = 'NGram'
         else:
             raise ValueError(
                 f"Not recognized speculative_config: {llm.args.speculative_config}."
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index ada8352b9ad..bab00d7dec9 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -21,10 +21,10 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, SamplingParams,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 SamplingParams, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -356,6 +356,23 @@ def test_guided_decoding_with_ngram(self, backend: str, mocker):
         task = JsonModeEval(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_hopper
+    def test_auto_spec_decode(self):
+        pytorch_config = {
+            "cuda_graph_config":
+            CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
+        }
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        free_gpu_memory_fraction=0.5)
+        spec_config = AutoDecodingConfig()
+        with LLM(model=self.MODEL_PATH,
+                 **pytorch_config,
+                 kv_cache_config=kv_cache_config,
+                 speculative_config=spec_config,
+                 max_batch_size=64) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.2-1B"
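
For context beyond the diff itself, here is a minimal, hypothetical sketch of how the AUTO speculative-decoding path exercised by the new test_auto_spec_decode could be driven through the LLM API. The model path, prompt, and max_tokens value are placeholders and not part of this change; AutoDecodingConfig, CudaGraphConfig, KvCacheConfig, and the LLM constructor arguments mirror the added test, and, per the accuracy-harness update above, AUTO currently resolves to the NGram drafter.

# Hypothetical usage sketch (not part of this diff): enable AUTO speculative
# decoding through the LLM API, mirroring the new test_auto_spec_decode test.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                 KvCacheConfig, SamplingParams)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    speculative_config=AutoDecodingConfig(),   # AUTO currently maps to NGram
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  free_gpu_memory_fraction=0.5),
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 32, 64],
                                      enable_padding=True),
    max_batch_size=64,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)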