[None][feat] Add support of scheduling attention dp request #6246

Merged: 10 commits, Aug 2, 2025
Changes from all commits

190 changes: 150 additions & 40 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
self,
waiting_queue: deque[RequestQueueItem],
max_req_count: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List[RequestQueueItem]:
"""Safely extracts up to max_req_count items from a deque.

"""
Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.

enable_attention_dp: Whether to enable attention DP scheduling.
all_ranks_num_active_requests: Number of active requests for each rank.
Returns:
List of retrieved items (may be shorter than max_req_count if queue empties first).
List of requests that can be processed.
"""
# Edge case handling
if max_req_count <= 0: # Handles negative/zero counts

if max_req_count <= 0:
return []

items = []
req_count = 0
items = []
pending_requests = []

# Track the request with strict requirements
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
) if enable_attention_dp else None
while req_count < max_req_count and waiting_queue:
items.append(waiting_queue.popleft())
req_count += 1
req_item = waiting_queue.popleft()
can_process = self._can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests
) if enable_attention_dp else True

if can_process:
items.append(req_item)
req_count += 1
else:
pending_requests.append(req_item)

# Put the pending requests back to the waiting queue
# All ranks should have the same waiting queue
waiting_queue.extendleft(reversed(pending_requests))

return items

def _can_process_attention_dp_request(
self, req_item: RequestQueueItem,
all_ranks_num_active_requests: List[int]) -> bool:
"""Return True if the request can be processed immediately, else False."""

scheduling_params = getattr(req_item.request, 'py_scheduling_params',
None)
if scheduling_params is None:
return True

target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is None or scheduling_params.attention_dp_relax:
return True

if all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
return True

return False

def enqueue_requests(self, requests: List[ExecutorRequest]):
req_ids = []
try:
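
The two methods above form the admission gate for attention-DP requests: a request pinned to a specific DP rank (and not in relax mode) is only pulled from the waiting queue while that rank still has headroom, otherwise it is deferred and pushed back to the front of the queue so arrival order is preserved. A minimal standalone sketch of that idea, using simplified stand-in classes rather than the real RequestQueueItem/SchedulingParams:

from collections import deque
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StubSchedulingParams:               # stand-in for SchedulingParams
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False

@dataclass
class StubRequest:                        # stand-in for RequestQueueItem.request
    py_scheduling_params: Optional[StubSchedulingParams] = None

def pull_from_waiting_queue(queue: deque, max_count: int,
                            active_per_rank: List[int],
                            max_active_per_rank: int) -> List[StubRequest]:
    """Pop up to max_count admissible requests; defer rank-pinned ones whose target rank is full."""
    admitted, deferred = [], []
    budget = active_per_rank.copy()       # local copy, mirrors the scheduling_* copy above
    while len(admitted) < max_count and queue:
        req = queue.popleft()
        params = req.py_scheduling_params
        pinned = (params is not None and params.attention_dp_rank is not None
                  and not params.attention_dp_relax)
        if pinned and budget[params.attention_dp_rank] >= max_active_per_rank:
            deferred.append(req)          # target rank is saturated, retry next iteration
            continue
        if pinned:
            budget[params.attention_dp_rank] += 1
        admitted.append(req)
    queue.extendleft(reversed(deferred))  # deferred items return to the queue front, in order
    return admitted

# Rank 0 is already full, so the pinned request is deferred and the untargeted one goes through.
q = deque([StubRequest(StubSchedulingParams(attention_dp_rank=0)), StubRequest()])
print(len(pull_from_waiting_queue(q, 2, active_per_rank=[4, 1], max_active_per_rank=4)))  # 1
print(len(q))                             # 1, the deferred request stays queued
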
@@ -166,8 +207,12 @@ def can_enqueue_request(self) -> bool:
return can_enqueue and self.dist.rank == 0

def _fetch_and_process_requests(
self, total_num_active_requests: int,
total_max_num_active_requests: int) -> List[RequestQueueItem]:
self,
total_num_active_requests: int,
total_max_num_active_requests: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Calculate timeout
timeout = None if (total_num_active_requests == 0) and len(
Expand Down Expand Up @@ -195,7 +240,8 @@ def _fetch_and_process_requests(

new_requests = self._get_from_waiting_queue(
self.waiting_queue,
total_max_num_active_requests - total_num_active_requests)
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)

# Update performance metrics
if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -218,9 +264,11 @@ def _fetch_new_requests_attention_tp(
total_num_active_requests = num_active_requests
total_max_num_active_requests = self.max_num_active_requests

# Use common request fetching logic
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests, total_max_num_active_requests)
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=False)

# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
@@ -238,34 +286,84 @@ def _fetch_new_requests_attention_dp(
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests

# Use common request fetching logic
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests, total_max_num_active_requests)
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=True,
all_ranks_num_active_requests=all_ranks_num_active_requests)

# Balance requests across ranks
num_new_requests_all_ranks = len(new_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)

new_requests_cur_rank = self._balance_requests_across_ranks(
# Schedule attention dp requests
all_ranks_new_requests = self._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]

# Update performance metrics
if self.enable_iter_perf_stats and self.start_times:
self._update_new_active_requests_queue_latency(
new_requests_cur_rank)

# Update counters
self.num_fetch_requests += num_new_requests_all_ranks
self.num_fetch_requests += len(new_requests)
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)

# Merge requests and add to active list
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
return new_requests_cur_rank

def _schedule_attention_dp_requests(
self, new_requests: List[RequestQueueItem],
all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
"""Schedule attention dp requests."""

# Map from ranks to new requests
all_ranks_new_requests = {
tp_rank: []
for tp_rank in range(self.dist.tp_size)
}

# Prioritize the requests that are not in relax mode
def get_relax_value(req_item):
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is None:
return True
return scheduling_params.attention_dp_relax

new_requests = sorted(new_requests, key=get_relax_value, reverse=True)

# Try to put the requests to the target dp rank until the max_num_active_requests is reached
remaining_unscheduled = []
for req_item in new_requests:
scheduled = False
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is not None:
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is not None and all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
scheduled = True
all_ranks_new_requests[target_dp_rank].append(req_item)

if not scheduled:
remaining_unscheduled.append(req_item)

# Balance the remaining unscheduled requests across ranks
num_new_requests_all_ranks = len(remaining_unscheduled)
total_num_active_requests = sum(all_ranks_num_active_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)

all_ranks_new_requests = self._balance_requests_across_ranks(
remaining_unscheduled, all_ranks_new_requests,
all_ranks_num_active_requests)

return all_ranks_new_requests
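
_schedule_attention_dp_requests works in two phases: it first pins requests that name a target attention DP rank onto that rank while the rank still has capacity, then hands whatever could not be pinned to _balance_requests_across_ranks. The balancing target expected_num_active_requests is the larger of the ceiling average over ranks and the current per-rank maximum. A small worked example of that formula (the numbers are illustrative):

tp_size = 4
all_ranks_num_active_requests = [3, 1, 0, 2]   # per-rank active counts after the pinning phase
num_unscheduled = 6                            # requests left for the balancing phase

total_active = sum(all_ranks_num_active_requests)                # 6
expected_num_active_requests = max(
    (total_active + num_unscheduled + tp_size - 1) // tp_size,   # ceil(12 / 4) = 3
    max(all_ranks_num_active_requests),                          # 3
)
print(expected_num_active_requests)  # 3 -> each rank may grow to at most 3 active requests
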

def _handle_request_broadcasting(self,
new_requests: List[RequestQueueItem]):
"""Handle broadcasting of requests and Python objects across ranks."""
@@ -274,8 +372,13 @@ def _handle_request_broadcasting(self,
new_requests, "py_logits_post_processors")
py_multimodal_data = self._collect_py_objects_from_requests(
new_requests, "py_multimodal_data")
py_scheduling_params = self._collect_py_objects_from_requests(
new_requests, "py_scheduling_params")
py_request_objects = tuple(
filter(None, [py_logits_post_processors, py_multimodal_data]))
filter(None, [
py_logits_post_processors, py_multimodal_data,
py_scheduling_params
]))
else:
py_request_objects = None

@@ -314,28 +417,30 @@ def _validate_and_filter_requests(

def _balance_requests_across_ranks(
self, new_requests: List[RequestQueueItem],
all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
"""Balance requests across ranks for attention DP."""
new_requests_cur_rank = []

if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
self.dist.tp_rank]:
if new_requests:
# Balance context tokens across ranks using heap
HeapVal = namedtuple(
'HeapVal',
['num_tokens', 'num_requests', 'rank', 'request_list'])

all_ranks_new_requests_heap = [
HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
HeapVal(0, val, tp_rank, [])
for tp_rank, val in enumerate(all_ranks_num_active_requests)
]

new_requests_cur_rank = all_ranks_new_requests_heap[
self.dist.tp_rank].request_list
all_ranks_new_requests_heap = [
val for val in all_ranks_new_requests_heap
if val.num_requests > 0
if val.num_requests < self.expected_num_active_requests
]

all_ranks_new_scheduled_requests = {
val.rank: val.request_list
for val in all_ranks_new_requests_heap
}

heapq.heapify(all_ranks_new_requests_heap)

# Sort by token count (descending) for better load balancing
@@ -351,17 +456,22 @@ def _balance_requests_across_ranks(
token_count = len(
getattr(req_item.request, 'input_token_ids',
[])) if req_item.request else 0
# Update the heap value with the new request
val = val._replace(
num_tokens=val.num_tokens + token_count,
num_requests=val.num_requests - 1,
num_requests=val.num_requests + 1,
)

val.request_list.append(req_item)
if val.num_requests > 0:
# If rank still has room for new requests, push back into heap
if val.num_requests < self.expected_num_active_requests:
heapq.heappush(all_ranks_new_requests_heap, val)
elif val.rank == self.dist.tp_rank:
break

return new_requests_cur_rank
# Extend all_ranks_new_requests with the new requests that have been scheduled
for rank, reqs in all_ranks_new_scheduled_requests.items():
all_ranks_new_requests[rank].extend(reqs)

return all_ranks_new_requests
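
The balancing step keeps a min-heap of HeapVal tuples, ordered by (num_tokens, num_requests, rank), so each remaining request is handed to the rank that has accumulated the fewest new context tokens and still sits below expected_num_active_requests. A compact sketch of the same greedy strategy, using plain prompt lengths instead of RequestQueueItem objects (a simplification, not the production code):

import heapq
from typing import Dict, List

def balance_by_tokens(prompt_lengths: List[int], active_per_rank: List[int],
                      expected_max: int) -> Dict[int, List[int]]:
    """Greedily assign prompts (longest first) to the least token-loaded rank with spare capacity."""
    # Heap entries mirror HeapVal: (accumulated_new_tokens, num_active_requests, rank).
    heap = [(0, active, rank) for rank, active in enumerate(active_per_rank)
            if active < expected_max]
    heapq.heapify(heap)
    assignment: Dict[int, List[int]] = {rank: [] for rank in range(len(active_per_rank))}

    for tokens in sorted(prompt_lengths, reverse=True):   # longest prompts placed first
        if not heap:
            break                                         # every rank is at capacity
        new_tokens, num_reqs, rank = heapq.heappop(heap)
        assignment[rank].append(tokens)
        if num_reqs + 1 < expected_max:                   # rank still has room, push it back
            heapq.heappush(heap, (new_tokens + tokens, num_reqs + 1, rank))
    return assignment

# Two ranks, rank 0 already holds one request, cap of 3 active requests per rank.
print(balance_by_tokens([900, 400, 300, 100], active_per_rank=[1, 0], expected_max=3))
# {0: [400, 300], 1: [900, 100]} -- context tokens end up roughly even
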

def _collect_py_objects_from_requests(
self, requests: List[RequestQueueItem],
5 changes: 4 additions & 1 deletion tensorrt_llm/executor/executor.py
@@ -29,6 +29,7 @@
print_colored_debug)
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
SamplingParams)
from ..scheduling_params import SchedulingParams
from .ipc import FusedIpcQueue
from .postproc_worker import PostprocParams, PostprocWorkerConfig
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -120,6 +121,7 @@ def generate_async(
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@@ -142,7 +144,8 @@
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
multimodal_params=multimodal_params)
multimodal_params=multimodal_params,
scheduling_params=scheduling_params)
result = self.submit(request)
# release memory in time
if hasattr(request, "multimodal_params"):
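
With this plumbing in place, a caller on the PyTorch backend can pin a request to a specific attention DP rank, or mark it relaxed so it may fall back to any rank, by passing a SchedulingParams object. A usage sketch, assuming SchedulingParams accepts attention_dp_rank and attention_dp_relax as keyword arguments (the surrounding executor and sampling setup is illustrative, not taken from this diff):

from tensorrt_llm.scheduling_params import SchedulingParams

# `executor` and `sampling_params` are assumed to already exist.
result = executor.generate_async(
    prompt_token_ids=[101, 2023, 2003, 1037, 3231, 102],
    sampling_params=sampling_params,
    # Prefer attention DP rank 2, but allow rescheduling elsewhere if that rank is saturated.
    scheduling_params=SchedulingParams(attention_dp_rank=2, attention_dp_relax=True),
)
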
3 changes: 3 additions & 0 deletions tensorrt_llm/executor/request.py
@@ -10,6 +10,7 @@
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .postproc_worker import PostprocParams

__all__ = [
@@ -95,6 +96,7 @@ def __init__(
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
):
if isinstance(prompt_token_ids, list):
self.prompt_token_ids = prompt_token_ids
@@ -119,6 +121,7 @@ def __init__(
self.kv_cache_retention_config = kv_cache_retention_config
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.scheduling_params = scheduling_params

def set_id(self, id):
assert self.id is None, f"Request ID is already set: {self.id}"
4 changes: 4 additions & 0 deletions tensorrt_llm/executor/worker.py
@@ -520,6 +520,10 @@ def _deduce_max_tokens(request: GenerationRequest,
executor_request.py_logits_post_processors = lp if isinstance(
lp, list) else [lp]

executor_request.py_scheduling_params = None
if self._is_pytorch_backend and request.scheduling_params is not None:
executor_request.py_scheduling_params = request.scheduling_params

if request.query_token_ids is not None:
# pytorch star attention workflow
# a workaround to avoid public interface update
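
On the worker side the new parameters simply ride along on the executor request as a py_ attribute, and only when the PyTorch backend is active; the scheduler later reads them back defensively with getattr, as _can_process_attention_dp_request does above. A tiny illustration of that hand-off pattern with throwaway stand-in objects (not the real ExecutorRequest):

from types import SimpleNamespace

executor_request = SimpleNamespace()      # stand-in for the real ExecutorRequest
scheduling_params = SimpleNamespace(attention_dp_rank=1, attention_dp_relax=False)

# Producer side (worker): attach the params only when the caller provided them.
executor_request.py_scheduling_params = scheduling_params

# Consumer side (scheduler): read them back without assuming the attribute exists.
params = getattr(executor_request, 'py_scheduling_params', None)
target_rank = params.attention_dp_rank if params is not None else None
print(target_rank)  # 1
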