
Commit 9c4da6b

code refactoring
Signed-off-by: Xuanyu Chen <[email protected]>
1 parent 5ccfa4d commit 9c4da6b

File tree

3 files changed: +105 −99 lines

tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/mtp.py
tests/unittest/_torch/speculative/test_mtp.py

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 51 additions & 98 deletions
@@ -263,10 +263,6 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
-
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens
@@ -381,23 +377,6 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-
-            if self.is_advanced_mtp_sampler:
-                mtp_total_sampling_size = self.batch_size * (
-                    self.max_draft_len + 1)
-                self.temperatures_cuda = torch.empty(
-                    (mtp_total_sampling_size, ),
-                    dtype=torch.float,
-                    device='cuda')
-                self.top_k_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.int,
-                                              device='cuda')
-                self.top_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
-                self.min_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0
@@ -1185,38 +1164,50 @@ def _prepare_tp_inputs(
         top_p = []
         min_p = []

-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 1.0
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
-
-        def get_request_top_k(request: LlmRequest) -> int:
-            if not request.sampling_config.top_k:
-                top_k = 0
-            else:
-                top_k = request.sampling_config.top_k[0]
+        # advanced mtp sampling's request preprocessing helper functions
+        def collect_req_mtp_sampling_params(request: LlmRequest,
+                                            draft_len: int = 0):
+
+            def get_request_temperature(request: LlmRequest) -> float:
+                if not request.sampling_config.temperature:
+                    return 1.0
+                temperature = request.sampling_config.temperature[0]
+                if 0 < temperature < 1e-2:
+                    # temperature less than 0.01 may cause numerical errors
+                    temperature = 0.01
+                return temperature
+
+            def get_request_top_k(request: LlmRequest) -> int:
+                if not request.sampling_config.top_k:
+                    top_k = 0
+                else:
+                    top_k = request.sampling_config.top_k[0]

-            if top_k <= 0:
-                top_k = 2147483647
-            return top_k
+                # set k to a very large value (larger than vocab size) to disable top_k sampling
+                TOP_K_DISABLED = (1 << 31) - 1
+                if top_k <= 0:
+                    top_k = TOP_K_DISABLED
+                return top_k

-        def get_request_top_p(request: LlmRequest) -> float:
-            if not request.sampling_config.top_p:
-                top_p = 1.0
-            else:
-                top_p = request.sampling_config.top_p[0]
-            return top_p
+            def get_request_top_p(request: LlmRequest) -> float:
+                if not request.sampling_config.top_p:
+                    top_p = 1.0
+                else:
+                    top_p = request.sampling_config.top_p[0]
+                return top_p

-        def get_request_min_p(request: LlmRequest) -> float:
-            if not request.sampling_config.min_p:
-                min_p = 0.0
-            else:
-                min_p = request.sampling_config.min_p[0]
-            return min_p
+            def get_request_min_p(request: LlmRequest) -> float:
+                if not request.sampling_config.min_p:
+                    min_p = 0.0
+                else:
+                    min_p = request.sampling_config.min_p[0]
+                return min_p
+
+            temperatures.extend([get_request_temperature(request)] *
+                                (draft_len + 1))
+            top_k.extend([get_request_top_k(request)] * (draft_len + 1))
+            top_p.extend([get_request_top_p(request)] * (draft_len + 1))
+            min_p.extend([get_request_min_p(request)] * (draft_len + 1))

         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
@@ -1252,10 +1243,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                 multimodal_params_list.append(multimodal_params)

             if self.is_advanced_mtp_sampler:
-                temperatures.append(get_request_temperature(request))
-                top_k.append(get_request_top_k(request))
-                top_p.append(get_request_top_p(request))
-                min_p.append(get_request_min_p(request))
+                collect_req_mtp_sampling_params(request)

             request.py_batch_idx = request.py_seq_slot

@@ -1341,14 +1329,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             request_ids.append(request.py_request_id)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (num_draft_tokens + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (num_draft_tokens + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (num_draft_tokens + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (num_draft_tokens + 1))
+                collect_req_mtp_sampling_params(request, num_draft_tokens)

             # update batch index
             request.py_batch_idx = request.py_seq_slot
@@ -1380,14 +1361,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             request_ids.append(request.py_request_id)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)

         for request in generation_requests:
             beam_width = request.sampling_config.beam_width
@@ -1422,14 +1396,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             gen_request_seq_slots.append(request.py_seq_slot)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)

             request.py_batch_idx = request.py_seq_slot

@@ -1550,21 +1517,6 @@ def previous_seq_slots_device():
         self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
             gather_ids, dtype=torch.int, pin_memory=True),
                                                      non_blocking=True)
-        if self.is_advanced_mtp_sampler:
-            self.temperatures_cuda[:len(temperatures)].copy_(
-                torch.tensor(temperatures,
-                             dtype=torch.float,
-                             pin_memory=True),
-                non_blocking=True)
-            self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-                top_k, dtype=torch.int, pin_memory=True),
-                                               non_blocking=True)
-            self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-                top_p, dtype=torch.float, pin_memory=True),
-                                               non_blocking=True)
-            self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-                min_p, dtype=torch.float, pin_memory=True),
-                                               non_blocking=True)

         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies
@@ -1640,11 +1592,8 @@ def previous_seq_slots_device():
             spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]

             if self.is_advanced_mtp_sampler:
-                spec_metadata.temperatures = self.temperatures_cuda[:len(
-                    temperatures)]
-                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+                spec_metadata.update_advanced_mtp_sampling_params(
+                    temperatures, top_k, top_p, min_p)

             spec_metadata.num_generations = len(
                 scheduled_requests.generation_requests)
@@ -2194,6 +2143,10 @@ def forward(
                 spec_metadata.is_spec_dec_tree,
                 spec_metadata.is_spec_dec_dynamic_tree,
                 spec_metadata.max_draft_len)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata._set_up_advanced_mtp_sampling(
+                    self.batch_size, self.max_draft_len)
         else:
             spec_metadata = None

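Taken together, these hunks replace the four repeated per-request extends with a single collect_req_mtp_sampling_params call. Below is a minimal, standalone sketch of the same pattern, assuming a stand-in SamplingConfig dataclass and module-level lists instead of the real LlmRequest and _prepare_tp_inputs locals; only the default values and the TOP_K_DISABLED sentinel are taken from the diff.

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class SamplingConfig:
        # Stand-in for the per-request sampling config used in the diff.
        temperature: List[float] = field(default_factory=list)
        top_k: List[int] = field(default_factory=list)
        top_p: List[float] = field(default_factory=list)
        min_p: List[float] = field(default_factory=list)


    # Larger than any vocabulary size, so top-k sampling is effectively disabled.
    TOP_K_DISABLED = (1 << 31) - 1

    temperatures: List[float] = []
    top_k: List[int] = []
    top_p: List[float] = []
    min_p: List[float] = []


    def collect_sampling_params(cfg: SamplingConfig, draft_len: int = 0) -> None:
        """Append one entry per sampled token (draft tokens + the target token)."""
        temp = cfg.temperature[0] if cfg.temperature else 1.0
        if 0 < temp < 1e-2:
            temp = 0.01  # very small temperatures may cause numerical errors
        k = cfg.top_k[0] if cfg.top_k else 0
        k = TOP_K_DISABLED if k <= 0 else k
        p = cfg.top_p[0] if cfg.top_p else 1.0
        m = cfg.min_p[0] if cfg.min_p else 0.0
        n = draft_len + 1
        temperatures.extend([temp] * n)
        top_k.extend([k] * n)
        top_p.extend([p] * n)
        min_p.extend([m] * n)


    # One call per request replaces the four repeated .extend(...) blocks:
    collect_sampling_params(SamplingConfig(temperature=[0.8], top_k=[50]), draft_len=3)
    collect_sampling_params(SamplingConfig())  # falls back to the defaults

Keeping the per-knob readers nested inside the collector keeps the defaults (1.0, top-k disabled, 1.0, 0.0) and the draft_len + 1 replication in one place, which the context, draft, and generation loops above now share.
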
tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 52 additions & 0 deletions
@@ -108,6 +108,11 @@ class MTPSpecMetadata(SpecMetadata):
     # subsequence draft forward.
     subseq_all_rank_num_tokens: Optional[List[int]] = None

+    temperatures_cuda: Optional[torch.Tensor] = None
+    top_k_cuda: Optional[torch.Tensor] = None
+    top_p_cuda: Optional[torch.Tensor] = None
+    min_p_cuda: Optional[torch.Tensor] = None
+
     def __post_init__(self) -> None:
         if self.mtp_hidden_states_manager is not None:
             # mtp_hidden_states_ptrs is a pointer tensor
@@ -197,6 +202,53 @@ def prepare(self):
                                     pin_memory=True)
         self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

+    def _set_up_advanced_mtp_sampling(self, batch_size: int,
+                                      max_draft_len: int):
+        # create once and reuse
+        if self.temperatures_cuda is None:
+            # Set deterministic seed (one time) for consistent multi-GPU sampling using PyTorch RNG
+            # operations that avoid torch.multinomial's CPU-GPU sync overhead
+            torch.manual_seed(0)
+
+            max_total_sampling_size = batch_size * (max_draft_len + 1)
+            self.temperatures_cuda = torch.empty((max_total_sampling_size, ),
+                                                 dtype=torch.float,
+                                                 device='cuda')
+            self.top_k_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.int,
+                                          device='cuda')
+            self.top_p_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.float,
+                                          device='cuda')
+            self.min_p_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.float,
+                                          device='cuda')
+
+    def update_advanced_mtp_sampling_params(self, temperatures: list[float],
+                                            top_k: list[int],
+                                            top_p: list[float],
+                                            min_p: list[float]):
+        self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor(
+            temperatures, dtype=torch.float, pin_memory=True),
+                                                         non_blocking=True)
+        self.top_k_cuda[:len(top_k)].copy_(torch.tensor(top_k,
+                                                        dtype=torch.int,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+        self.top_p_cuda[:len(top_p)].copy_(torch.tensor(top_p,
+                                                        dtype=torch.float,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+        self.min_p_cuda[:len(min_p)].copy_(torch.tensor(min_p,
+                                                        dtype=torch.float,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+
+        self.temperatures = self.temperatures_cuda[:len(temperatures)]
+        self.top_k = self.top_k_cuda[:len(top_k)]
+        self.top_p = self.top_p_cuda[:len(top_p)]
+        self.min_p = self.min_p_cuda[:len(min_p)]
+

 class MTPSampler(TorchSampler):
     """

tests/unittest/_torch/speculative/test_mtp.py

Lines changed: 2 additions & 1 deletion
@@ -367,7 +367,8 @@ def test_sample_and_accept_draft_tokens_adv_torch_sampler_greedy_mode(
         for i in range(batch_size):
             num_draft_tokens = draft_len[i]
             # set to greedy sampling mode (temperature <= 0.01 boundary) for advanced pytorch sampler
-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
             temperatures.extend([0.01] * (num_draft_tokens + 1))
             top_k.extend([2147483647] * (num_draft_tokens + 1))
             top_p.extend([1.0] * (num_draft_tokens + 1))
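The 2147483647 literal kept in the test is the same top-k-disabled sentinel the refactor names in model_engine.py; a trivial standalone check of that equivalence (illustrative only):

    # The named constant introduced in model_engine.py and the literal kept in
    # the test refer to the same INT32_MAX sentinel that disables top-k sampling.
    TOP_K_DISABLED = (1 << 31) - 1
    assert TOP_K_DISABLED == 2147483647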
