
Commit 4a68f67

add advanced torch sampler flag; add test; cleanup and rebase
Signed-off-by: Xuanyu Chen <[email protected]>
1 parent 5cdb0b9 commit 4a68f67

File tree

7 files changed: +205, -127 lines changed


examples/llm-api/quickstart_advanced.py

Lines changed: 5 additions & 1 deletion
@@ -112,6 +112,9 @@ def add_llm_args(parser):
     parser.add_argument('--draft_model_dir', type=str, default=None)
     parser.add_argument('--max_matching_ngram_size', type=int, default=5)
     parser.add_argument('--use_one_model', default=False, action='store_true')
+    parser.add_argument('--use_advanced_mtp_sampler',
+                        default=False,
+                        action='store_true')

     # Relaxed acceptance
     parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -163,7 +166,8 @@ def setup_llm(args, **kwargs):
             use_relaxed_acceptance_for_thinking=args.
             use_relaxed_acceptance_for_thinking,
             relaxed_topk=args.relaxed_topk,
-            relaxed_delta=args.relaxed_delta)
+            relaxed_delta=args.relaxed_delta,
+            use_advanced_mtp_sampler=args.use_advanced_mtp_sampler)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
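Note: a minimal sketch of how the new store_true flag behaves, using only the argparse pattern shown in this hunk; in the script, setup_llm() then forwards the parsed value into the MTP speculative-decoding config. The helper name and the example arguments below are illustrative, not part of the commit.

import argparse

def add_sampler_flag(parser):  # illustrative helper; the script's add_llm_args() does the same
    parser.add_argument('--use_advanced_mtp_sampler',
                        default=False,
                        action='store_true')
    return parser

args = add_sampler_flag(argparse.ArgumentParser()).parse_args(
    ['--use_advanced_mtp_sampler'])
print(args.use_advanced_mtp_sampler)  # True; setup_llm() passes this into the MTP config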

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 92 additions & 55 deletions
@@ -17,8 +17,8 @@
 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
-from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
@@ -261,7 +261,10 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
+        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+        # operations that avoid torch.multinomial's CPU-GPU sync overhead
         torch.manual_seed(0)
+
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens
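Note: a small illustration of why the constructor seeds PyTorch's RNG. With a fixed seed, every rank draws identical exponential noise in the sampler, so tensor-parallel ranks agree on sampled tokens without extra synchronization. This is a standalone sketch, not engine code.

import torch

def noise(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)
    return torch.empty(4).exponential_()  # the same RNG op the sampler relies on

assert torch.equal(noise(0), noise(0))  # identical draws for the same seed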
@@ -278,6 +281,8 @@ def __init__(
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.is_draft_model = is_draft_model
+        self.is_advanced_mtp_sampler = self.is_spec_decode and self.spec_config.spec_dec_mode.is_mtp(
+        ) and self.spec_config.use_advanced_mtp_sampler

         self.in_warmup = False

@@ -373,18 +378,24 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-            self.temperatures_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                                 dtype=torch.float,
-                                                 device='cuda')
-            self.top_k_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.int,
-                                          device='cuda')
-            self.top_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
-            self.min_p_cuda = torch.empty((self.batch_size * (self.max_draft_len + 1), ),
-                                          dtype=torch.float,
-                                          device='cuda')
+
+            if self.is_advanced_mtp_sampler:
+                self.temperatures_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.top_k_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.int,
+                    device='cuda')
+                self.top_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
+                self.min_p_cuda = torch.empty(
+                    (self.batch_size * (self.max_draft_len + 1), ),
+                    dtype=torch.float,
+                    device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0
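Note: the four buffers above are sized so that each request contributes one slot per draft position plus one for the target token. A quick sketch with illustrative numbers (the real buffers live on the GPU):

import torch

batch_size, max_draft_len = 4, 3                 # illustrative values
num_slots = batch_size * (max_draft_len + 1)     # 16 sampling-parameter slots
temperatures_cuda = torch.empty((num_slots, ), dtype=torch.float)  # device='cuda' in the engine
print(num_slots)  # 16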
@@ -1142,11 +1153,12 @@ def _prepare_tp_inputs(
         draft_lens = []
         multimodal_params_list = []
         gen_request_seq_slots = []  # per generation request
-
-        temperatures = []
-        top_k = []
-        top_p = []
-        min_p = []
+
+        if self.is_advanced_mtp_sampler:
+            temperatures = []
+            top_k = []
+            top_p = []
+            min_p = []

         def get_request_temperature(request: LlmRequest) -> float:
             if not request.sampling_config.temperature:
@@ -1213,12 +1225,13 @@ def get_request_min_p(request: LlmRequest) -> float:

             if multimodal_params.has_content():
                 multimodal_params_list.append(multimodal_params)
-
-            temperatures.append(get_request_temperature(request))
-            top_k.append(get_request_top_k(request))
-            top_p.append(get_request_top_p(request))
-            min_p.append(get_request_min_p(request))
-
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.append(get_request_temperature(request))
+                top_k.append(get_request_top_k(request))
+                top_p.append(get_request_top_p(request))
+                min_p.append(get_request_min_p(request))
+
             request.py_batch_idx = request.py_seq_slot

         num_ctx_requests = len(scheduled_requests.context_requests)
@@ -1300,10 +1313,17 @@ def get_request_min_p(request: LlmRequest) -> float:
                         past_seen_token_num + 1 + num_draft_tokens)))
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (num_draft_tokens + 1))
-                top_k.extend([get_request_top_k(request)] * (num_draft_tokens + 1))
-                top_p.extend([get_request_top_p(request)] * (num_draft_tokens + 1))
-                min_p.extend([get_request_min_p(request)] * (num_draft_tokens + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (num_draft_tokens + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (num_draft_tokens + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (num_draft_tokens + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (num_draft_tokens + 1))
+
                 # update batch index
                 request.py_batch_idx = request.py_seq_slot
             else:
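Note: the extend() calls above flatten per-request sampling parameters into per-token lists; each request's value is repeated once per draft token plus once for the target token, so the lists line up row-for-row with the flattened logits. A minimal sketch with made-up values:

num_draft_tokens = 2
request_temperatures = [1.0, 0.7]    # two hypothetical requests
temperatures = []
for t in request_temperatures:
    temperatures.extend([t] * (num_draft_tokens + 1))
print(temperatures)  # [1.0, 1.0, 1.0, 0.7, 0.7, 0.7]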
@@ -1332,10 +1352,16 @@ def get_request_min_p(request: LlmRequest) -> float:
                                            self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
                 request_ids.append(request.py_request_id)
-                temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+                if self.is_advanced_mtp_sampler:
+                    temperatures.extend([get_request_temperature(request)] *
+                                        (self.max_draft_len + 1))
+                    top_k.extend([get_request_top_k(request)] *
+                                 (self.max_draft_len + 1))
+                    top_p.extend([get_request_top_p(request)] *
+                                 (self.max_draft_len + 1))
+                    min_p.extend([get_request_min_p(request)] *
+                                 (self.max_draft_len + 1))

         for request in generation_requests:
             beam_width = request.sampling_config.beam_width
@@ -1368,11 +1394,17 @@ def get_request_min_p(request: LlmRequest) -> float:

             request_ids.append(request.py_request_id)
             gen_request_seq_slots.append(request.py_seq_slot)
-
-            temperatures.extend([get_request_temperature(request)] * (self.max_draft_len + 1))
-            top_k.extend([get_request_top_k(request)] * (self.max_draft_len + 1))
-            top_p.extend([get_request_top_p(request)] * (self.max_draft_len + 1))
-            min_p.extend([get_request_min_p(request)] * (self.max_draft_len + 1))
+
+            if self.is_advanced_mtp_sampler:
+                temperatures.extend([get_request_temperature(request)] *
+                                    (self.max_draft_len + 1))
+                top_k.extend([get_request_top_k(request)] *
+                             (self.max_draft_len + 1))
+                top_p.extend([get_request_top_p(request)] *
+                             (self.max_draft_len + 1))
+                min_p.extend([get_request_min_p(request)] *
+                             (self.max_draft_len + 1))
+
             request.py_batch_idx = request.py_seq_slot

         previous_batch_len = len(previous_batch_indices)
@@ -1476,18 +1508,21 @@ def previous_seq_slots_device():
         self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
             gather_ids, dtype=torch.int, pin_memory=True),
                                                      non_blocking=True)
-        self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor(
-            temperatures, dtype=torch.float, pin_memory=True),
-                                                         non_blocking=True)
-        self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-            top_k, dtype=torch.int, pin_memory=True),
-                                           non_blocking=True)
-        self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-            top_p, dtype=torch.float, pin_memory=True),
-                                           non_blocking=True)
-        self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-            min_p, dtype=torch.float, pin_memory=True),
-                                           non_blocking=True)
+        if self.is_advanced_mtp_sampler:
+            self.temperatures_cuda[:len(temperatures)].copy_(
+                torch.tensor(temperatures,
+                             dtype=torch.float,
+                             pin_memory=True),
+                non_blocking=True)
+            self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
+                top_k, dtype=torch.int, pin_memory=True),
+                                               non_blocking=True)
+            self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
+                top_p, dtype=torch.float, pin_memory=True),
+                                               non_blocking=True)
+            self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
+                min_p, dtype=torch.float, pin_memory=True),
+                                               non_blocking=True)

         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies
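Note: the copies above follow the same pattern as the existing gather_ids path: stage the Python list in a pinned CPU tensor and copy it into a slice of the preallocated CUDA buffer with non_blocking=True so the transfer can overlap other stream work. A standalone sketch (requires a CUDA device):

import torch

if torch.cuda.is_available():
    temperatures = [1.0, 1.0, 0.7, 0.7]          # illustrative per-token values
    temperatures_cuda = torch.empty((8, ), dtype=torch.float, device='cuda')
    temperatures_cuda[:len(temperatures)].copy_(
        torch.tensor(temperatures, dtype=torch.float, pin_memory=True),
        non_blocking=True)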
@@ -1562,12 +1597,14 @@ def previous_seq_slots_device():
                                                          total_draft_lens]
             spec_metadata.request_ids = request_ids
             spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]
-            spec_metadata.temperatures = self.temperatures_cuda[:len(temperatures)]
-            spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-            spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-            spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
-            # if attn_metadata.is_cuda_graph and not torch.cuda.is_current_stream_capturing():
-            #     spec_metadata.generator = torch.Generator(device='cpu').manual_seed(0)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata.temperatures = self.temperatures_cuda[:len(
+                    temperatures)]
+                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
+                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
+                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+
             spec_metadata.num_generations = len(
                 scheduled_requests.generation_requests)
             spec_metadata.num_tokens = total_num_tokens
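Note: slicing the preallocated buffers, as above, hands spec_metadata views rather than copies, so no per-step CUDA allocations occur and buffer addresses stay stable across iterations. Sketch with a CPU tensor standing in for temperatures_cuda:

import torch

buf = torch.empty(16, dtype=torch.float)   # stands in for self.temperatures_cuda
active_tokens = 6
view = buf[:active_tokens]                 # a view into the same storage, not a copy
assert view.data_ptr() == buf.data_ptr()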

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 17 additions & 60 deletions
@@ -4,7 +4,6 @@
 from typing import Literal, Optional

 import torch
-import flashinfer

 from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits
 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
@@ -151,42 +150,6 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
     next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
     return next_tokens, softmax

-def flashinfer_sample(
-    logits: torch.Tensor,
-    k: Optional[torch.Tensor],
-    p: Optional[torch.Tensor],
-    generator: Optional[torch.Generator] = None,
-) -> torch.Tensor:
-    """Sample from the logits using FlashInfer.
-
-    Statistically, this function is equivalent to the `random_sample` function.
-    However, this function is faster because it avoids sorting the logits tensor
-    via rejection sampling.
-
-    NOTE: The outputs of this function do not necessarily match the outputs of
-    the `random_sample` function. It only guarantees that the outputs are
-    statistically equivalent.
-
-    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
-    does not. Call this function at the end of the forward pass to minimize
-    the synchronization overhead.
-    """
-    assert not (k is None and p is None)
-    if k is None:
-        # Top-p only.
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
-            probs, p, deterministic=True, generator=generator)
-    elif p is None:
-        # Top-k only.
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
-            probs, k, deterministic=True, generator=generator)
-    else:
-        # Both top-k and top-p.
-        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
-            logits, k, p, deterministic=True, generator=generator)
-    return next_token_ids.view(-1).long()

 def forward_native(
     logits: torch.Tensor,
@@ -202,9 +165,8 @@ def forward_native(
     probs = logits.softmax(dim=-1, dtype=torch.float32)
     return random_sample(probs)

-def random_sample(
-    probs: torch.Tensor,
-) -> torch.Tensor:
+
+def random_sample(probs: torch.Tensor, ) -> torch.Tensor:
     """Randomly sample from the probabilities.

     We use this function instead of torch.multinomial because torch.multinomial
@@ -214,6 +176,7 @@ def random_sample(
     q.exponential_()
     return probs.div_(q).argmax(dim=-1).view(-1)

+
 def apply_min_p(
     logits: torch.Tensor,
     min_p: torch.Tensor,
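Note: the exponential trick in random_sample works because, for probabilities p_i and independent q_i ~ Exponential(1), argmax_i p_i / q_i = argmin_i q_i / p_i, and q_i / p_i is Exponential with rate p_i, so the argmin is distributed as Categorical(p). A quick standalone check:

import torch

probs = torch.tensor([[0.1, 0.2, 0.7]])
q = torch.empty_like(probs).exponential_()   # q_i ~ Exp(1)
token = probs.div(q).argmax(dim=-1)          # distributed as Categorical([0.1, 0.2, 0.7])
print(token)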
@@ -224,9 +187,7 @@ def apply_min_p(
     # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
     # Calculate maximum probabilities per sequence
-    max_probabilities = torch.amax(probability_values,
-                                   dim=-1,
-                                   keepdim=True)
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
     # Reshape min_p for broadcasting
     adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
     # Identify valid tokens using threshold comparison
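Note: a worked example of the thresholding above: with min_p = 0.2 and a maximum token probability of roughly 0.71, anything below 0.2 * 0.71 ≈ 0.14 is masked to -inf. Values are illustrative:

import torch

logits = torch.tensor([[2.0, 1.0, -1.0]])
min_p = torch.tensor([0.2])
probs = torch.softmax(logits, dim=-1)                        # ~[0.71, 0.26, 0.04]
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
print(logits.masked_fill(probs < threshold, -float('inf')))  # last token masked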
@@ -235,6 +196,7 @@ def apply_min_p(
     logits[~valid_token_mask] = -float('inf')
     return logits

+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
@@ -268,44 +230,39 @@ def apply_top_k_top_p(
     logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
     return logits

+
 def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
     return logits.argmax(dim=-1).view(-1)

+
 def apply_temperature(
     logits: torch.Tensor,
     temp: torch.Tensor,
 ) -> torch.Tensor:
     # Use in-place division to avoid creating a new tensor.
     return logits.div_(temp.unsqueeze(dim=1))

-def sampling_batch(
-    logits: torch.Tensor,
-    temperatures: torch.Tensor,
-    top_k: torch.Tensor,
-    top_p: torch.Tensor,
-    min_p: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
+
+def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
+                   top_k: torch.Tensor, top_p: torch.Tensor,
+                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     raw_probs = torch.softmax(logits, dim=-1)
     greedy_sampled = greedy_sample(logits)
     logits = apply_temperature(logits, temperatures)
     logits = apply_min_p(logits, min_p)
-    # if not torch.cuda.is_current_stream_capturing():
-    #     generator = torch.Generator(device="cuda")
-    #     generator.manual_seed(0)
-    #     next_tokens = flashinfer_sample(adjusted_logits, top_k, top_p, generator)
-    # logits = apply_top_k_top_p(logits, top_k, top_p)
     random_sampled = forward_native(logits, top_k, top_p)
     next_tokens = torch.where(
-        temperatures < 1e-5,
-        greedy_sampled,
-        random_sampled,
-        out=greedy_sampled,  # Reuse tensor
-    )
+        temperatures <= 1e-2,  # Match the clamping threshold
+        greedy_sampled,
+        random_sampled,
+        out=greedy_sampled,  # Reuse tensor
+    )
     token_probs = torch.gather(raw_probs, dim=1,
                                index=next_tokens.unsqueeze(1)).squeeze(-1)
     log_probs = torch.log(token_probs)
     return next_tokens, log_probs

+
 def greedy_search_sampling_batch(logits):
     next_tokens = torch.argmax(logits, dim=-1)
     softmax = torch.softmax(logits, dim=-1)
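Note: a hypothetical call of sampling_batch for two rows of logits, assuming a build that includes this commit so the function is importable from the module above. The first row uses a near-zero temperature (<= 1e-2) and so takes the greedy branch of torch.where; the second is sampled after temperature, min-p, top-k, and top-p are applied. The input values are made up.

import torch

from tensorrt_llm._torch.pyexecutor.sampler import sampling_batch

logits = torch.tensor([[4.0, 1.0, 0.5, 0.2],
                       [1.0, 1.1, 0.9, 1.05]])
temperatures = torch.tensor([1e-3, 0.8])
top_k = torch.tensor([4, 2])
top_p = torch.tensor([1.0, 0.9])
min_p = torch.tensor([0.0, 0.0])

next_tokens, log_probs = sampling_batch(logits, temperatures, top_k, top_p, min_p)
print(next_tokens)  # first entry is the argmax of row 0; second is a top-k/top-p sample
print(log_probs)    # log-probabilities of the chosen tokens under the unmodified softmax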
