
Commit 3be14a6

code refactoring
Signed-off-by: Xuanyu Chen <[email protected]>
1 parent b66befe commit 3be14a6

File tree: 3 files changed (+105, -99 lines)


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 51 additions & 98 deletions
@@ -263,10 +263,6 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
-
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens

@@ -381,23 +377,6 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-
-            if self.is_advanced_mtp_sampler:
-                mtp_total_sampling_size = self.batch_size * (
-                    self.max_draft_len + 1)
-                self.temperatures_cuda = torch.empty(
-                    (mtp_total_sampling_size, ),
-                    dtype=torch.float,
-                    device='cuda')
-                self.top_k_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.int,
-                                              device='cuda')
-                self.top_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
-                self.min_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0

@@ -1196,38 +1175,50 @@ def _prepare_tp_inputs(
         top_p = []
         min_p = []

-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 1.0
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
-
-        def get_request_top_k(request: LlmRequest) -> int:
-            if not request.sampling_config.top_k:
-                top_k = 0
-            else:
-                top_k = request.sampling_config.top_k[0]
+        # advanced mtp sampling's request preprocessing helper functions
+        def collect_req_mtp_sampling_params(request: LlmRequest,
+                                            draft_len: int = 0):
+
+            def get_request_temperature(request: LlmRequest) -> float:
+                if not request.sampling_config.temperature:
+                    return 1.0
+                temperature = request.sampling_config.temperature[0]
+                if 0 < temperature < 1e-2:
+                    # temperature less than 0.01 may cause numerical errors
+                    temperature = 0.01
+                return temperature
+
+            def get_request_top_k(request: LlmRequest) -> int:
+                if not request.sampling_config.top_k:
+                    top_k = 0
+                else:
+                    top_k = request.sampling_config.top_k[0]

-            if top_k <= 0:
-                top_k = 2147483647
-            return top_k
+                # set k to a very large value (larger than vocab size) to disable top_k sampling
+                TOP_K_DISABLED = (1 << 31) - 1
+                if top_k <= 0:
+                    top_k = TOP_K_DISABLED
+                return top_k

-        def get_request_top_p(request: LlmRequest) -> float:
-            if not request.sampling_config.top_p:
-                top_p = 1.0
-            else:
-                top_p = request.sampling_config.top_p[0]
-            return top_p
+            def get_request_top_p(request: LlmRequest) -> float:
+                if not request.sampling_config.top_p:
+                    top_p = 1.0
+                else:
+                    top_p = request.sampling_config.top_p[0]
+                return top_p

-        def get_request_min_p(request: LlmRequest) -> float:
-            if not request.sampling_config.min_p:
-                min_p = 0.0
-            else:
-                min_p = request.sampling_config.min_p[0]
-            return min_p
+            def get_request_min_p(request: LlmRequest) -> float:
+                if not request.sampling_config.min_p:
+                    min_p = 0.0
+                else:
+                    min_p = request.sampling_config.min_p[0]
+                return min_p
+
+            temperatures.extend([get_request_temperature(request)] *
+                                (draft_len + 1))
+            top_k.extend([get_request_top_k(request)] * (draft_len + 1))
+            top_p.extend([get_request_top_p(request)] * (draft_len + 1))
+            min_p.extend([get_request_min_p(request)] * (draft_len + 1))

         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)

@@ -1263,10 +1254,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                 multimodal_params_list.append(multimodal_params)

             if self.is_advanced_mtp_sampler:
-                temperatures.append(get_request_temperature(request))
-                top_k.append(get_request_top_k(request))
-                top_p.append(get_request_top_p(request))
-                min_p.append(get_request_min_p(request))
+                collect_req_mtp_sampling_params(request)

             request.py_batch_idx = request.py_seq_slot

@@ -1352,14 +1340,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             request_ids.append(request.py_request_id)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (num_draft_tokens + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (num_draft_tokens + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (num_draft_tokens + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (num_draft_tokens + 1))
+                collect_req_mtp_sampling_params(request, num_draft_tokens)

             # update batch index
             request.py_batch_idx = request.py_seq_slot

@@ -1391,14 +1372,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             request_ids.append(request.py_request_id)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)

         for request in generation_requests:
             beam_width = request.sampling_config.beam_width

@@ -1433,14 +1407,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             gen_request_seq_slots.append(request.py_seq_slot)

             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)

             request.py_batch_idx = request.py_seq_slot

@@ -1561,21 +1528,6 @@ def previous_seq_slots_device():
         self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
             gather_ids, dtype=torch.int, pin_memory=True),
                                                      non_blocking=True)
-        if self.is_advanced_mtp_sampler:
-            self.temperatures_cuda[:len(temperatures)].copy_(
-                torch.tensor(temperatures,
-                             dtype=torch.float,
-                             pin_memory=True),
-                non_blocking=True)
-            self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-                top_k, dtype=torch.int, pin_memory=True),
-                                               non_blocking=True)
-            self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-                top_p, dtype=torch.float, pin_memory=True),
-                                               non_blocking=True)
-            self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-                min_p, dtype=torch.float, pin_memory=True),
-                                               non_blocking=True)

         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies

@@ -1651,11 +1603,8 @@ def previous_seq_slots_device():
         spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]

         if self.is_advanced_mtp_sampler:
-            spec_metadata.temperatures = self.temperatures_cuda[:len(
-                temperatures)]
-            spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-            spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-            spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+            spec_metadata.update_advanced_mtp_sampling_params(
+                temperatures, top_k, top_p, min_p)

         spec_metadata.num_generations = len(
             scheduled_requests.generation_requests)

@@ -2205,6 +2154,10 @@ def forward(
                 spec_metadata.is_spec_dec_tree,
                 spec_metadata.is_spec_dec_dynamic_tree,
                 spec_metadata.max_draft_len)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata._set_up_advanced_mtp_sampling(
+                    self.batch_size, self.max_draft_len)
         else:
             spec_resource_manager = None
             spec_metadata = None
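
The refactor above replaces the four parallel temperatures/top_k/top_p/min_p extend calls at every call site with a single collect_req_mtp_sampling_params helper that repeats each request's sampling parameters draft_len + 1 times. The snippet below is a minimal, self-contained sketch of that pattern only (clamp tiny temperatures, disable top-k with a sentinel larger than any vocabulary, repeat per sampling slot); ToySamplingConfig and ToyRequest are illustrative stand-ins, not TensorRT-LLM types.

# Sketch of the per-request sampling-parameter collection pattern (assumed toy types).
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySamplingConfig:
    temperature: List[float] = field(default_factory=list)
    top_k: List[int] = field(default_factory=list)


@dataclass
class ToyRequest:
    sampling_config: ToySamplingConfig


def collect_sampling_params(request: ToyRequest, draft_len: int = 0):
    temperature = (request.sampling_config.temperature or [1.0])[0]
    if 0 < temperature < 1e-2:
        temperature = 0.01  # guard against numerical issues at very small temperatures
    top_k = (request.sampling_config.top_k or [0])[0]
    if top_k <= 0:
        top_k = (1 << 31) - 1  # effectively disables top-k (larger than any vocab size)
    # each request contributes draft_len + 1 sampling slots (draft tokens plus the target token)
    return [temperature] * (draft_len + 1), [top_k] * (draft_len + 1)


temps, top_ks = collect_sampling_params(
    ToyRequest(ToySamplingConfig(temperature=[0.005], top_k=[0])), draft_len=2)
print(temps, top_ks)  # [0.01, 0.01, 0.01] [2147483647, 2147483647, 2147483647]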

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 52 additions & 0 deletions
@@ -115,6 +115,11 @@ class MTPSpecMetadata(SpecMetadata):
     # subsequence draft forward.
     subseq_all_rank_num_tokens: Optional[List[int]] = None

+    temperatures_cuda: Optional[torch.Tensor] = None
+    top_k_cuda: Optional[torch.Tensor] = None
+    top_p_cuda: Optional[torch.Tensor] = None
+    min_p_cuda: Optional[torch.Tensor] = None
+
     def __post_init__(self) -> None:
         if self.mtp_hidden_states_manager is not None:
             # mtp_hidden_states_ptrs is a pointer tensor

@@ -204,6 +209,53 @@ def prepare(self):
                                        pin_memory=True)
         self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

+    def _set_up_advanced_mtp_sampling(self, batch_size: int,
+                                      max_draft_len: int):
+        # create once and reuse
+        if self.temperatures_cuda is None:
+            # Set deterministic seed (one time) for consistent multi-GPU sampling using PyTorch RNG
+            # operations that avoid torch.multinomial's CPU-GPU sync overhead
+            torch.manual_seed(0)
+
+            max_total_sampling_size = batch_size * (max_draft_len + 1)
+            self.temperatures_cuda = torch.empty((max_total_sampling_size, ),
+                                                 dtype=torch.float,
+                                                 device='cuda')
+            self.top_k_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.int,
+                                          device='cuda')
+            self.top_p_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.float,
+                                          device='cuda')
+            self.min_p_cuda = torch.empty((max_total_sampling_size, ),
+                                          dtype=torch.float,
+                                          device='cuda')
+
+    def update_advanced_mtp_sampling_params(self, temperatures: list[float],
+                                            top_k: list[int],
+                                            top_p: list[float],
+                                            min_p: list[float]):
+        self.temperatures_cuda[:len(temperatures)].copy_(torch.tensor(
+            temperatures, dtype=torch.float, pin_memory=True),
+                                                         non_blocking=True)
+        self.top_k_cuda[:len(top_k)].copy_(torch.tensor(top_k,
+                                                        dtype=torch.int,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+        self.top_p_cuda[:len(top_p)].copy_(torch.tensor(top_p,
+                                                        dtype=torch.float,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+        self.min_p_cuda[:len(min_p)].copy_(torch.tensor(min_p,
+                                                        dtype=torch.float,
+                                                        pin_memory=True),
+                                           non_blocking=True)
+
+        self.temperatures = self.temperatures_cuda[:len(temperatures)]
+        self.top_k = self.top_k_cuda[:len(top_k)]
+        self.top_p = self.top_p_cuda[:len(top_p)]
+        self.min_p = self.min_p_cuda[:len(min_p)]
+

 class MTPSampler(TorchSampler):
     """

tests/unittest/_torch/speculative/test_mtp.py

Lines changed: 2 additions & 1 deletion
@@ -367,7 +367,8 @@ def test_sample_and_accept_draft_tokens_adv_torch_sampler_greedy_mode(
     for i in range(batch_size):
         num_draft_tokens = draft_len[i]
         # set to greedy sampling mode (temperature <= 0.01 boundary) for advanced pytorch sampler
-        # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+        # sampling default config vals set in
+        # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
         temperatures.extend([0.01] * (num_draft_tokens + 1))
         top_k.extend([2147483647] * (num_draft_tokens + 1))
         top_p.extend([1.0] * (num_draft_tokens + 1))
