diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index b67a0821fd4..8da982aba2b 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -823,7 +823,8 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int):
+                              max_beam_width: int,
+                              disable_flash_infer_sampling: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -836,20 +837,32 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
     )
 
 
 def instantiate_sampler(
-        engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
-        max_batch_size: int, max_beam_width: int, max_seq_len: int,
-        mm_encoder_only: bool, speculative_config: SpeculativeConfig,
-        decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
+    engine: PyTorchModelEngine,
+    llm_args: TorchLlmArgs,
+    mapping: Mapping,
+    *,
+    max_batch_size: int,
+    max_beam_width: int,
+    max_seq_len: int,
+    mm_encoder_only: bool,
+    speculative_config: SpeculativeConfig,
+    decoding_config: trtllm.DecodingConfig,
+    kv_cache_config: KvCacheConfig,
+    disable_flash_infer_sampling: bool,
+):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
-        max_beam_width=max_beam_width)
+        max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
+    )
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 813585950c9..be407a47ed4 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -493,16 +493,19 @@ def drafting_loop_wrapper(model):
     )
 
     with allocation_scope(ExecutorMemoryType.SAMPLER, RestoreMode.PINNED):
-        sampler = instantiate_sampler(model_engine,
-                                      llm_args,
-                                      mapping,
-                                      max_batch_size=max_batch_size,
-                                      max_beam_width=max_beam_width,
-                                      max_seq_len=max_seq_len,
-                                      mm_encoder_only=mm_encoder_only,
-                                      speculative_config=spec_config,
-                                      decoding_config=decoding_config,
-                                      kv_cache_config=kv_cache_config)
+        sampler = instantiate_sampler(
+            model_engine,
+            llm_args,
+            mapping,
+            max_batch_size=max_batch_size,
+            max_beam_width=max_beam_width,
+            max_seq_len=max_seq_len,
+            mm_encoder_only=mm_encoder_only,
+            speculative_config=spec_config,
+            decoding_config=decoding_config,
+            kv_cache_config=kv_cache_config,
+            disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
+        )
 
     logger.info(f"Using Sampler: {type(sampler).__name__}")
     if kv_connector_config is not None:
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 276aa977003..88c0a57f715 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -19,7 +19,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import repeat
-from typing import Any, Callable, List, Optional, TypeVar, cast
+from typing import Any, Callable, List, Optional, Type, TypeVar, cast
 
 import torch
 import torch.nn.functional as F
@@ -53,6 +53,7 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams
 
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
@@ -60,6 +61,7 @@
 from .sampling_utils import (
     GREEDY,
     GenericStrategyKeyType,
+    GroupedStrategySampler,
     SimpleGroupedStrategySampler,
     Strategy,
     UtilsSamplingParams,
@@ -266,7 +268,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
 def _group_requests_by_strategy_key(
     requests: Iterable[LlmRequest],
     *,
-    strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
+    strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
     pin_memory: bool = False,
     vocab_size: int,
 ) -> dict[tuple[GenericStrategyKeyType, bool], tuple[torch.Tensor, List[Strategy]]]:
@@ -276,8 +278,8 @@ def _group_requests_by_strategy_key(
     )
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
-        strategy_key = strategy_to_key(strategy)
         speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        strategy_key = strategy_to_key(strategy, speculation_needs_probs)
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
@@ -586,6 +588,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
+        disable_flash_infer_sampling: bool = False
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -602,6 +605,14 @@ def __init__(self, args: Args):
         with torch.inference_mode(False):
             self.store = self.create_store()
 
+        self._grouped_sampler_cls: Type[GroupedStrategySampler]
+        if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
+            from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler
+
+            self._grouped_sampler_cls = FlashInferGroupedStrategySampler
+        else:
+            self._grouped_sampler_cls = SimpleGroupedStrategySampler
+
         # Initialize seed for multi-GPU consistency
         self._global_seed = 42
         self._generator = None
@@ -1181,7 +1192,7 @@ def _sample_batched_by_strategy(
             requests,
             pin_memory=True,
            vocab_size=logits_cuda.size(1),
-            strategy_to_key=SimpleGroupedStrategySampler.strategy_grouping_key,
+            strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
         )
 
         generator_cuda = self.get_generator(cuda_device)
@@ -1238,7 +1249,7 @@ def _sample_batched_by_strategy(
                 for _ in range(steps)
             ]
             group_next_tokens_cuda, group_softmax_cuda = (
-                SimpleGroupedStrategySampler.sample_grouped_strategies(
+                self._grouped_sampler_cls.sample_grouped_strategies(
                     strategy_key,
                     group_strategies_per_step,
                    group_logits_cuda,
diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
index 0ea3494aa1c..35e64afe4c2 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -33,13 +33,13 @@
     from typing_extensions import override
 
 
-TemperatureOnly = tuple[Literal["temperature"], float]
-TopK = tuple[Literal["top_k"], int, float]
-TopP = tuple[Literal["top_p"], float, float]
-TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
-Greedy = tuple[Literal["greedy"], None]
+TemperatureOnly: TypeAlias = tuple[Literal["temperature"], float]
+TopK: TypeAlias = tuple[Literal["top_k"], int, float]
+TopP: TypeAlias = tuple[Literal["top_p"], float, float]
+TopKTopP: TypeAlias = tuple[Literal["top_k_top_p"], int, float, float]
+Greedy: TypeAlias = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
+Strategy: TypeAlias = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -258,7 +258,10 @@ def sample(
     match strategy:
         case ("top_k", top_k, temperature):
             tokens, softmax = top_k_sampling_batch(
-                logits, top_k=top_k, temperature=temperature, generator=generator
+                logits,
+                top_k=top_k,
+                temperature=temperature,
+                generator=generator,
             )
         case ("top_p", top_p, temperature):
             tokens, softmax = top_p_sampling_batch(
@@ -292,7 +295,7 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
 
     @staticmethod
     @abc.abstractmethod
-    def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
         raise NotImplementedError
 
     @staticmethod
@@ -314,7 +317,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
 
     @override
     @staticmethod
-    def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
         return strategy
 
     @override
diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
new file mode 100644
index 00000000000..f5da51de6bc
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
@@ -0,0 +1,579 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for using FlashInfer.sampling.
+
+Code in this module should operate on logits and probs, without
+referring to types like LlmRequest.
+"""
+
+import abc
+import sys
+from typing import Optional, Type, TypeAlias, cast
+
+import flashinfer.sampling
+import torch
+
+if sys.version_info[:2] >= (3, 12):
+    from typing import override
+else:
+    from typing_extensions import override
+
+from ..flashinfer_utils import ENABLE_PDL
+from .sampling_utils import (
+    GREEDY,
+    GroupedStrategySampler,
+    Strategy,
+    TemperatureOnly,
+    TopK,
+    TopKTopP,
+    TopP,
+    greedy_search_sampling_batch,
+)
+
+
+class _StrategyImpls:
+    class StrategyImpl(abc.ABC):
+        @classmethod
+        @abc.abstractmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.StrategyImpl":
+            pass
+
+        @classmethod
+        @abc.abstractmethod
+        def computes_probs(cls) -> bool:
+            pass
+
+        @abc.abstractmethod
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            pass
+
+        # TODO: Revisit this after determining performance impact
+        #
+        # NB: NaN logits can lead to crashes, see
+        # https://github.com/flashinfer-ai/flashinfer/issues/1575
+        #
+        @staticmethod
+        def _flashinfer_check_nans(inputs: torch.Tensor) -> bool:
+            # Using explicit async NaN check because FlashInfer.sampling 'nan_check' syncs
+
+            # https://github.com/pytorch/pytorch/issues/36853
+            torch._assert_async(~torch.any(torch.isnan(inputs)))
+
+            return False
+
+        @staticmethod
+        def _make_tensor(data: list, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+            return torch.tensor(data, dtype=dtype, pin_memory=True).to(
+                device=device, non_blocking=True
+            )
+
+        @staticmethod
+        def _prepare_logits_with_temperature(
+            logits: torch.Tensor,
+            group_logit_indices: Optional[torch.Tensor],
+            temperature: Optional[torch.Tensor],
+        ) -> torch.Tensor:
+            if temperature is not None:
+                temperature = temperature.unsqueeze(-1)
+                if group_logit_indices is not None:
+                    logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
+                    logits /= temperature
+                else:
+                    logits = logits / temperature  # not inplace
+            elif group_logit_indices is not None:
+                logits = logits[group_logit_indices]
+            return logits
+
+        @staticmethod
+        def _prepare_probs_with_temperature(
+            logits: torch.Tensor,
+            group_logit_indices: Optional[torch.Tensor],
+            temperature: Optional[torch.Tensor],
+        ) -> torch.Tensor:
+            if group_logit_indices is not None:
+                logits = logits[group_logit_indices]
+            logits = flashinfer.sampling.softmax(
+                logits,
+                temperature,
+                enable_pdl=ENABLE_PDL,
+            )
+            return logits
+
+        @classmethod
+        def _sample_from_probs(
+            cls,
+            probs: torch.Tensor,
+            generator: Optional[torch.Generator],
+        ) -> torch.Tensor:
+            new_tokens = flashinfer.sampling.sampling_from_probs(
+                probs,
+                deterministic=True,
+                generator=generator,
+                check_nan=cls._flashinfer_check_nans(probs),
+            )
+            return new_tokens
+
+        def _sample_greedy_with_probs(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor],
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None)
+            new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False)
+            return new_tokens, probs
+
+        @classmethod
+        def _sample_with_probs(
+            cls,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor],
+            top_k: Optional[torch.Tensor],
+            top_p: Optional[torch.Tensor],
+            temperature: Optional[torch.Tensor],
+            generator: Optional[torch.Generator],
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            if top_k is not None:
+                logits = cls._prepare_logits_with_temperature(
+                    logits, group_logit_indices, temperature
+                )
+                logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
+                probs = cls._prepare_probs_with_temperature(logits, None, None)  # plain softmax
+            else:
+                probs = cls._prepare_probs_with_temperature(
+                    logits, group_logit_indices, temperature
+                )
+
+            if top_p is not None:
+                probs = flashinfer.sampling.top_p_renorm_probs(probs, top_p)
+
+            new_tokens = cls._sample_from_probs(probs, generator=generator)
+            return new_tokens, probs
+
+    class StrategyImplWithProbs(StrategyImpl):
+        @override
+        @classmethod
+        def computes_probs(cls) -> bool:
+            return True
+
+    class GreedyWithProbs(StrategyImplWithProbs):
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.GreedyWithProbs":
+            assert all(strat == GREEDY for strat in strategies)
+            return cls()
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            return self._sample_greedy_with_probs(logits, group_logit_indices=group_logit_indices)
+
+    class TopKTopPWithProbs(StrategyImplWithProbs):
+        def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor):
+            self._top_k = top_k
+            self._top_p = top_p
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopKTopPWithProbs":
+            assert all(strat[0] == "top_k_top_p" for strat in strategies)
+            narrowed_strats = cast(list[TopKTopP], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            top_p = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[3] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_k, top_p, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            new_tokens, probs = self._sample_with_probs(
+                logits,
+                group_logit_indices=group_logit_indices,
+                top_k=self._top_k,
+                top_p=self._top_p,
+                temperature=self._temperature,
+                generator=generator,
+            )
+            return new_tokens, probs
+
+    class TopKWithProbs(StrategyImplWithProbs):
+        def __init__(self, top_k: torch.Tensor, temperature: torch.Tensor):
+            self._top_k = top_k
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopKWithProbs":
+            assert all(strat[0] == "top_k" for strat in strategies)
+            narrowed_strats = cast(list[TopK], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_k, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            new_tokens, probs = self._sample_with_probs(
+                logits,
+                group_logit_indices=group_logit_indices,
+                top_k=self._top_k,
+                top_p=None,
+                temperature=self._temperature,
+                generator=generator,
+            )
+            return new_tokens, probs
+
+    class TopPWithProbs(StrategyImplWithProbs):
+        def __init__(self, top_p: torch.Tensor, temperature: torch.Tensor):
+            self._top_p = top_p
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopPWithProbs":
+            assert all(strat[0] == "top_p" for strat in strategies)
+            narrowed_strats = cast(list[TopP], strategies)
+            top_p = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_p, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            new_tokens, probs = self._sample_with_probs(
+                logits,
+                group_logit_indices=group_logit_indices,
+                top_k=None,
+                top_p=self._top_p,
+                temperature=self._temperature,
+                generator=generator,
+            )
+            return new_tokens, probs
+
+    class TemperatureOnlyWithProbs(StrategyImplWithProbs):
+        def __init__(self, temperature: torch.Tensor):
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TemperatureOnlyWithProbs":
+            assert all(strat[0] == "temperature" for strat in strategies)
+            narrowed_strats = cast(list[TemperatureOnly], strategies)
+            temperature = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            new_tokens, probs = self._sample_with_probs(
+                logits,
+                group_logit_indices=group_logit_indices,
+                top_k=None,
+                top_p=None,
+                temperature=self._temperature,
+                generator=generator,
+            )
+            return new_tokens, probs
+
+    class StrategyImplSampleOnly(StrategyImpl):
+        @override
+        @classmethod
+        def computes_probs(cls) -> bool:
+            return False
+
+    class GreedySampleOnly(StrategyImplSampleOnly):
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.GreedySampleOnly":
+            assert all(strat == GREEDY for strat in strategies)
+            return cls()
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            if group_logit_indices is not None:
+                logits = logits[group_logit_indices]
+            return greedy_search_sampling_batch(logits, return_probs=False)
+
+    class TopKTopPSampleOnly(StrategyImplSampleOnly):
+        def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor):
+            self._top_k = top_k
+            self._top_p = top_p
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopKTopPSampleOnly":
+            assert all(strat[0] in ["top_k_top_p", "top_k"] for strat in strategies)
+            narrowed_strats = cast(list[TopKTopP | TopK], strategies)
+            top_k_list = []
+            top_p_list = []
+            temperature_list = []
+            for strat in narrowed_strats:
+                top_k_list.append(strat[1])
+                if strat[0] == "top_k_top_p":
+                    top_p_list.append(strat[2])
+                    temperature_list.append(strat[3])
+                else:
+                    top_p_list.append(1.0)
+                    temperature_list.append(strat[2])
+            top_k = cls._make_tensor(top_k_list, torch.int32, cuda_device)
+            top_p = cls._make_tensor(top_p_list, torch.float32, cuda_device)
+            temperature = cls._make_tensor(temperature_list, torch.float32, cuda_device)
+            return cls(top_k, top_p, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            logits = self._prepare_logits_with_temperature(
+                logits, group_logit_indices, self._temperature
+            )
+            return flashinfer.sampling.top_k_top_p_sampling_from_logits(
+                logits,
+                top_k=self._top_k,
+                top_p=self._top_p,
+                # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
+                # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
+                # compute unnecessarily softmax also for situations allowing
+                # flashinfer.sampling...._sampling_from_logits.
+                # indices=group_logit_indices,
+                filter_apply_order="top_k_first",
+                deterministic=True,
+                check_nan=self._flashinfer_check_nans(logits),
+                generator=generator,
+            ), None
+
+    class TopPSampleOnly(StrategyImplSampleOnly):
+        def __init__(self, top_p: torch.Tensor, temperature: torch.Tensor):
+            self._top_p = top_p
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopPSampleOnly":
+            assert all(strat[0] == "top_p" for strat in strategies)
+            narrowed_strats = cast(list[TopP], strategies)
+            top_p = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_p, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            probs = self._prepare_probs_with_temperature(
+                logits, group_logit_indices, self._temperature
+            )
+            return flashinfer.sampling.top_p_sampling_from_probs(
+                probs,
+                top_p=self._top_p,
+                # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
+                # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
+                # compute unnecessarily softmax also for situations allowing
+                # flashinfer.sampling...._sampling_from_logits.
+                # indices=group_logit_indices,
+                deterministic=True,
+                check_nan=self._flashinfer_check_nans(probs),
+                generator=generator,
+            ), None
+
+    class TemperatureOnlySampleOnly(StrategyImplSampleOnly):
+        def __init__(self, temperature: torch.Tensor):
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TemperatureOnlySampleOnly":
+            assert all(strat[0] == "temperature" for strat in strategies)
+            narrowed_strats = cast(list[TemperatureOnly], strategies)
+            temperature = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            logits = self._prepare_logits_with_temperature(
+                logits, group_logit_indices, self._temperature
+            )
+            new_tokens = flashinfer.sampling.sampling_from_logits(
+                logits,
+                # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
+                # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
+                # compute unnecessarily softmax also for situations allowing
+                # flashinfer.sampling...._sampling_from_logits.
+                # indices=group_logit_indices,
+                deterministic=True,
+                generator=generator,
+                check_nan=self._flashinfer_check_nans(logits),
+            )
+            return new_tokens, None
+
+
+class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpls.StrategyImpl]]):
+    """Implements batched sampling with FlashInfer.sampling kernels.
+
+    Note: Currently, FlashInfer.sampling appears to have limited CUDA graph
+    support, see https://github.com/flashinfer-ai/flashinfer/issues/978.
+    """
+
+    STRATEGY_KEY_TYPE: TypeAlias = Type[_StrategyImpls.StrategyImpl]
+
+    @override
+    @staticmethod
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
+        if return_probs:
+            match strategy:
+                case ("top_k", _, _):
+                    return _StrategyImpls.TopKWithProbs
+                case ("top_p", _, _):
+                    return _StrategyImpls.TopPWithProbs
+                case ("top_k_top_p", _, _, _):
+                    return _StrategyImpls.TopKTopPWithProbs
+                case ("temperature", _):
+                    return _StrategyImpls.TemperatureOnlyWithProbs
+                case ("greedy", None):
+                    return _StrategyImpls.GreedyWithProbs
+        else:
+            match strategy:
+                case ("top_p", _, _):
+                    return _StrategyImpls.TopPSampleOnly
+                case ("top_k_top_p", _, _, _) | ("top_k", _, _):
+                    # NB: There is no TopKSampleOnly, because FlashInfer only provides
+                    # top_k_sampling_from_probs (not top_k_sampling_from_logits),
+                    # which is likely slower than top_k_top_p_sampling_from_logits.
+                    return _StrategyImpls.TopKTopPSampleOnly
+                case ("temperature", _):
+                    return _StrategyImpls.TemperatureOnlySampleOnly
+                case ("greedy", None):
+                    return _StrategyImpls.GreedySampleOnly
+
+    @override
+    @staticmethod
+    def sample_grouped_strategies(
+        group_key: STRATEGY_KEY_TYPE,
+        strategies: list[Strategy],
+        logits: torch.Tensor,
+        *,
+        group_logit_indices: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        return_probs: bool,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if group_logit_indices is None:
+            assert logits.size(0) == len(strategies)
+        else:
+            assert group_logit_indices.size(0) == len(strategies)
+
+        assert return_probs == group_key.computes_probs()
+
+        strategy_impl_cls = group_key
+        return strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device).sample(
+            logits,
+            group_logit_indices=group_logit_indices,
+            generator=generator,
+        )
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 1cc5b373341..ddd3f5b6164 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2593,6 +2593,9 @@ class TorchLlmArgs(BaseLlmArgs):
 
     # PrivateVars
     _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
+    _disable_flash_infer_sampling: bool = PrivateAttr(default=True)
+    """Unless this is set to False, FlashInfer.sampling is not used, even if available."""
+
     @property
     def quant_config(self) -> QuantConfig:
         if self._quant_config is None:
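Illustrative note (a sketch, not part of the patch): with this change the grouping key depends on whether speculative decoding needs the sampling probabilities, so the same request strategy can be dispatched to different FlashInfer kernels. Assuming FlashInfer is installed, a hypothetical ("top_k", 50, 0.8) strategy would be routed as follows:

    from tensorrt_llm._torch.pyexecutor.sampling_utils_flashinfer import (
        FlashInferGroupedStrategySampler,
    )

    strategy = ("top_k", 50, 0.8)  # hypothetical TopK strategy: (name, top_k, temperature)

    # Without probabilities (no draft logits to verify), top-k is folded into the
    # combined top-k/top-p kernel that samples directly from logits.
    FlashInferGroupedStrategySampler.strategy_grouping_key(strategy, False)
    # -> _StrategyImpls.TopKTopPSampleOnly

    # When speculation needs probs, the logits are masked instead and the softmax is
    # returned alongside the sampled tokens.
    FlashInferGroupedStrategySampler.strategy_grouping_key(strategy, True)
    # -> _StrategyImpls.TopKWithProbs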