9 changes: 7 additions & 2 deletions examples/llm-api/quickstart_advanced.py
@@ -133,6 +133,9 @@ def add_llm_args(parser):
parser.add_argument('--draft_model_dir', type=str, default=None)
parser.add_argument('--max_matching_ngram_size', type=int, default=5)
parser.add_argument('--use_one_model', default=False, action='store_true')
parser.add_argument('--use_advanced_spec_dec_sampling',
default=False,
action='store_true')
parser.add_argument('--eagle_choices', type=str, default=None)
parser.add_argument('--use_dynamic_tree',
default=False,
@@ -194,15 +197,17 @@ def setup_llm(args, **kwargs):
relaxed_topk=args.relaxed_topk,
relaxed_delta=args.relaxed_delta,
mtp_eagle_one_model=args.use_one_model,
speculative_model_dir=args.model_dir)
speculative_model_dir=args.model_dir,
use_advanced_spec_dec_sampling=args.use_advanced_spec_dec_sampling)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
speculative_model_dir=args.draft_model_dir,
eagle3_one_model=args.use_one_model,
eagle_choices=args.eagle_choices,
use_dynamic_tree=args.use_dynamic_tree,
dynamic_tree_max_topK=args.dynamic_tree_max_topK)
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
use_advanced_spec_dec_sampling=args.use_advanced_spec_dec_sampling)
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
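For reference, a rough sketch of what the new flag corresponds to when driving the LLM API directly rather than through quickstart_advanced.py (config field names are taken from this diff; the import paths, model paths, and surrounding setup are illustrative assumptions):

```python
# Hypothetical usage sketch, not part of the PR; mirrors the EAGLE3 branch of setup_llm().
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,                                # --spec_decode_max_draft_len
    speculative_model_dir="/path/to/eagle3_draft",  # --draft_model_dir (placeholder path)
    eagle3_one_model=True,                          # --use_one_model
    use_advanced_spec_dec_sampling=True,            # the flag added by this PR
)
llm = LLM(model="/path/to/target_model", speculative_config=spec_config)
```

On the command line, the same behavior should be reachable by adding --use_advanced_spec_dec_sampling to an existing EAGLE3 or MTP invocation of quickstart_advanced.py.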
104 changes: 103 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -14,6 +14,7 @@
import torch._dynamo.config

import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
@@ -173,6 +174,12 @@ def __init__(
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model

        self.is_advanced_spec_dec_sampler = (
            self.is_spec_decode
            and (self.spec_config.spec_dec_mode.is_mtp_one_model()
                 or self.spec_config.spec_dec_mode.is_eagle3_one_model())
            and self.spec_config.use_advanced_spec_dec_sampling)

self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)
self.input_processor = create_input_processor(model_path, None)
@@ -1211,11 +1218,79 @@ def _prepare_tp_inputs(
num_cached_tokens_per_seq = [] # per sequence
draft_tokens = []
draft_lens = []

gen_request_seq_slots = [] # per generation request
multimodal_params_list = []
mrope_position_ids = []
num_accepted_draft_tokens = [] # per request

if self.is_advanced_spec_dec_sampler:
temperatures = []
top_k = []
top_p = []
min_p = []

        # Helper for collecting per-request sampling params for advanced spec-dec sampling
def collect_req_spec_dec_sampling_params(request: LlmRequest,
Collaborator:

The resolution of request.sampling_config to sampling strategy has been cleaned up in #8132. See PR description for the intended semantics. The relevant function is

def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:

The existing function covers various corner cases already (e.g. temperature=0, top_p=1, etc.) and has extensive unit tests. Consider reusing this function here (perhaps make it "public", i.e., rename to something that does not start with _).

draft_len: int = 0):

def get_request_temperature(request: LlmRequest,
is_greedy: bool) -> float:
if is_greedy:
return 0.01 # avoid numerical errors and keep the same behavior as greedy sampling
if not request.sampling_config.temperature:
return 1.0
temperature = request.sampling_config.temperature[0]
if 0 <= temperature < 1e-2:
# temperature less than 0.01 may cause numerical errors
temperature = 0.01
return temperature

def get_request_top_k(request: LlmRequest, is_greedy: bool) -> int:
if is_greedy:
return 1
if not request.sampling_config.top_k:
top_k = 0
else:
top_k = request.sampling_config.top_k[0]

# set k to a very large value (larger than vocab size) to disable top_k sampling
TOP_K_DISABLED = torch.iinfo(torch.int32).max
if top_k <= 0:
top_k = TOP_K_DISABLED
return top_k

def get_request_top_p(request: LlmRequest) -> float:
if not request.sampling_config.top_p:
top_p = 1.0
else:
top_p = request.sampling_config.top_p[0]
return top_p

def get_request_min_p(request: LlmRequest) -> float:
if not request.sampling_config.min_p:
min_p = 0.0
else:
min_p = request.sampling_config.min_p[0]
return min_p

def is_greedy_sampling(request: LlmRequest) -> bool:
if (request.sampling_config.temperature is None
and request.sampling_config.top_p is None
and request.sampling_config.top_k is None):
return True
return False

is_greedy_req = is_greedy_sampling(request)

temperatures.extend(
[get_request_temperature(request, is_greedy_req)] *
(draft_len + 1))
top_k.extend([get_request_top_k(request, is_greedy_req)] *
(draft_len + 1))
top_p.extend([get_request_top_p(request)] * (draft_len + 1))
min_p.extend([get_request_min_p(request)] * (draft_len + 1))

for request in scheduled_requests.context_requests:
request_ids.append(request.py_request_id)
all_prompt_tokens = request.get_tokens(0)
@@ -1233,7 +1308,6 @@ def _prepare_tp_inputs(
past_seen_token_num = begin_compute
num_cached_tokens_per_seq.append(past_seen_token_num)
request.cached_tokens = num_cached_tokens_per_seq[-1]

# Multimodal
py_multimodal_runtime = MultimodalRuntimeData(
mm_token_lengths=request.multimodal_lengths,
Expand Down Expand Up @@ -1269,6 +1343,9 @@ def _prepare_tp_inputs(
request.py_multimodal_data = multimodal_params.multimodal_data
multimodal_params_list.append(multimodal_params)

if self.is_advanced_spec_dec_sampler:
collect_req_spec_dec_sampling_params(request)

request.py_batch_idx = request.py_seq_slot

if len(multimodal_params_list) > 0:
Expand Down Expand Up @@ -1344,6 +1421,11 @@ def _prepare_tp_inputs(
past_seen_token_num + 1 + num_draft_tokens)))
num_cached_tokens_per_seq.append(past_seen_token_num)
request.cached_tokens = num_cached_tokens_per_seq[-1]

if self.is_advanced_spec_dec_sampler:
collect_req_spec_dec_sampling_params(
request, num_draft_tokens)

# update batch index
request.py_batch_idx = request.py_seq_slot
else:
@@ -1378,6 +1460,9 @@ def _prepare_tp_inputs(
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)
request.cached_tokens = num_cached_tokens_per_seq[-1]
if self.is_advanced_spec_dec_sampler:
collect_req_spec_dec_sampling_params(
request, self.runtime_draft_len)
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
self.attn_backend):
prompt_lengths.append(1 + self.runtime_draft_len)
Expand Down Expand Up @@ -1408,6 +1493,10 @@ def _prepare_tp_inputs(
# update batch index
request.py_batch_idx = request.py_seq_slot

if self.is_advanced_spec_dec_sampler:
collect_req_spec_dec_sampling_params(
request, self.original_max_draft_len)

for request in generation_requests:
request_ids.append(request.py_request_id)
beam_width = request.sampling_config.beam_width
Expand Down Expand Up @@ -1467,6 +1556,10 @@ def _prepare_tp_inputs(
])
multimodal_params_list.append(multimodal_params)

if self.is_advanced_spec_dec_sampler:
collect_req_spec_dec_sampling_params(request,
self.runtime_draft_len)

request.py_batch_idx = request.py_seq_slot
# Do not add a gen_request_seq_slot for CUDA graph dummy requests
# to prevent access errors due to None values
@@ -1701,6 +1794,11 @@ def previous_seq_slots_device():
total_draft_lens]
spec_metadata.request_ids = request_ids
spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]

if self.is_advanced_spec_dec_sampler:
spec_metadata.update_advanced_sampling_params(
temperatures, top_k, top_p, min_p)

spec_metadata.num_generations = len(
scheduled_requests.generation_requests)
spec_metadata.num_tokens = total_num_tokens
@@ -2257,6 +2355,10 @@ def forward(
is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
spec_metadata.is_spec_dec_dynamic_tree,
self.original_max_draft_len)

if self.is_advanced_spec_dec_sampler:
spec_metadata._set_up_advanced_sampling(
self.batch_size, self.original_max_draft_len)
else:
spec_resource_manager = None
spec_metadata = None
107 changes: 107 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -235,6 +235,113 @@ def is_generation_model(self) -> bool:
return False


def forward_native(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""
PyTorch-native implementation of top-k and top-p sampling.

The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
Collaborator:

I think we can technically skip this softmax

Collaborator:

I think I am also wrong here

Collaborator:

I think @IzzyPutterman is right: apply_min_p evaluates the softmax of the temperature-scaled logits and uses that to mask out some of the logits (set to -inf). The probs could be masked in the same way (set to 0). The resulting probs can (mind paragraph below) then be reused in apply_top_k_top_p, which masks out more logits/probs.

Every time logits/probs are masked, it is sufficient to renormalize the probs such that they sum to one, which is much cheaper than computing softmax. This is probably also why https://docs.flashinfer.ai/api/sampling.html uses function names like ..._renorm_probs.

Note that much of this is already worked out in #8581, albeit using flashinfer.sampling.

return random_sample(probs)
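To make the renormalization idea from the comment above concrete, here is a small self-contained check (editorial illustration, not code from the PR): dividing an already-normalized distribution by its remaining mass after masking gives the same result as re-running softmax over the surviving logits.

```python
# Illustrative only: masking probs + renormalizing == softmax over masked logits.
import torch

probs = torch.softmax(torch.randn(2, 8), dim=-1)
mask = probs < 0.05                                  # pretend these tokens were filtered out
renormed = probs.masked_fill(mask, 0.0)
renormed = renormed / renormed.sum(dim=-1, keepdim=True)

reference = torch.softmax(torch.log(probs).masked_fill(mask, float("-inf")), dim=-1)
torch.testing.assert_close(renormed, reference)
```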


def random_sample(
probs: torch.Tensor,
) -> torch.Tensor:
"""Randomly sample from the probabilities.

We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
q.exponential_()
return probs.div_(q).argmax(dim=-1).view(-1)
Collaborator:

I have to admit that I am not familiar with this sampling scheme. If you happen to have a literature reference at hand, I would be curious to learn more (perhaps also include a comment stating the name of the method).

BTW, TorchSampler is using torch.multinomial and I did not notice any stream syncs. Code:

next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)

Collaborator:

(Disclaimer: I might well have overlooked that torch.multinomial is syncing, so I would be curious to hear more on this.)

Collaborator:

I think I have found the answer to my first question: this uses the "Gumbel max trick" (Coderabbit even points that out in its review...), after a variable transformation from log-probabilities to probabilities. Including a corresponding remark in the doc-string might be useful for future readers.
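For readers unfamiliar with the trick, a quick standalone sanity check of the exponential-race formulation used by random_sample above (editorial illustration, not part of the PR):

```python
# argmax(p_i / E_i) with E_i ~ Exp(1) samples index i with probability p_i
# (the Gumbel-max trick expressed in probability space).
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]]).repeat(200_000, 1)

q = torch.empty_like(probs).exponential_()
samples = probs.div(q).argmax(dim=-1)

empirical = torch.bincount(samples, minlength=4).float() / samples.numel()
print(empirical)   # approximately [0.1, 0.2, 0.3, 0.4]
```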



def apply_min_p(
logits: torch.Tensor,
min_p: torch.Tensor,
) -> torch.Tensor:
"""
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
Collaborator:

This effectively neutralizes the temperature, right? We apply the temperature, then softmax again, which undoes the scaling.

Collaborator:

Or perhaps in sampling_batch_spec_dec_one_model, we should remove the first softmax and put one just before the sort in the apply top_k top_P

Collaborator:

nvm, I'm wrong here; misread something

# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
# Reshape min_p for broadcasting
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float("inf")
return logits
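A toy example of the thresholding performed by apply_min_p above (editorial illustration; assumes the function is in scope):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
min_p = torch.tensor([0.1])            # keep tokens whose prob >= 0.1 * max prob
filtered = apply_min_p(logits.clone(), min_p)
print(filtered)                        # only the -3.0 logit falls below the threshold and becomes -inf
```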


def apply_top_k_top_p(
Collaborator:

Note that we already have

def top_k_top_p_sampling_batch(

and some related functions. We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).

Also note that I am working on FlashInfer.sampling based alternatives for those functions. This upcoming PR brings support for Tensor-type k, p, temperature, etc. when FlashInfer is used. If you integrate the improvements made here for the non-FlashInfer case, this could give a quite nice feature set.

Ideally, this PR could (i) improve the existing sampling routines and (ii) use them via

class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):

Collaborator:

Put up #8581 (work in progress!) to give an idea of what to expect.

Collaborator:

See also TRTLLM-7723 (and TRTLLM-7152) for scope of ongoing work.

Collaborator:

I think I tested flashinfer with CUDA graphs and it was breaking a bunch. With the generator objects it's quite annoying in TRTLLM because in warmup we alternate between CUDA graph warmup and non-CUDA-graph warmup, which breaks.

Collaborator:

Worth a double check ofc, perhaps there is an easy way around it

Collaborator:

@ixlmar I think the current implementation of TopK/TopP only allows all requests to have the same TopK/TopP values, rather than letting individual requests use different values; please correct me if I'm wrong.

The current logic in model_engine.py did not parse out all the sampling params into GPU tensors for CUDA graphs; this PR enables that.

Collaborator:

@IzzyPutterman The idea of #8581 is to allow choosing between the sampling routines we have today in sampling_utils.py and those provided by FlashInfer. Both will be available as implementations of GroupedStrategySampler. SimpleGroupedStrategySampler uses the former sampling routines (non FlashInfer) and is already available in main.

Collaborator:

@jhaotingc Correct. This was what I meant in the first comment:

We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).

Ideally, this PR could extend SimpleGroupedStrategySampler to allow for Tensor-type k, p, temperature, etc., in the same way as FlashInferGroupedStrategySampler does it for FlashInfer.sampling in #8581. If the GroupedStrategySampler abstraction is not viable (e.g. because the host data structures interfere with CUDA graphs), then I think we should extend top_k_top_p_sampling_batch (and the related functions) directly (promote scalar arguments to accept any broadcastable torch.Tensor) and reuse them here.

logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.

If a top-p is used, this function will sort the logits tensor,
Collaborator:

As a perf optimization, should we skip the expensive sorting / softmax / cumsum ops for top_p >=1?

Collaborator:

If top_p is 1, we can skip the expensive sorting / softmax / cumsum ops.

In the latest TRT-LLM version, this is already implemented. Please refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/sampling_utils.py#L159-L171.

Collaborator:

The skipping is not possible here: in regular decoding the sampling is not captured in the CUDA graph, whereas this part is captured in the CUDA graph. So unless there is a kernel that determines whether to skip or not (like checkAllTopP in cpp/kernels/samplingTopPKernel.cu), there is no way to check the CPU flag need_top_p.

which can be slow for large batches.

The logits tensor may be updated in-place.
"""
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
top_k_mask = top_k_mask.clamp(min=0)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Scatter the masked values back to the original (unsorted) token order.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
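A short usage note for apply_top_k_top_p above, showing per-request tensor parameters and the top-k "disabled" sentinel used in model_engine.py (editorial illustration):

```python
import torch

logits = torch.randn(2, 32000)                    # (batch, vocab)
TOP_K_DISABLED = torch.iinfo(torch.int32).max     # same sentinel as get_request_top_k()
k = torch.tensor([50, TOP_K_DISABLED])            # request 0: top-k 50; request 1: top-k off
p = torch.tensor([0.9, 1.0])                      # top_p == 1.0 masks nothing for request 1
masked = apply_top_k_top_p(logits, k, p)          # -inf outside each request's kept set
```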


def apply_temperature(
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))


@torch.compile(options={"max-autotune": True})
def sampling_batch_spec_dec_one_model(
logits: torch.Tensor,
temperatures: torch.Tensor,
top_k: torch.Tensor,
top_p: torch.Tensor,
min_p: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
raw_probs = torch.softmax(logits, dim=-1)
logits = apply_temperature(logits, temperatures)
logits = apply_min_p(logits, min_p)
random_sampled = forward_native(logits, top_k, top_p)
token_probs = torch.gather(raw_probs, dim=1, index=random_sampled.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
return random_sampled, log_probs
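Finally, an end-to-end sketch of how sampling_batch_spec_dec_one_model is fed (editorial illustration; in the PR the per-token tensors are built in _prepare_tp_inputs by repeating each request's params (draft_len + 1) times):

```python
import torch

num_tokens, vocab = 8, 1000                        # flattened (target + draft) token slots
logits = torch.randn(num_tokens, vocab)
temperatures = torch.full((num_tokens,), 0.8)
top_k = torch.full((num_tokens,), 50, dtype=torch.int32)
top_p = torch.full((num_tokens,), 0.9)
min_p = torch.zeros(num_tokens)

tokens, log_probs = sampling_batch_spec_dec_one_model(
    logits, temperatures, top_k, top_p, min_p)
print(tokens.shape, log_probs.shape)               # both torch.Size([8])
```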


# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params
# in LlmRequest.sampling_params are either None or single-element lists.
# This helper method simplifies code using such params.