diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 3e60d0f48cb..cae5b9d96da 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -133,6 +133,9 @@ def add_llm_args(parser): parser.add_argument('--draft_model_dir', type=str, default=None) parser.add_argument('--max_matching_ngram_size', type=int, default=5) parser.add_argument('--use_one_model', default=False, action='store_true') + parser.add_argument('--use_advanced_spec_dec_sampling', + default=False, + action='store_true') parser.add_argument('--eagle_choices', type=str, default=None) parser.add_argument('--use_dynamic_tree', default=False, @@ -194,7 +197,8 @@ def setup_llm(args, **kwargs): relaxed_topk=args.relaxed_topk, relaxed_delta=args.relaxed_delta, mtp_eagle_one_model=args.use_one_model, - speculative_model_dir=args.model_dir) + speculative_model_dir=args.model_dir, + use_advanced_spec_dec_sampling=args.use_advanced_spec_dec_sampling) elif spec_decode_algo == "EAGLE3": spec_config = EagleDecodingConfig( max_draft_len=args.spec_decode_max_draft_len, @@ -202,7 +206,8 @@ def setup_llm(args, **kwargs): eagle3_one_model=args.use_one_model, eagle_choices=args.eagle_choices, use_dynamic_tree=args.use_dynamic_tree, - dynamic_tree_max_topK=args.dynamic_tree_max_topK) + dynamic_tree_max_topK=args.dynamic_tree_max_topK, + use_advanced_spec_dec_sampling=args.use_advanced_spec_dec_sampling) elif spec_decode_algo == "DRAFT_TARGET": spec_config = DraftTargetDecodingConfig( max_draft_len=args.spec_decode_max_draft_len, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index f1d61462566..73e49ce419a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -14,6 +14,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, torch_dtype_to_str, trace_func) from tensorrt_llm.inputs.multimodal import (MultimodalParams, @@ -173,6 +174,12 @@ def __init__( self.enable_spec_decode = self.is_spec_decode self.is_draft_model = is_draft_model + self.is_advanced_spec_dec_sampler = (self.is_spec_decode and ( + (self.spec_config.spec_dec_mode.is_mtp_one_model() + and self.spec_config.use_advanced_spec_dec_sampling) or + (self.spec_config.spec_dec_mode.is_eagle3_one_model() + and self.spec_config.use_advanced_spec_dec_sampling))) + self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( ) self.input_processor = create_input_processor(model_path, None) @@ -1211,11 +1218,79 @@ def _prepare_tp_inputs( num_cached_tokens_per_seq = [] # per sequence draft_tokens = [] draft_lens = [] + gen_request_seq_slots = [] # per generation request multimodal_params_list = [] mrope_position_ids = [] num_accepted_draft_tokens = [] # per request + if self.is_advanced_spec_dec_sampler: + temperatures = [] + top_k = [] + top_p = [] + min_p = [] + + # advanced mtp sampling's request preprocessing helper functions + def collect_req_spec_dec_sampling_params(request: LlmRequest, + draft_len: int = 0): + + def get_request_temperature(request: LlmRequest, + is_greedy: bool) -> float: + if is_greedy: + return 0.01 # avoid numerical errors and keep the same behavior as greedy sampling + if not request.sampling_config.temperature: + return 1.0 + temperature = 
request.sampling_config.temperature[0] + if 0 <= temperature < 1e-2: + # temperature less than 0.01 may cause numerical errors + temperature = 0.01 + return temperature + + def get_request_top_k(request: LlmRequest, is_greedy: bool) -> int: + if is_greedy: + return 1 + if not request.sampling_config.top_k: + top_k = 0 + else: + top_k = request.sampling_config.top_k[0] + + # set k to a very large value (larger than vocab size) to disable top_k sampling + TOP_K_DISABLED = torch.iinfo(torch.int32).max + if top_k <= 0: + top_k = TOP_K_DISABLED + return top_k + + def get_request_top_p(request: LlmRequest) -> float: + if not request.sampling_config.top_p: + top_p = 1.0 + else: + top_p = request.sampling_config.top_p[0] + return top_p + + def get_request_min_p(request: LlmRequest) -> float: + if not request.sampling_config.min_p: + min_p = 0.0 + else: + min_p = request.sampling_config.min_p[0] + return min_p + + def is_greedy_sampling(request: LlmRequest) -> bool: + if (request.sampling_config.temperature is None + and request.sampling_config.top_p is None + and request.sampling_config.top_k is None): + return True + return False + + is_greedy_req = is_greedy_sampling(request) + + temperatures.extend( + [get_request_temperature(request, is_greedy_req)] * + (draft_len + 1)) + top_k.extend([get_request_top_k(request, is_greedy_req)] * + (draft_len + 1)) + top_p.extend([get_request_top_p(request)] * (draft_len + 1)) + min_p.extend([get_request_min_p(request)] * (draft_len + 1)) + for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) all_prompt_tokens = request.get_tokens(0) @@ -1233,7 +1308,6 @@ def _prepare_tp_inputs( past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) request.cached_tokens = num_cached_tokens_per_seq[-1] - # Multimodal py_multimodal_runtime = MultimodalRuntimeData( mm_token_lengths=request.multimodal_lengths, @@ -1269,6 +1343,9 @@ def _prepare_tp_inputs( request.py_multimodal_data = multimodal_params.multimodal_data multimodal_params_list.append(multimodal_params) + if self.is_advanced_spec_dec_sampler: + collect_req_spec_dec_sampling_params(request) + request.py_batch_idx = request.py_seq_slot if len(multimodal_params_list) > 0: @@ -1344,6 +1421,11 @@ def _prepare_tp_inputs( past_seen_token_num + 1 + num_draft_tokens))) num_cached_tokens_per_seq.append(past_seen_token_num) request.cached_tokens = num_cached_tokens_per_seq[-1] + + if self.is_advanced_spec_dec_sampler: + collect_req_spec_dec_sampling_params( + request, num_draft_tokens) + # update batch index request.py_batch_idx = request.py_seq_slot else: @@ -1378,6 +1460,9 @@ def _prepare_tp_inputs( num_cached_tokens_per_seq.append(past_seen_token_num + self.runtime_draft_len + 1) request.cached_tokens = num_cached_tokens_per_seq[-1] + if self.is_advanced_spec_dec_sampler: + collect_req_spec_dec_sampling_params( + request, self.runtime_draft_len) if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend): prompt_lengths.append(1 + self.runtime_draft_len) @@ -1408,6 +1493,10 @@ def _prepare_tp_inputs( # update batch index request.py_batch_idx = request.py_seq_slot + if self.is_advanced_spec_dec_sampler: + collect_req_spec_dec_sampling_params( + request, self.original_max_draft_len) + for request in generation_requests: request_ids.append(request.py_request_id) beam_width = request.sampling_config.beam_width @@ -1467,6 +1556,10 @@ def _prepare_tp_inputs( ]) multimodal_params_list.append(multimodal_params) + if 
self.is_advanced_spec_dec_sampler: + collect_req_spec_dec_sampling_params(request, + self.runtime_draft_len) + request.py_batch_idx = request.py_seq_slot # Do not add a gen_request_seq_slot for CUDA graph dummy requests # to prevent access errors due to None values @@ -1701,6 +1794,11 @@ def previous_seq_slots_device(): total_draft_lens] spec_metadata.request_ids = request_ids spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)] + + if self.is_advanced_spec_dec_sampler: + spec_metadata.update_advanced_sampling_params( + temperatures, top_k, top_p, min_p) + spec_metadata.num_generations = len( scheduled_requests.generation_requests) spec_metadata.num_tokens = total_num_tokens @@ -2257,6 +2355,10 @@ def forward( is_spec_dec_mode, spec_metadata.is_spec_dec_tree, spec_metadata.is_spec_dec_dynamic_tree, self.original_max_draft_len) + + if self.is_advanced_spec_dec_sampler: + spec_metadata._set_up_advanced_sampling( + self.batch_size, self.original_max_draft_len) else: spec_resource_manager = None spec_metadata = None diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 0c37b2856c5..cb7fa1396e9 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -235,6 +235,113 @@ def is_generation_model(self) -> bool: return False +def forward_native( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + PyTorch-native implementation of top-k and top-p sampling. + + The logits tensor may be updated in-place. + """ + logits = apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs) + + +def random_sample( + probs: torch.Tensor, +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. + """ + q = torch.empty_like(probs) + q.exponential_() + return probs.div_(q).argmax(dim=-1).view(-1) + + +def apply_min_p( + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Filters logits using adaptive probability thresholding. + """ + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + valid_token_mask = probability_values >= adjusted_min_p + # Apply mask using boolean indexing + logits[~valid_token_mask] = -float("inf") + return logits + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + top_k_mask = top_k_mask.clamp(min=0) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + if p is not None: + # Apply top-p. 
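+        # The lines below convert the ascending-sorted logits to probabilities
+        # and take a cumulative sum, so entry j holds the total mass of the
+        # (j + 1) least-likely tokens. Positions whose cumulative mass is
+        # <= (1 - p) form the low-probability tail and are masked; the final
+        # (most likely) position is always kept.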
+ probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def apply_temperature( + logits: torch.Tensor, + temp: torch.Tensor, +) -> torch.Tensor: + # Use in-place division to avoid creating a new tensor. + return logits.div_(temp.unsqueeze(dim=1)) + + +@torch.compile(options={"max-autotune": True}) +def sampling_batch_spec_dec_one_model( + logits: torch.Tensor, + temperatures: torch.Tensor, + top_k: torch.Tensor, + top_p: torch.Tensor, + min_p: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + raw_probs = torch.softmax(logits, dim=-1) + logits = apply_temperature(logits, temperatures) + logits = apply_min_p(logits, min_p) + random_sampled = forward_native(logits, top_k, top_p) + token_probs = torch.gather(raw_probs, dim=1, index=random_sampled.unsqueeze(1)).squeeze(-1) + log_probs = torch.log(token_probs) + return random_sampled, log_probs + + # Due to tensorrt_llm::runtime::SamplingConfig using vectors, params # in LlmRequest.sampling_params are either None or single-element lists. # This helper method simplifies code using such params. diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 423f4354bd9..346bf75da7d 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -10,7 +10,7 @@ from ..pyexecutor.guided_decoder import CapturableGuidedDecoder from ..pyexecutor.llm_request import LlmRequest from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager -from ..pyexecutor.sampler import TorchSampler +from ..pyexecutor.sampler import TorchSampler, sampling_batch_spec_dec_one_model from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecMetadata from .mtp import MTPSampler @@ -244,6 +244,11 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): # The index of the batche inputs batch_indices_cuda: Optional[torch.Tensor] = None + temperatures_cuda: Optional[torch.Tensor] = None + top_k_cuda: Optional[torch.Tensor] = None + top_p_cuda: Optional[torch.Tensor] = None + min_p_cuda: Optional[torch.Tensor] = None + def __post_init__(self): if self.layers_to_capture is None: if self.num_layers == 1: @@ -304,6 +309,51 @@ def maybe_capture_hidden_states( non_blocking=True) break + def _set_up_advanced_sampling(self, batch_size: int, max_draft_len: int): + # create once and reuse + if self.temperatures_cuda is None: + # Set deterministic seed (one time) for consistent multi-GPU sampling using PyTorch RNG + # operations that avoid torch.multinomial's CPU-GPU sync overhead + torch.manual_seed(0) + + max_total_sampling_size = batch_size * (max_draft_len + 1) + self.temperatures_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + self.top_k_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.int, + device='cuda') + self.top_p_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + self.min_p_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + + def update_advanced_sampling_params(self, temperatures: list[float], + top_k: list[int], top_p: list[float], + min_p: list[float]): + 
self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor( + temperatures, dtype=torch.float, pin_memory=True), + non_blocking=True) + self.top_k_cuda[:len(top_k)].copy_(torch.tensor(top_k, + dtype=torch.int, + pin_memory=True), + non_blocking=True) + self.top_p_cuda[:len(top_p)].copy_(torch.tensor(top_p, + dtype=torch.float, + pin_memory=True), + non_blocking=True) + self.min_p_cuda[:len(min_p)].copy_(torch.tensor(min_p, + dtype=torch.float, + pin_memory=True), + non_blocking=True) + + self.temperatures = self.temperatures_cuda[:len(temperatures)] + self.top_k = self.top_k_cuda[:len(top_k)] + self.top_p = self.top_p_cuda[:len(top_p)] + self.min_p = self.min_p_cuda[:len(min_p)] + class Eagle3OneModelSampler(MTPSampler): @@ -467,8 +517,13 @@ def sample_and_accept_draft_tokens( dtype=torch.int, device=logits.device) - # Do greedy sampling for the input logits - target_tokens = torch.argmax(logits, dim=-1) + if self.spec_config.use_advanced_spec_dec_sampling: + target_tokens, _ = sampling_batch_spec_dec_one_model( + logits, spec_metadata.temperatures, spec_metadata.top_k, + spec_metadata.top_p, spec_metadata.min_p) + else: + # Do greedy sampling for the input logits + target_tokens = torch.argmax(logits, dim=-1) # context accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 998f4b28cbb..c3848b1ccf5 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -175,6 +175,14 @@ class SpecMetadata: gather_ids: Optional[torch.Tensor] = None # The number of accepted draft tokens for each request. num_accepted_draft_tokens: Optional[torch.Tensor] = None + # The temperatures for requests. + temperatures: Optional[torch.Tensor] = None + # The top_k for requests. + top_k: Optional[torch.Tensor] = None + # The top_p for requests. + top_p: Optional[torch.Tensor] = None + # The min_p for requests. + min_p: Optional[torch.Tensor] = None # The number of tokens for speculative model/layer num_tokens: int = 0 # The number of tokens for speculative model/layer of different rank diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 991fb672759..99c6960290c 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -14,7 +14,8 @@ from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler, - add_token, int_tensor) + add_token, int_tensor, + sampling_batch_spec_dec_one_model) from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecMetadata @@ -119,6 +120,11 @@ class MTPSpecMetadata(SpecMetadata): # subsequence draft forward. 
subseq_all_rank_num_tokens: Optional[List[int]] = None + temperatures_cuda: Optional[torch.Tensor] = None + top_k_cuda: Optional[torch.Tensor] = None + top_p_cuda: Optional[torch.Tensor] = None + min_p_cuda: Optional[torch.Tensor] = None + def __post_init__(self) -> None: if self.mtp_hidden_states_manager is not None: # mtp_hidden_states_ptrs is a pointer tensor @@ -208,6 +214,51 @@ def prepare(self): pin_memory=True) self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True) + def _set_up_advanced_sampling(self, batch_size: int, max_draft_len: int): + # create once and reuse + if self.temperatures_cuda is None: + # Set deterministic seed (one time) for consistent multi-GPU sampling using PyTorch RNG + # operations that avoid torch.multinomial's CPU-GPU sync overhead + torch.manual_seed(0) + + max_total_sampling_size = batch_size * (max_draft_len + 1) + self.temperatures_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + self.top_k_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.int, + device='cuda') + self.top_p_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + self.min_p_cuda = torch.empty((max_total_sampling_size, ), + dtype=torch.float, + device='cuda') + + def update_advanced_sampling_params(self, temperatures: list[float], + top_k: list[int], top_p: list[float], + min_p: list[float]): + self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor( + temperatures, dtype=torch.float, pin_memory=True), + non_blocking=True) + self.top_k_cuda[:len(top_k)].copy_(torch.tensor(top_k, + dtype=torch.int, + pin_memory=True), + non_blocking=True) + self.top_p_cuda[:len(top_p)].copy_(torch.tensor(top_p, + dtype=torch.float, + pin_memory=True), + non_blocking=True) + self.min_p_cuda[:len(min_p)].copy_(torch.tensor(min_p, + dtype=torch.float, + pin_memory=True), + non_blocking=True) + + self.temperatures = self.temperatures_cuda[:len(temperatures)] + self.top_k = self.top_k_cuda[:len(top_k)] + self.top_p = self.top_p_cuda[:len(top_p)] + self.min_p = self.min_p_cuda[:len(min_p)] + class MTPSampler(TorchSampler): """ @@ -868,8 +919,15 @@ def sample_and_accept_draft_tokens( logits, spec_metadata.draft_tokens, target_tokens_cache, mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) else: - # Do greedy sampling for the input logits - target_tokens = torch.argmax(logits, dim=-1) + if self.spec_config.use_advanced_spec_dec_sampling: + # Do advanced sampling for the input logits + # target_log_probs currently unused but kept for future log probs support in MTP + target_tokens, target_log_probs = sampling_batch_spec_dec_one_model( + logits, spec_metadata.temperatures, spec_metadata.top_k, + spec_metadata.top_p, spec_metadata.min_p) + else: + # Do greedy sampling for the input logits + target_tokens = torch.argmax(logits, dim=-1) # context accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 35d02350e9c..55d856143fe 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -552,6 +552,7 @@ class EagleDecodingConfig(DecodingBaseConfig): max_non_leaves_per_layer: Optional[int] = None eagle3_one_model: Optional[bool] = True eagle3_layers_to_capture: Optional[Set[int]] = None + use_advanced_spec_dec_sampling: Optional[bool] = False def __init__(self, **kwargs): super().__init__() @@ -764,6 +765,7 @@ class MTPDecodingConfig(DecodingBaseConfig): relaxed_delta: float = 0. 
     use_mtp_vanilla: bool = False
     mtp_eagle_one_model: bool = True
+    use_advanced_spec_dec_sampling: Optional[bool] = False
 
     # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
     # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
diff --git a/tests/unittest/_torch/speculative/test_mtp.py b/tests/unittest/_torch/speculative/test_mtp.py
index c4a9783e792..9f638334565 100644
--- a/tests/unittest/_torch/speculative/test_mtp.py
+++ b/tests/unittest/_torch/speculative/test_mtp.py
@@ -330,6 +330,76 @@ def test_sample_and_accept_draft_tokens(self, test_case_name,
                 accepted_tokens[i][0:ref_num_accepted_tokens[i]],
                 ref_accepted_tokens[i][0:ref_num_accepted_tokens[i]])
 
+    @parameterized.expand(load_sample_and_accept_draft_tokens_test_cases,
+                          name_func=unittest_name_func)
+    def test_sample_and_accept_draft_tokens_adv_torch_sampler_greedy_mode(
+            self, test_case_name, mtp_num_modules, logits, draft_tokens,
+            draft_len, num_context_requests, ref_accepted_tokens,
+            ref_num_accepted_tokens):
+        # Set a deterministic seed so the advanced PyTorch sampler produces
+        # consistent results across GPUs.
+        torch.manual_seed(0)
+
+        batch_size = len(draft_len)
+        # Enable the advanced PyTorch sampler.
+        spec_config = MTPDecodingConfig(
+            num_nextn_predict_layers=mtp_num_modules,
+            use_advanced_spec_dec_sampling=True)
+
+        # attention metadata
+        attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
+                                                max_num_tokens=1024,
+                                                kv_cache_manager=None)
+        attn_metadata.seq_lens = torch.tensor(
+            [1] * batch_size, dtype=torch.int)  # dummy sequence length
+        attn_metadata.num_contexts = num_context_requests
+
+        # speculative decoding metadata
+        spec_metadata = MTPSpecMetadata(max_num_requests=32,
+                                        spec_dec_mode=spec_config.spec_dec_mode,
+                                        max_draft_len=mtp_num_modules,
+                                        mtp_num_modules=mtp_num_modules)
+        spec_metadata.draft_tokens = draft_tokens
+
+        temperatures = []
+        top_k = []
+        top_p = []
+        min_p = []
+        for i in range(batch_size):
+            num_draft_tokens = draft_len[i]
+            # Use greedy-equivalent sampling params (temperature at the 0.01
+            # floor, top_k=1). These match the defaults produced by the
+            # get_request_* helpers in
+            # tensorrt_llm/_torch/pyexecutor/model_engine.py.
+            temperatures.extend([0.01] * (num_draft_tokens + 1))
+            top_k.extend([1] * (num_draft_tokens + 1))
+            top_p.extend([1.0] * (num_draft_tokens + 1))
+            min_p.extend([0.0] * (num_draft_tokens + 1))
+        spec_metadata.temperatures = torch.tensor(temperatures,
+                                                  dtype=torch.float,
+                                                  device="cuda")
+        spec_metadata.top_k = torch.tensor(top_k,
+                                           dtype=torch.int,
+                                           device="cuda")
+        spec_metadata.top_p = torch.tensor(top_p,
+                                           dtype=torch.float,
+                                           device="cuda")
+        spec_metadata.min_p = torch.tensor(min_p,
+                                           dtype=torch.float,
+                                           device="cuda")
+
+        # MTP worker; is_thop is left at its default (False) so only the
+        # advanced PyTorch sampler path is exercised.
+        mtpworker = MTPWorker(spec_config)
+
+        # Test the advanced torch sampler.
+        accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens(
+            None, logits, spec_metadata, attn_metadata)
+
+        torch.testing.assert_close(num_accepted_tokens, ref_num_accepted_tokens)
+        for i in range(len(draft_len)):
+            torch.testing.assert_close(
+                accepted_tokens[i][0:ref_num_accepted_tokens[i]],
+                ref_accepted_tokens[i][0:ref_num_accepted_tokens[i]])
+
 
 class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
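For reference, the standalone sketch below walks the same pipeline that the added sampling_batch_spec_dec_one_model wires together: temperature scaling, min-p filtering, top-k/top-p masking on ascending-sorted logits (as in apply_top_k_top_p), and finally the exponential-race trick in place of torch.multinomial. It is illustrative only and not part of the patch; the batch size, vocabulary size, and parameter values are made up, it runs on CPU, and the raw_probs log-prob bookkeeping from the real function is omitted.

import torch

torch.manual_seed(0)

batch, vocab = 2, 8
logits = torch.randn(batch, vocab)

# Per-row sampling params (the real code expands these to one entry per
# target/draft position). Values here are arbitrary.
temperature = torch.tensor([0.7, 1.0])
top_k = torch.tensor([4, torch.iinfo(torch.int32).max])  # 2nd row: top-k disabled
top_p = torch.tensor([0.9, 1.0])
min_p = torch.tensor([0.05, 0.0])

# 1) Temperature: per-row division of the logits.
logits = logits / temperature.unsqueeze(1)

# 2) Min-p: drop tokens whose probability is below min_p * (row max probability).
probs = logits.softmax(dim=-1)
keep = probs >= min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
logits = logits.masked_fill(~keep, float("-inf"))

# 3) Top-k / top-p on ascending-sorted logits, mirroring apply_top_k_top_p.
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
kth_index = (vocab - top_k.clamp(max=vocab)).clamp(min=0).unsqueeze(1)
kth_value = logits_sort.gather(1, kth_index)
logits_sort = logits_sort.masked_fill(logits_sort < kth_value, float("-inf"))
probs_sort = logits_sort.softmax(dim=-1)
tail_mask = probs_sort.cumsum(dim=-1) <= 1 - top_p.unsqueeze(1)
tail_mask[:, -1] = False  # always keep the most likely token
logits_sort = logits_sort.masked_fill(tail_mask, float("-inf"))
logits = torch.empty_like(logits).scatter_(dim=-1, index=logits_idx, src=logits_sort)

# 4) Exponential-race sampling: argmax(p / q) with q ~ Exp(1) draws token i with
#    probability p_i, without the CPU-GPU sync that torch.multinomial incurs.
q = torch.empty_like(logits).exponential_()
tokens = (logits.softmax(dim=-1) / q).argmax(dim=-1)
print(tokens)  # one sampled token id per row

The exponential-race step samples from the same categorical distribution as torch.multinomial but stays entirely on the GPU, which is why the patch seeds the PyTorch RNG once in _set_up_advanced_sampling to keep ranks consistent. On the API side, the new --use_advanced_spec_dec_sampling flag simply flows into MTPDecodingConfig or EagleDecodingConfig as use_advanced_spec_dec_sampling, as the quickstart_advanced.py hunk at the top shows.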