Skip to content

Commit ac0d750

Browse files
Shunkang
authored and committed
Add attention dp scheduling logic
Signed-off-by: Shunkang <[email protected]>
1 parent 84bb95c commit ac0d750

File tree

1 file changed

+86
-11
lines changed

1 file changed

+86
-11
lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,30 +244,105 @@ def _fetch_new_requests_attention_dp(
244244
new_requests = self._fetch_and_process_requests(
245245
total_num_active_requests, total_max_num_active_requests)
246246

247-
# Balance requests across ranks
248-
num_new_requests_all_ranks = len(new_requests)
249-
self.expected_num_active_requests = max(
250-
(total_num_active_requests + num_new_requests_all_ranks +
251-
self.dist.tp_size - 1) // self.dist.tp_size,
252-
max(all_ranks_num_active_requests),
253-
)
254-
255-
new_requests_cur_rank = self._balance_requests_across_ranks(
256-
new_requests, all_ranks_num_active_requests)
247+
# Schedule attention dp requests
248+
new_requests_cur_rank = self._schedule_attention_dp_requests(
249+
num_active_requests, new_requests, all_ranks_num_active_requests)
257250

258251
# Update performance metrics
259252
if self.enable_iter_perf_stats and self.start_times:
260253
self._update_new_active_requests_queue_latency(
261254
new_requests_cur_rank)
262255

263256
# Update counters
264-
self.num_fetch_requests += num_new_requests_all_ranks
257+
self.num_fetch_requests += len(new_requests)
265258
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
266259

267260
# Merge requests and add to active list
268261
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
269262
return new_requests_cur_rank
270263

264+
def _schedule_attention_dp_requests(
265+
self, num_active_requests: int,
266+
new_requests: List[RequestQueueItem],
267+
all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
268+
"""Schedule attention dp requests."""
269+
# Separate the requests into two groups
270+
# 1. requests without schedule params or with schedule params that don't specify attention dp rank
271+
# 2. requests with schedule params that specify attention dp rank
272+
requests_specified_attention_dp_rank = []
273+
for req_item in new_requests:
274+
if req_item.request.schedule_params is not None and \
275+
req_item.request.schedule_params.attention_dp_rank == self.dist.tp_rank:
276+
requests_specified_attention_dp_rank.append(req_item)
277+
278+
# Routing requests to the corresponding attention dp without exceeding the max_num_active_requests
279+
new_requests_cur_rank = []
280+
new_requests_cur_rank_waiting = []
281+
new_requests_cur_rank_relax = []
282+
283+
available_slots = self.max_num_active_requests - num_active_requests
284+
285+
for req_item in requests_specified_attention_dp_rank:
286+
is_relax = req_item.request.schedule_params.attention_dp_relax
287+
288+
if len(new_requests_cur_rank) < available_slots:
289+
# Prioritize the non-relax requests
290+
target_list = new_requests_cur_rank_relax if is_relax else new_requests_cur_rank
291+
target_list.append(req_item)
292+
else:
293+
# Add to waiting queue
294+
target_list = new_requests_cur_rank_relax if is_relax else new_requests_cur_rank_waiting
295+
target_list.append(req_item)
296+
297+
items_to_move = available_slots - len(new_requests_cur_rank)
298+
if items_to_move > 0:
299+
new_requests_cur_rank.extend(
300+
new_requests_cur_rank_relax[:items_to_move])
301+
new_requests_cur_rank_relax = new_requests_cur_rank_relax[
302+
items_to_move:]
303+
304+
# Allgather the non-scheduled requests across ranks
305+
# TODO: Remove the padding overhead
306+
new_requests_cur_rank_relax_ids = [
307+
req_item.id for req_item in new_requests_cur_rank_relax
308+
]
309+
padding_num = self.max_num_active_requests - len(
310+
new_requests_cur_rank_relax)
311+
for _ in range(padding_num):
312+
new_requests_cur_rank_relax_ids.append(None)
313+
non_scheduled_requests_id = self.dist.tp_allgather(
314+
new_requests_cur_rank_relax_ids)
315+
non_scheduled_requests_id = [
316+
req_id for req_id in non_scheduled_requests_id if req_id is not None
317+
]
318+
319+
# Non-scheduled requests should be same across ranks
320+
non_scheduled_requests = []
321+
for req_item in new_requests:
322+
if req_item.id in non_scheduled_requests_id:
323+
non_scheduled_requests.append(req_item)
324+
elif req_item.request.schedule_params is None or \
325+
req_item.request.schedule_params is not None and \
326+
req_item.request.schedule_params.attention_dp_rank is None:
327+
non_scheduled_requests.append(req_item)
328+
329+
# Put the request back to the waiting queue
330+
self.waiting_queue.extendleft(new_requests_cur_rank_waiting)
331+
332+
# TODO: Balance the no attention dp rank requests and relax requests across ranks
333+
num_new_requests_all_ranks = len(new_requests)
334+
total_num_active_requests = sum(all_ranks_num_active_requests)
335+
self.expected_num_active_requests = max(
336+
(total_num_active_requests + num_new_requests_all_ranks +
337+
self.dist.tp_size - 1) // self.dist.tp_size,
338+
max(all_ranks_num_active_requests),
339+
)
340+
341+
new_requests_cur_rank = self._balance_requests_across_ranks(
342+
non_scheduled_requests, )
343+
344+
return new_requests_cur_rank
345+
271346
def _handle_request_broadcasting(self,
272347
new_requests: List[RequestQueueItem]):
273348
"""Handle broadcasting of requests and Python objects across ranks."""

0 commit comments

Comments
 (0)