@@ -263,10 +263,6 @@ def __init__(
        lora_config: Optional[LoraConfig] = None,
        is_draft_model: bool = False,
    ):
-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
-
        self.ub_buffers = None
        self.batch_size = batch_size
        self.max_num_tokens = max_num_tokens
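
The constructor no longer calls torch.manual_seed(0), so building the engine no longer clobbers the process-wide RNG state. If deterministic multi-GPU sampling is still wanted, one way to keep it without touching the global seed is a sampler-owned torch.Generator. This is only an illustrative sketch of that option, not code from this PR:

import torch


def make_sampling_generator(seed: int = 0,
                            device: str = "cuda") -> torch.Generator:
    # Hypothetical helper: every rank seeds the same value, so sampling stays
    # consistent across GPUs, while the global torch RNG is left untouched.
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return gen


# Pass it explicitly to the RNG ops used for sampling, e.g.
#   noise = torch.rand(logits.shape, device="cuda", generator=gen)
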
@@ -381,23 +377,6 @@ def __init__(
            self.without_logits = self.spec_config.spec_dec_mode.without_logits(
            )
            self.max_draft_len = spec_config.max_draft_len
-
-            if self.is_advanced_mtp_sampler:
-                mtp_total_sampling_size = self.batch_size * (
-                    self.max_draft_len + 1)
-                self.temperatures_cuda = torch.empty(
-                    (mtp_total_sampling_size, ),
-                    dtype=torch.float,
-                    device='cuda')
-                self.top_k_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.int,
-                                              device='cuda')
-                self.top_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
-                self.min_p_cuda = torch.empty((mtp_total_sampling_size, ),
-                                              dtype=torch.float,
-                                              device='cuda')
        else:
            self.without_logits = False
            self.max_draft_len = 0
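
The buffers removed above are flat, one slot per request per sampled position: batch_size * (max_draft_len + 1) entries (the target token plus each draft token). A later hunk calls spec_metadata._set_up_advanced_mtp_sampling(self.batch_size, self.max_draft_len), so the allocation presumably moves onto the MTP spec metadata. The sketch below shows what such a setup could look like under that assumption; only the method name and arguments appear in this diff, the body is illustrative:

import torch


class _MtpSamplingBuffers:
    # Hypothetical stand-in for the spec-metadata side of the refactor.

    def _set_up_advanced_mtp_sampling(self, batch_size: int,
                                      max_draft_len: int) -> None:
        # One sampling slot per request for the target token plus each draft token.
        mtp_total_sampling_size = batch_size * (max_draft_len + 1)
        self.temperatures_cuda = torch.empty((mtp_total_sampling_size, ),
                                             dtype=torch.float,
                                             device='cuda')
        self.top_k_cuda = torch.empty((mtp_total_sampling_size, ),
                                      dtype=torch.int,
                                      device='cuda')
        self.top_p_cuda = torch.empty((mtp_total_sampling_size, ),
                                      dtype=torch.float,
                                      device='cuda')
        self.min_p_cuda = torch.empty((mtp_total_sampling_size, ),
                                      dtype=torch.float,
                                      device='cuda')
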
@@ -1196,38 +1175,50 @@ def _prepare_tp_inputs(
        top_p = []
        min_p = []

-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 1.0
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
-
-        def get_request_top_k(request: LlmRequest) -> int:
-            if not request.sampling_config.top_k:
-                top_k = 0
-            else:
-                top_k = request.sampling_config.top_k[0]
+        # advanced mtp sampling's request preprocessing helper functions
+        def collect_req_mtp_sampling_params(request: LlmRequest,
+                                            draft_len: int = 0):
+
+            def get_request_temperature(request: LlmRequest) -> float:
+                if not request.sampling_config.temperature:
+                    return 1.0
+                temperature = request.sampling_config.temperature[0]
+                if 0 < temperature < 1e-2:
+                    # temperature less than 0.01 may cause numerical errors
+                    temperature = 0.01
+                return temperature
+
+            def get_request_top_k(request: LlmRequest) -> int:
+                if not request.sampling_config.top_k:
+                    top_k = 0
+                else:
+                    top_k = request.sampling_config.top_k[0]

-            if top_k <= 0:
-                top_k = 2147483647
-            return top_k
+                # set k to a very large value (larger than vocab size) to disable top_k sampling
+                TOP_K_DISABLED = (1 << 31) - 1
+                if top_k <= 0:
+                    top_k = TOP_K_DISABLED
+                return top_k

-        def get_request_top_p(request: LlmRequest) -> float:
-            if not request.sampling_config.top_p:
-                top_p = 1.0
-            else:
-                top_p = request.sampling_config.top_p[0]
-            return top_p
+            def get_request_top_p(request: LlmRequest) -> float:
+                if not request.sampling_config.top_p:
+                    top_p = 1.0
+                else:
+                    top_p = request.sampling_config.top_p[0]
+                return top_p

-        def get_request_min_p(request: LlmRequest) -> float:
-            if not request.sampling_config.min_p:
-                min_p = 0.0
-            else:
-                min_p = request.sampling_config.min_p[0]
-            return min_p
+            def get_request_min_p(request: LlmRequest) -> float:
+                if not request.sampling_config.min_p:
+                    min_p = 0.0
+                else:
+                    min_p = request.sampling_config.min_p[0]
+                return min_p
+
+            temperatures.extend([get_request_temperature(request)] *
+                                (draft_len + 1))
+            top_k.extend([get_request_top_k(request)] * (draft_len + 1))
+            top_p.extend([get_request_top_p(request)] * (draft_len + 1))
+            min_p.extend([get_request_min_p(request)] * (draft_len + 1))

        for request in scheduled_requests.context_requests:
            request_ids.append(request.py_request_id)
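
collect_req_mtp_sampling_params broadcasts one request's sampling settings across its target token and draft_len draft positions, so each flattened list ends up with draft_len + 1 entries per request, in the order the sampler consumes them. A self-contained illustration of that extend pattern follows; the stub request is hypothetical and mirrors only the sampling_config fields read above:

from types import SimpleNamespace

# Hypothetical stand-in for LlmRequest: empty lists mean "use the default".
request = SimpleNamespace(sampling_config=SimpleNamespace(
    temperature=[0.7], top_k=[], top_p=[0.95], min_p=[]))

temperatures, top_k, top_p, min_p = [], [], [], []
draft_len = 3  # e.g. MTP proposing three draft tokens per step

# Same broadcast the helper performs: one entry per target + draft position.
temperatures.extend([request.sampling_config.temperature[0]] * (draft_len + 1))
top_k.extend([(1 << 31) - 1] * (draft_len + 1))  # missing top_k -> disabled sentinel
top_p.extend([request.sampling_config.top_p[0]] * (draft_len + 1))
min_p.extend([0.0] * (draft_len + 1))  # missing min_p -> 0.0 default

assert temperatures == [0.7, 0.7, 0.7, 0.7]
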
@@ -1263,10 +1254,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                    multimodal_params_list.append(multimodal_params)

            if self.is_advanced_mtp_sampler:
-                temperatures.append(get_request_temperature(request))
-                top_k.append(get_request_top_k(request))
-                top_p.append(get_request_top_p(request))
-                min_p.append(get_request_min_p(request))
+                collect_req_mtp_sampling_params(request)

            request.py_batch_idx = request.py_seq_slot

@@ -1352,14 +1340,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                request_ids.append(request.py_request_id)

                if self.is_advanced_mtp_sampler:
-                    temperatures.extend([get_request_temperature(request)] *
-                                        (num_draft_tokens + 1))
-                    top_k.extend([get_request_top_k(request)] *
-                                 (num_draft_tokens + 1))
-                    top_p.extend([get_request_top_p(request)] *
-                                 (num_draft_tokens + 1))
-                    min_p.extend([get_request_min_p(request)] *
-                                 (num_draft_tokens + 1))
+                    collect_req_mtp_sampling_params(request, num_draft_tokens)

                # update batch index
                request.py_batch_idx = request.py_seq_slot
@@ -1391,14 +1372,7 @@ def get_request_min_p(request: LlmRequest) -> float:
                request_ids.append(request.py_request_id)

                if self.is_advanced_mtp_sampler:
-                    temperatures.extend([get_request_temperature(request)] *
-                                        (self.max_draft_len + 1))
-                    top_k.extend([get_request_top_k(request)] *
-                                 (self.max_draft_len + 1))
-                    top_p.extend([get_request_top_p(request)] *
-                                 (self.max_draft_len + 1))
-                    min_p.extend([get_request_min_p(request)] *
-                                 (self.max_draft_len + 1))
+                    collect_req_mtp_sampling_params(request, self.max_draft_len)

        for request in generation_requests:
            beam_width = request.sampling_config.beam_width
@@ -1433,14 +1407,7 @@ def get_request_min_p(request: LlmRequest) -> float:
            gen_request_seq_slots.append(request.py_seq_slot)

            if self.is_advanced_mtp_sampler:
-                temperatures.extend([get_request_temperature(request)] *
-                                    (self.max_draft_len + 1))
-                top_k.extend([get_request_top_k(request)] *
-                             (self.max_draft_len + 1))
-                top_p.extend([get_request_top_p(request)] *
-                             (self.max_draft_len + 1))
-                min_p.extend([get_request_min_p(request)] *
-                             (self.max_draft_len + 1))
+                collect_req_mtp_sampling_params(request, self.max_draft_len)

            request.py_batch_idx = request.py_seq_slot

@@ -1561,21 +1528,6 @@ def previous_seq_slots_device():
            self.gather_ids_cuda[:len(gather_ids)].copy_(torch.tensor(
                gather_ids, dtype=torch.int, pin_memory=True),
                                                         non_blocking=True)
-            if self.is_advanced_mtp_sampler:
-                self.temperatures_cuda[:len(temperatures)].copy_(
-                    torch.tensor(temperatures,
-                                 dtype=torch.float,
-                                 pin_memory=True),
-                    non_blocking=True)
-                self.top_k_cuda[:len(top_k)].copy_(torch.tensor(
-                    top_k, dtype=torch.int, pin_memory=True),
-                                                   non_blocking=True)
-                self.top_p_cuda[:len(top_p)].copy_(torch.tensor(
-                    top_p, dtype=torch.float, pin_memory=True),
-                                                   non_blocking=True)
-                self.min_p_cuda[:len(min_p)].copy_(torch.tensor(
-                    min_p, dtype=torch.float, pin_memory=True),
-                                                   non_blocking=True)

        if not attn_metadata.is_cuda_graph:
            # Assumes seq lens do not change between CUDA graph invocations. This applies
@@ -1651,11 +1603,8 @@ def previous_seq_slots_device():
            spec_metadata.gather_ids = self.gather_ids_cuda[:len(gather_ids)]

            if self.is_advanced_mtp_sampler:
-                spec_metadata.temperatures = self.temperatures_cuda[:len(
-                    temperatures)]
-                spec_metadata.top_k = self.top_k_cuda[:len(top_k)]
-                spec_metadata.top_p = self.top_p_cuda[:len(top_p)]
-                spec_metadata.min_p = self.min_p_cuda[:len(min_p)]
+                spec_metadata.update_advanced_mtp_sampling_params(
+                    temperatures, top_k, top_p, min_p)

            spec_metadata.num_generations = len(
                scheduled_requests.generation_requests)
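
Only the call site of update_advanced_mtp_sampling_params is visible in this diff. Assuming the preallocated *_cuda buffers now live on the spec metadata (see the setup sketch after the __init__ hunk), a plausible body would stage the host lists in pinned memory, issue non-blocking copies, and expose batch-sized views, mirroring the copies and slices removed in the two hunks above. Everything here except the method name and argument order is an assumption:

from typing import List

import torch


class _MtpSamplingBuffersMixin:
    # Hypothetical continuation of the earlier setup sketch.

    def update_advanced_mtp_sampling_params(self, temperatures: List[float],
                                            top_k: List[int],
                                            top_p: List[float],
                                            min_p: List[float]) -> None:

        def _stage(dst: torch.Tensor, values: list,
                   dtype: torch.dtype) -> torch.Tensor:
            # Pinned host staging + async copy into the preallocated buffer,
            # then return a view trimmed to the current batch.
            src = torch.tensor(values, dtype=dtype, pin_memory=True)
            dst[:len(values)].copy_(src, non_blocking=True)
            return dst[:len(values)]

        self.temperatures = _stage(self.temperatures_cuda, temperatures,
                                   torch.float)
        self.top_k = _stage(self.top_k_cuda, top_k, torch.int)
        self.top_p = _stage(self.top_p_cuda, top_p, torch.float)
        self.min_p = _stage(self.min_p_cuda, min_p, torch.float)
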
@@ -2205,6 +2154,10 @@ def forward(
                spec_metadata.is_spec_dec_tree,
                spec_metadata.is_spec_dec_dynamic_tree,
                spec_metadata.max_draft_len)
+
+            if self.is_advanced_mtp_sampler:
+                spec_metadata._set_up_advanced_mtp_sampling(
+                    self.batch_size, self.max_draft_len)
        else:
            spec_resource_manager = None
            spec_metadata = None