Commit d0ea3a1

Shunkang authored and committed

Refactor code

Signed-off-by: Shunkang <[email protected]>

1 parent a7e6ebd

7 files changed: +65 -99 lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 44 additions & 84 deletions
@@ -83,45 +83,13 @@ def _get_from_request_queue(
                 pass
         return items
 
-    def _get_from_waiting_queue_attention_tp(
+    def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
         max_req_count: int,
+        enable_attention_dp: bool,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
 
-        Args:
-            waiting_queue: The queue to pop items from.
-            max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
-        Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
-        """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
-            return []
-
-        items = []
-        req_count = 0
-        while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
-        return items
-
-    def _get_from_waiting_queue_attention_dp(
-        self,
-        waiting_queue: deque[RequestQueueItem],
-        max_req_count: int,
-    ) -> List[RequestQueueItem]:
-        """Extract requests from waiting queue with attention DP load balancing.
-
-        Args:
-            waiting_queue: Queue of pending requests
-            max_req_count: Maximum number of requests to extract
-
-        Returns:
-            List of requests that can be processed immediately
-        """
         if max_req_count <= 0:
             return []
 
@@ -130,55 +98,45 @@ def _get_from_waiting_queue_attention_dp
         pending_requests = []
 
         # Track the request with strict requirements
-        all_ranks_num_active_requests = self.all_ranks_num_active_requests.copy(
-        )
+        scheduling_all_ranks_num_active_requests = self.all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
            req_item = waiting_queue.popleft()
-            can_process_now = self._can_process_attention_dp_request(
-                req_item, all_ranks_num_active_requests)
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
 
-            if can_process_now:
+            if can_process:
                 items.append(req_item)
                 req_count += 1
             else:
                 pending_requests.append(req_item)
 
         # Put the pending requests back to the waiting queue
         # All ranks should have the same waiting queue
-        self.waiting_queue.extendleft(pending_requests)
+        self.waiting_queue.extendleft(reversed(pending_requests))
 
         return items
 
     def _can_process_attention_dp_request(
             self, req_item: RequestQueueItem,
             all_ranks_num_active_requests: List[int]) -> bool:
-        """Check if a request can be processed immediately.
+        """Return True if the request can be processed immediately, else False."""
 
-        Returns:
-            True if the request can be processed now, False if it should be deferred.
-        """
-        # Handle requests without schedule parameters
-        if req_item.request.py_schedule_params is None:
+        scheduling_params = req_item.request.py_scheduling_params
+        if scheduling_params is None:
            return True
 
-        schedule_params = req_item.request.py_schedule_params
-        target_dp_rank = schedule_params.attention_dp_rank
-        is_relax = schedule_params.attention_dp_relax
-
-        # Handle requests without target rank or in relax mode
-        if target_dp_rank is None or is_relax:
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
            return True
 
-        # Handle strict mode requests - check target rank capacity
-        target_rank_has_capacity = (
-            all_ranks_num_active_requests[target_dp_rank]
-            < self.max_num_active_requests)
-
-        if target_rank_has_capacity:
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
            all_ranks_num_active_requests[target_dp_rank] += 1
            return True
-        else:
-            return False
+
+        return False
 
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
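
The switch from extendleft(pending_requests) to extendleft(reversed(pending_requests)) matters because deque.extendleft() pushes elements one at a time onto the left, which reverses their order. A minimal, standalone sketch of the difference (plain Python, with "r1".."r4" standing in for RequestQueueItem objects):

from collections import deque

deferred = ["r1", "r2"]          # requests popped earlier but not schedulable yet

# extendleft() appends one element at a time to the left,
# so the deferred items come back in reversed order:
buggy = deque(["r3", "r4"])
buggy.extendleft(deferred)
print(list(buggy))               # ['r2', 'r1', 'r3', 'r4'] -- arrival order broken

# Reversing first restores the original arrival order:
fixed = deque(["r3", "r4"])
fixed.extendleft(reversed(deferred))
print(list(fixed))               # ['r1', 'r2', 'r3', 'r4']
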
@@ -238,7 +196,9 @@ def can_enqueue_request(self) -> bool:
         return can_enqueue and self.dist.rank == 0
 
     def _fetch_and_process_requests(
-            self, total_num_active_requests: int) -> List[RequestQueueItem]:
+            self, total_num_active_requests: int,
+            total_max_num_active_requests: int,
+            enable_attention_dp: bool) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
@@ -264,6 +224,17 @@
 
         self.waiting_queue.extend(new_requests)
 
+        new_requests = self._get_from_waiting_queue(
+            self.waiting_queue,
+            total_max_num_active_requests - total_num_active_requests,
+            enable_attention_dp)
+
+        # Update performance metrics
+        if self.enable_iter_perf_stats and self.dist.rank == 0:
+            self._update_new_active_requests_queue_latency(new_requests)
+
+        return new_requests
+
     @nvtx_range("_fetch_new_requests")
     def fetch_new_requests(self,
                            num_active_requests: int) -> List[RequestQueueItem]:
@@ -280,15 +251,10 @@ def _fetch_new_requests_attention_tp(
         total_max_num_active_requests = self.max_num_active_requests
 
         # fetch and process requests into waiting queue
-        self._fetch_and_process_requests(total_num_active_requests)
-
-        new_requests = self._get_from_waiting_queue_attention_tp(
-            self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
-
-        # Update performance metrics
-        if self.enable_iter_perf_stats and self.dist.rank == 0:
-            self._update_new_active_requests_queue_latency(new_requests)
+        new_requests = self._fetch_and_process_requests(
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
@@ -307,16 +273,10 @@ def _fetch_new_requests_attention_dp(
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
 
         # fetch and process requests into waiting queue
-        self._fetch_and_process_requests(total_num_active_requests)
-
-        new_requests = self._get_from_waiting_queue_attention_dp(
-            self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
-
-        # Update performance metrics
-        # TODO: Check whether we should update the performance metrics for all ranks
-        if self.enable_iter_perf_stats and self.dist.rank == 0:
-            self._update_new_active_requests_queue_latency(new_requests)
+        new_requests = self._fetch_and_process_requests(
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True)
 
         # Schedule attention dp requests
         new_requests_cur_rank = self._schedule_attention_dp_requests(
@@ -342,9 +302,9 @@ def _schedule_attention_dp_requests(
 
         # Prioritize the requests that are not in relax mode
         def get_relax_value(req_item):
-            if req_item.request.py_schedule_params is None:
+            if req_item.request.py_scheduling_params is None:
                 return True
-            return req_item.request.py_schedule_params.attention_dp_relax
+            return req_item.request.py_scheduling_params.attention_dp_relax
 
         new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
 
@@ -353,8 +313,8 @@ def get_relax_value(req_item):
         new_requests_cur_rank = []
         for req_item in new_requests:
             scheduled = False
-            if req_item.request.py_schedule_params is not None:
-                target_dp_rank = req_item.request.py_schedule_params.attention_dp_rank
+            if req_item.request.py_scheduling_params is not None:
+                target_dp_rank = req_item.request.py_scheduling_params.attention_dp_rank
                 if target_dp_rank is not None and self.all_ranks_num_active_requests[
                         target_dp_rank] < self.max_num_active_requests:
                     self.all_ranks_num_active_requests[target_dp_rank] += 1
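
Taken together, the hunks above fold the former TP- and DP-specific extraction paths into one method gated by enable_attention_dp. A self-contained sketch of that merged logic, with the class plumbing (self, RequestQueueItem, per-rank bookkeeping) replaced by plain data so it runs on its own; everything outside the control flow is illustrative only:

from collections import deque
from typing import List, Optional, Tuple

# A queued request is modeled as (target_dp_rank, attention_dp_relax);
# the real code carries RequestQueueItem objects instead.
Request = Tuple[Optional[int], bool]

def can_process_dp_request(req: Request, ranks_num_active: List[int],
                           max_num_active_requests: int) -> bool:
    # Capacity check for strict-mode requests (illustrative mirror).
    target_dp_rank, relax = req
    if target_dp_rank is None or relax:
        return True
    if ranks_num_active[target_dp_rank] < max_num_active_requests:
        ranks_num_active[target_dp_rank] += 1  # reserve a slot on that rank
        return True
    return False

def get_from_waiting_queue(waiting_queue: deque, max_req_count: int,
                           enable_attention_dp: bool,
                           ranks_num_active: Optional[List[int]] = None,
                           max_num_active_requests: int = 8) -> List[Request]:
    if max_req_count <= 0:
        return []
    items, pending = [], []
    counts = list(ranks_num_active) if enable_attention_dp else None
    while len(items) < max_req_count and waiting_queue:
        req = waiting_queue.popleft()
        ok = (can_process_dp_request(req, counts, max_num_active_requests)
              if enable_attention_dp else True)
        (items if ok else pending).append(req)
    # Deferred requests go back to the front in their original order.
    waiting_queue.extendleft(reversed(pending))
    return items

# Example: rank 0 is already full (2/2), so a strict request pinned to rank 0
# is deferred while the other two requests are admitted.
q = deque([(0, False), (1, False), (None, True)])
print(get_from_waiting_queue(q, 3, True, ranks_num_active=[2, 0],
                             max_num_active_requests=2))
# -> [(1, False), (None, True)]
print(list(q))                   # [(0, False)] stays queued for a later pass
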

tensorrt_llm/executor/executor.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@
                             print_colored_debug)
 from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
                                SamplingParams)
-from ..schedule_params import ScheduleParams
+from ..scheduling_params import SchedulingParams
 from .ipc import FusedIpcQueue
 from .postproc_worker import PostprocParams, PostprocWorkerConfig
 from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -121,7 +121,7 @@ def generate_async(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
-        schedule_params: Optional[ScheduleParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ) -> GenerationResult:
         """Generate output for the given prompt token ids in the asynchronous mode.
         Asynchronous generation accepts single prompt only.
@@ -145,7 +145,7 @@
             kv_cache_retention_config=kv_cache_retention_config,
             disaggregated_params=disaggregated_params,
             multimodal_params=multimodal_params,
-            schedule_params=schedule_params)
+            scheduling_params=scheduling_params)
         result = self.submit(request)
         # release memory in time
         if hasattr(request, "multimodal_params"):

tensorrt_llm/executor/request.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..sampling_params import SamplingParams
-from ..schedule_params import ScheduleParams
+from ..scheduling_params import SchedulingParams
 from .postproc_worker import PostprocParams
 
 __all__ = [
@@ -96,7 +96,7 @@ def __init__(
         disaggregated_params: Optional[DisaggregatedParams] = None,
         postproc_params: Optional[PostprocParams] = None,
         multimodal_params: Optional[MultimodalParams] = None,
-        schedule_params: Optional[ScheduleParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ):
         if isinstance(prompt_token_ids, list):
             self.prompt_token_ids = prompt_token_ids
@@ -121,7 +121,7 @@ def __init__(
         self.kv_cache_retention_config = kv_cache_retention_config
         self.id: Optional[int] = None
         self.disaggregated_params = disaggregated_params
-        self.schedule_params = schedule_params
+        self.scheduling_params = scheduling_params
 
     def set_id(self, id):
         assert self.id is None, f"Request ID is already set: {self.id}"

tensorrt_llm/executor/worker.py

Lines changed: 2 additions & 2 deletions
@@ -510,8 +510,8 @@ def _deduce_max_tokens(request: GenerationRequest,
                 lp, list) else [lp]
 
         executor_request.py_schedule_params = None
-        if self._is_pytorch_backend and request.schedule_params is not None:
-            executor_request.py_schedule_params = request.schedule_params
+        if self._is_pytorch_backend and request.scheduling_params is not None:
+            executor_request.py_scheduling_params = request.scheduling_params
 
         if request.query_token_ids is not None:
             # pytorch star attention workflow

tensorrt_llm/llmapi/llm.py

Lines changed: 6 additions & 6 deletions
@@ -30,7 +30,7 @@
                       create_input_processor_with_hash, prompt_inputs)
 from ..logger import logger
 from ..sampling_params import SamplingParams
-from ..schedule_params import ScheduleParams
+from ..scheduling_params import SchedulingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
                        TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
                        TorchLlmArgs, TrtLlmArgs)
@@ -236,8 +236,8 @@ def generate(
             KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
         disaggregated_params: Optional[Union[
             DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
-        schedule_params: Optional[Union[ScheduleParams,
-                                        List[ScheduleParams]]] = None,
+        scheduling_params: Optional[Union[SchedulingParams,
+                                          List[SchedulingParams]]] = None,
     ) -> Union[RequestOutput, List[RequestOutput]]:
         """Generate output for the given prompts in the synchronous mode.
         Synchronous generation accepts either single prompt or batched prompts.
@@ -285,7 +285,7 @@ def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
                 kv_cache_retention_config=_item_at(kv_cache_retention_config,
                                                    i),
                 disaggregated_params=_item_at(disaggregated_params, i),
-                schedule_params=_item_at(schedule_params, i),
+                scheduling_params=_item_at(scheduling_params, i),
                 streaming=False)
             futures.append(future)
 
@@ -311,7 +311,7 @@ def generate_async(
         kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
         disaggregated_params: Optional[DisaggregatedParams] = None,
         _postproc_params: Optional[PostprocParams] = None,
-        schedule_params: Optional[ScheduleParams] = None,
+        scheduling_params: Optional[SchedulingParams] = None,
     ) -> RequestOutput:
         """Generate output for the given prompt in the asynchronous mode.
         Asynchronous generation accepts single prompt only.
@@ -422,7 +422,7 @@ def generate_async(
             disaggregated_params=disaggregated_params,
             postproc_params=_postproc_params,
             multimodal_params=multimodal_params,
-            schedule_params=schedule_params,
+            scheduling_params=scheduling_params,
         )
 
         return RequestOutput._from_generation_result(result, prompt,
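
With the rename in place, callers pass the new keyword to the LLM API. A hedged usage sketch — the model name is just an example, and the SchedulingParams field names are taken from how executor_request_queue.py reads them, not from this file:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.scheduling_params import SchedulingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # example model, not from this commit

# Pin the request to attention-DP rank 0; attention_dp_relax=True lets the
# scheduler fall back to another rank when rank 0 has no free slots.
outputs = llm.generate(
    ["Hello, my name is"],
    sampling_params=SamplingParams(max_tokens=32),
    scheduling_params=SchedulingParams(attention_dp_rank=0,
                                       attention_dp_relax=True),
)
print(outputs[0].outputs[0].text)
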

tensorrt_llm/schedule_params.py renamed to tensorrt_llm/scheduling_params.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 
 @dataclass(slots=True, kw_only=True)
-class ScheduleParams:
+class SchedulingParams:
     """Schedule parameters.
 
     Args:
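
Only the class name changes in the renamed module. Based on how the fields are consumed in executor_request_queue.py, the dataclass has roughly the following shape; the types, defaults, and docstring wording below are assumptions, not shown in this diff:

from dataclasses import dataclass
from typing import Optional

@dataclass(slots=True, kw_only=True)
class SchedulingParams:
    """Schedule parameters.

    Args:
        attention_dp_rank (int, optional): Target attention DP rank for the request.
        attention_dp_relax (bool): Whether the request may fall back to another rank.
    """
    # Field names mirror py_scheduling_params usage; defaults are assumed.
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False
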

tests/unittest/api_stability/references/llm.yaml

Lines changed: 6 additions & 0 deletions
@@ -126,6 +126,9 @@ methods:
       kv_cache_retention_config:
         annotation: Union[tensorrt_llm.bindings.executor.KvCacheRetentionConfig, Sequence[tensorrt_llm.bindings.executor.KvCacheRetentionConfig], NoneType]
         default: null
+      scheduling_params:
+        annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
+        default: null
     return_annotation: Union[tensorrt_llm.llmapi.llm.RequestOutput, List[tensorrt_llm.llmapi.llm.RequestOutput]]
   generate_async:
     parameters:
@@ -135,6 +138,9 @@ methods:
       kv_cache_retention_config:
         annotation: Optional[tensorrt_llm.bindings.executor.KvCacheRetentionConfig]
         default: null
+      scheduling_params:
+        annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
+        default: null
     return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
   get_kv_cache_events:
     parameters:
