Commit b4fcd5f

[https://nvbugs/5441438][fix] Set correct draft length for the cuda graph dummy request (#6701)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent ead89a0 commit b4fcd5f

File tree: 4 files changed (+51, -28 lines)


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 23 additions & 19 deletions
@@ -453,6 +453,10 @@ def __init__(
         else:
             self.cache_indirection_attention = None

+    @property
+    def runtime_draft_len(self):
+        return self.max_draft_len if self.enable_spec_decode else 0
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
@@ -573,7 +577,7 @@ def get_torch_compile_warmup_request(batch_size,
                     list(range(batch_size)), [num_tokens_per_request] *
                     batch_size if not is_gen else None,
                     is_gen=is_gen,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)

                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(
@@ -592,7 +596,7 @@ def get_torch_compile_warmup_request(batch_size,

         def get_autotune_warmup_request():
             available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
             num_tokens_per_request = min(
                 min(available_tokens, self.max_seq_len - 1),
                 self.max_num_tokens)
@@ -626,14 +630,14 @@ def get_autotune_warmup_request():
                 request_ids=list(range(full_len_request_num)),
                 token_nums=[num_tokens_per_request] * full_len_request_num,
                 is_gen=False,
-                max_num_draft_tokens=self.max_draft_len)
+                max_num_draft_tokens=self.runtime_draft_len)

             if remaining_tokens > 0:
                 final_request = kv_cache_manager.add_dummy_requests(
                     request_ids=[full_len_request_num],
                     token_nums=[remaining_tokens],
                     is_gen=False,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)

                 requests += final_request

@@ -680,7 +684,7 @@ def disable_optimization(backend: Backend):
         # Disable cuda graph capture here so that we can properly capture it later
         with self.no_cuda_graph():
             available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
             warmup_batch_size = [1, self.batch_size // 2]
             if self.batch_size < 2:
                 warmup_batch_size = [1]
@@ -898,7 +902,7 @@ def _get_padded_batch(
             self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
                 cuda_graph_dummy_request_ids,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.runtime_draft_len,
                 use_mrope=self.use_mrope,
                 max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
@@ -1332,7 +1336,7 @@ def _prepare_tp_inputs(
                 gather_ids.extend(
                     list(
                         range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                 position_ids.extend(
                     list(
                         range(past_seen_token_num,
@@ -1348,23 +1352,23 @@ def _prepare_tp_inputs(
             # inputs
             # overlap scheduler can only support the speculative decoding
             # methods with a fixed number of draft tokens
-            sequence_lengths.append(1 + self.max_draft_len)
+            sequence_lengths.append(1 + self.runtime_draft_len)
             past_seen_token_num = request.max_beam_num_tokens - 1
-            draft_lens.append(self.max_draft_len)
+            draft_lens.append(self.runtime_draft_len)
             gather_ids.extend(
                 list(
                     range(len(position_ids),
-                          len(position_ids) + 1 + self.max_draft_len)))
+                          len(position_ids) + 1 + self.runtime_draft_len)))
             position_ids.extend(
                 list(
-                    range(past_seen_token_num,
-                          past_seen_token_num + 1 + self.max_draft_len)))
+                    range(past_seen_token_num, past_seen_token_num + 1 +
+                          self.runtime_draft_len)))
             # previous tensor
             previous_batch_indices.append(previous_batch_idx)
             previous_pos_indices.extend([previous_batch_idx] *
-                                        (1 + self.max_draft_len))
+                                        (1 + self.runtime_draft_len))
             num_cached_tokens_per_seq.append(past_seen_token_num +
-                                             self.max_draft_len + 1)
+                                             self.runtime_draft_len + 1)
             prompt_lengths.append(request.py_prompt_len)
             request_ids.append(request.py_request_id)

@@ -1441,21 +1445,21 @@ def previous_seq_slots_device():
                 previous_slots = previous_seq_slots_device()
                 # previous input ids
                 previous_batch_tokens = previous_batch_len * (
-                    1 + self.max_draft_len)
+                    1 + self.runtime_draft_len)
                 new_tokens = new_tokens_device.transpose(
                     0, 1)[previous_slots, :].flatten()
                 self.input_ids_cuda[num_tokens:num_tokens +
                                     previous_batch_tokens].copy_(
                                         new_tokens, non_blocking=True)
                 # previous draft tokens
-                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
                 self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
                                        previous_batch_draft_tokens].copy_(
                                            next_draft_tokens_device[
                                                previous_slots, :].flatten(),
                                            non_blocking=True)
                 # prepare data for the preprocess inputs
-                kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+                kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
                 previous_pos_indices_host = torch.tensor(previous_pos_indices,
                                                          dtype=torch.int,
                                                          pin_memory=True)
@@ -1480,8 +1484,8 @@ def previous_seq_slots_device():
                     extend_dummy_requests)
             self.previous_pos_id_offsets_cuda[
                 (num_extend_reqeust_wo_dummy - previous_batch_len) *
-                (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
-                (1 + self.max_draft_len)].copy_(
+                (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
+                (1 + self.runtime_draft_len)].copy_(
                     new_tokens_lens_device[self.previous_pos_indices_cuda[
                         0:previous_batch_tokens]],
                     non_blocking=True)
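
The core of the fix is the new runtime_draft_len property: dummy requests created for CUDA graph padding and warmup should only reserve space for draft tokens when speculative decoding is actually enabled for the current iteration; otherwise the draft length must be 0 rather than max_draft_len. Below is a minimal sketch of that gating logic, using an illustrative stand-in class rather than the real model engine.

# Illustrative sketch of the gating introduced by this commit; the class and
# attribute layout are simplified stand-ins, not the real TensorRT-LLM engine.
class EngineDraftLenSketch:

    def __init__(self, max_draft_len: int, enable_spec_decode: bool):
        self.max_draft_len = max_draft_len
        self.enable_spec_decode = enable_spec_decode

    @property
    def runtime_draft_len(self) -> int:
        # Same expression as the new property in model_engine.py.
        return self.max_draft_len if self.enable_spec_decode else 0


# A CUDA graph dummy request reserves 1 + draft_len token slots per step.
spec_on = EngineDraftLenSketch(max_draft_len=4, enable_spec_decode=True)
spec_off = EngineDraftLenSketch(max_draft_len=4, enable_spec_decode=False)
assert 1 + spec_on.runtime_draft_len == 5
assert 1 + spec_off.runtime_draft_len == 1  # previously 5, even with spec decode off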

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 5 additions & 5 deletions
@@ -336,7 +336,7 @@ def process_draft_tokens(self, request: LlmRequest,
         if request.py_draft_logits is None:
             new_token = add_token(request, new_tokens, beam=self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
-            if stop or len(request.py_draft_tokens) == 0:
+            if stop or get_draft_token_length(request) == 0:
                 return 0
             num_accepted = 0

@@ -360,10 +360,10 @@ def process_draft_tokens(self, request: LlmRequest,
                 request.py_draft_logits[0],
                 generator=generator)
             target_probs = request.py_target_probs
-            p = draft_probs[torch.arange(len(request.py_draft_tokens)),
+            p = draft_probs[torch.arange(get_draft_token_length(request)),
                             request.py_draft_tokens]
             q = target_probs[:-1]
-            q = q[torch.arange(len(request.py_draft_tokens)),
+            q = q[torch.arange(get_draft_token_length(request)),
                   request.py_draft_tokens]
             accept_probs = torch.minimum(torch.ones(()), q / p)
             # Use deterministic random generation for multi-GPU consistency
@@ -374,7 +374,7 @@ def process_draft_tokens(self, request: LlmRequest,
             sample_last = True
             stop = False
             if rejected_indices.numel() == 0:
-                num_initially_accepted = len(request.py_draft_tokens)
+                num_initially_accepted = get_draft_token_length(request)
                 sample_last = False
             else:
                 num_initially_accepted = rejected_indices[0].item()
@@ -575,7 +575,7 @@ def _process_requests(self,
         logits = raw_logits[:sum_steps]
         # Collect steps per request for batched strategy
         steps_per_request = [
-            1 + len(req.py_draft_tokens) for req in requests
+            1 + get_draft_token_length(req) for req in requests
         ]
         logits = self._apply_embedding_bias(logits, requests,
                                             steps_per_request)
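
sampler.py now calls get_draft_token_length(request) instead of len(request.py_draft_tokens). The helper's definition is not part of this diff; presumably it treats a request whose py_draft_tokens is unset as contributing zero draft tokens, which a bare len() call would not handle. A rough sketch of that assumed behavior:

# Assumed behavior of the helper referenced in sampler.py; the real
# implementation lives elsewhere in the code base and may differ in detail.
from typing import Optional, Protocol, Sequence


class RequestLike(Protocol):
    py_draft_tokens: Optional[Sequence[int]]


def get_draft_token_length_sketch(request: RequestLike) -> int:
    # A request with no draft tokens contributes zero extra decode steps;
    # calling len() on an unset py_draft_tokens would raise instead.
    if request.py_draft_tokens is None:
        return 0
    return len(request.py_draft_tokens)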

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 2 additions & 0 deletions
@@ -155,6 +155,8 @@ def evaluate(self,
             spec_dec_algo = None
         elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
             spec_dec_algo = llm.args.speculative_config.decoding_type
+            if spec_dec_algo == 'AUTO':
+                spec_dec_algo = 'NGram'
         else:
             raise ValueError(
                 f"Not recognized speculative_config: {llm.args.speculative_config}."

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 21 additions & 4 deletions
@@ -21,10 +21,10 @@
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, SamplingParams,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 SamplingParams, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo

 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -355,6 +355,23 @@ def test_guided_decoding_with_ngram(self, backend: str, mocker):
         task = JsonModeEval(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_hopper
+    def test_auto_spec_decode(self):
+        pytorch_config = {
+            "cuda_graph_config":
+            CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
+        }
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        free_gpu_memory_fraction=0.5)
+        spec_config = AutoDecodingConfig()
+        with LLM(model=self.MODEL_PATH,
+                 **pytorch_config,
+                 kv_cache_config=kv_cache_config,
+                 speculative_config=spec_config,
+                 max_batch_size=64) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+

 class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.2-1B"
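
For reference, the new test drives roughly the following user-facing flow through the LLM API. The snippet below is an illustrative stand-alone variant of what test_auto_spec_decode exercises: the checkpoint path and prompt are placeholders, not values from this commit.

# Stand-alone sketch of enabling automatic speculative decoding via the LLM API.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                 KvCacheConfig)

llm = LLM(
    model="/path/to/llama-3.1-8b-instruct",  # placeholder checkpoint path
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 32, 64],
                                      enable_padding=True),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  free_gpu_memory_fraction=0.5),
    speculative_config=AutoDecodingConfig(),  # let the runtime pick the drafter
    max_batch_size=64,
)
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)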
