@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
             self,
             waiting_queue: deque[RequestQueueItem],
             max_req_count: int,
+            enable_attention_dp: bool,
+            all_ranks_num_active_requests: Optional[List[int]] = None,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
-
+ """
         Args:
             waiting_queue: The queue to pop items from.
             max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
+            enable_attention_dp: Whether to enable attention DP scheduling.
+            all_ranks_num_active_requests: Number of active requests for each rank.
         Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
+            List of requests that can be processed.
         """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
+
+        if max_req_count <= 0:
             return []

-        items = []
         req_count = 0
+        items = []
+        pending_requests = []
+
+        # Work on a copy of the per-rank active counts so strict requests only reserve free slots
+        scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
+            req_item = waiting_queue.popleft()
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
+
+            if can_process:
+                items.append(req_item)
+                req_count += 1
+            else:
+                pending_requests.append(req_item)
+
+        # Put the pending requests back to the waiting queue
+        # All ranks should have the same waiting queue
+        waiting_queue.extendleft(reversed(pending_requests))
+
         return items

+    def _can_process_attention_dp_request(
+            self, req_item: RequestQueueItem,
+            all_ranks_num_active_requests: List[int]) -> bool:
+        """Return True if the request can be processed immediately, else False."""
+
+        scheduling_params = getattr(req_item.request, 'py_scheduling_params',
+                                    None)
+        if scheduling_params is None:
+            return True
+
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
+            return True
+
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
+            all_ranks_num_active_requests[target_dp_rank] += 1
+            return True
+
+        return False
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         with self.enqueue_lock:
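Note: the admission rule added in _can_process_attention_dp_request reduces to three cases: requests without scheduling hints are always admitted, requests with no target rank or with relaxed pinning are admitted, and requests pinned to a specific attention-DP rank are admitted only while that rank still has a free slot. The sketch below restates that rule in isolation; the SchedulingParams stand-in and the concrete numbers are assumptions for illustration, not part of this change.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SchedulingParams:
    # Assumed shape of py_scheduling_params for this sketch.
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False

def can_admit(params: Optional[SchedulingParams], active_per_rank: List[int],
              max_active: int) -> bool:
    """Admit unless the request is pinned to a rank that is already full."""
    if params is None:
        return True  # no scheduling hints -> admit
    if params.attention_dp_rank is None or params.attention_dp_relax:
        return True  # no pinned rank, or relaxed pinning -> admit
    if active_per_rank[params.attention_dp_rank] < max_active:
        active_per_rank[params.attention_dp_rank] += 1  # reserve a slot on that rank
        return True
    return False  # strict request stays in the waiting queue

# Rank 1 is already at the (made-up) limit of 4, so a strict request pinned
# to it is held back, while a relaxed one is admitted.
active = [1, 4]
print(can_admit(SchedulingParams(attention_dp_rank=1), active, max_active=4))                 # False
print(can_admit(SchedulingParams(attention_dp_rank=1, attention_dp_relax=True), active, 4))   # True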
@@ -152,8 +193,12 @@ def can_enqueue_request(self) -> bool:
             return self.active and self.dist.rank == 0

     def _fetch_and_process_requests(
-            self, total_num_active_requests: int,
-            total_max_num_active_requests: int) -> List[RequestQueueItem]:
+            self,
+            total_num_active_requests: int,
+            total_max_num_active_requests: int,
+            enable_attention_dp: bool,
+            all_ranks_num_active_requests: Optional[List[int]] = None
+    ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
@@ -181,7 +226,8 @@ def _fetch_and_process_requests(

         new_requests = self._get_from_waiting_queue(
             self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
+            total_max_num_active_requests - total_num_active_requests,
+            enable_attention_dp, all_ranks_num_active_requests)

         # Update performance metrics
         if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -204,9 +250,11 @@ def _fetch_new_requests_attention_tp(
         total_num_active_requests = num_active_requests
         total_max_num_active_requests = self.max_num_active_requests

-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)

         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
@@ -224,34 +272,84 @@ def _fetch_new_requests_attention_dp(
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests

-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True,
+            all_ranks_num_active_requests=all_ranks_num_active_requests)

-        # Balance requests across ranks
-        num_new_requests_all_ranks = len(new_requests)
-        self.expected_num_active_requests = max(
-            (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.tp_size - 1) // self.dist.tp_size,
-            max(all_ranks_num_active_requests),
-        )
-
-        new_requests_cur_rank = self._balance_requests_across_ranks(
+        # Schedule attention dp requests
+        all_ranks_new_requests = self._schedule_attention_dp_requests(
             new_requests, all_ranks_num_active_requests)
+        new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]

         # Update performance metrics
         if self.enable_iter_perf_stats and self.start_times:
             self._update_new_active_requests_queue_latency(
                 new_requests_cur_rank)

         # Update counters
-        self.num_fetch_requests += num_new_requests_all_ranks
+        self.num_fetch_requests += len(new_requests)
         self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)

         # Merge requests and add to active list
         new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
         return new_requests_cur_rank

+    def _schedule_attention_dp_requests(
+            self, new_requests: List[RequestQueueItem],
+            all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
+        """Schedule attention dp requests."""
+
+        # Map from ranks to new requests
+        all_ranks_new_requests = {
+            tp_rank: []
+            for tp_rank in range(self.dist.tp_size)
+        }
+
+        # Prioritize the requests that are not in relax mode
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
+
+        # Try to place each request on its target DP rank until max_num_active_requests is reached
+        remaining_unscheduled = []
+        for req_item in new_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if target_dp_rank is not None and all_ranks_num_active_requests[
+                        target_dp_rank] < self.max_num_active_requests:
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        # Balance the remaining unscheduled requests across ranks
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        self.expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks +
+             self.dist.tp_size - 1) // self.dist.tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        all_ranks_new_requests = self._balance_requests_across_ranks(
+            remaining_unscheduled, all_ranks_new_requests,
+            all_ranks_num_active_requests)
+
+        return all_ranks_new_requests
+
     def _handle_request_broadcasting(self,
                                      new_requests: List[RequestQueueItem]):
         """Handle broadcasting of requests and Python objects across ranks."""
@@ -260,8 +358,13 @@ def _handle_request_broadcasting(self,
                 new_requests, "py_logits_post_processors")
             py_multimodal_data = self._collect_py_objects_from_requests(
                 new_requests, "py_multimodal_data")
+            py_scheduling_params = self._collect_py_objects_from_requests(
+                new_requests, "py_scheduling_params")
             py_request_objects = tuple(
-                filter(None, [py_logits_post_processors, py_multimodal_data]))
+                filter(None, [
+                    py_logits_post_processors, py_multimodal_data,
+                    py_scheduling_params
+                ]))
         else:
             py_request_objects = None

@@ -300,28 +403,30 @@ def _validate_and_filter_requests(

     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
+            all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
             all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
-        new_requests_cur_rank = []
-
-        if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
-                self.dist.tp_rank]:
+        if new_requests:
             # Balance context tokens across ranks using heap
             HeapVal = namedtuple(
                 'HeapVal',
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])

             all_ranks_new_requests_heap = [
-                HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
+                HeapVal(0, val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]

-            new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
-                if val.num_requests > 0
+                if val.num_requests < self.expected_num_active_requests
             ]
+
+            all_ranks_new_scheduled_requests = {
+                val.rank: val.request_list
+                for val in all_ranks_new_requests_heap
+            }
+
             heapq.heapify(all_ranks_new_requests_heap)

             # Sort by token count (descending) for better load balancing
@@ -337,17 +442,22 @@ def _balance_requests_across_ranks(
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
                             [])) if req_item.request else 0
+                # Update the heap value with the new request
                 val = val._replace(
                     num_tokens=val.num_tokens + token_count,
-                    num_requests=val.num_requests - 1,
+                    num_requests=val.num_requests + 1,
                 )
+
                 val.request_list.append(req_item)
-                if val.num_requests > 0:
+                # If rank still has room for new requests, push back into heap
+                if val.num_requests < self.expected_num_active_requests:
                     heapq.heappush(all_ranks_new_requests_heap, val)
-                elif val.rank == self.dist.tp_rank:
-                    break

-        return new_requests_cur_rank
+        # Extend all_ranks_new_requests with the new requests that have been scheduled
+        for rank, reqs in all_ranks_new_scheduled_requests.items():
+            all_ranks_new_requests[rank].extend(reqs)
+
+        return all_ranks_new_requests

     def _collect_py_objects_from_requests(
             self, requests: List[RequestQueueItem],
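Note: the heap bookkeeping in _balance_requests_across_ranks can be hard to read inside a diff. The self-contained sketch below shows the same greedy idea: always hand the next-largest request to the currently least-loaded rank that still has room. Requests are reduced to plain token counts and the rank counts and limit are invented for the example; the real method operates on RequestQueueItem objects and the class state.

import heapq
from collections import namedtuple

HeapVal = namedtuple('HeapVal', ['num_tokens', 'num_requests', 'rank', 'request_list'])

def balance(token_counts, active_per_rank, expected_max):
    """Greedy balance: give the next-largest request to the least-loaded eligible rank."""
    heap = [
        HeapVal(0, active, rank, [])
        for rank, active in enumerate(active_per_rank) if active < expected_max
    ]
    # Keep a handle on every eligible rank's request_list before entries get popped.
    scheduled = {val.rank: val.request_list for val in heap}
    heapq.heapify(heap)
    for tokens in sorted(token_counts, reverse=True):  # largest requests first
        if not heap:
            break  # every eligible rank reached expected_max
        val = heapq.heappop(heap)  # rank with the fewest scheduled tokens so far
        val.request_list.append(tokens)
        val = val._replace(num_tokens=val.num_tokens + tokens,
                           num_requests=val.num_requests + 1)
        if val.num_requests < expected_max:  # rank still has room, keep it in play
            heapq.heappush(heap, val)
    return scheduled

# Two ranks, rank 0 already has one active request; requests of 700/300/200/100 tokens.
print(balance([700, 300, 200, 100], active_per_rank=[1, 0], expected_max=3))
# -> {0: [300, 200], 1: [700, 100]}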