Commit 0ee79b3

Change the logic to reuse enable_mixed_sampler flag; remove mtp advanced sampling flag
Signed-off-by: Xuanyu Chen <[email protected]>
1 parent 3be14a6 commit 0ee79b3

File tree

6 files changed: +16 lines, -11 lines

examples/llm-api/quickstart_advanced.py (1 addition, 5 deletions)

@@ -112,9 +112,6 @@ def add_llm_args(parser):
     parser.add_argument('--draft_model_dir', type=str, default=None)
     parser.add_argument('--max_matching_ngram_size', type=int, default=5)
     parser.add_argument('--use_one_model', default=False, action='store_true')
-    parser.add_argument('--use_advanced_mtp_sampler',
-                        default=False,
-                        action='store_true')

     # Relaxed acceptance
     parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -166,8 +163,7 @@ def setup_llm(args, **kwargs):
             use_relaxed_acceptance_for_thinking=args.
             use_relaxed_acceptance_for_thinking,
             relaxed_topk=args.relaxed_topk,
-            relaxed_delta=args.relaxed_delta,
-            use_advanced_mtp_sampler=args.use_advanced_mtp_sampler)
+            relaxed_delta=args.relaxed_delta)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
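With the dedicated --use_advanced_mtp_sampler flag removed, the quickstart builds MTPDecodingConfig without it, and per-request sampling for MTP is now controlled by the backend-wide enable_mixed_sampler option. A minimal sketch of the resulting caller-side usage, assuming the PyTorch-backend LLM constructor accepts enable_mixed_sampler and forwards it to pytorch_backend_config (the forwarding this commit adds in model_engine.py); the model name and draft depth are illustrative:

# Sketch only: the exact plumbing of enable_mixed_sampler into the LLM
# constructor is assumed, not shown in this diff; the MTPDecodingConfig
# usage mirrors quickstart_advanced.py.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

spec_config = MTPDecodingConfig(
    num_nextn_predict_layers=3,      # illustrative draft depth
    # use_advanced_mtp_sampler=True  <- removed by this commit
)

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # illustrative model
    speculative_config=spec_config,
    enable_mixed_sampler=True,        # assumed to reach pytorch_backend_config
)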

tensorrt_llm/_torch/model_config.py (4 additions, 0 deletions)

@@ -89,6 +89,10 @@ class ModelConfig(Generic[TConfig]):
     # Allow models to select op according to whether CUDA Graphs are used.
     use_cuda_graph: bool = False

+    # If true, iterate over sampling_params of each request and use the corresponding sampling strategy.
+    # Currently only used for DeepSeek-MTP.
+    enable_mixed_sampler: bool = False
+
     force_dynamic_quantization: bool = False

     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
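The new ModelConfig comment describes the "mixed sampler" behavior: instead of one global greedy argmax, the sampler walks the sampling params of each request and applies the matching strategy. An illustrative-only sketch of that idea (none of these helper names exist in TensorRT-LLM; they are placeholders):

import torch

def sample_token(logits: torch.Tensor, temperature: float, top_k: int) -> int:
    """Pick one token id from a 1D logits tensor for a single request."""
    if temperature == 0.0:           # greedy request
        return int(torch.argmax(logits).item())
    logits = logits / temperature
    if top_k > 0:                    # restrict to the k most likely tokens
        topk_vals, topk_idx = torch.topk(logits, k=min(top_k, logits.numel()))
        probs = torch.softmax(topk_vals, dim=-1)
        return int(topk_idx[torch.multinomial(probs, 1)].item())
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())

def mixed_sample(batch_logits: torch.Tensor, requests: list[dict]) -> list[int]:
    """Iterate over per-request sampling params, one row of logits per request."""
    return [
        sample_token(row, req.get("temperature", 0.0), req.get("top_k", 0))
        for row, req in zip(batch_logits, requests)
    ]

# Example: one greedy request and one top-k sampled request in the same batch.
logits = torch.randn(2, 32000)
print(mixed_sample(logits, [{"temperature": 0.0}, {"temperature": 0.8, "top_k": 50}]))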

tensorrt_llm/_torch/pyexecutor/model_engine.py (3 additions, 2 deletions)

@@ -280,7 +280,7 @@ def __init__(
         self.is_spec_decode = spec_config is not None
         self.is_draft_model = is_draft_model
         self.is_advanced_mtp_sampler = self.is_spec_decode and self.spec_config.spec_dec_mode.is_mtp(
-        ) and self.spec_config.use_advanced_mtp_sampler
+        ) and self.pytorch_backend_config.enable_mixed_sampler

         self.in_warmup = False

@@ -298,6 +298,7 @@ def __init__(
             max_num_tokens=max_num_tokens,
             moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
             moe_load_balancer=pytorch_backend_config.moe_load_balancer,
+            enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
             lora_config=lora_config)
         # In case that some tests use stub models and override `_load_model`.
         if not hasattr(self.model, 'extra_attrs'):
@@ -1195,7 +1196,7 @@ def get_request_top_k(request: LlmRequest) -> int:
                 top_k = request.sampling_config.top_k[0]

             # set k to a very large value (larger than vocab size) to disable top_k sampling
-            TOP_K_DISABLED = (1 << 31) - 1
+            TOP_K_DISABLED = torch.iinfo(torch.int32).max
             if top_k <= 0:
                 top_k = TOP_K_DISABLED
             return top_k
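The TOP_K_DISABLED change is a readability refactor, not a behavior change; both expressions evaluate to the same 32-bit sentinel:

import torch

# 2**31 - 1 == 2147483647, the maximum value of a signed 32-bit integer.
assert torch.iinfo(torch.int32).max == (1 << 31) - 1 == 2147483647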

tensorrt_llm/_torch/speculative/mtp.py (6 additions, 1 deletion)

@@ -376,6 +376,11 @@ def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         self.model_config = model_config
         self.is_thop = False

+        # Default to greedy mode. If true, use advanced pytorch sampling strategy.
+        self.enable_mixed_sampler = False
+        if self.model_config is not None:
+            self.enable_mixed_sampler = self.model_config.enable_mixed_sampler
+
     def forward(
         self,
         input_ids,
@@ -891,7 +896,7 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            if self.spec_config.use_advanced_mtp_sampler:
+            if self.enable_mixed_sampler:
                 # Do advanced sampling for the input logits
                 # target_log_probs currently unused but kept for future log probs support in MTP
                 target_tokens, target_log_probs = sampling_batch(
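After this change MTPWorker no longer consults the spec config for the sampling mode; it caches the flag from the ModelConfig it is constructed with and branches on it at sampling time. A condensed sketch of that flag resolution and dispatch, assuming greedy argmax as the default path (not shown in this hunk) and using a placeholder sampling_batch in place of the real per-request sampler:

import torch

# Condensed sketch (not verbatim TensorRT-LLM source).
def sampling_batch(logits: torch.Tensor) -> torch.Tensor:
    """Placeholder per-request sampler: multinomial over softmax probabilities."""
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

class MTPWorkerSketch:
    def __init__(self, model_config=None):
        # Default to greedy mode; opt in only when ModelConfig carries the flag.
        self.enable_mixed_sampler = False
        if model_config is not None:
            self.enable_mixed_sampler = model_config.enable_mixed_sampler

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        if self.enable_mixed_sampler:
            return sampling_batch(logits)       # advanced, per-request path
        return torch.argmax(logits, dim=-1)     # default greedy path

# Greedy by default; pass a ModelConfig (or set the attribute, as the unit
# test below does) to exercise the mixed-sampling branch.
tokens = MTPWorkerSketch().sample(torch.randn(2, 128))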

tensorrt_llm/llmapi/llm_args.py (0 additions, 1 deletion)

@@ -478,7 +478,6 @@ class MTPDecodingConfig(DecodingBaseConfig):
     relaxed_topk: int = 1
     relaxed_delta: float = 0.
     use_mtp_vanilla: bool = False
-    use_advanced_mtp_sampler: Optional[bool] = False

     # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
     # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.

tests/unittest/_torch/speculative/test_mtp.py (2 additions, 2 deletions)

@@ -342,8 +342,7 @@ def test_sample_and_accept_draft_tokens_adv_torch_sampler_greedy_mode(
     batch_size = len(draft_len)
     # enable advanced pytorch sampler
     spec_config = MTPDecodingConfig(
-        num_nextn_predict_layers=mtp_num_modules,
-        use_advanced_mtp_sampler=True)
+        num_nextn_predict_layers=mtp_num_modules)

     # attention metedata
     attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
@@ -389,6 +388,7 @@ def test_sample_and_accept_draft_tokens_adv_torch_sampler_greedy_mode(
     # mtp worker
     # is_thop default to False for advanced pytorch sampler testing only
     mtpworker = MTPWorker(spec_config)
+    mtpworker.enable_mixed_sampler = True

     # Test advanced torch sampler
     accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens(
