@@ -241,6 +241,10 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
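The three new attributes are read straight off `pytorch_backend_config`. For context, here is a rough sketch of the fields this code expects to exist; the names come from the diff, but the dataclass shape, types, and defaults are assumptions, not confirmed by this patch:

# Hypothetical sketch of the backend-config fields this diff reads.
# Field names match the diff; types and defaults are assumptions.
from dataclasses import dataclass

@dataclass
class PyTorchBackendConfigSketch:
    # Master switch for the attention-DP scheduling heuristics below.
    use_attention_dp_config: bool = False
    # Max iterations a rank may withhold context requests while waiting
    # for all ranks to have context work, before forcing them through.
    attention_dp_time_out_iters: int = 50
    # Max iterations to hold back context requests so several of them
    # can be batched into a single iteration.
    attention_dp_batching_wait_iters: int = 10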
@@ -287,6 +291,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters = 0
+        self.adp_ctx_batching_wait_iters = 0
 
         self.stats_lock = threading.Lock()
         self.stats = []
@@ -1228,7 +1235,15 @@ def _broadcast_new_requests(
     def _fetch_new_requests(self) -> List[RequestQueueItem]:
         if self.enable_attention_dp:
             all_ranks_num_active_requests = []
-            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            num_active_requests = len(self.active_requests)
+            responses_list = self.dist.tp_allgather(num_active_requests)
+            # Debug check - remove after verification
+            if not all(isinstance(x, int) for x in responses_list):
+                raise RuntimeError(
+                    f"tp_allgather returned non-integer values: {responses_list}. "
+                    f"Expected every rank to return an int from {num_active_requests} "
+                    f"and {self.active_requests}.")
+
             for num_active_requests in responses_list:
                 all_ranks_num_active_requests.append(num_active_requests)
             total_num_active_requests = sum(all_ranks_num_active_requests)
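The debug check leans on the allgather contract: every rank contributes its local `len(self.active_requests)` and receives the same ordered list of per-rank counts, so each rank derives an identical `total_num_active_requests` without further communication. A minimal single-process stand-in for that contract (illustrative only; `fake_tp_allgather` is a hypothetical helper, not the executor's real `dist.tp_allgather`):

# Single-process sketch of the allgather contract assumed above: every
# rank contributes one value and all ranks receive the same ordered list.
def fake_tp_allgather(per_rank_values):
    # Rank r calls tp_allgather(per_rank_values[r]) and gets this list.
    return list(per_rank_values)

num_active_per_rank = [3, 1, 4, 2]      # len(active_requests) on each rank
responses_list = fake_tp_allgather(num_active_per_rank)
total_num_active_requests = sum(responses_list)
assert total_num_active_requests == 10  # identical result on every rank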
@@ -1518,8 +1533,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            if self.use_attention_dp_config:
+                # Wait until all ranks have context requests to schedule.
+                if all_ranks_have_multi_gen:
+                    if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                        self.adp_ctx_waiting_iters = 0
+                    else:
+                        self.adp_ctx_waiting_iters += 1
+                        context_requests = []
+                        if self.adp_ctx_waiting_iters >= self.attention_dp_time_out_iters:
+                            self.adp_ctx_waiting_iters = 0
+                            context_requests = scheduler_output.context_requests
+                # Balance the number of context requests across ranks.
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests:
+                    if self.adp_ctx_batching_wait_iters <= self.attention_dp_batching_wait_iters:
+                        self.adp_ctx_batching_wait_iters += 1
+                        context_requests = []
+                    else:
+                        self.adp_ctx_batching_wait_iters = 0
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
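Putting the two counters together: while every rank still has more than one generation request in flight, a rank withholds its context requests until all ranks have context work and free batch slots, or until `attention_dp_time_out_iters` forces them through; once contexts could run alongside generations everywhere, they are still held for up to `attention_dp_batching_wait_iters` iterations so more of them land in one batch. Below is a standalone re-derivation of that gate for a single rank; it is a simplified sketch that mirrors the diff's predicate and counter names, not the executor code itself:

# Simplified re-derivation of the context-scheduling gate above. The
# boolean inputs correspond to the allgathered per-rank stats; returning
# False corresponds to setting context_requests = [].
class AdpContextGate:
    def __init__(self, time_out_iters=50, batching_wait_iters=10):
        self.time_out_iters = time_out_iters
        self.batching_wait_iters = batching_wait_iters
        self.waiting_iters = 0    # mirrors adp_ctx_waiting_iters
        self.batching_iters = 0   # mirrors adp_ctx_batching_wait_iters

    def schedule_contexts(self, have_free_slots, have_multi_gen, have_ctx,
                          have_gen):
        schedule = True
        # Phase 1: while every rank has >1 generation request, defer
        # contexts until all ranks also have context work and free slots,
        # or until the timeout forces them through.
        if have_multi_gen:
            if have_free_slots and have_ctx:
                self.waiting_iters = 0
            else:
                self.waiting_iters += 1
                schedule = False
                if self.waiting_iters >= self.time_out_iters:
                    self.waiting_iters = 0
                    schedule = True
        # Phase 2: even when contexts could run everywhere, hold them a
        # few iterations so several can be batched into one iteration.
        if have_free_slots and have_ctx and have_gen:
            if self.batching_iters <= self.batching_wait_iters:
                self.batching_iters += 1
                schedule = False
            else:
                self.batching_iters = 0
        return schedule

gate = AdpContextGate(time_out_iters=3, batching_wait_iters=2)
# All ranks have free slots plus context and generation work: contexts
# are held back to batch up, then released on the fourth iteration.
for it in range(5):
    print(it, gate.schedule_contexts(have_free_slots=True,
                                     have_multi_gen=False,
                                     have_ctx=True, have_gen=True))
# -> 0 False, 1 False, 2 False, 3 True, 4 False

Note that because every rank computes the gate from the same allgathered stats, the counters advance identically on all ranks, which is what keeps the attention-DP ranks in lock-step without extra synchronization.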