@@ -241,6 +241,10 @@ def __init__(self,
241
241
self .enable_iter_perf_stats = model_engine .pytorch_backend_config .enable_iter_perf_stats
242
242
self .enable_iter_req_stats = model_engine .pytorch_backend_config .enable_iter_req_stats
243
243
self .stream_interval = model_engine .pytorch_backend_config .stream_interval
244
+ self .use_attention_dp_config = model_engine .pytorch_backend_config .use_attention_dp_config
245
+ self .attention_dp_time_out_iters = model_engine .pytorch_backend_config .attention_dp_time_out_iters
246
+ self .attention_dp_batching_wait_iters = model_engine .pytorch_backend_config .attention_dp_batching_wait_iters
247
+
244
248
self .num_fetch_requests_cur_rank = 0
245
249
self .num_fetch_requests = 0
246
250
self .shutdown_event = threading .Event ()
@@ -287,6 +291,9 @@ def __init__(self,
287
291
self .draft_model_engine .warmup (self .resource_manager )
288
292
289
293
self .is_shutdown = False
294
+ self .max_batch_size = max_batch_size
295
+ self .adp_ctx_waiting_iters = 0
296
+ self .adp_ctx_batching_wait_iters = 0
290
297
291
298
self .stats_lock = threading .Lock ()
292
299
self .stats = []
@@ -1228,7 +1235,16 @@ def _broadcast_new_requests(
1228
1235
def _fetch_new_requests (self ) -> List [RequestQueueItem ]:
1229
1236
if self .enable_attention_dp :
1230
1237
all_ranks_num_active_requests = []
1231
- responses_list = self .dist .tp_allgather (len (self .active_requests ))
1238
+ num_active_requests = len (self .active_requests )
1239
+ responses_list = self .dist .tp_allgather (num_active_requests )
1240
+ # Debug check - remove after verification
1241
+ if not all (isinstance (x , int ) for x in responses_list ):
1242
+ raise RuntimeError (
1243
+ f"tp_allgather returned non-integer values: { responses_list } "
1244
+ +
1245
+ f"Expected all ranks to return int from { num_active_requests } and { self .active_requests } ."
1246
+ )
1247
+
1232
1248
for num_active_requests in responses_list :
1233
1249
all_ranks_num_active_requests .append (num_active_requests )
1234
1250
total_num_active_requests = sum (all_ranks_num_active_requests )
@@ -1518,8 +1534,66 @@ def _schedule(self):
1518
1534
scheduler_output = self .scheduler .schedule_request (
1519
1535
self .active_requests , self .inflight_req_ids )
1520
1536
scheduled_requests = ScheduledRequests ()
1537
+ context_requests = scheduler_output .context_requests
1538
+ if self .enable_attention_dp :
1539
+ num_scheduled_context_requests = len (
1540
+ scheduler_output .context_requests )
1541
+ num_scheduled_generation_requests = len (
1542
+ scheduler_output .generation_requests )
1543
+ num_scheduled_tokens = sum ([
1544
+ len (req .get_tokens (0 )) for req in context_requests
1545
+ ]) + num_scheduled_generation_requests
1546
+ responses_list = self .dist .tp_allgather ([
1547
+ num_scheduled_context_requests ,
1548
+ num_scheduled_generation_requests , num_scheduled_tokens
1549
+ ])
1550
+ all_ranks_num_scheduled_context_requests = [
1551
+ response [0 ] for response in responses_list
1552
+ ]
1553
+ all_ranks_num_scheduled_generation_requests = [
1554
+ response [1 ] for response in responses_list
1555
+ ]
1556
+ all_ranks_num_scheduled_tokens = [
1557
+ response [2 ] for response in responses_list
1558
+ ]
1559
+
1560
+ all_ranks_have_free_ctx_slots = all ([
1561
+ num_gen < self .max_batch_size
1562
+ for num_gen in all_ranks_num_scheduled_generation_requests
1563
+ ])
1564
+ all_ranks_have_multi_gen = all ([
1565
+ num_gen > 1
1566
+ for num_gen in all_ranks_num_scheduled_generation_requests
1567
+ ])
1568
+ all_ranks_have_ctx_requests = all ([
1569
+ num_ctx > 0
1570
+ for num_ctx in all_ranks_num_scheduled_context_requests
1571
+ ])
1572
+
1573
+ all_ranks_have_gen_requests = all ([
1574
+ num_gen > 0
1575
+ for num_gen in all_ranks_num_scheduled_generation_requests
1576
+ ])
1577
+ if self .use_attention_dp_config :
1578
+ # wait for all ranks have context requests
1579
+ if all_ranks_have_multi_gen :
1580
+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests :
1581
+ self .adp_ctx_waiting_iters = 0
1582
+ else :
1583
+ self .adp_ctx_waiting_iters += 1
1584
+ context_requests = []
1585
+ if self .adp_ctx_waiting_iters >= self .attention_dp_time_out_iters :
1586
+ self .adp_ctx_waiting_iters = 0
1587
+ context_requests = scheduler_output .context_requests
1588
+ # balance number of context requests across ranks
1589
+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests :
1590
+ if self .adp_ctx_batching_wait_iters <= self .attention_dp_batching_wait_iters :
1591
+ self .adp_ctx_batching_wait_iters += 1
1592
+ context_requests = []
1593
+ else :
1594
+ self .adp_ctx_batching_wait_iters = 0
1521
1595
1522
- scheduled_requests .context_requests = scheduler_output . context_requests
1596
+ scheduled_requests .context_requests = context_requests
1523
1597
scheduled_requests .generation_requests = scheduler_output .generation_requests
1524
1598
scheduled_requests .paused_requests = scheduler_output .paused_requests
1525
1599
return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
0 commit comments