 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
-from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
@@ -261,7 +261,10 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
+        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+        # operations that avoid torch.multinomial's CPU-GPU sync overhead
         torch.manual_seed(0)
+
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens
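The comment above refers to replacing torch.multinomial with GPU-only tensor ops. One common way to do that is the Gumbel-max trick; the sketch below is illustrative only (the helper name and shapes are assumptions, not code from this change):

from typing import Optional
import torch

def gumbel_max_sample(logits: torch.Tensor,
                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # logits: [num_tokens, vocab_size], already on the GPU.
    # Adding Gumbel noise and taking an argmax draws from softmax(logits)
    # without torch.multinomial, so no host/device synchronization is needed.
    u = torch.rand(logits.shape, device=logits.device, generator=generator)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    return torch.argmax(logits + gumbel, dim=-1)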
@@ -278,6 +281,8 @@ def __init__(
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.is_draft_model = is_draft_model
+        self.is_advanced_mtp_sampler = self.is_spec_decode and self.spec_config.spec_dec_mode.is_mtp(
+        ) and self.spec_config.use_advanced_mtp_sampler
 
         self.in_warmup = False
 
@@ -373,18 +378,24 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-            self.temperatures_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                                 dtype=torch.float,
-                                                 device='cuda')
-            self.top_k_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.int,
-                                          device='cuda')
-            self.top_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
-            self.min_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
+
+            if self.is_advanced_mtp_sampler:
+                self.temperatures_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.top_k_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.int,
+                    device='cuda')
+                self.top_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.min_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0
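Each of the four buffers above holds one entry per sampled position: every request contributes its sampling parameters once for the target token and once per draft token, hence batch_size * (max_draft_len + 1) slots. A small illustrative example of that layout (the values are made up):

max_draft_len = 3
requests = [{"temperature": 0.7}, {"temperature": 1.0}]  # pretend batch_size = 2
temperatures = []
for req in requests:
    # one slot for the target token plus one per draft token
    temperatures.extend([req["temperature"]] * (max_draft_len + 1))
assert len(temperatures) == len(requests) * (max_draft_len + 1)  # 8 slots total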
@@ -1142,11 +1153,12 @@ def _prepare_tp_inputs(
         draft_lens = []
         multimodal_params_list = []
         gen_request_seq_slots = []  # per generation request
-
-        temperatures = []
-        top_k = []
-        top_p = []
-        min_p = []
+
+        if self.is_advanced_mtp_sampler:
+            temperatures = []
+            top_k = []
+            top_p = []
+            min_p = []
 
         def get_request_temperature(request: LlmRequest) -> float:
             if not request.sampling_config.temperature:
@@ -1213,12 +1225,13 @@ def get_request_min_p(request: LlmRequest) -> float:
 
             if multimodal_params.has_content():
                 multimodal_params_list.append(multimodal_params)
-
-            temperatures.append(get_request_temperature(request))
-            top_k.append(get_request_top_k(request))
-            top_p.append(get_request_top_p(request))
-            min_p.append(get_request_min_p(request))
-
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.append(get_request_temperature(request))
+                top_k.append(get_request_top_k(request))
+                top_p.append(get_request_top_p(request))
+                min_p.append(get_request_min_p(request))
+
             request.py_batch_idx = request.py_seq_slot
 
         num_ctx_requests = len(scheduled_requests.context_requests)
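The get_request_* helpers read each request's sampling_config and fall back to a default when a field is unset (their full bodies sit outside this hunk). A hedged sketch of the pattern, with a placeholder default rather than the one this PR actually uses:

def get_request_temperature(request: LlmRequest) -> float:
    # Placeholder fallback value; the real default is defined in the helper
    # outside this hunk.
    if not request.sampling_config.temperature:
        return 1.0
    # sampling_config fields are assumed to be per-beam lists here.
    return request.sampling_config.temperature[0]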
@@ -1301,10 +1314,17 @@ def get_request_min_p(request: LlmRequest) -> float:
                               past_seen_token_num + 1 + num_draft_tokens)))
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (num_draft_tokens + 1))
-                top_k.extend([get_request_top_k(request)] * (num_draft_tokens + 1))
-                top_p.extend([get_request_top_p(request)] * (num_draft_tokens + 1))
-                min_p.extend([get_request_min_p(request)] * (num_draft_tokens + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (num_draft_tokens + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (num_draft_tokens + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (num_draft_tokens + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (num_draft_tokens + 1))
+
                 # update batch index
                 request.py_batch_idx = request.py_seq_slot
             else:
@@ -1333,10 +1353,16 @@ def get_request_min_p(request: LlmRequest) -> float:
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (self.max_draft_len + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (self.max_draft_len + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (self.max_draft_len + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (self.max_draft_len + 1))
 
         for request in generation_requests:
             beam_width = request.sampling_config.beam_width
@@ -1369,11 +1395,17 @@ def get_request_min_p(request: LlmRequest) -> float:
 
             request_ids.append(request.py_request_id)
             gen_request_seq_slots.append(request.py_seq_slot)
-
-            temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-            top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-            top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-            min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.extend([get_request_temperature(request)] *
+                                    (self.max_draft_len + 1))
+                top_k.extend([get_request_top_k(request)] *
+                             (self.max_draft_len + 1))
+                top_p.extend([get_request_top_p(request)] *
+                             (self.max_draft_len + 1))
+                min_p.extend([get_request_min_p(request)] *
+                             (self.max_draft_len + 1))
+
             request.py_batch_idx = request.py_seq_slot
 
         previous_batch_len = len(previous_batch_indices)
@@ -1477,18 +1509,21 @@ def previous_seq_slots_device():
             self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
                 gather_ids, dtype=torch.int, pin_memory=True),
                                                          non_blocking=True)
-            self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor(
-                temperatures, dtype=torch.float, pin_memory=True),
-                                                        non_blocking=True)
-            self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-                top_k, dtype=torch.int, pin_memory=True),
-                                                        non_blocking=True)
-            self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-                top_p, dtype=torch.float, pin_memory=True),
-                                                        non_blocking=True)
-            self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-                min_p, dtype=torch.float, pin_memory=True),
-                                                        non_blocking=True)
+            if self.is_advanced_mtp_sampler:
+                self.temperatures_cuda[:len(temperatures)].copy_(
+                    torch.tensor(temperatures,
+                                 dtype=torch.float,
+                                 pin_memory=True),
+                    non_blocking=True)
+                self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
+                    top_k, dtype=torch.int, pin_memory=True),
+                                                   non_blocking=True)
+                self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
+                    top_p, dtype=torch.float, pin_memory=True),
+                                                   non_blocking=True)
+                self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
+                    min_p, dtype=torch.float, pin_memory=True),
+                                                   non_blocking=True)
 
         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies
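The copies above stage the Python lists in pinned host tensors and issue copy_(..., non_blocking=True) into the preallocated CUDA buffers, so the transfers can overlap with other work on the stream. A condensed sketch of that pattern (names and sizes are placeholders, and it assumes a CUDA device is available):

import torch

values = [0.7, 0.7, 1.0, 1.0]                    # host-side sampling params
device_buf = torch.empty(16, dtype=torch.float, device='cuda')  # preallocated
staging = torch.tensor(values, dtype=torch.float, pin_memory=True)
# Pinned (page-locked) memory lets the async H2D copy proceed without blocking.
device_buf[:len(values)].copy_(staging, non_blocking=True)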
@@ -1563,12 +1598,14 @@ def previous_seq_slots_device():
                                                                 total_draft_lens]
             spec_metadata.request_ids = request_ids
             spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]
-            spec_metadata.temperatures = self.temperatures_cuda[:len(temperatures)]
-            spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-            spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-            spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
-            # if attn_metadata.is_cuda_graph and not torch.cuda.is_current_stream_capturing():
-                # spec_metadata.generator = torch.Generator(device='cpu').manual_seed(0)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata.temperatures = self.temperatures_cuda[:len(
+                    temperatures)]
+                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
+                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
+                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+
             spec_metadata.num_generations = len(
                 scheduled_requests.generation_requests)
             spec_metadata.num_tokens = total_num_tokens
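spec_metadata now exposes per-position temperature/top-k/top-p/min-p tensors for the advanced MTP sampler to consume. As a rough illustration of how per-position top-k/top-p parameters are typically applied to logits before sampling (generic filtering code, not the sampler introduced by this PR):

import torch

def filter_top_k_top_p(logits: torch.Tensor, top_k: torch.Tensor,
                       top_p: torch.Tensor) -> torch.Tensor:
    # logits: [num_positions, vocab]; top_k/top_p: [num_positions].
    # Written for clarity, not speed.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    ranks = torch.arange(logits.size(-1), device=logits.device).expand_as(logits)
    # top_k <= 0 is treated as "disabled" (keep the whole vocab)
    k = torch.where(top_k > 0, top_k, torch.full_like(top_k, logits.size(-1)))
    top_k_mask = ranks >= k.unsqueeze(-1)
    # drop the tail of the nucleus once cumulative probability exceeds top_p
    probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    top_p_mask = (cumulative - probs) > top_p.unsqueeze(-1)
    sorted_logits = sorted_logits.masked_fill(top_k_mask | top_p_mask,
                                              float('-inf'))
    # scatter the filtered values back to the original token order
    return torch.full_like(logits, float('-inf')).scatter(
        -1, sorted_idx, sorted_logits)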