
Commit 652cbfd

Shunkangz and Shunkang authored and pcastonguay committed
[None][feat] Add support of scheduling attention dp request (NVIDIA#6246)
Signed-off-by: Shunkang <[email protected]>
Signed-off-by: Patrice Castonguay <[email protected]>
Co-authored-by: Shunkang <[email protected]>
Co-authored-by: Patrice Castonguay <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent c17c423 commit 652cbfd

8 files changed: +693 −153 lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 150 additions & 40 deletions
@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
         max_req_count: int,
+        enable_attention_dp: bool,
+        all_ranks_num_active_requests: Optional[List[int]] = None,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
-
+        """
         Args:
             waiting_queue: The queue to pop items from.
             max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
+            enable_attention_dp: Whether to enable attention DP scheduling.
+            all_ranks_num_active_requests: Number of active requests for each rank.
         Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
+            List of requests that can be processed.
         """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
+
+        if max_req_count <= 0:
             return []
 
-        items = []
         req_count = 0
+        items = []
+        pending_requests = []
+
+        # Track the request with strict requirements
+        scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
+            req_item = waiting_queue.popleft()
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
+
+            if can_process:
+                items.append(req_item)
+                req_count += 1
+            else:
+                pending_requests.append(req_item)
+
+        # Put the pending requests back to the waiting queue
+        # All ranks should have the same waiting queue
+        waiting_queue.extendleft(reversed(pending_requests))
+
         return items
 
+    def _can_process_attention_dp_request(
+            self, req_item: RequestQueueItem,
+            all_ranks_num_active_requests: List[int]) -> bool:
+        """Return True if the request can be processed immediately, else False."""
+
+        scheduling_params = getattr(req_item.request, 'py_scheduling_params',
+                                    None)
+        if scheduling_params is None:
+            return True
+
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
+            return True
+
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
+            all_ranks_num_active_requests[target_dp_rank] += 1
+            return True
+
+        return False
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         with self.enqueue_lock:
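The net effect of this hunk: the waiting queue no longer hands out requests blindly. Each candidate is checked against the per-rank active counts, and requests that cannot be placed yet are pushed back to the front of the queue in their original order. Below is a minimal, self-contained sketch of the admission rule applied by `_can_process_attention_dp_request`; the `FakeRequest` class, the local `SchedulingParams` stand-in, and the capacity constant are illustrative, not part of this commit.

```python
from dataclasses import dataclass
from typing import List, Optional

MAX_NUM_ACTIVE_REQUESTS = 2  # illustrative per-rank capacity


@dataclass
class SchedulingParams:  # stand-in mirroring the attributes the queue reads
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False


@dataclass
class FakeRequest:  # stand-in for the executor request
    py_scheduling_params: Optional[SchedulingParams] = None


def can_process(req: FakeRequest, active_per_rank: List[int]) -> bool:
    """Admission rule: strict requests only pass while their target rank has room."""
    params = getattr(req, "py_scheduling_params", None)
    if params is None:
        return True
    target = params.attention_dp_rank
    if target is None or params.attention_dp_relax:
        return True
    if active_per_rank[target] < MAX_NUM_ACTIVE_REQUESTS:
        active_per_rank[target] += 1  # reserve a slot on the target rank
        return True
    return False


active = [2, 0]  # rank 0 is full, rank 1 is empty
print(can_process(FakeRequest(SchedulingParams(attention_dp_rank=0)), active))  # False
print(can_process(FakeRequest(SchedulingParams(attention_dp_rank=1)), active))  # True
print(can_process(FakeRequest(), active))                                       # True
```

Note that `_get_from_waiting_queue` runs this check against a copy of `all_ranks_num_active_requests`, so the tentative reservations made during the scan do not leak into the real counters.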
@@ -152,8 +193,12 @@ def can_enqueue_request(self) -> bool:
         return self.active and self.dist.rank == 0
 
     def _fetch_and_process_requests(
-            self, total_num_active_requests: int,
-            total_max_num_active_requests: int) -> List[RequestQueueItem]:
+            self,
+            total_num_active_requests: int,
+            total_max_num_active_requests: int,
+            enable_attention_dp: bool,
+            all_ranks_num_active_requests: Optional[List[int]] = None
+    ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
timeout = None if (total_num_active_requests == 0) and len(
@@ -181,7 +226,8 @@ def _fetch_and_process_requests(
181226

182227
new_requests = self._get_from_waiting_queue(
183228
self.waiting_queue,
184-
total_max_num_active_requests - total_num_active_requests)
229+
total_max_num_active_requests - total_num_active_requests,
230+
enable_attention_dp, all_ranks_num_active_requests)
185231

186232
# Update performance metrics
187233
if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -204,9 +250,11 @@ def _fetch_new_requests_attention_tp(
         total_num_active_requests = num_active_requests
         total_max_num_active_requests = self.max_num_active_requests
 
-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
@@ -224,34 +272,84 @@ def _fetch_new_requests_attention_dp(
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
 
-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True,
+            all_ranks_num_active_requests=all_ranks_num_active_requests)
 
-        # Balance requests across ranks
-        num_new_requests_all_ranks = len(new_requests)
-        self.expected_num_active_requests = max(
-            (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.tp_size - 1) // self.dist.tp_size,
-            max(all_ranks_num_active_requests),
-        )
-
-        new_requests_cur_rank = self._balance_requests_across_ranks(
+        # Schedule attention dp requests
+        all_ranks_new_requests = self._schedule_attention_dp_requests(
             new_requests, all_ranks_num_active_requests)
+        new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
 
         # Update performance metrics
         if self.enable_iter_perf_stats and self.start_times:
             self._update_new_active_requests_queue_latency(
                 new_requests_cur_rank)
 
         # Update counters
-        self.num_fetch_requests += num_new_requests_all_ranks
+        self.num_fetch_requests += len(new_requests)
         self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
 
         # Merge requests and add to active list
         new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
         return new_requests_cur_rank
 
+    def _schedule_attention_dp_requests(
+            self, new_requests: List[RequestQueueItem],
+            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+        """Schedule attention dp requests."""
+
+        # Map from ranks to new requests
+        all_ranks_new_requests = {
+            tp_rank: []
+            for tp_rank in range(self.dist.tp_size)
+        }
+
+        # Prioritize the requests that are not in relax mode
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
+
+        # Try to put the requests to the target dp rank until the max_num_active_requests is reached
+        remaining_unscheduled = []
+        for req_item in new_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if target_dp_rank is not None and all_ranks_num_active_requests[
+                        target_dp_rank] < self.max_num_active_requests:
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        # Balance the remaining unscheduled requests across ranks
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        self.expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks +
+             self.dist.tp_size - 1) // self.dist.tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        all_ranks_new_requests = self._balance_requests_across_ranks(
+            remaining_unscheduled, all_ranks_new_requests,
+            all_ranks_num_active_requests)
+
+        return all_ranks_new_requests
+
     def _handle_request_broadcasting(self,
                                      new_requests: List[RequestQueueItem]):
         """Handle broadcasting of requests and Python objects across ranks."""
@@ -260,8 +358,13 @@ def _handle_request_broadcasting(self,
                 new_requests, "py_logits_post_processors")
             py_multimodal_data = self._collect_py_objects_from_requests(
                 new_requests, "py_multimodal_data")
+            py_scheduling_params = self._collect_py_objects_from_requests(
+                new_requests, "py_scheduling_params")
             py_request_objects = tuple(
-                filter(None, [py_logits_post_processors, py_multimodal_data]))
+                filter(None, [
+                    py_logits_post_processors, py_multimodal_data,
+                    py_scheduling_params
+                ]))
         else:
             py_request_objects = None
 
@@ -300,28 +403,30 @@ def _validate_and_filter_requests(
 
     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
+            all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
             all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
-        new_requests_cur_rank = []
-
-        if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
-                self.dist.tp_rank]:
+        if new_requests:
             # Balance context tokens across ranks using heap
             HeapVal = namedtuple(
                 'HeapVal',
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])
 
             all_ranks_new_requests_heap = [
-                HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
+                HeapVal(0, val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
 
-            new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
-                if val.num_requests > 0
+                if val.num_requests < self.expected_num_active_requests
            ]
+
+            all_ranks_new_scheduled_requests = {
+                val.rank: val.request_list
+                for val in all_ranks_new_requests_heap
+            }
+
             heapq.heapify(all_ranks_new_requests_heap)
 
             # Sort by token count (descending) for better load balancing
@@ -337,17 +442,22 @@ def _balance_requests_across_ranks(
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
                             [])) if req_item.request else 0
+                # Update the heap value with the new request
                 val = val._replace(
                     num_tokens=val.num_tokens + token_count,
-                    num_requests=val.num_requests - 1,
+                    num_requests=val.num_requests + 1,
                 )
+
                 val.request_list.append(req_item)
-                if val.num_requests > 0:
+                # If rank still has room for new requests, push back into heap
+                if val.num_requests < self.expected_num_active_requests:
                     heapq.heappush(all_ranks_new_requests_heap, val)
-                elif val.rank == self.dist.tp_rank:
-                    break
 
-        return new_requests_cur_rank
+            # Extend all_ranks_new_requests with the new requests that have been scheduled
+            for rank, reqs in all_ranks_new_scheduled_requests.items():
+                all_ranks_new_requests[rank].extend(reqs)
+
+        return all_ranks_new_requests
 
     def _collect_py_objects_from_requests(
             self, requests: List[RequestQueueItem],
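The reworked balancer keeps a min-heap of `HeapVal(num_tokens, num_requests, rank, request_list)` entries, feeds the next longest request to the least-loaded eligible rank, and stops pushing a rank back once it reaches `self.expected_num_active_requests`. A compact standalone sketch of that pattern follows; the input format and the driver call are illustrative, and the real method additionally merges the result into the `all_ranks_new_requests` map built by the scheduler.

```python
import heapq
from collections import namedtuple

HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_requests', 'rank', 'request_list'])


def balance(requests_with_tokens, active_per_rank, expected_num_active):
    """Greedy token balancing: the next longest request goes to the rank
    currently holding the fewest context tokens."""
    heap = [HeapVal(0, active, rank, [])
            for rank, active in enumerate(active_per_rank)
            if active < expected_num_active]       # only ranks with remaining capacity
    # The request lists are shared by reference, so they survive heap pops.
    scheduled = {val.rank: val.request_list for val in heap}
    heapq.heapify(heap)
    # Longest requests first so large contexts are spread out early.
    # Assumes the remaining capacity covers all requests, as the caller guarantees
    # via expected_num_active_requests in the real code.
    for req, token_count in sorted(requests_with_tokens, key=lambda x: -x[1]):
        val = heapq.heappop(heap)
        val = val._replace(num_tokens=val.num_tokens + token_count,
                           num_requests=val.num_requests + 1)
        val.request_list.append(req)
        if val.num_requests < expected_num_active:  # rank can still take more work
            heapq.heappush(heap, val)
    return scheduled


print(balance([("a", 100), ("b", 60), ("c", 10)],
              active_per_rank=[1, 0], expected_num_active=2))
# {0: ['b'], 1: ['a', 'c']}
```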

tensorrt_llm/executor/executor.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@
                             print_colored_debug)
 from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
                                SamplingParams)
+from ..scheduling_params import SchedulingParams
 from .ipc import FusedIpcQueue
 from .postproc_worker import PostprocParams, PostprocWorkerConfig
 from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -120,6 +121,7 @@ def generate_async(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ) -> GenerationResult:
         """Generate output for the given prompt token ids in the asynchronous mode.
         Asynchronous generation accepts single prompt only.
@@ -142,7 +144,8 @@ def generate_async(
             streaming=streaming,
             kv_cache_retention_config=kv_cache_retention_config,
             disaggregated_params=disaggregated_params,
-            multimodal_params=multimodal_params)
+            multimodal_params=multimodal_params,
+            scheduling_params=scheduling_params)
         result = self.submit(request)
         # release memory in time
         if hasattr(request, "multimodal_params"):
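On the API side, `generate_async` now threads an optional `SchedulingParams` object through to the `GenerationRequest`. A hypothetical call could look like the sketch below; the `executor`, `prompt_token_ids`, and `sampling_params` objects are assumed to exist already, and the constructor arguments shown (`attention_dp_rank`, `attention_dp_relax`) match the attributes the executor reads, but the exact `SchedulingParams` signature should be verified against `tensorrt_llm/scheduling_params.py`.

```python
from tensorrt_llm.scheduling_params import SchedulingParams

# Ask the scheduler to place this request on attention-DP rank 1, but allow it
# to fall back to any rank if rank 1 is saturated (relaxed mode).
scheduling_params = SchedulingParams(attention_dp_rank=1, attention_dp_relax=True)

result = executor.generate_async(
    prompt_token_ids,                     # assumed: token ids of the prompt
    sampling_params=sampling_params,      # assumed: an existing SamplingParams
    scheduling_params=scheduling_params,
)
```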

tensorrt_llm/executor/request.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..sampling_params import SamplingParams
+from ..scheduling_params import SchedulingParams
 from .postproc_worker import PostprocParams
 
 __all__ = [
@@ -95,6 +96,7 @@ def __init__(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ):
         if isinstance(prompt_token_ids, list):
             self.prompt_token_ids = prompt_token_ids
@@ -119,6 +121,7 @@ def __init__(
         self.kv_cache_retention_config = kv_cache_retention_config
         self.id: Optional[int] = None
         self.disaggregated_params = disaggregated_params
+        self.scheduling_params = scheduling_params
 
     def set_id(self, id):
         assert self.id is None, f"Request ID is already set: {self.id}"

tensorrt_llm/executor/worker.py

Lines changed: 4 additions & 0 deletions
@@ -520,6 +520,10 @@ def _deduce_max_tokens(request: GenerationRequest,
             executor_request.py_logits_post_processors = lp if isinstance(
                 lp, list) else [lp]
 
+        executor_request.py_scheduling_params = None
+        if self._is_pytorch_backend and request.scheduling_params is not None:
+            executor_request.py_scheduling_params = request.scheduling_params
+
         if request.query_token_ids is not None:
             # pytorch star attention workflow
             # a workaround to avoid public interface update
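On the worker side, the Python-level `SchedulingParams` is attached to the C++ executor request as `py_scheduling_params` (PyTorch backend only), which is exactly the attribute the queue code above reads back with `getattr`. A tiny sketch of that round trip using stand-in objects; nothing here is the real executor API.

```python
from types import SimpleNamespace

# Stand-ins for the real executor request and SchedulingParams objects.
executor_request = SimpleNamespace()
scheduling_params = SimpleNamespace(attention_dp_rank=0, attention_dp_relax=False)
is_pytorch_backend = True

# Worker side: attach the Python-level params to the request (as in worker.py).
executor_request.py_scheduling_params = None
if is_pytorch_backend and scheduling_params is not None:
    executor_request.py_scheduling_params = scheduling_params

# Queue side: read them back defensively, as executor_request_queue.py does.
params = getattr(executor_request, 'py_scheduling_params', None)
print(params.attention_dp_rank if params else "no scheduling params")  # 0
```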
