@@ -83,45 +83,13 @@ def _get_from_request_queue(
83
83
pass
84
84
return items
85
85
86
- def _get_from_waiting_queue_attention_tp (
86
+ def _get_from_waiting_queue (
87
87
self ,
88
88
waiting_queue : deque [RequestQueueItem ],
89
89
max_req_count : int ,
90
+ enable_attention_dp : bool ,
90
91
) -> List [RequestQueueItem ]:
91
- """Safely extracts up to max_req_count items from a deque.
92
92
93
- Args:
94
- waiting_queue: The queue to pop items from.
95
- max_req_count: Maximum items to retrieve. Returns empty list if <=0.
96
-
97
- Returns:
98
- List of retrieved items (may be shorter than max_req_count if queue empties first).
99
- """
100
- # Edge case handling
101
- if max_req_count <= 0 : # Handles negative/zero counts
102
- return []
103
-
104
- items = []
105
- req_count = 0
106
- while req_count < max_req_count and waiting_queue :
107
- items .append (waiting_queue .popleft ())
108
- req_count += 1
109
- return items
110
-
111
- def _get_from_waiting_queue_attention_dp (
112
- self ,
113
- waiting_queue : deque [RequestQueueItem ],
114
- max_req_count : int ,
115
- ) -> List [RequestQueueItem ]:
116
- """Extract requests from waiting queue with attention DP load balancing.
117
-
118
- Args:
119
- waiting_queue: Queue of pending requests
120
- max_req_count: Maximum number of requests to extract
121
-
122
- Returns:
123
- List of requests that can be processed immediately
124
- """
125
93
if max_req_count <= 0 :
126
94
return []
127
95
@@ -130,55 +98,45 @@ def _get_from_waiting_queue_attention_dp(
130
98
pending_requests = []
131
99
132
100
# Track the request with strict requirements
133
- all_ranks_num_active_requests = self .all_ranks_num_active_requests .copy (
134
- )
101
+ scheduling_all_ranks_num_active_requests = self .all_ranks_num_active_requests .copy (
102
+ ) if enable_attention_dp else None
135
103
while req_count < max_req_count and waiting_queue :
136
104
req_item = waiting_queue .popleft ()
137
- can_process_now = self ._can_process_attention_dp_request (
138
- req_item , all_ranks_num_active_requests )
105
+ can_process = self ._can_process_attention_dp_request (
106
+ req_item , scheduling_all_ranks_num_active_requests
107
+ ) if enable_attention_dp else True
139
108
140
- if can_process_now :
109
+ if can_process :
141
110
items .append (req_item )
142
111
req_count += 1
143
112
else :
144
113
pending_requests .append (req_item )
145
114
146
115
# Put the pending requests back to the waiting queue
147
116
# All ranks should have the same waiting queue
148
- self .waiting_queue .extendleft (pending_requests )
117
+ self .waiting_queue .extendleft (reversed ( pending_requests ) )
149
118
150
119
return items
151
120
152
121
def _can_process_attention_dp_request (
153
122
self , req_item : RequestQueueItem ,
154
123
all_ranks_num_active_requests : List [int ]) -> bool :
155
- """Check if a request can be processed immediately.
124
+ """Return True if the request can be processed immediately, else False."""
156
125
157
- Returns:
158
- True if the request can be processed now, False if it should be deferred.
159
- """
160
- # Handle requests without schedule parameters
161
- if req_item .request .py_schedule_params is None :
126
+ scheduling_params = req_item .request .py_scheduling_params
127
+ if scheduling_params is None :
162
128
return True
163
129
164
- schedule_params = req_item .request .py_schedule_params
165
- target_dp_rank = schedule_params .attention_dp_rank
166
- is_relax = schedule_params .attention_dp_relax
167
-
168
- # Handle requests without target rank or in relax mode
169
- if target_dp_rank is None or is_relax :
130
+ target_dp_rank = scheduling_params .attention_dp_rank
131
+ if target_dp_rank is None or scheduling_params .attention_dp_relax :
170
132
return True
171
133
172
- # Handle strict mode requests - check target rank capacity
173
- target_rank_has_capacity = (
174
- all_ranks_num_active_requests [target_dp_rank ]
175
- < self .max_num_active_requests )
176
-
177
- if target_rank_has_capacity :
134
+ if all_ranks_num_active_requests [
135
+ target_dp_rank ] < self .max_num_active_requests :
178
136
all_ranks_num_active_requests [target_dp_rank ] += 1
179
137
return True
180
- else :
181
- return False
138
+
139
+ return False
182
140
183
141
def enqueue_requests (self , requests : List [ExecutorRequest ]):
184
142
req_ids = []
@@ -238,7 +196,9 @@ def can_enqueue_request(self) -> bool:
238
196
return can_enqueue and self .dist .rank == 0
239
197
240
198
def _fetch_and_process_requests (
241
- self , total_num_active_requests : int ) -> List [RequestQueueItem ]:
199
+ self , total_num_active_requests : int ,
200
+ total_max_num_active_requests : int ,
201
+ enable_attention_dp : bool ) -> List [RequestQueueItem ]:
242
202
"""Common logic for fetching and processing requests from the queue."""
243
203
# Calculate timeout
244
204
timeout = None if (total_num_active_requests == 0 ) and len (
@@ -264,6 +224,17 @@ def _fetch_and_process_requests(
264
224
265
225
self .waiting_queue .extend (new_requests )
266
226
227
+ new_requests = self ._get_from_waiting_queue (
228
+ self .waiting_queue ,
229
+ total_max_num_active_requests - total_num_active_requests ,
230
+ enable_attention_dp )
231
+
232
+ # Update performance metrics
233
+ if self .enable_iter_perf_stats and self .dist .rank == 0 :
234
+ self ._update_new_active_requests_queue_latency (new_requests )
235
+
236
+ return new_requests
237
+
267
238
@nvtx_range ("_fetch_new_requests" )
268
239
def fetch_new_requests (self ,
269
240
num_active_requests : int ) -> List [RequestQueueItem ]:
@@ -280,15 +251,10 @@ def _fetch_new_requests_attention_tp(
280
251
total_max_num_active_requests = self .max_num_active_requests
281
252
282
253
# fetch and process requests into waiting queue
283
- self ._fetch_and_process_requests (total_num_active_requests )
284
-
285
- new_requests = self ._get_from_waiting_queue_attention_tp (
286
- self .waiting_queue ,
287
- total_max_num_active_requests - total_num_active_requests )
288
-
289
- # Update performance metrics
290
- if self .enable_iter_perf_stats and self .dist .rank == 0 :
291
- self ._update_new_active_requests_queue_latency (new_requests )
254
+ new_requests = self ._fetch_and_process_requests (
255
+ total_num_active_requests ,
256
+ total_max_num_active_requests ,
257
+ enable_attention_dp = False )
292
258
293
259
# Merge requests and add to active list
294
260
merged_requests = self ._merge_requests (new_requests )
@@ -307,16 +273,10 @@ def _fetch_new_requests_attention_dp(
307
273
total_max_num_active_requests = self .dist .tp_size * self .max_num_active_requests
308
274
309
275
# fetch and process requests into waiting queue
310
- self ._fetch_and_process_requests (total_num_active_requests )
311
-
312
- new_requests = self ._get_from_waiting_queue_attention_dp (
313
- self .waiting_queue ,
314
- total_max_num_active_requests - total_num_active_requests )
315
-
316
- # Update performance metrics
317
- # TODO: Check whether we should update the performance metrics for all ranks
318
- if self .enable_iter_perf_stats and self .dist .rank == 0 :
319
- self ._update_new_active_requests_queue_latency (new_requests )
276
+ new_requests = self ._fetch_and_process_requests (
277
+ total_num_active_requests ,
278
+ total_max_num_active_requests ,
279
+ enable_attention_dp = True )
320
280
321
281
# Schedule attention dp requests
322
282
new_requests_cur_rank = self ._schedule_attention_dp_requests (
@@ -342,9 +302,9 @@ def _schedule_attention_dp_requests(
342
302
343
303
# Prioritize the requests that are not in relax mode
344
304
def get_relax_value (req_item ):
345
- if req_item .request .py_schedule_params is None :
305
+ if req_item .request .py_scheduling_params is None :
346
306
return True
347
- return req_item .request .py_schedule_params .attention_dp_relax
307
+ return req_item .request .py_scheduling_params .attention_dp_relax
348
308
349
309
new_requests = sorted (new_requests , key = get_relax_value , reverse = True )
350
310
@@ -353,8 +313,8 @@ def get_relax_value(req_item):
353
313
new_requests_cur_rank = []
354
314
for req_item in new_requests :
355
315
scheduled = False
356
- if req_item .request .py_schedule_params is not None :
357
- target_dp_rank = req_item .request .py_schedule_params .attention_dp_rank
316
+ if req_item .request .py_scheduling_params is not None :
317
+ target_dp_rank = req_item .request .py_scheduling_params .attention_dp_rank
358
318
if target_dp_rank is not None and self .all_ranks_num_active_requests [
359
319
target_dp_rank ] < self .max_num_active_requests :
360
320
self .all_ranks_num_active_requests [target_dp_rank ] += 1
0 commit comments