 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
-from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
@@ -261,7 +261,10 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
+        # Set a deterministic seed for consistent multi-GPU sampling, using PyTorch
+        # RNG ops that avoid torch.multinomial's CPU-GPU sync overhead.
         torch.manual_seed(0)
+
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens
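
Note (not part of the patch): the comment above refers to sampling with plain PyTorch RNG ops instead of `torch.multinomial`, which forces a host sync. As a rough illustration only, with a made-up helper name and signature, one multinomial-free way to draw tokens is the exponential-race (Gumbel-max) trick:

```python
from typing import Optional

import torch


def sample_tokens_no_multinomial(logits: torch.Tensor,
                                 generator: Optional[torch.Generator] = None
                                 ) -> torch.Tensor:
    # Hypothetical helper (not from TensorRT-LLM): draw one token id per row of
    # `logits` from softmax(logits) via the exponential-race (Gumbel-max) trick.
    # argmax(p_i / E_i) with E_i ~ Exp(1) samples the categorical distribution,
    # and every op here stays on the GPU with no host synchronization.
    probs = torch.softmax(logits.float(), dim=-1)
    noise = torch.empty_like(probs).exponential_(1.0, generator=generator)
    return torch.argmax(probs / noise, dim=-1)
```

With `torch.manual_seed(0)` applied identically on every rank, each rank draws the same noise and therefore the same tokens, which is what makes the multi-GPU sampling consistent without extra communication.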
@@ -278,6 +281,8 @@ def __init__(
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.is_draft_model = is_draft_model
+        self.is_advanced_mtp_sampler = self.is_spec_decode and self.spec_config.spec_dec_mode.is_mtp(
+        ) and self.spec_config.use_advanced_mtp_sampler
 
         self.in_warmup = False
 
@@ -373,18 +378,24 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-            self.temperatures_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                                 dtype=torch.float,
-                                                 device='cuda')
-            self.top_k_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.int,
-                                          device='cuda')
-            self.top_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
-            self.min_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
+
+            if self.is_advanced_mtp_sampler:
+                self.temperatures_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.top_k_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.int,
+                    device='cuda')
+                self.top_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.min_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0
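
For context on the sizing above: every request slot owns `max_draft_len + 1` consecutive per-token entries (one target token plus up to `max_draft_len` MTP draft tokens), so the flat buffers hold `batch_size * (max_draft_len + 1)` values. A small standalone illustration with made-up numbers (variable names and values are not from the patch):

```python
import torch

# Assumed toy sizes: 2 request slots, 3 draft tokens per step.
batch_size, max_draft_len = 2, 3
tokens_per_request = max_draft_len + 1

# Mirrors the preallocation above, but on CPU so it runs anywhere.
temperatures_buf = torch.empty((batch_size * tokens_per_request, ),
                               dtype=torch.float)

# Each request's temperature is repeated once per target/draft token, exactly
# like the `extend([...] * (num_draft_tokens + 1))` calls further down.
per_request_temperature = [0.7, 1.0]
flat = [t for t in per_request_temperature for _ in range(tokens_per_request)]
temperatures_buf[:len(flat)].copy_(torch.tensor(flat, dtype=torch.float))

# Request r's per-token temperatures live in one contiguous slice.
r = 1
print(temperatures_buf[r * tokens_per_request:(r + 1) * tokens_per_request])
```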
@@ -1142,11 +1153,12 @@ def _prepare_tp_inputs(
         draft_lens = []
         multimodal_params_list = []
         gen_request_seq_slots = []  # per generation request
-
-        temperatures = []
-        top_k = []
-        top_p = []
-        min_p = []
+
+        if self.is_advanced_mtp_sampler:
+            temperatures = []
+            top_k = []
+            top_p = []
+            min_p = []
 
         def get_request_temperature(request: LlmRequest) -> float:
             if not request.sampling_config.temperature:
@@ -1213,12 +1225,13 @@ def get_request_min_p(request: LlmRequest) -> float:
 
             if multimodal_params.has_content():
                 multimodal_params_list.append(multimodal_params)
-
-            temperatures.append(get_request_temperature(request))
-            top_k.append(get_request_top_k(request))
-            top_p.append(get_request_top_p(request))
-            min_p.append(get_request_min_p(request))
-
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.append(get_request_temperature(request))
+                top_k.append(get_request_top_k(request))
+                top_p.append(get_request_top_p(request))
+                min_p.append(get_request_min_p(request))
+
             request.py_batch_idx = request.py_seq_slot
 
         num_ctx_requests = len(scheduled_requests.context_requests)
@@ -1300,10 +1313,17 @@ def get_request_min_p(request: LlmRequest) -> float:
                         past_seen_token_num + 1 + num_draft_tokens)))
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (num_draft_tokens + 1))
-                top_k.extend([get_request_top_k(request)] * (num_draft_tokens + 1))
-                top_p.extend([get_request_top_p(request)] * (num_draft_tokens + 1))
-                min_p.extend([get_request_min_p(request)] * (num_draft_tokens + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (num_draft_tokens + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (num_draft_tokens + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (num_draft_tokens + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (num_draft_tokens + 1))
+
                 # update batch index
                 request.py_batch_idx = request.py_seq_slot
             else:
@@ -1332,10 +1352,16 @@ def get_request_min_p(request: LlmRequest) -> float:
                                              self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (self.max_draft_len + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (self.max_draft_len + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (self.max_draft_len + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (self.max_draft_len + 1))
 
         for request in generation_requests:
             beam_width = request.sampling_config.beam_width
@@ -1368,11 +1394,17 @@ def get_request_min_p(request: LlmRequest) -> float:
 
             request_ids.append(request.py_request_id)
             gen_request_seq_slots.append(request.py_seq_slot)
-
-            temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-            top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-            top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-            min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.extend([get_request_temperature(request)] *
+                                    (self.max_draft_len + 1))
+                top_k.extend([get_request_top_k(request)] *
+                             (self.max_draft_len + 1))
+                top_p.extend([get_request_top_p(request)] *
+                             (self.max_draft_len + 1))
+                min_p.extend([get_request_min_p(request)] *
+                             (self.max_draft_len + 1))
+
             request.py_batch_idx = request.py_seq_slot
 
         previous_batch_len = len(previous_batch_indices)
@@ -1476,18 +1508,21 @@ def previous_seq_slots_device():
         self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
             gather_ids, dtype=torch.int, pin_memory=True),
                                                      non_blocking=True)
-        self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor(
-            temperatures, dtype=torch.float, pin_memory=True),
-                                                          non_blocking=True)
-        self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-            top_k, dtype=torch.int, pin_memory=True),
-                                           non_blocking=True)
-        self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-            top_p, dtype=torch.float, pin_memory=True),
-                                           non_blocking=True)
-        self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-            min_p, dtype=torch.float, pin_memory=True),
-                                           non_blocking=True)
+        if self.is_advanced_mtp_sampler:
+            self.temperatures_cuda[:len(temperatures)].copy_(
+                torch.tensor(temperatures,
+                             dtype=torch.float,
+                             pin_memory=True),
+                non_blocking=True)
+            self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
+                top_k, dtype=torch.int, pin_memory=True),
+                                               non_blocking=True)
+            self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
+                top_p, dtype=torch.float, pin_memory=True),
+                                               non_blocking=True)
+            self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
+                min_p, dtype=torch.float, pin_memory=True),
+                                               non_blocking=True)
 
         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies
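
The copies above stage the Python lists in a pinned host tensor and enqueue an asynchronous host-to-device copy into the preallocated CUDA buffer. A minimal self-contained sketch of that pattern (buffer name and sizes are invented for illustration, not taken from the patch):

```python
import torch

if torch.cuda.is_available():
    # Preallocated device buffer, as in __init__ above (size is arbitrary here).
    dst = torch.empty((8, ), dtype=torch.float, device='cuda')

    # pin_memory=True places the staging tensor in page-locked host memory, so
    # copy_(..., non_blocking=True) becomes an async transfer on the current
    # CUDA stream instead of a blocking memcpy.
    host_vals = [0.7, 0.7, 0.7, 0.7, 1.0, 1.0, 1.0, 1.0]
    staging = torch.tensor(host_vals, dtype=torch.float, pin_memory=True)
    dst[:len(host_vals)].copy_(staging, non_blocking=True)

    # Kernels launched later on the same stream see the copied values; only a
    # host-side read of `dst` needs an explicit synchronize.
    torch.cuda.synchronize()
    print(dst)
```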
@@ -1562,12 +1597,14 @@ def previous_seq_slots_device():
                                                        total_draft_lens]
             spec_metadata.request_ids = request_ids
             spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]
-            spec_metadata.temperatures = self.temperatures_cuda[:len(temperatures)]
-            spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-            spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-            spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
-            # if attn_metadata.is_cuda_graph and not torch.cuda.is_current_stream_capturing():
-            #     spec_metadata.generator = torch.Generator(device='cpu').manual_seed(0)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata.temperatures = self.temperatures_cuda[:len(
+                    temperatures)]
+                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
+                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
+                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+
             spec_metadata.num_generations = len(
                 scheduled_requests.generation_requests)
             spec_metadata.num_tokens = total_num_tokens
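
One detail worth noting about the assignments above: `self.temperatures_cuda[:len(temperatures)]` and friends are views into the preallocated buffers, not copies, so `spec_metadata` keeps referring to stable device memory, which is presumably why the buffers are allocated once up front. A quick standalone check of that slicing behavior:

```python
import torch

buf = torch.empty((16, ), dtype=torch.float)
view = buf[:4]

# The slice shares storage with the full buffer: same data pointer, zero copy.
assert view.data_ptr() == buf.data_ptr()

# Writes through the buffer are visible through the view (and vice versa).
buf[:4] = 0.5
print(view)  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```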