@@ -453,6 +453,10 @@ def __init__(
         else:
             self.cache_indirection_attention = None
 
+    @property
+    def runtime_draft_len(self):
+        return self.max_draft_len if self.enable_spec_decode else 0
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
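The remaining hunks swap `self.max_draft_len` for this new `runtime_draft_len` property wherever warmup requests, CUDA-graph dummy requests, and TP inputs are sized. Below is a minimal, runnable sketch of the gating pattern the property introduces; the class and helper names are illustrative stand-ins, not the actual engine class from this file.

```python
# Minimal sketch (illustrative names, not the real engine class) of the
# pattern added above: sizing logic reads a runtime draft length that
# collapses to 0 whenever speculative decoding is disabled.
class _EngineSketch:

    def __init__(self, max_draft_len: int, enable_spec_decode: bool):
        self.max_draft_len = max_draft_len
        self.enable_spec_decode = enable_spec_decode

    @property
    def runtime_draft_len(self) -> int:
        # Mirrors the property in the diff: only reserve draft-token slots
        # when speculative decoding is actually active for this run.
        return self.max_draft_len if self.enable_spec_decode else 0

    def tokens_per_generation_request(self) -> int:
        # One newly sampled token plus any draft-token slots, as used by the
        # sequence_lengths / dummy-request sizing in the later hunks.
        return 1 + self.runtime_draft_len


engine = _EngineSketch(max_draft_len=4, enable_spec_decode=False)
assert engine.runtime_draft_len == 0
assert engine.tokens_per_generation_request() == 1
```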
@@ -573,7 +577,7 @@ def get_torch_compile_warmup_request(batch_size,
                    list(range(batch_size)), [num_tokens_per_request] *
                    batch_size if not is_gen else None,
                    is_gen=is_gen,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                if spec_resource_manager is not None:
                    spec_resource_manager.add_dummy_requests(
@@ -592,7 +596,7 @@ def get_torch_compile_warmup_request(batch_size,
 
        def get_autotune_warmup_request():
            available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
            num_tokens_per_request = min(
                min(available_tokens, self.max_seq_len - 1),
                self.max_num_tokens)
@@ -626,14 +630,14 @@ def get_autotune_warmup_request():
                request_ids=list(range(full_len_request_num)),
                token_nums=[num_tokens_per_request] * full_len_request_num,
                is_gen=False,
-                max_num_draft_tokens=self.max_draft_len)
+                max_num_draft_tokens=self.runtime_draft_len)
 
            if remaining_tokens > 0:
                final_request = kv_cache_manager.add_dummy_requests(
                    request_ids=[full_len_request_num],
                    token_nums=[remaining_tokens],
                    is_gen=False,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                requests += final_request
 
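For readers skimming the warmup hunks: the surrounding code appears to split the warmup token budget into full-length dummy requests plus one shorter remainder request. A hedged sketch of that split, reusing the variable names visible in the diff but with an assumed `available_tokens` budget and a standalone helper (`split_warmup_requests` is not part of the source):

```python
# Sketch (illustrative, not the file's actual logic) of the warmup split:
# create as many full-length dummy requests as the budget allows, then one
# shorter request for whatever token budget is left over.
def split_warmup_requests(available_tokens: int,
                          num_tokens_per_request: int) -> list[int]:
    full_len_request_num = available_tokens // num_tokens_per_request
    remaining_tokens = available_tokens % num_tokens_per_request
    token_nums = [num_tokens_per_request] * full_len_request_num
    if remaining_tokens > 0:
        token_nums.append(remaining_tokens)
    return token_nums


assert split_warmup_requests(10, 4) == [4, 4, 2]
assert split_warmup_requests(8, 4) == [4, 4]
```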
@@ -680,7 +684,7 @@ def disable_optimization(backend: Backend):
        # Disable cuda graph capture here so that we can properly capture it later
        with self.no_cuda_graph():
            available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
            warmup_batch_size = [1, self.batch_size // 2]
            if self.batch_size < 2:
                warmup_batch_size = [1]
@@ -898,7 +902,7 @@ def _get_padded_batch(
            self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
                cuda_graph_dummy_request_ids,
                is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.runtime_draft_len,
                use_mrope=self.use_mrope,
                max_beam_width=self.max_beam_width)[0]
            self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
@@ -1332,7 +1336,7 @@ def _prepare_tp_inputs(
                gather_ids.extend(
                    list(
                        range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                position_ids.extend(
                    list(
                        range(past_seen_token_num,
@@ -1348,23 +1352,23 @@ def _prepare_tp_inputs(
                # inputs
                # overlap scheduler can only support the speculative decoding
                # methods with a fixed number of draft tokens
-                sequence_lengths.append(1 + self.max_draft_len)
+                sequence_lengths.append(1 + self.runtime_draft_len)
                past_seen_token_num = request.max_beam_num_tokens - 1
-                draft_lens.append(self.max_draft_len)
+                draft_lens.append(self.runtime_draft_len)
                gather_ids.extend(
                    list(
                        range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                position_ids.extend(
                    list(
-                        range(past_seen_token_num,
-                              past_seen_token_num + 1 + self.max_draft_len)))
+                        range(past_seen_token_num, past_seen_token_num + 1 +
+                              self.runtime_draft_len)))
                # previous tensor
                previous_batch_indices.append(previous_batch_idx)
                previous_pos_indices.extend([previous_batch_idx] *
-                                            (1 + self.max_draft_len))
+                                            (1 + self.runtime_draft_len))
                num_cached_tokens_per_seq.append(past_seen_token_num +
-                                                 self.max_draft_len + 1)
+                                                 self.runtime_draft_len + 1)
                prompt_lengths.append(request.py_prompt_len)
                request_ids.append(request.py_request_id)
 
@@ -1441,21 +1445,21 @@ def previous_seq_slots_device():
                previous_slots = previous_seq_slots_device()
                # previous input ids
                previous_batch_tokens = previous_batch_len * (
-                    1 + self.max_draft_len)
+                    1 + self.runtime_draft_len)
                new_tokens = new_tokens_device.transpose(
                    0, 1)[previous_slots, :].flatten()
                self.input_ids_cuda[num_tokens:num_tokens +
                                    previous_batch_tokens].copy_(
                                        new_tokens, non_blocking=True)
                # previous draft tokens
-                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
                self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
                                       previous_batch_draft_tokens].copy_(
                                           next_draft_tokens_device[
                                               previous_slots, :].flatten(),
                                           non_blocking=True)
                # prepare data for the preprocess inputs
-                kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+                kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
                previous_pos_indices_host = torch.tensor(previous_pos_indices,
                                                         dtype=torch.int,
                                                         pin_memory=True)
@@ -1480,8 +1484,8 @@ def previous_seq_slots_device():
                        extend_dummy_requests)
                self.previous_pos_id_offsets_cuda[
                    (num_extend_reqeust_wo_dummy - previous_batch_len) *
-                    (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
-                    (1 + self.max_draft_len)].copy_(
+                    (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
+                    (1 + self.runtime_draft_len)].copy_(
                        new_tokens_lens_device[self.previous_pos_indices_cuda[
                            0:previous_batch_tokens]],
                        non_blocking=True)
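For context on the `_prepare_tp_inputs` hunks above, here is a self-contained sketch (an illustrative helper, not the TensorRT-LLM implementation) of how gather indices and position ids grow by `1 + runtime_draft_len` per extend request, and how a zero draft length shrinks every slice to a single token.

```python
# Sketch of the indexing scheme the _prepare_tp_inputs hunks adjust: each
# extend request contributes 1 + runtime_draft_len positions, so disabling
# speculative decoding (runtime_draft_len == 0) reduces every slice to 1.
def build_extend_indices(past_seen_token_nums: list[int],
                         runtime_draft_len: int):
    gather_ids: list[int] = []
    position_ids: list[int] = []
    tokens_per_request = 1 + runtime_draft_len
    for past_seen in past_seen_token_nums:
        # Gather logits for the new token and each draft-token slot.
        gather_ids.extend(
            range(len(position_ids), len(position_ids) + tokens_per_request))
        # Positions continue from the tokens already stored in the KV cache.
        position_ids.extend(range(past_seen, past_seen + tokens_per_request))
    return gather_ids, position_ids


gather, pos = build_extend_indices([10, 20], runtime_draft_len=3)
assert len(gather) == len(pos) == 2 * (1 + 3)

gather, pos = build_extend_indices([10, 20], runtime_draft_len=0)
assert pos == [10, 20]  # one position per request when spec decode is off
```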