@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
         max_req_count: int,
+        enable_attention_dp: bool,
+        all_ranks_num_active_requests: Optional[List[int]] = None,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
-
+        """Extract up to max_req_count schedulable requests from the waiting queue.
         Args:
             waiting_queue: The queue to pop items from.
             max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
+            enable_attention_dp: Whether to enable attention DP scheduling.
+            all_ranks_num_active_requests: Number of active requests for each rank.
         Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
+            List of requests that can be processed.
         """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
+
+        if max_req_count <= 0:
            return []
 
-        items = []
         req_count = 0
+        items = []
+        pending_requests = []
+
+        # Track spare capacity for requests with strict rank requirements
+        scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
+            req_item = waiting_queue.popleft()
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
+
+            if can_process:
+                items.append(req_item)
+                req_count += 1
+            else:
+                pending_requests.append(req_item)
+
+        # Put the pending requests back into the waiting queue;
+        # all ranks must keep identical waiting queues
+        waiting_queue.extendleft(reversed(pending_requests))
+
         return items
 
+    def _can_process_attention_dp_request(
+            self, req_item: RequestQueueItem,
+            all_ranks_num_active_requests: List[int]) -> bool:
+        """Return True if the request can be processed immediately, else False."""
+
+        scheduling_params = getattr(req_item.request, 'py_scheduling_params',
+                                    None)
+        if scheduling_params is None:
+            return True
+
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
+            return True
+
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
+            all_ranks_num_active_requests[target_dp_rank] += 1
+            return True
+
+        return False
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         try:
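Note on the extraction loop above: requests that cannot be scheduled yet are buffered and pushed back with `extendleft(reversed(...))`, which restores them to the head of the queue in their original order on every rank. A minimal, self-contained sketch of that pattern, with toy items and a stand-in predicate in place of `RequestQueueItem` and `_can_process_attention_dp_request`:

```python
from collections import deque

def take_schedulable(queue: deque, max_count: int, can_process) -> list:
    """Pop up to max_count items that pass can_process; requeue the rest."""
    taken, pending = [], []
    while len(taken) < max_count and queue:
        item = queue.popleft()
        (taken if can_process(item) else pending).append(item)
    # reversed() + extendleft() puts skipped items back at the head
    # of the queue in their original relative order.
    queue.extendleft(reversed(pending))
    return taken

q = deque([1, 2, 3, 4, 5])
print(take_schedulable(q, 3, lambda x: x % 2 == 1))  # [1, 3, 5]
print(list(q))                                       # [2, 4]
```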
@@ -166,8 +207,12 @@ def can_enqueue_request(self) -> bool:
         return can_enqueue and self.dist.rank == 0
 
     def _fetch_and_process_requests(
-            self, total_num_active_requests: int,
-            total_max_num_active_requests: int) -> List[RequestQueueItem]:
+            self,
+            total_num_active_requests: int,
+            total_max_num_active_requests: int,
+            enable_attention_dp: bool,
+            all_ranks_num_active_requests: Optional[List[int]] = None
+    ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
@@ -195,7 +240,8 @@ def _fetch_and_process_requests(
 
         new_requests = self._get_from_waiting_queue(
             self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
+            total_max_num_active_requests - total_num_active_requests,
+            enable_attention_dp, all_ranks_num_active_requests)
 
         # Update performance metrics
         if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -218,9 +264,11 @@ def _fetch_new_requests_attention_tp(
         total_num_active_requests = num_active_requests
         total_max_num_active_requests = self.max_num_active_requests
 
-        # Use common request fetching logic
+        # Fetch and process requests into the waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
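For the capacity math in the two fetch paths: the TP path keeps a single per-scheduler budget, while the DP path (below) scales it by the number of ranks. A toy calculation with made-up numbers:

```python
# Made-up values for illustration only.
max_num_active_requests = 8   # per-rank budget
tp_size = 4

total_max_tp = max_num_active_requests            # attention TP: 8
total_max_dp = tp_size * max_num_active_requests  # attention DP: 32
```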
@@ -238,34 +286,84 @@ def _fetch_new_requests_attention_dp(
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
 
-        # Use common request fetching logic
+        # Fetch and process requests into the waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True,
+            all_ranks_num_active_requests=all_ranks_num_active_requests)
 
-        # Balance requests across ranks
-        num_new_requests_all_ranks = len(new_requests)
-        self.expected_num_active_requests = max(
-            (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.tp_size - 1) // self.dist.tp_size,
-            max(all_ranks_num_active_requests),
-        )
-
-        new_requests_cur_rank = self._balance_requests_across_ranks(
+        # Schedule attention DP requests
+        all_ranks_new_requests = self._schedule_attention_dp_requests(
             new_requests, all_ranks_num_active_requests)
+        new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
 
         # Update performance metrics
         if self.enable_iter_perf_stats and self.start_times:
             self._update_new_active_requests_queue_latency(
                 new_requests_cur_rank)
 
         # Update counters
-        self.num_fetch_requests += num_new_requests_all_ranks
+        self.num_fetch_requests += len(new_requests)
         self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
 
         # Merge requests and add to active list
         new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
         return new_requests_cur_rank
 
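The two scheduling fields read from `py_scheduling_params` define a small contract: `attention_dp_rank` pins a request to one DP rank, and `attention_dp_relax` allows fallback to load balancing when that rank is full. A sketch with a stand-in dataclass (the real object's constructor and import path are not shown in this diff, so this mirror is an assumption):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SchedulingParams:
    """Stand-in mirroring the fields read from py_scheduling_params."""
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False

# Pinned: only runs on DP rank 2; waits in the queue while rank 2 is full.
pinned = SchedulingParams(attention_dp_rank=2, attention_dp_relax=False)
# Soft: prefers rank 2, but may be balanced elsewhere when rank 2 is full.
soft = SchedulingParams(attention_dp_rank=2, attention_dp_relax=True)
# Unconstrained: no target rank; always placed by the load balancer.
unconstrained = SchedulingParams()
```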
+    def _schedule_attention_dp_requests(
+            self, new_requests: List[RequestQueueItem],
+            all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
+        """Schedule attention DP requests."""
+
+        # Map from ranks to new requests
+        all_ranks_new_requests = {
+            tp_rank: []
+            for tp_rank in range(self.dist.tp_size)
+        }
+
+        # Prioritize the requests that are not in relax mode
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        new_requests = sorted(new_requests, key=get_relax_value)
+
+        # Try to place requests on their target DP rank until max_num_active_requests is reached
+        remaining_unscheduled = []
+        for req_item in new_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if target_dp_rank is not None and all_ranks_num_active_requests[
+                        target_dp_rank] < self.max_num_active_requests:
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        # Balance the remaining unscheduled requests across ranks
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        self.expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks +
+             self.dist.tp_size - 1) // self.dist.tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        all_ranks_new_requests = self._balance_requests_across_ranks(
+            remaining_unscheduled, all_ranks_new_requests,
+            all_ranks_num_active_requests)
+
+        return all_ranks_new_requests
+
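Worked numbers for the balancing target computed above, evaluated on made-up values: `expected_num_active_requests` is the ceiling of (active + new) divided over ranks, clamped from below by the busiest rank.

```python
tp_size = 4
all_ranks_num_active_requests = [3, 5, 2, 4]  # after pinning; sum = 14
num_new_requests_all_ranks = 10               # len(remaining_unscheduled)

total_num_active_requests = sum(all_ranks_num_active_requests)  # 14
expected_num_active_requests = max(
    (total_num_active_requests + num_new_requests_all_ranks +
     tp_size - 1) // tp_size,                 # ceil(24 / 4) = 6
    max(all_ranks_num_active_requests),       # busiest rank has 5
)
assert expected_num_active_requests == 6      # each rank may grow to 6
```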
     def _handle_request_broadcasting(self,
                                      new_requests: List[RequestQueueItem]):
         """Handle broadcasting of requests and Python objects across ranks."""
@@ -274,8 +372,13 @@ def _handle_request_broadcasting(self,
                 new_requests, "py_logits_post_processors")
             py_multimodal_data = self._collect_py_objects_from_requests(
                 new_requests, "py_multimodal_data")
+            py_scheduling_params = self._collect_py_objects_from_requests(
+                new_requests, "py_scheduling_params")
             py_request_objects = tuple(
-                filter(None, [py_logits_post_processors, py_multimodal_data]))
+                filter(None, [
+                    py_logits_post_processors, py_multimodal_data,
+                    py_scheduling_params
+                ]))
         else:
             py_request_objects = None
 
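Since `filter(None, ...)` drops falsy entries, object categories that no request carries never enter the broadcast tuple. A quick demonstration, assuming each `_collect_py_objects_from_requests` call returns something falsy when the attribute is absent from every request (the exact return shape is not visible in this diff):

```python
py_logits_post_processors = None                      # nothing collected
py_multimodal_data = None                             # nothing collected
py_scheduling_params = ('py_scheduling_params', {7: 'params-for-req-7'})

py_request_objects = tuple(
    filter(None, [
        py_logits_post_processors, py_multimodal_data, py_scheduling_params
    ]))
assert py_request_objects == (py_scheduling_params,)  # only the non-empty entry
```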
@@ -314,28 +417,30 @@ def _validate_and_filter_requests(
 
     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
-            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+            all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
+            all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
         """Balance requests across ranks for attention DP."""
-        new_requests_cur_rank = []
-
-        if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
-                self.dist.tp_rank]:
+        if new_requests:
             # Balance context tokens across ranks using heap
             HeapVal = namedtuple(
                 'HeapVal',
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])
 
             all_ranks_new_requests_heap = [
-                HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
+                HeapVal(0, val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
 
-            new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
-                if val.num_requests > 0
+                if val.num_requests < self.expected_num_active_requests
             ]
+
+            all_ranks_new_scheduled_requests = {
+                val.rank: val.request_list
+                for val in all_ranks_new_requests_heap
+            }
+
             heapq.heapify(all_ranks_new_requests_heap)
 
             # Sort by token count (descending) for better load balancing
@@ -351,17 +456,22 @@ def _balance_requests_across_ranks(
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
                             [])) if req_item.request else 0
+                # Update the heap value with the new request
                 val = val._replace(
                     num_tokens=val.num_tokens + token_count,
-                    num_requests=val.num_requests - 1,
+                    num_requests=val.num_requests + 1,
                 )
+
                 val.request_list.append(req_item)
-                if val.num_requests > 0:
+                # If the rank still has room for new requests, push it back into the heap
+                if val.num_requests < self.expected_num_active_requests:
                     heapq.heappush(all_ranks_new_requests_heap, val)
-                elif val.rank == self.dist.tp_rank:
-                    break
 
-        return new_requests_cur_rank
+            # Extend all_ranks_new_requests with the newly scheduled requests
+            for rank, reqs in all_ranks_new_scheduled_requests.items():
+                all_ranks_new_requests[rank].extend(reqs)
+
+        return all_ranks_new_requests
 
     def _collect_py_objects_from_requests(
             self, requests: List[RequestQueueItem],
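The heap balancing above can be exercised in isolation. Below is a minimal reproduction with toy token counts in place of `input_token_ids`, keeping the same `HeapVal` layout: entries compare by `(num_tokens, num_requests, rank)`, so each request lands on the currently least-loaded rank, and a rank drops out of the heap once it reaches `expected_num_active_requests`.

```python
import heapq
from collections import namedtuple

HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_requests', 'rank', 'request_list'])

def balance(token_counts, active_per_rank, expected):
    heap = [
        HeapVal(0, active, rank, [])
        for rank, active in enumerate(active_per_rank)
        if active < expected  # ranks already at capacity never join
    ]
    per_rank = {val.rank: val.request_list for val in heap}
    heapq.heapify(heap)
    # Largest requests first, so big prompts spread across ranks.
    for tokens in sorted(token_counts, reverse=True):
        val = heapq.heappop(heap)  # least-loaded rank comes first
        val = val._replace(num_tokens=val.num_tokens + tokens,
                           num_requests=val.num_requests + 1)
        val.request_list.append(tokens)
        if val.num_requests < expected:  # still has headroom: re-enter heap
            heapq.heappush(heap, val)
    return per_rank

# Two ranks; rank 0 already has one active request; target of 3 per rank.
print(balance([100, 80, 10], active_per_rank=[1, 0], expected=3))
# -> {0: [80, 10], 1: [100]}
```

Note that `_replace` copies the counters but keeps the same `request_list` object, which is why the dict built from the heap entries (like `all_ranks_new_scheduled_requests` in the diff) sees every append.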