[TRTLLM-5627] feat: Implement pytorch sampler for MTP #6245

Open · wants to merge 7 commits into main
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -89,6 +89,10 @@ class ModelConfig(Generic[TConfig]):
# Allow models to select op according to whether CUDA Graphs are used.
use_cuda_graph: bool = False

# If true, iterate over sampling_params of each request and use the corresponding sampling strategy.
# Currently only used for DeepSeek-MTP.
enable_mixed_sampler: bool = False

force_dynamic_quantization: bool = False

extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
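A minimal sketch of what the new flag controls: per-request (mixed) sampling only takes effect when speculative decoding runs in MTP mode and the flag is set, mirroring the gating added in model_engine.py below. The config class and helper here are toy stand-ins, not the real ModelConfig or PyTorch backend config.

```python
# Toy stand-ins only; mirrors the gating condition added in model_engine.py in this PR.
from dataclasses import dataclass

@dataclass
class _ToyBackendConfig:
    enable_mixed_sampler: bool = False  # same default as the new ModelConfig field

def use_advanced_mtp_sampler(is_spec_decode: bool, is_mtp_mode: bool,
                             cfg: _ToyBackendConfig) -> bool:
    # is_spec_decode and spec_dec_mode.is_mtp() and enable_mixed_sampler
    return is_spec_decode and is_mtp_mode and cfg.enable_mixed_sampler

print(use_advanced_mtp_sampler(True, True, _ToyBackendConfig(enable_mixed_sampler=True)))  # True
print(use_advanced_mtp_sampler(True, True, _ToyBackendConfig()))                           # False
```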
79 changes: 78 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -17,6 +17,7 @@
import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative import (
get_num_extra_kv_tokens, update_spec_config_from_model_config)
@@ -278,6 +279,8 @@ def __init__(
self.spec_config = spec_config
self.is_spec_decode = spec_config is not None
self.is_draft_model = is_draft_model
self.is_advanced_mtp_sampler = self.is_spec_decode and self.spec_config.spec_dec_mode.is_mtp(
) and self.pytorch_backend_config.enable_mixed_sampler

self.in_warmup = False

@@ -295,6 +298,7 @@ def __init__(
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
lora_config=lora_config)
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
@@ -1166,6 +1170,57 @@ def _prepare_tp_inputs(
multimodal_params_list = []
gen_request_seq_slots = [] # per generation request

if self.is_advanced_mtp_sampler:
temperatures = []
top_k = []
top_p = []
min_p = []

# Helper for collecting per-request sampling parameters for advanced MTP sampling.
def collect_req_mtp_sampling_params(request: LlmRequest,
draft_len: int = 0):

def get_request_temperature(request: LlmRequest) -> float:
if not request.sampling_config.temperature:
return 1.0
temperature = request.sampling_config.temperature[0]
if 0 < temperature < 1e-2:
# temperature less than 0.01 may cause numerical errors
temperature = 0.01
return temperature

def get_request_top_k(request: LlmRequest) -> int:
if not request.sampling_config.top_k:
top_k = 0
else:
top_k = request.sampling_config.top_k[0]

# set k to a very large value (larger than vocab size) to disable top_k sampling
TOP_K_DISABLED = torch.iinfo(torch.int32).max
if top_k <= 0:
top_k = TOP_K_DISABLED
return top_k

def get_request_top_p(request: LlmRequest) -> float:
if not request.sampling_config.top_p:
top_p = 1.0
else:
top_p = request.sampling_config.top_p[0]
return top_p

def get_request_min_p(request: LlmRequest) -> float:
if not request.sampling_config.min_p:
min_p = 0.0
else:
min_p = request.sampling_config.min_p[0]
return min_p

temperatures.extend([get_request_temperature(request)] *
(draft_len + 1))
top_k.extend([get_request_top_k(request)] * (draft_len + 1))
top_p.extend([get_request_top_p(request)] * (draft_len + 1))
min_p.extend([get_request_min_p(request)] * (draft_len + 1))

for request in scheduled_requests.context_requests:
request_ids.append(request.py_request_id)
all_prompt_tokens = request.get_tokens(0)
@@ -1181,7 +1236,6 @@ def _prepare_tp_inputs(
prompt_lengths.append(len(prompt_tokens))
past_seen_token_num = begin_compute
num_cached_tokens_per_seq.append(past_seen_token_num)

# Multimodal
# TODO: enable chunk prefill for multimodal (maybe need to pass prompt_tokens to MultimodalRuntimeData)
py_multimodal_runtime = MultimodalRuntimeData(
@@ -1200,6 +1254,9 @@ def _prepare_tp_inputs(
if multimodal_params.has_content():
multimodal_params_list.append(multimodal_params)

if self.is_advanced_mtp_sampler:
collect_req_mtp_sampling_params(request)

request.py_batch_idx = request.py_seq_slot

num_ctx_requests = len(scheduled_requests.context_requests)
@@ -1282,6 +1339,10 @@ def _prepare_tp_inputs(
past_seen_token_num + 1 + num_draft_tokens)))
num_cached_tokens_per_seq.append(past_seen_token_num)
request_ids.append(request.py_request_id)

if self.is_advanced_mtp_sampler:
collect_req_mtp_sampling_params(request, num_draft_tokens)

# update batch index
request.py_batch_idx = request.py_seq_slot
else:
@@ -1311,6 +1372,9 @@ def _prepare_tp_inputs(
prompt_lengths.append(request.py_prompt_len)
request_ids.append(request.py_request_id)

if self.is_advanced_mtp_sampler:
collect_req_mtp_sampling_params(request, self.max_draft_len)

for request in generation_requests:
beam_width = request.sampling_config.beam_width
for beam in range(beam_width):
@@ -1342,6 +1406,10 @@ def _prepare_tp_inputs(

request_ids.append(request.py_request_id)
gen_request_seq_slots.append(request.py_seq_slot)

if self.is_advanced_mtp_sampler:
collect_req_mtp_sampling_params(request, self.max_draft_len)

request.py_batch_idx = request.py_seq_slot

previous_batch_len = len(previous_batch_indices)
@@ -1534,6 +1602,11 @@ def previous_seq_slots_device():
total_draft_lens]
spec_metadata.request_ids = request_ids
spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]

if self.is_advanced_mtp_sampler:
spec_metadata.update_advanced_mtp_sampling_params(
temperatures, top_k, top_p, min_p)

spec_metadata.num_generations = len(
scheduled_requests.generation_requests)
spec_metadata.num_tokens = total_num_tokens
@@ -2082,6 +2155,10 @@ def forward(
spec_metadata.is_spec_dec_tree,
spec_metadata.is_spec_dec_dynamic_tree,
spec_metadata.max_draft_len)

if self.is_advanced_mtp_sampler:
spec_metadata._set_up_advanced_mtp_sampling(
self.batch_size, self.max_draft_len)
else:
spec_resource_manager = None
spec_metadata = None
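To make the parameter collection above concrete, here is a standalone sketch of what collect_req_mtp_sampling_params does: each request contributes draft_len + 1 identical entries (one target token plus its draft tokens), so the flattened lists line up row-for-row with the logits the MTP sampler later receives. The request objects below are SimpleNamespace stand-ins, not real LlmRequest instances.

```python
from types import SimpleNamespace

def expand_params(requests, draft_len):
    temperatures, top_k, top_p, min_p = [], [], [], []
    for req in requests:
        sc = req.sampling_config
        temp = sc.temperature[0] if sc.temperature else 1.0
        if 0 < temp < 1e-2:
            temp = 0.01                  # clamp tiny temperatures, as in the diff
        k = sc.top_k[0] if sc.top_k else 0
        if k <= 0:
            k = 2**31 - 1                # torch.iinfo(torch.int32).max sentinel: top-k disabled
        p = sc.top_p[0] if sc.top_p else 1.0
        mp = sc.min_p[0] if sc.min_p else 0.0
        # Each request contributes one entry per sampled position: 1 target + draft_len drafts.
        temperatures += [temp] * (draft_len + 1)
        top_k += [k] * (draft_len + 1)
        top_p += [p] * (draft_len + 1)
        min_p += [mp] * (draft_len + 1)
    return temperatures, top_k, top_p, min_p

reqs = [
    SimpleNamespace(sampling_config=SimpleNamespace(
        temperature=[0.7], top_k=[40], top_p=[0.9], min_p=[])),
    SimpleNamespace(sampling_config=SimpleNamespace(
        temperature=[], top_k=[], top_p=[], min_p=[])),  # falls back to defaults
]
temps, ks, ps, mps = expand_params(reqs, draft_len=2)
print(temps)  # [0.7, 0.7, 0.7, 1.0, 1.0, 1.0]
print(ks)     # [40, 40, 40, 2147483647, 2147483647, 2147483647]
```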
114 changes: 113 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
from typing import Literal, Optional

import torch

@@ -151,6 +151,118 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
return next_tokens, softmax


def forward_native(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""
PyTorch-native implementation of top-k and top-p sampling.

The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs)


def random_sample(probs: torch.Tensor) -> torch.Tensor:
"""Randomly sample from the probabilities.

We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
q.exponential_()
return probs.div_(q).argmax(dim=-1).view(-1)


def apply_min_p(
logits: torch.Tensor,
min_p: torch.Tensor,
) -> torch.Tensor:
"""
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
# Reshape min_p for broadcasting
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float('inf')
return logits


def apply_top_k_top_p(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.

If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.

The logits tensor may be updated in-place.
"""
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
top_k_mask = top_k_mask.clamp(min=0)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Restore the original token order.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits


def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)


def apply_temperature(
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))


def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
top_k: torch.Tensor, top_p: torch.Tensor,
min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
raw_probs = torch.softmax(logits, dim=-1)
greedy_sampled = greedy_sample(logits)
logits = apply_temperature(logits, temperatures)
logits = apply_min_p(logits, min_p)
random_sampled = forward_native(logits, top_k, top_p)
next_tokens = torch.where(
temperatures <= 1e-2, # Match the clamping threshold
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
token_probs = torch.gather(raw_probs, dim=1,
index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
return next_tokens, log_probs

Comment on lines +246 to +264
Contributor

🛠️ Refactor suggestion

The unified sampling function is well-designed, but it has a potential numerical stability issue.

The sampling_batch function effectively combines all sampling strategies, with a temperature threshold selecting between greedy and random sampling. The design is sound, but there is a numerical issue:

The log probability calculation at Line 262 could produce -inf values (and downstream NaNs) if token_probs contains zeros due to precision issues:

    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
-   log_probs = torch.log(token_probs)
+   log_probs = torch.log(token_probs.clamp(min=1e-8))

This prevents log(0) = -inf issues that could propagate through the system.



def greedy_search_sampling_batch(logits):
next_tokens = torch.argmax(logits, dim=-1)
softmax = torch.softmax(logits, dim=-1)
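A minimal usage sketch of the new sampling_batch helper: one row per sampled position, each with its own parameters. Shapes, dtypes, and the int32-max top-k sentinel are inferred from the code in this diff and from collect_req_mtp_sampling_params, so treat this as an illustration rather than a tested call pattern.

```python
import torch
from tensorrt_llm._torch.pyexecutor.sampler import sampling_batch  # added by this PR

torch.manual_seed(0)
logits = torch.randn(3, 8)                     # batch of 3 rows, vocab size 8

temperatures = torch.tensor([0.01, 0.7, 1.0])  # row 0: <= 1e-2 -> greedy path
top_k = torch.tensor([2147483647, 40, 5])      # int32 max disables top-k for row 0
top_p = torch.tensor([1.0, 0.9, 0.95])
min_p = torch.tensor([0.0, 0.0, 0.05])

# Pass a clone because sampling_batch modifies the logits in place.
next_tokens, log_probs = sampling_batch(logits.clone(), temperatures, top_k, top_p, min_p)
print(next_tokens.shape, log_probs.shape)      # torch.Size([3]) torch.Size([3])
assert next_tokens[0] == logits[0].argmax()    # greedy row matches plain argmax
```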
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -132,6 +132,14 @@ class SpecMetadata:
seq_lens: Optional[List[int]] = None
# The gather ids for logits.
gather_ids: Optional[torch.Tensor] = None
# The temperatures for requests.
temperatures: Optional[torch.Tensor] = None
# The top_k for requests.
top_k: Optional[torch.Tensor] = None
# The top_p for requests.
top_p: Optional[torch.Tensor] = None
# The min_p for requests.
min_p: Optional[torch.Tensor] = None
# The number of tokens for speculative model/layer
num_tokens: int = 0
# The number of tokens for speculative model/layer of different rank
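The new SpecMetadata fields are populated via update_advanced_mtp_sampling_params, whose implementation is not shown in this diff. The following is a hypothetical sketch of the intended data flow, using a toy stand-in class rather than the real SpecMetadata.

```python
# Hypothetical sketch only: the real update_advanced_mtp_sampling_params is not in this diff.
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class _ToySpecMetadata:                       # stand-in for SpecMetadata
    temperatures: Optional[torch.Tensor] = None
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None
    min_p: Optional[torch.Tensor] = None

    def update_advanced_mtp_sampling_params(self, temperatures, top_k, top_p, min_p):
        # One entry per sampled token position (draft_len + 1 per request).
        self.temperatures = torch.tensor(temperatures, dtype=torch.float32)
        self.top_k = torch.tensor(top_k, dtype=torch.int32)
        self.top_p = torch.tensor(top_p, dtype=torch.float32)
        self.min_p = torch.tensor(min_p, dtype=torch.float32)

meta = _ToySpecMetadata()
meta.update_advanced_mtp_sampling_params([0.7, 0.7], [40, 40], [0.9, 0.9], [0.0, 0.0])
print(meta.top_k)  # tensor([40, 40], dtype=torch.int32)
```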