@@ -1323,7 +1323,6 @@ def previous_seq_slots_device():
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1340,6 +1339,10 @@ def previous_seq_slots_device():
         self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                          non_blocking=True)
         if next_draft_tokens_device is not None:
+            # Initialize these two values to zeros
+            self.previous_pos_id_offsets_cuda *= 0
+            self.previous_kv_lens_offsets_cuda *= 0
+
             if previous_batch_len > 0:
                 previous_slots = previous_seq_slots_device()
                 # previous input ids
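Zeroing both offset buffers up front, rather than zeroing the leftover tail segments after the copies as the old code did, means the only data movement left inside the `previous_batch_len > 0` branch is the copy for the previous-batch segment. A minimal CPU sketch of the reset pattern, assuming preallocated buffers; the names mirror the diff, but the sizes and the `torch.ones` initial contents are illustrative, not from the source:

```python
import torch

# Illustrative sizes; the real buffers are preallocated once on the GPU.
max_num_requests, max_draft_len = 4, 3

# Pretend the buffers hold stale offsets from a previous iteration.
previous_pos_id_offsets = torch.ones(max_num_requests * (1 + max_draft_len),
                                     dtype=torch.int64)
previous_kv_lens_offsets = torch.ones(max_num_requests, dtype=torch.int64)

# In-place multiply zeroes the contents without reallocating the storage,
# so later slice-wise copy_ calls still write into the same memory, and any
# segment they skip is guaranteed to read as 0.
previous_pos_id_offsets *= 0
previous_kv_lens_offsets *= 0

assert int(previous_pos_id_offsets.sum()) == 0
assert int(previous_kv_lens_offsets.sum()) == 0
```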
@@ -1364,24 +1367,37 @@ def previous_seq_slots_device():
                     pin_memory=True)
                 self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
                     previous_pos_indices_host, non_blocking=True)
+
+                # The order of requests in a batch: [context requests, generation requests]
+                # generation requests: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
+                # 1) 'requests that do not have previous batch': the overlap scheduler is disabled, or this is the first step in the generation server of disaggregated serving.
+                # 2) 'requests that already have previous batch': requests carried over from the previous iteration.
+                # 3) 'dummy requests': dummy requests padded for CUDA graphs or attention DP.
+                # Therefore, self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda also consist of 3 segments.
+                # For 1) 'requests that do not have previous batch':
+                #    set their previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs.
+                #    They are already set to '0' during initialization above.
+                # For 2) 'requests that already have previous batch' (overlap scheduler enabled):
+                #    set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device.
+                # For 3) 'dummy requests':
+                #    already set to '0' during initialization above.
+
+                num_extend_reqeust_wo_dummy = len(extend_requests) - len(
+                    extend_dummy_requests)
                 self.previous_pos_id_offsets_cuda[
-                    0:previous_batch_tokens].copy_(
+                    (num_extend_reqeust_wo_dummy - previous_batch_len) *
+                    (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
+                    (1 + self.max_draft_len)].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
                             0:previous_batch_tokens]],
                         non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
-                    kv_len_offsets_device[previous_slots], non_blocking=True)
-                # for the requests that do not have previous batch, set the previous_pos_id_offsets and
-                # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[
-                    previous_batch_tokens:num_requests *
-                    (1 + self.max_draft_len)] *= 0
+
                 self.previous_kv_lens_offsets_cuda[
-                    previous_batch_len:num_requests] *= 0
-            else:
-                # change the data to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda *= 0
-                self.previous_kv_lens_offsets_cuda *= 0
+                    num_extend_reqeust_wo_dummy -
+                    previous_batch_len:num_extend_reqeust_wo_dummy].copy_(
+                        kv_len_offsets_device[previous_slots],
+                        non_blocking=True)
+
         elif new_tokens_device is not None:
             seq_slots_device = previous_seq_slots_device()
             max_draft_len = max(draft_lens)
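The new slice bounds place segment 2) at the tail of the non-dummy extend region instead of at offset 0. A small worked example of the index arithmetic, with made-up sizes; the variable names follow the diff, but none of the numbers come from a real batch:

```python
# Hypothetical batch: 5 extend requests after excluding dummies, of which the
# last 3 carry results from the previous iteration; each request contributes
# 1 new token plus max_draft_len draft tokens to the flattened token buffers.
num_extend_reqeust_wo_dummy = 5
previous_batch_len = 3
max_draft_len = 3
tokens_per_request = 1 + max_draft_len

# Token-level slice written by the previous_pos_id_offsets copy:
start = (num_extend_reqeust_wo_dummy - previous_batch_len) * tokens_per_request
end = num_extend_reqeust_wo_dummy * tokens_per_request
assert (start, end) == (8, 20)

# Its length matches previous_batch_tokens, the number of source elements:
previous_batch_tokens = previous_batch_len * tokens_per_request
assert end - start == previous_batch_tokens == 12

# Request-level slice written by the previous_kv_lens_offsets copy:
assert (num_extend_reqeust_wo_dummy - previous_batch_len,
        num_extend_reqeust_wo_dummy) == (2, 5)

# Tokens in [0:start) belong to segment 1) and anything past `end` to
# segment 3); both were zeroed up front, so no extra writes are needed.
```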