1919from  collections .abc  import  Iterable 
2020from  dataclasses  import  dataclass 
2121from  itertools  import  repeat 
22- from  typing  import  Any , Callable , List , Optional , TypeVar , cast 
22+ from  typing  import  Any , Callable , List , Optional , Type ,  TypeVar , cast 
2323
2424import  torch 
2525import  torch .nn .functional  as  F 
5353from  tensorrt_llm .mapping  import  Mapping 
5454from  tensorrt_llm .sampling_params  import  SamplingParams 
5555
56+ from  ..flashinfer_utils  import  IS_FLASHINFER_AVAILABLE 
5657from  ..speculative .spec_tree_manager  import  SpecTreeManager 
5758from  .finish_reason  import  FinishedState 
5859from  .llm_request  import  LlmRequest , LlmRequestState , get_draft_token_length 
5960from  .resource_manager  import  ResourceManager , ResourceManagerType 
6061from  .sampling_utils  import  (
6162    GREEDY ,
6263    GenericStrategyKeyType ,
64+     GroupedStrategySampler ,
6365    SimpleGroupedStrategySampler ,
6466    Strategy ,
6567    UtilsSamplingParams ,
7173)
7274from  .scheduler  import  ScheduledRequests 
7375
76+ if  IS_FLASHINFER_AVAILABLE :
77+     from  .sampling_utils_flashinfer  import  FlashInferGroupedStrategySampler 
78+ 
7479if  sys .version_info [:2 ] >=  (3 , 12 ):
7580    from  typing  import  override 
7681else :
@@ -266,7 +271,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
266271def  _group_requests_by_strategy_key (
267272    requests : Iterable [LlmRequest ],
268273    * ,
269-     strategy_to_key : Callable [[Strategy ], GenericStrategyKeyType ],
274+     strategy_to_key : Callable [[Strategy ,  bool ], GenericStrategyKeyType ],
270275    pin_memory : bool  =  False ,
271276    vocab_size : int ,
272277) ->  dict [tuple [GenericStrategyKeyType , bool ], tuple [torch .Tensor , List [Strategy ]]]:
@@ -276,8 +281,8 @@ def _group_requests_by_strategy_key(
276281    )
277282    for  req_index , req  in  enumerate (requests ):
278283        strategy  =  _request_strategy (req , vocab_size = vocab_size )
279-         strategy_key  =  strategy_to_key (strategy )
280284        speculation_needs_probs  =  req .py_draft_logits  is  not None  and  strategy  is  not GREEDY 
285+         strategy_key  =  strategy_to_key (strategy , speculation_needs_probs )
281286        group_dict_entry  =  group_dict [(strategy_key , speculation_needs_probs )]
282287        group_dict_entry [0 ].append (req_index )
283288        group_dict_entry [1 ].append (strategy )
@@ -586,6 +591,7 @@ class Args:
586591        max_num_sequences : int 
587592        max_beam_width : int 
588593        max_total_draft_tokens : int 
594+         disable_flash_infer_sampling : bool  =  False 
589595
590596    def  __init__ (self , args : Args ):
591597        self .max_seq_len  =  args .max_seq_len 
@@ -602,6 +608,13 @@ def __init__(self, args: Args):
602608        with  torch .inference_mode (False ):
603609            self .store  =  self .create_store ()
604610
611+         self ._grouped_sampler_cls : Type [GroupedStrategySampler ]
612+         if  IS_FLASHINFER_AVAILABLE  and  not  args .disable_flash_infer_sampling :
613+             cls_not_possibly_unbound  =  FlashInferGroupedStrategySampler   # type: ignore 
614+             self ._grouped_sampler_cls  =  cls_not_possibly_unbound 
615+         else :
616+             self ._grouped_sampler_cls  =  SimpleGroupedStrategySampler 
617+ 
605618        # Initialize seed for multi-GPU consistency 
606619        self ._global_seed  =  42 
607620        self ._generator  =  None 
@@ -1181,7 +1194,7 @@ def _sample_batched_by_strategy(
11811194            requests ,
11821195            pin_memory = True ,
11831196            vocab_size = logits_cuda .size (1 ),
1184-             strategy_to_key = SimpleGroupedStrategySampler .strategy_grouping_key ,
1197+             strategy_to_key = self . _grouped_sampler_cls .strategy_grouping_key ,
11851198        )
11861199        generator_cuda  =  self .get_generator (cuda_device )
11871200
@@ -1238,7 +1251,7 @@ def _sample_batched_by_strategy(
12381251                for  _  in  range (steps )
12391252            ]
12401253            group_next_tokens_cuda , group_softmax_cuda  =  (
1241-                 SimpleGroupedStrategySampler .sample_grouped_strategies (
1254+                 self . _grouped_sampler_cls .sample_grouped_strategies (
12421255                    strategy_key ,
12431256                    group_strategies_per_step ,
12441257                    group_logits_cuda ,
0 commit comments