
Commit 67a3fd8

Authored by Shunkangz and pcastonguay
[None][feat] Add support of scheduling attention dp request (#6246)
Signed-off-by: Shunkang <[email protected]>
Signed-off-by: Patrice Castonguay <[email protected]>
Co-authored-by: Shunkang <[email protected]>
Co-authored-by: Patrice Castonguay <[email protected]>
1 parent 31802de commit 67a3fd8

8 files changed: +693 -153 lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 150 additions & 40 deletions
```diff
@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
         max_req_count: int,
+        enable_attention_dp: bool,
+        all_ranks_num_active_requests: Optional[List[int]] = None,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
-
+        """
         Args:
             waiting_queue: The queue to pop items from.
             max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
+            enable_attention_dp: Whether to enable attention DP scheduling.
+            all_ranks_num_active_requests: Number of active requests for each rank.
         Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
+            List of requests that can be processed.
         """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
+
+        if max_req_count <= 0:
             return []
 
-        items = []
         req_count = 0
+        items = []
+        pending_requests = []
+
+        # Track the request with strict requirements
+        scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
+            req_item = waiting_queue.popleft()
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
+
+            if can_process:
+                items.append(req_item)
+                req_count += 1
+            else:
+                pending_requests.append(req_item)
+
+        # Put the pending requests back to the waiting queue
+        # All ranks should have the same waiting queue
+        waiting_queue.extendleft(reversed(pending_requests))
+
         return items
 
+    def _can_process_attention_dp_request(
+            self, req_item: RequestQueueItem,
+            all_ranks_num_active_requests: List[int]) -> bool:
+        """Return True if the request can be processed immediately, else False."""
+
+        scheduling_params = getattr(req_item.request, 'py_scheduling_params',
+                                    None)
+        if scheduling_params is None:
+            return True
+
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
+            return True
+
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
+            all_ranks_num_active_requests[target_dp_rank] += 1
+            return True
+
+        return False
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         try:
@@ -166,8 +207,12 @@ def can_enqueue_request(self) -> bool:
         return can_enqueue and self.dist.rank == 0
 
     def _fetch_and_process_requests(
-            self, total_num_active_requests: int,
-            total_max_num_active_requests: int) -> List[RequestQueueItem]:
+            self,
+            total_num_active_requests: int,
+            total_max_num_active_requests: int,
+            enable_attention_dp: bool,
+            all_ranks_num_active_requests: Optional[List[int]] = None
+    ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
@@ -195,7 +240,8 @@ def _fetch_and_process_requests(
 
         new_requests = self._get_from_waiting_queue(
             self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
+            total_max_num_active_requests - total_num_active_requests,
+            enable_attention_dp, all_ranks_num_active_requests)
 
         # Update performance metrics
         if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -218,9 +264,11 @@ def _fetch_new_requests_attention_tp(
         total_num_active_requests = num_active_requests
         total_max_num_active_requests = self.max_num_active_requests
 
-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
@@ -238,34 +286,84 @@ def _fetch_new_requests_attention_dp(
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
 
-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True,
+            all_ranks_num_active_requests=all_ranks_num_active_requests)
 
-        # Balance requests across ranks
-        num_new_requests_all_ranks = len(new_requests)
-        self.expected_num_active_requests = max(
-            (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.tp_size - 1) // self.dist.tp_size,
-            max(all_ranks_num_active_requests),
-        )
-
-        new_requests_cur_rank = self._balance_requests_across_ranks(
+        # Schedule attention dp requests
+        all_ranks_new_requests = self._schedule_attention_dp_requests(
             new_requests, all_ranks_num_active_requests)
+        new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
 
         # Update performance metrics
         if self.enable_iter_perf_stats and self.start_times:
            self._update_new_active_requests_queue_latency(
                new_requests_cur_rank)
 
         # Update counters
-        self.num_fetch_requests += num_new_requests_all_ranks
+        self.num_fetch_requests += len(new_requests)
         self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
 
         # Merge requests and add to active list
         new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
         return new_requests_cur_rank
 
+    def _schedule_attention_dp_requests(
+            self, new_requests: List[RequestQueueItem],
+            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+        """Schedule attention dp requests."""
+
+        # Map from ranks to new requests
+        all_ranks_new_requests = {
+            tp_rank: []
+            for tp_rank in range(self.dist.tp_size)
+        }
+
+        # Prioritize the requests that are not in relax mode
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
+
+        # Try to put the requests to the target dp rank until the max_num_active_requests is reached
+        remaining_unscheduled = []
+        for req_item in new_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if target_dp_rank is not None and all_ranks_num_active_requests[
+                        target_dp_rank] < self.max_num_active_requests:
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        # Balance the remaining unscheduled requests across ranks
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        self.expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks +
+             self.dist.tp_size - 1) // self.dist.tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        all_ranks_new_requests = self._balance_requests_across_ranks(
+            remaining_unscheduled, all_ranks_new_requests,
+            all_ranks_num_active_requests)
+
+        return all_ranks_new_requests
+
     def _handle_request_broadcasting(self,
                                      new_requests: List[RequestQueueItem]):
         """Handle broadcasting of requests and Python objects across ranks."""
@@ -274,8 +372,13 @@ def _handle_request_broadcasting(self,
                 new_requests, "py_logits_post_processors")
             py_multimodal_data = self._collect_py_objects_from_requests(
                 new_requests, "py_multimodal_data")
+            py_scheduling_params = self._collect_py_objects_from_requests(
+                new_requests, "py_scheduling_params")
             py_request_objects = tuple(
-                filter(None, [py_logits_post_processors, py_multimodal_data]))
+                filter(None, [
+                    py_logits_post_processors, py_multimodal_data,
+                    py_scheduling_params
+                ]))
         else:
             py_request_objects = None
 
@@ -314,28 +417,30 @@ def _validate_and_filter_requests(
 
     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
+            all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
             all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
-        new_requests_cur_rank = []
-
-        if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
-                self.dist.tp_rank]:
+        if new_requests:
             # Balance context tokens across ranks using heap
             HeapVal = namedtuple(
                 'HeapVal',
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])
 
             all_ranks_new_requests_heap = [
-                HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
+                HeapVal(0, val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
 
-            new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
-                if val.num_requests > 0
+                if val.num_requests < self.expected_num_active_requests
            ]
+
+            all_ranks_new_scheduled_requests = {
+                val.rank: val.request_list
+                for val in all_ranks_new_requests_heap
+            }
+
             heapq.heapify(all_ranks_new_requests_heap)
 
             # Sort by token count (descending) for better load balancing
@@ -351,17 +456,22 @@ def _balance_requests_across_ranks(
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
                             [])) if req_item.request else 0
+                # Update the heap value with the new request
                 val = val._replace(
                     num_tokens=val.num_tokens + token_count,
-                    num_requests=val.num_requests - 1,
+                    num_requests=val.num_requests + 1,
                 )
+
                 val.request_list.append(req_item)
-                if val.num_requests > 0:
+                # If rank still has room for new requests, push back into heap
+                if val.num_requests < self.expected_num_active_requests:
                     heapq.heappush(all_ranks_new_requests_heap, val)
-                elif val.rank == self.dist.tp_rank:
-                    break
 
-        return new_requests_cur_rank
+        # Extend all_ranks_new_requests with the new requests that have been scheduled
+        for rank, reqs in all_ranks_new_scheduled_requests.items():
+            all_ranks_new_requests[rank].extend(reqs)
+
+        return all_ranks_new_requests
 
     def _collect_py_objects_from_requests(
             self, requests: List[RequestQueueItem],
```
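The key new rule lives in `_can_process_attention_dp_request`: a request that pins a specific attention-DP rank and is not in relax mode is only pulled from the shared waiting queue while that rank still has capacity; otherwise it stays queued and the deque order is restored with `extendleft(reversed(...))`. A minimal standalone sketch of that admission rule (the stub class and function names below are illustrative, not part of the commit):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SchedulingParamsStub:
    """Hypothetical stand-in for the request's py_scheduling_params object."""
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False


def can_admit(params: Optional[SchedulingParamsStub],
              active_per_rank: List[int],
              max_num_active_requests: int) -> bool:
    """Mirror of the strict-targeting check: admit unless the pinned rank is full."""
    if params is None:
        return True  # no scheduling constraint attached to the request
    target = params.attention_dp_rank
    if target is None or params.attention_dp_relax:
        return True  # untargeted or relaxed requests can be placed anywhere
    if active_per_rank[target] < max_num_active_requests:
        active_per_rank[target] += 1  # reserve a slot on the pinned rank
        return True
    return False  # pinned rank is full: leave the request in the waiting queue


# Toy example with 2 ranks and a cap of 2 active requests per rank.
active = [2, 0]  # rank 0 is already full
print(can_admit(SchedulingParamsStub(attention_dp_rank=0), active, 2))  # False
print(can_admit(SchedulingParamsStub(attention_dp_rank=1), active, 2))  # True
print(can_admit(SchedulingParamsStub(attention_dp_rank=0, attention_dp_relax=True), active, 2))  # True
```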

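Requests that are not pinned (or that opted into relax mode) still go through `_balance_requests_across_ranks`, whose heap entries now count up from each rank's current load toward `expected_num_active_requests` instead of counting remaining slots down. A rough sketch of that greedy min-heap balancing with simplified types (plain tuples instead of `HeapVal`/`RequestQueueItem`; an approximation, not the commit's code):

```python
import heapq
from typing import Dict, List, Tuple


def balance(token_counts: List[int], active_per_rank: List[int],
            expected_num_active: int) -> Dict[int, List[int]]:
    """Assign request indices to ranks, largest prompts first, always to the
    least-loaded rank (by tokens, then request count) that still has room."""
    # Heap entries: (num_tokens, num_requests, rank); ranks at capacity are excluded.
    heap: List[Tuple[int, int, int]] = [(0, active, rank)
                                        for rank, active in enumerate(active_per_rank)
                                        if active < expected_num_active]
    heapq.heapify(heap)

    assignment: Dict[int, List[int]] = {r: [] for r in range(len(active_per_rank))}
    order = sorted(range(len(token_counts)), key=lambda i: token_counts[i], reverse=True)

    for idx in order:
        if not heap:
            break  # every rank has reached expected_num_active
        num_tokens, num_requests, rank = heapq.heappop(heap)
        assignment[rank].append(idx)
        num_tokens += token_counts[idx]
        num_requests += 1
        if num_requests < expected_num_active:  # rank still has room: push it back
            heapq.heappush(heap, (num_tokens, num_requests, rank))
    return assignment


# Toy example: 2 ranks, rank 0 already holds one active request, target of 3 per rank.
print(balance([700, 300, 100], active_per_rank=[1, 0], expected_num_active=3))
# -> {0: [1, 2], 1: [0]}
```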
tensorrt_llm/executor/executor.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -29,6 +29,7 @@
                             print_colored_debug)
 from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
                                SamplingParams)
+from ..scheduling_params import SchedulingParams
 from .ipc import FusedIpcQueue
 from .postproc_worker import PostprocParams, PostprocWorkerConfig
 from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -120,6 +121,7 @@ def generate_async(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ) -> GenerationResult:
         """Generate output for the given prompt token ids in the asynchronous mode.
         Asynchronous generation accepts single prompt only.
@@ -142,7 +144,8 @@ def generate_async(
             streaming=streaming,
             kv_cache_retention_config=kv_cache_retention_config,
             disaggregated_params=disaggregated_params,
-            multimodal_params=multimodal_params)
+            multimodal_params=multimodal_params,
+            scheduling_params=scheduling_params)
         result = self.submit(request)
         # release memory in time
         if hasattr(request, "multimodal_params"):
```
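With the parameter threaded through `generate_async`, a caller can hint which attention-DP rank should serve a request. A hedged usage sketch: the `tensorrt_llm.scheduling_params` module and the `attention_dp_rank` / `attention_dp_relax` fields are taken from this diff, but the exact `SchedulingParams` constructor and the surrounding executor/prompt objects are assumptions for illustration:

```python
# Illustrative only: assumes an already-constructed executor object and that
# SchedulingParams accepts these fields as keyword arguments.
from tensorrt_llm.scheduling_params import SchedulingParams

scheduling_params = SchedulingParams(
    attention_dp_rank=2,       # ask for attention-DP rank 2
    attention_dp_relax=True,   # allow fallback to another rank if rank 2 is full
)

future = executor.generate_async(
    prompt_token_ids=[1, 2, 3, 4],      # placeholder token ids
    sampling_params=sampling_params,    # placeholder SamplingParams instance
    scheduling_params=scheduling_params,
)
```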

tensorrt_llm/executor/request.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..sampling_params import SamplingParams
+from ..scheduling_params import SchedulingParams
 from .postproc_worker import PostprocParams
 
 __all__ = [
@@ -95,6 +96,7 @@ def __init__(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ):
         if isinstance(prompt_token_ids, list):
             self.prompt_token_ids = prompt_token_ids
@@ -119,6 +121,7 @@ def __init__(
         self.kv_cache_retention_config = kv_cache_retention_config
         self.id: Optional[int] = None
         self.disaggregated_params = disaggregated_params
+        self.scheduling_params = scheduling_params
 
     def set_id(self, id):
         assert self.id is None, f"Request ID is already set: {self.id}"
```

tensorrt_llm/executor/worker.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -520,6 +520,10 @@ def _deduce_max_tokens(request: GenerationRequest,
             executor_request.py_logits_post_processors = lp if isinstance(
                 lp, list) else [lp]
 
+        executor_request.py_scheduling_params = None
+        if self._is_pytorch_backend and request.scheduling_params is not None:
+            executor_request.py_scheduling_params = request.scheduling_params
+
         if request.query_token_ids is not None:
             # pytorch star attention workflow
             # a workaround to avoid public interface update
```
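This is the glue between the public API and the scheduler: `scheduling_params` set on the `GenerationRequest` becomes `py_scheduling_params` on the executor request (PyTorch backend only), which is exactly the attribute the queue code later reads with `getattr(..., 'py_scheduling_params', None)`. A tiny sketch of that hand-off using stand-in objects (the `SimpleNamespace` stubs are illustrative, not the real request types):

```python
from types import SimpleNamespace

# Stand-ins for the real GenerationRequest / executor request objects.
generation_request = SimpleNamespace(
    scheduling_params=SimpleNamespace(attention_dp_rank=1, attention_dp_relax=False))
executor_request = SimpleNamespace()

# Worker side: attach the Python-only attribute (PyTorch backend only).
is_pytorch_backend = True
executor_request.py_scheduling_params = None
if is_pytorch_backend and generation_request.scheduling_params is not None:
    executor_request.py_scheduling_params = generation_request.scheduling_params

# Scheduler side: read it back defensively, as the queue code does.
params = getattr(executor_request, 'py_scheduling_params', None)
print(params.attention_dp_rank if params is not None else None)  # -> 1
```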
