
Commit f2af5a0

feat: sampling using FlashInfer.sampling
Signed-off-by: ixlmar <[email protected]>
1 parent a0024f4 commit f2af5a0

7 files changed: +625, -18 lines
pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,7 @@ extend_skip_glob = [
     "tests/unittest/_torch/sampler/test_torch_sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
+    "tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py",
 ]

 [tool.yapf]
@@ -71,6 +72,7 @@ ignore_patterns = [
     "tests/unittest/_torch/sampler/test_torch_sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
+    "tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py",
 ]

 [tool.codespell]
@@ -108,6 +110,7 @@ exclude = [
     "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
     "tensorrt_llm/_torch/pyexecutor/sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
+    "tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py",
 ]


@@ -155,6 +158,7 @@ include = [
     "tests/unittest/_torch/sampler/test_torch_sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampler.py",
     "tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
+    "tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py",
 ]
 exclude = [
     "**3rdparty/**",

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 9 additions & 3 deletions
@@ -820,7 +820,8 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int):
+                              max_beam_width: int,
+                              disable_flash_infer_sampling: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -838,6 +839,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
     )


@@ -847,13 +849,17 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         max_seq_len: int, mm_encoder_only: bool,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
-                        kv_cache_config: KvCacheConfig):
+                        kv_cache_config: KvCacheConfig,
+                        disable_flash_infer_sampling: bool,
+                        ):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
-        max_beam_width=max_beam_width)
+        max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
+    )
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 4 additions & 2 deletions
@@ -201,7 +201,7 @@ def get_guided_decoding_config(guided_decoding_backend: str,

 def create_py_executor(
     llm_args: TorchLlmArgs,
-    checkpoint_dir: str = None,
+    checkpoint_dir: Optional[str] = None,
     tokenizer: Optional[TokenizerBase] = None,
     profiling_stage_data: Optional[dict] = None,
 ) -> PyExecutor:
@@ -501,7 +501,9 @@ def drafting_loop_wrapper(model):
         mm_encoder_only=mm_encoder_only,
         speculative_config=spec_config,
         decoding_config=decoding_config,
-        kv_cache_config=kv_cache_config)
+        kv_cache_config=kv_cache_config,
+        disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
+    )
     logger.info(f"Using Sampler: {type(sampler).__name__}")

     if kv_connector_config is not None:
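
Taken together with tensorrt_llm/_torch/pyexecutor/_util.py above, the new flag flows from the LLM-API arguments down to the sampler. A condensed, comment-only summary of that call chain (the real signatures carry many more parameters than shown here):

# create_py_executor(llm_args, ...)                                  [py_executor_creator.py]
#   -> instantiate_sampler(...,
#          disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling)   [_util.py]
#     -> create_torch_sampler_args(..., disable_flash_infer_sampling=...)         [_util.py]
#       -> TorchSampler.Args(disable_flash_infer_sampling=...)                    [sampler.py]
#         -> TorchSampler.__init__ selects FlashInferGroupedStrategySampler when
#            FlashInfer is importable and the flag is left at its default (False),
#            otherwise SimpleGroupedStrategySampler.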

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 18 additions & 5 deletions
@@ -19,7 +19,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import repeat
-from typing import Any, Callable, List, Optional, TypeVar, cast
+from typing import Any, Callable, List, Optional, Type, TypeVar, cast

 import torch
 import torch.nn.functional as F
@@ -53,13 +53,15 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams

+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
 from .resource_manager import ResourceManager, ResourceManagerType
 from .sampling_utils import (
     GREEDY,
     GenericStrategyKeyType,
+    GroupedStrategySampler,
     SimpleGroupedStrategySampler,
     Strategy,
     UtilsSamplingParams,
@@ -71,6 +73,9 @@
 )
 from .scheduler import ScheduledRequests

+if IS_FLASHINFER_AVAILABLE:
+    from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler
+
 if sys.version_info[:2] >= (3, 12):
     from typing import override
 else:
@@ -266,7 +271,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
 def _group_requests_by_strategy_key(
     requests: Iterable[LlmRequest],
     *,
-    strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
+    strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
     pin_memory: bool = False,
     vocab_size: int,
 ) -> dict[tuple[GenericStrategyKeyType, bool], tuple[torch.Tensor, List[Strategy]]]:
@@ -276,8 +281,8 @@ def _group_requests_by_strategy_key(
     )
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
-        strategy_key = strategy_to_key(strategy)
         speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        strategy_key = strategy_to_key(strategy, speculation_needs_probs)
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
@@ -586,6 +591,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
+        disable_flash_infer_sampling: bool = False

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -602,6 +608,13 @@ def __init__(self, args: Args):
         with torch.inference_mode(False):
             self.store = self.create_store()

+        self._grouped_sampler_cls: Type[GroupedStrategySampler]
+        if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
+            cls_not_possibly_unbound = FlashInferGroupedStrategySampler  # type: ignore
+            self._grouped_sampler_cls = cls_not_possibly_unbound
+        else:
+            self._grouped_sampler_cls = SimpleGroupedStrategySampler
+
         # Initialize seed for multi-GPU consistency
         self._global_seed = 42
         self._generator = None
@@ -1181,7 +1194,7 @@ def _sample_batched_by_strategy(
             requests,
             pin_memory=True,
             vocab_size=logits_cuda.size(1),
-            strategy_to_key=SimpleGroupedStrategySampler.strategy_grouping_key,
+            strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
         )
         generator_cuda = self.get_generator(cuda_device)

@@ -1238,7 +1251,7 @@ def _sample_batched_by_strategy(
             for _ in range(steps)
         ]
         group_next_tokens_cuda, group_softmax_cuda = (
-            SimpleGroupedStrategySampler.sample_grouped_strategies(
+            self._grouped_sampler_cls.sample_grouped_strategies(
                 strategy_key,
                 group_strategies_per_step,
                 group_logits_cuda,
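
The grouping helper now forwards each request's speculation_needs_probs flag to the key callback, so a grouped sampler can fold "are probabilities needed?" into its grouping decision. Below is a minimal, self-contained sketch of the new two-argument key in use; the hard-coded strategies and needs_probs values are illustrative stand-ins for what _request_strategy and py_draft_logits would produce.

from collections import defaultdict

from tensorrt_llm._torch.pyexecutor.sampling_utils import (
    GREEDY,
    SimpleGroupedStrategySampler,
)

# Illustrative stand-ins for per-request sampling strategies and for
# "request has draft logits and is not greedy" (speculation_needs_probs).
strategies = [("top_p", 0.9, 1.0), ("top_p", 0.9, 1.0), GREEDY]
needs_probs = [False, True, False]

groups = defaultdict(list)
for req_index, (strategy, probs) in enumerate(zip(strategies, needs_probs)):
    # New contract: the key callback also receives the probs requirement.
    key = SimpleGroupedStrategySampler.strategy_grouping_key(strategy, probs)
    groups[(key, probs)].append(req_index)

# SimpleGroupedStrategySampler keys on the full strategy tuple, so the two
# top_p requests only split because one of them needs probabilities for
# speculative decoding; a FlashInfer-backed sampler may key differently.
print(dict(groups))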

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 8 additions & 8 deletions
@@ -33,13 +33,13 @@
 from typing_extensions import override


-TemperatureOnly = tuple[Literal["temperature"], float]
-TopK = tuple[Literal["top_k"], int, float]
-TopP = tuple[Literal["top_p"], float, float]
-TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
-Greedy = tuple[Literal["greedy"], None]
+TemperatureOnly: TypeAlias = tuple[Literal["temperature"], float]
+TopK: TypeAlias = tuple[Literal["top_k"], int, float]
+TopP: TypeAlias = tuple[Literal["top_p"], float, float]
+TopKTopP: TypeAlias = tuple[Literal["top_k_top_p"], int, float, float]
+Greedy: TypeAlias = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
+Strategy: TypeAlias = TopK | TopP | Greedy | TopKTopP | TemperatureOnly


 @dataclass(frozen=True, kw_only=True)
@@ -292,7 +292,7 @@ def sample(
 class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
     @staticmethod
     @abc.abstractmethod
-    def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
         raise NotImplementedError

     @staticmethod
@@ -314,7 +314,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):

     @override
     @staticmethod
-    def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
         return strategy

     @override
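
The abstract strategy_grouping_key now also receives return_probs, so an implementation can fold the probability requirement into its key. The actual FlashInferGroupedStrategySampler in sampling_utils_flashinfer.py is not shown in this commit view; the subclass below is a purely hypothetical sketch of the new contract (CoarseGroupedStrategySampler and COARSE_KEY_TYPE are invented names), grouping only by strategy kind plus the probs requirement and omitting sample_grouped_strategies.

from typing_extensions import override

from tensorrt_llm._torch.pyexecutor.sampling_utils import (
    GroupedStrategySampler,
    Strategy,
)

# Hypothetical key type: (strategy kind, whether probabilities are needed).
COARSE_KEY_TYPE = tuple[str, bool]


class CoarseGroupedStrategySampler(GroupedStrategySampler[COARSE_KEY_TYPE]):
    """Illustrative sketch only; not the FlashInfer implementation."""

    @override
    @staticmethod
    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> COARSE_KEY_TYPE:
        # strategy is a tuple such as ("top_p", 0.9, 1.0) or ("greedy", None);
        # its first element is the strategy kind.
        return strategy[0], return_probs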
