28 changes: 19 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -824,7 +824,8 @@ def create_py_executor_instance(
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_batch_size: int,
speculative_config: SpeculativeConfig,
max_beam_width: int):
max_beam_width: int,
disable_flash_infer_sampling: bool):
max_num_sequences = max_batch_size * mapping.pp_size
max_draft_len = (0 if speculative_config is None else
speculative_config.max_draft_len)
@@ -837,22 +838,31 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_total_draft_tokens=max_total_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=max_beam_width,
disable_flash_infer_sampling=disable_flash_infer_sampling,
)


def instantiate_sampler(engine: PyTorchModelEngine,
pytorch_backend_config: PyTorchConfig, mapping: Mapping,
max_batch_size: int, max_beam_width: int,
max_seq_len: int, mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
kv_cache_config: KvCacheConfig):
def instantiate_sampler(
engine: PyTorchModelEngine,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping,
max_batch_size: int,
max_beam_width: int,
max_seq_len: int,
mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
kv_cache_config: KvCacheConfig,
disable_flash_infer_sampling: bool,
):
sampler_args = create_torch_sampler_args(
mapping,
max_seq_len=engine.max_seq_len,
max_batch_size=max_batch_size,
speculative_config=speculative_config,
max_beam_width=max_beam_width)
max_beam_width=max_beam_width,
disable_flash_infer_sampling=disable_flash_infer_sampling,
)
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
max_beam_width=max_beam_width)
if mapping.cp_config.get('cp_type') == CpType.STAR:
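Note: the new `disable_flash_infer_sampling` keyword is simply threaded from `instantiate_sampler` through `create_torch_sampler_args` into the sampler's `Args`. A minimal sketch of that flow with simplified stand-in types (`SamplerArgs` and `create_sampler_args` are illustrative, not the real signatures):

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class SamplerArgs:
    # Hypothetical stand-in for the sampler's Args dataclass; the real one has more fields.
    max_seq_len: int
    max_num_sequences: int
    disable_flash_infer_sampling: bool = False


def create_sampler_args(*, max_seq_len: int, max_batch_size: int, pp_size: int,
                        disable_flash_infer_sampling: bool) -> SamplerArgs:
    # Mirrors the shape of create_torch_sampler_args: the new flag is passed straight through.
    return SamplerArgs(
        max_seq_len=max_seq_len,
        max_num_sequences=max_batch_size * pp_size,
        disable_flash_infer_sampling=disable_flash_infer_sampling,
    )


args = create_sampler_args(max_seq_len=4096, max_batch_size=8, pp_size=2,
                           disable_flash_infer_sampling=False)
assert args.max_num_sequences == 16
```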
25 changes: 14 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -201,7 +201,7 @@ def get_guided_decoding_config(guided_decoding_backend: str,

def create_py_executor(
llm_args: TorchLlmArgs,
checkpoint_dir: str = None,
checkpoint_dir: Optional[str] = None,
tokenizer: Optional[TokenizerBase] = None,
profiling_stage_data: Optional[dict] = None,
) -> PyExecutor:
@@ -482,16 +482,19 @@ def drafting_loop_wrapper(model):
)

with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
sampler = instantiate_sampler(model_engine,
pytorch_backend_config,
mapping,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_seq_len=max_seq_len,
mm_encoder_only=mm_encoder_only,
speculative_config=spec_config,
decoding_config=decoding_config,
kv_cache_config=kv_cache_config)
sampler = instantiate_sampler(
model_engine,
pytorch_backend_config,
mapping,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_seq_len=max_seq_len,
mm_encoder_only=mm_encoder_only,
speculative_config=spec_config,
decoding_config=decoding_config,
kv_cache_config=kv_cache_config,
disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
)
logger.info(f"Using Sampler: {type(sampler).__name__}")

if kv_connector_config is not None:
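Note on the `checkpoint_dir` hunk: annotating a `None`-defaulted parameter as plain `str` is rejected by strict type checkers, which is why it becomes `Optional[str]`. A tiny illustration with a hypothetical `load` helper:

```python
from typing import Optional


def load(checkpoint_dir: Optional[str] = None) -> str:
    # Optional[str] tells the type checker that None is an accepted value;
    # the bare "str = None" annotation it replaces is flagged under strict mode.
    return checkpoint_dir if checkpoint_dir is not None else "<no checkpoint>"


print(load())             # <no checkpoint>
print(load("/tmp/ckpt"))  # /tmp/ckpt
```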
23 changes: 18 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -19,7 +19,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import repeat
from typing import Any, Callable, List, Optional, TypeVar, cast
from typing import Any, Callable, List, Optional, Type, TypeVar, cast

import torch
import torch.nn.functional as F
@@ -53,13 +53,15 @@
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams

from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..speculative.spec_tree_manager import SpecTreeManager
from .finish_reason import FinishedState
from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
from .resource_manager import ResourceManager, ResourceManagerType
from .sampling_utils import (
GREEDY,
GenericStrategyKeyType,
GroupedStrategySampler,
SimpleGroupedStrategySampler,
Strategy,
UtilsSamplingParams,
@@ -71,6 +73,9 @@
)
from .scheduler import ScheduledRequests

if IS_FLASHINFER_AVAILABLE:
from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler

if sys.version_info[:2] >= (3, 12):
from typing import override
else:
@@ -266,7 +271,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
def _group_requests_by_strategy_key(
requests: Iterable[LlmRequest],
*,
strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
pin_memory: bool = False,
vocab_size: int,
) -> dict[tuple[GenericStrategyKeyType, bool], tuple[torch.Tensor, List[Strategy]]]:
@@ -276,8 +281,8 @@ def _group_requests_by_strategy_key(
)
for req_index, req in enumerate(requests):
strategy = _request_strategy(req, vocab_size=vocab_size)
strategy_key = strategy_to_key(strategy)
speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
strategy_key = strategy_to_key(strategy, speculation_needs_probs)
group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
group_dict_entry[0].append(req_index)
group_dict_entry[1].append(strategy)
@@ -586,6 +591,7 @@ class Args:
max_num_sequences: int
max_beam_width: int
max_total_draft_tokens: int
disable_flash_infer_sampling: bool = False

def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
@@ -602,6 +608,13 @@ def __init__(self, args: Args):
with torch.inference_mode(False):
self.store = self.create_store()

self._grouped_sampler_cls: Type[GroupedStrategySampler]
if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
cls_not_possibly_unbound = FlashInferGroupedStrategySampler # type: ignore
self._grouped_sampler_cls = cls_not_possibly_unbound
else:
self._grouped_sampler_cls = SimpleGroupedStrategySampler

# Initialize seed for multi-GPU consistency
self._global_seed = 42
self._generator = None
@@ -1181,7 +1194,7 @@ def _sample_batched_by_strategy(
requests,
pin_memory=True,
vocab_size=logits_cuda.size(1),
strategy_to_key=SimpleGroupedStrategySampler.strategy_grouping_key,
strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
)
generator_cuda = self.get_generator(cuda_device)

@@ -1238,7 +1251,7 @@ def _sample_batched_by_strategy(
for _ in range(steps)
]
group_next_tokens_cuda, group_softmax_cuda = (
SimpleGroupedStrategySampler.sample_grouped_strategies(
self._grouped_sampler_cls.sample_grouped_strategies(
strategy_key,
group_strategies_per_step,
group_logits_cuda,
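Note: the sampler now chooses its grouped-strategy implementation once, at construction time, and reuses it at every sampling call. A simplified sketch of that selection logic using stand-in classes rather than the real samplers (`flashinfer_available` and `disable_flag` stand for `IS_FLASHINFER_AVAILABLE` and `args.disable_flash_infer_sampling`):

```python
from typing import Type


class BaseGroupedSampler:
    """Stand-in for GroupedStrategySampler."""


class SimpleSampler(BaseGroupedSampler):
    """Stand-in for SimpleGroupedStrategySampler (always available, pure-PyTorch path)."""


class FlashInferSampler(BaseGroupedSampler):
    """Stand-in for FlashInferGroupedStrategySampler (only importable when flashinfer is installed)."""


def pick_grouped_sampler(flashinfer_available: bool,
                         disable_flag: bool) -> Type[BaseGroupedSampler]:
    # Mirrors the selection in the sampler's __init__ above: FlashInfer sampling is
    # the default when available, and disable_flash_infer_sampling opts back into
    # the simple grouped sampler.
    if flashinfer_available and not disable_flag:
        return FlashInferSampler
    return SimpleSampler


assert pick_grouped_sampler(True, False) is FlashInferSampler
assert pick_grouped_sampler(True, True) is SimpleSampler
assert pick_grouped_sampler(False, False) is SimpleSampler
```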
21 changes: 12 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -33,13 +33,13 @@
from typing_extensions import override


TemperatureOnly = tuple[Literal["temperature"], float]
TopK = tuple[Literal["top_k"], int, float]
TopP = tuple[Literal["top_p"], float, float]
TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
Greedy = tuple[Literal["greedy"], None]
TemperatureOnly: TypeAlias = tuple[Literal["temperature"], float]
TopK: TypeAlias = tuple[Literal["top_k"], int, float]
TopP: TypeAlias = tuple[Literal["top_p"], float, float]
TopKTopP: TypeAlias = tuple[Literal["top_k_top_p"], int, float, float]
Greedy: TypeAlias = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
Strategy: TypeAlias = TopK | TopP | Greedy | TopKTopP | TemperatureOnly


@dataclass(frozen=True, kw_only=True)
@@ -258,7 +258,10 @@ def sample(
match strategy:
case ("top_k", top_k, temperature):
tokens, softmax = top_k_sampling_batch(
logits, top_k=top_k, temperature=temperature, generator=generator
logits,
top_k=top_k,
temperature=temperature,
generator=generator,
)
case ("top_p", top_p, temperature):
tokens, softmax = top_p_sampling_batch(
@@ -292,7 +295,7 @@ def sample(
class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
@staticmethod
@abc.abstractmethod
def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
raise NotImplementedError

@staticmethod
@@ -314,7 +317,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):

@override
@staticmethod
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
return strategy

@override
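Note: `strategy_grouping_key` now also receives whether probabilities must be returned, so a backend-specific implementation can fold that into its grouping key; `SimpleGroupedStrategySampler` accepts the argument but keeps grouping by the strategy tuple alone. A self-contained sketch of the grouping idea with hypothetical requests (plain tuples and a `needs_probs` list, not the real `LlmRequest` objects):

```python
from collections import defaultdict

# Strategy tuples shaped like the aliases above, e.g. ("top_k", k, temperature).
requests = [
    ("top_k", 50, 0.8),
    ("greedy", None),
    ("top_k", 50, 0.8),
    ("top_p", 0.9, 1.0),
]
needs_probs = [False, False, True, False]  # e.g. speculation_needs_probs per request


def simple_key(strategy, return_probs):
    # Like SimpleGroupedStrategySampler.strategy_grouping_key: the key is the
    # strategy itself; return_probs is accepted but does not influence the key.
    return strategy


groups = defaultdict(list)
for idx, (strategy, probs) in enumerate(zip(requests, needs_probs)):
    groups[(simple_key(strategy, probs), probs)].append(idx)

# Requests 0 and 2 share a strategy but land in different groups, because only
# request 2 needs probabilities.
print(dict(groups))
```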