@@ -244,30 +244,105 @@ def _fetch_new_requests_attention_dp(
244
244
new_requests = self ._fetch_and_process_requests (
245
245
total_num_active_requests , total_max_num_active_requests )
246
246
247
- # Balance requests across ranks
248
- num_new_requests_all_ranks = len (new_requests )
249
- self .expected_num_active_requests = max (
250
- (total_num_active_requests + num_new_requests_all_ranks +
251
- self .dist .tp_size - 1 ) // self .dist .tp_size ,
252
- max (all_ranks_num_active_requests ),
253
- )
254
-
255
- new_requests_cur_rank = self ._balance_requests_across_ranks (
256
- new_requests , all_ranks_num_active_requests )
247
+ # Schedule attention dp requests
248
+ new_requests_cur_rank = self ._schedule_attention_dp_requests (
249
+ num_active_requests , new_requests , all_ranks_num_active_requests )
257
250
258
251
# Update performance metrics
259
252
if self .enable_iter_perf_stats and self .start_times :
260
253
self ._update_new_active_requests_queue_latency (
261
254
new_requests_cur_rank )
262
255
263
256
# Update counters
264
- self .num_fetch_requests += num_new_requests_all_ranks
257
+ self .num_fetch_requests += len ( new_requests )
265
258
self .num_fetch_requests_cur_rank += len (new_requests_cur_rank )
266
259
267
260
# Merge requests and add to active list
268
261
new_requests_cur_rank = self ._merge_requests (new_requests_cur_rank )
269
262
return new_requests_cur_rank
270
263
264
def _schedule_attention_dp_requests(
        self, num_active_requests: int,
        new_requests: List[RequestQueueItem],
        all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
    """Schedule newly fetched requests under attention data parallelism.

    Requests whose schedule params pin them to this rank are admitted
    directly, up to this rank's free capacity.  Everything else — requests
    with no pinned attention-dp rank, plus pinned-but-relaxed requests that
    did not fit locally — forms a "non-scheduled" set that is identical on
    every rank (established via allgather) and is balanced across ranks.

    Args:
        num_active_requests: Number of requests already active on this rank.
        new_requests: Newly fetched requests (same list on every rank).
        all_ranks_num_active_requests: Active-request count per rank.

    Returns:
        The new requests assigned to this rank (pinned + balanced share).

    Side effects:
        Pushes strict (non-relax) pinned requests that exceeded local
        capacity back onto the front of ``self.waiting_queue`` and updates
        ``self.expected_num_active_requests``.
    """
    # Requests explicitly pinned to this rank via schedule params.
    pinned_to_this_rank = [
        req_item for req_item in new_requests
        if req_item.request.schedule_params is not None and
        req_item.request.schedule_params.attention_dp_rank ==
        self.dist.tp_rank
    ]

    # Route pinned requests without exceeding max_num_active_requests.
    new_requests_cur_rank = []  # admitted on this rank
    waiting_cur_rank = []  # strict overflow: must wait for this rank
    relax_cur_rank = []  # relaxed: may be rebalanced onto any rank

    available_slots = self.max_num_active_requests - num_active_requests

    for req_item in pinned_to_this_rank:
        is_relax = req_item.request.schedule_params.attention_dp_relax
        if len(new_requests_cur_rank) < available_slots:
            # Prioritize the non-relax requests for the local slots.
            target_list = relax_cur_rank if is_relax else new_requests_cur_rank
        else:
            # Out of slots: strict requests wait for this rank; relaxed
            # ones may be rebalanced onto other ranks below.
            target_list = relax_cur_rank if is_relax else waiting_cur_rank
        target_list.append(req_item)

    # Fill any leftover local slots with relaxed requests first.
    items_to_move = available_slots - len(new_requests_cur_rank)
    if items_to_move > 0:
        new_requests_cur_rank.extend(relax_cur_rank[:items_to_move])
        relax_cur_rank = relax_cur_rank[items_to_move:]

    # Allgather the ids of relaxed requests that did not fit locally so
    # every rank derives the same rebalanceable set.
    # TODO: Remove the padding overhead
    relax_ids = [req_item.id for req_item in relax_cur_rank]
    # Pad to a fixed length so every rank contributes a uniform payload.
    relax_ids += [None] * (self.max_num_active_requests - len(relax_ids))
    gathered_ids = self.dist.tp_allgather(relax_ids)
    # Set for O(1) membership tests; padding entries are dropped.
    non_scheduled_requests_id = {
        req_id
        for req_id in gathered_ids if req_id is not None
    }

    # Non-scheduled requests must be identical across ranks: overflowed
    # relaxed requests plus requests that never specified an attention
    # dp rank.
    non_scheduled_requests = []
    for req_item in new_requests:
        schedule_params = req_item.request.schedule_params
        if (req_item.id in non_scheduled_requests_id
                or schedule_params is None
                or schedule_params.attention_dp_rank is None):
            non_scheduled_requests.append(req_item)

    # Put strict overflow requests back at the front of the waiting queue.
    # deque.extendleft inserts items one by one (i.e. reversed), so feed
    # it the reversed list to preserve the original FIFO order.
    self.waiting_queue.extendleft(reversed(waiting_cur_rank))

    # TODO: Balance the no attention dp rank requests and relax requests
    # across ranks
    num_new_requests_all_ranks = len(new_requests)
    total_num_active_requests = sum(all_ranks_num_active_requests)
    self.expected_num_active_requests = max(
        (total_num_active_requests + num_new_requests_all_ranks +
         self.dist.tp_size - 1) // self.dist.tp_size,
        max(all_ranks_num_active_requests),
    )

    # Append this rank's balanced share of the non-scheduled requests to
    # the requests already pinned here.  Overwriting instead of extending
    # would silently drop the pinned requests: no other rank will pick
    # them up (their ids are not allgathered and their attention_dp_rank
    # is not None).
    new_requests_cur_rank.extend(
        self._balance_requests_across_ranks(non_scheduled_requests,
                                            all_ranks_num_active_requests))

    return new_requests_cur_rank
345
+
271
346
def _handle_request_broadcasting (self ,
272
347
new_requests : List [RequestQueueItem ]):
273
348
"""Handle broadcasting of requests and Python objects across ranks."""
0 commit comments