@@ -263,10 +263,6 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
-
         self.ub_buffers = None
         self.batch_size = batch_size
         self.max_num_tokens = max_num_tokens
@@ -381,23 +377,6 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             )
             self.max_draft_len = spec_config.max_draft_len
-
-            if self.is_advanced_mtp_sampler:
-                mtp_total_sampling_size = self.batch_size * (
-                    self.max_draft_len + 1)
-                self.temperatures_cuda = torch.empty(
-                    (mtp_total_sampling_size, ),
-                    dtype=torch.float,
-                    device='cuda')
-                self.top_k_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.int,
-                                              device='cuda')
-                self.top_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
-                self.min_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
         else:
             self.without_logits = False
             self.max_draft_len = 0
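For context, the per-engine buffers deleted above were flat tensors with one slot per sampled position: batch_size * (max_draft_len + 1) entries, i.e. the target token plus each draft token of every request. A minimal sketch of that sizing with hypothetical numbers (not values from this PR):

batch_size = 8        # hypothetical value
max_draft_len = 3     # hypothetical value
mtp_total_sampling_size = batch_size * (max_draft_len + 1)
assert mtp_total_sampling_size == 32  # 8 requests x 4 sampled positions each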
@@ -1185,38 +1164,50 @@ def _prepare_tp_inputs(
         top_p = []
         min_p = []
 
-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 1.0
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
-
-        def get_request_top_k(request: LlmRequest) -> int:
-            if not request.sampling_config.top_k:
-                top_k = 0
-            else:
-                top_k = request.sampling_config.top_k[0]
+        # advanced mtp sampling's request preprocessing helper functions
+        def collect_req_mtp_sampling_params(request: LlmRequest,
+                                            draft_len: int = 0):
+
+            def get_request_temperature(request: LlmRequest) -> float:
+                if not request.sampling_config.temperature:
+                    return 1.0
+                temperature = request.sampling_config.temperature[0]
+                if 0 < temperature < 1e-2:
+                    # temperature less than 0.01 may cause numerical errors
+                    temperature = 0.01
+                return temperature
+
+            def get_request_top_k(request: LlmRequest) -> int:
+                if not request.sampling_config.top_k:
+                    top_k = 0
+                else:
+                    top_k = request.sampling_config.top_k[0]
 
-            if top_k <= 0:
-                top_k = 2147483647
-            return top_k
+                # set k to a very large value (larger than vocab size) to disable top_k sampling
+                TOP_K_DISABLED = (1 << 31) - 1
+                if top_k <= 0:
+                    top_k = TOP_K_DISABLED
+                return top_k
 
-        def get_request_top_p(request: LlmRequest) -> float:
-            if not request.sampling_config.top_p:
-                top_p = 1.0
-            else:
-                top_p = request.sampling_config.top_p[0]
-            return top_p
+            def get_request_top_p(request: LlmRequest) -> float:
+                if not request.sampling_config.top_p:
+                    top_p = 1.0
+                else:
+                    top_p = request.sampling_config.top_p[0]
+                return top_p
 
-        def get_request_min_p(request: LlmRequest) -> float:
-            if not request.sampling_config.min_p:
-                min_p = 0.0
-            else:
-                min_p = request.sampling_config.min_p[0]
-            return min_p
+            def get_request_min_p(request: LlmRequest) -> float:
+                if not request.sampling_config.min_p:
+                    min_p = 0.0
+                else:
+                    min_p = request.sampling_config.min_p[0]
+                return min_p
+
+            temperatures.extend([get_request_temperature(request)] *
+                                (draft_len + 1))
+            top_k.extend([get_request_top_k(request)] * (draft_len + 1))
+            top_p.extend([get_request_top_p(request)] * (draft_len + 1))
+            min_p.extend([get_request_min_p(request)] * (draft_len + 1))
 
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
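For intuition, the new collect_req_mtp_sampling_params helper flattens each request's sampling parameters into draft_len + 1 identical per-token entries, keeping the host lists aligned with the per-token device buffers. A minimal runnable sketch of that effect with made-up values (not from this PR):

# A request with temperature=0.7, top_k=40, top_p=0.9, min_p=0.0 and
# draft_len=3 contributes 4 entries per list: 1 target token + 3 draft tokens.
temperatures, top_k, top_p, min_p = [], [], [], []
draft_len = 3
temperatures.extend([0.7] * (draft_len + 1))
top_k.extend([40] * (draft_len + 1))
top_p.extend([0.9] * (draft_len + 1))
min_p.extend([0.0] * (draft_len + 1))
assert temperatures == [0.7, 0.7, 0.7, 0.7]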
@@ -1252,10 +1243,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                 multimodal_params_list.append(multimodal_params)
 
             if self.is_advanced_mtp_sampler:
-                temperatures.append(get_request_temperature(request))
-                top_k.append(get_request_top_k(request))
-                top_p.append(get_request_top_p(request))
-                min_p.append(get_request_min_p(request))
+                collect_req_mtp_sampling_params(request)
 
             request.py_batch_idx = request.py_seq_slot
 
@@ -1341,14 +1329,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                 request_ids.append(request.py_request_id)
 
                 if self.is_advanced_mtp_sampler:
-                    temperatures.extend([get_request_temperature(request)] *
-                                        (num_draft_tokens + 1))
-                    top_k.extend([get_request_top_k(request)] *
-                                 (num_draft_tokens + 1))
-                    top_p.extend([get_request_top_p(request)] *
-                                 (num_draft_tokens + 1))
-                    min_p.extend([get_request_min_p(request)] *
-                                 (num_draft_tokens + 1))
+                    collect_req_mtp_sampling_params(request, num_draft_tokens)
 
                 # update batch index
                 request.py_batch_idx = request.py_seq_slot
@@ -1380,14 +1361,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                 request_ids.append(request.py_request_id)
 
                 if self.is_advanced_mtp_sampler:
-                    temperatures.extend([get_request_temperature(request)] *
-                                        (self.max_draft_len + 1))
-                    top_k.extend([get_request_top_k(request)] *
-                                 (self.max_draft_len + 1))
-                    top_p.extend([get_request_top_p(request)] *
-                                 (self.max_draft_len + 1))
-                    min_p.extend([get_request_min_p(request)] *
-                                 (self.max_draft_len + 1))
+                    collect_req_mtp_sampling_params(request, self.max_draft_len)
 
         for request in generation_requests:
             beam_width = request.sampling_config.beam_width
@@ -1422,14 +1396,7 @@ def get_request_min_p(request: LlmRequest) -> float:
             gen_request_seq_slots.append(request.py_seq_slot)
 
             if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)
 
             request.py_batch_idx = request.py_seq_slot
 
@@ -1550,21 +1517,6 @@ def previous_seq_slots_device():
             self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
                 gather_ids, dtype=torch.int, pin_memory=True),
                                                          non_blocking=True)
-            if self.is_advanced_mtp_sampler:
-                self.temperatures_cuda[:len(temperatures)].copy_(
-                    torch.tensor(temperatures,
-                                 dtype=torch.float,
-                                 pin_memory=True),
-                    non_blocking=True)
-                self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-                    top_k, dtype=torch.int, pin_memory=True),
-                                                   non_blocking=True)
-                self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-                    top_p, dtype=torch.float, pin_memory=True),
-                                                   non_blocking=True)
-                self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-                    min_p, dtype=torch.float, pin_memory=True),
-                                                   non_blocking=True)
 
         if not attn_metadata.is_cuda_graph:
             # Assumes seq lens do not change between CUDA graph invocations. This applies
@@ -1640,11 +1592,8 @@ def previous_seq_slots_device():
             spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]
 
             if self.is_advanced_mtp_sampler:
-                spec_metadata.temperatures = self.temperatures_cuda[:len(
-                    temperatures)]
-                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+                spec_metadata.update_advanced_mtp_sampling_params(
+                    temperatures, top_k, top_p, min_p)
 
             spec_metadata.num_generations = len(
                 scheduled_requests.generation_requests)
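The host-to-device copies and tensor slices deleted in the two hunks above are replaced by this single call, so that logic presumably now lives on the spec-metadata object. A rough sketch of what update_advanced_mtp_sampling_params might look like under that assumption (illustrative only, not the actual spec-metadata implementation):

import torch

# Illustrative sketch: assumes the *_cuda buffers were allocated as in the
# __init__ code removed earlier in this diff.
def update_advanced_mtp_sampling_params(self, temperatures, top_k, top_p, min_p):
    # Copy the flattened per-token host lists into the preallocated CUDA
    # buffers asynchronously, then expose right-sized views for the sampler.
    self.temperatures_cuda[:len(temperatures)].copy_(
        torch.tensor(temperatures, dtype=torch.float, pin_memory=True),
        non_blocking=True)
    self.top_k_cuda[:len(top_k)].copy_(
        torch.tensor(top_k, dtype=torch.int, pin_memory=True),
        non_blocking=True)
    self.top_p_cuda[:len(top_p)].copy_(
        torch.tensor(top_p, dtype=torch.float, pin_memory=True),
        non_blocking=True)
    self.min_p_cuda[:len(min_p)].copy_(
        torch.tensor(min_p, dtype=torch.float, pin_memory=True),
        non_blocking=True)
    self.temperatures = self.temperatures_cuda[:len(temperatures)]
    self.top_k = self.top_k_cuda[:len(top_k)]
    self.top_p = self.top_p_cuda[:len(top_p)]
    self.min_p = self.min_p_cuda[:len(min_p)]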
@@ -2194,6 +2143,10 @@ def forward(
                 spec_metadata.is_spec_dec_tree,
                 spec_metadata.is_spec_dec_dynamic_tree,
                 spec_metadata.max_draft_len)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata._set_up_advanced_mtp_sampling(
+                    self.batch_size, self.max_draft_len)
         else:
             spec_metadata = None
 