Skip to content

Commit 84bb95c

Browse files
Shunkang
authored and committed
Add ScheduleParams
Signed-off-by: Shunkang <[email protected]>
1 parent ee45e0c commit 84bb95c

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

tensorrt_llm/executor/executor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
print_colored_debug)
3030
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
3131
SamplingParams)
32+
from ..schedule_params import ScheduleParams
3233
from .ipc import FusedIpcQueue
3334
from .postproc_worker import PostprocParams, PostprocWorkerConfig
3435
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -120,6 +121,7 @@ def generate_async(
120121
disaggregated_params: Optional[DisaggregatedParams] = None,
121122
postproc_params: Optional[PostprocParams] = None,
122123
multimodal_params: Optional[MultimodalParams] = None,
124+
schedule_params: Optional[ScheduleParams] = None,
123125
) -> GenerationResult:
124126
"""Generate output for the given prompt token ids in the asynchronous mode.
125127
Asynchronous generation accepts single prompt only.
@@ -142,7 +144,8 @@ def generate_async(
142144
streaming=streaming,
143145
kv_cache_retention_config=kv_cache_retention_config,
144146
disaggregated_params=disaggregated_params,
145-
multimodal_params=multimodal_params)
147+
multimodal_params=multimodal_params,
148+
schedule_params=schedule_params)
146149
result = self.submit(request)
147150
# release memory in time
148151
if hasattr(request, "multimodal_params"):

tensorrt_llm/executor/request.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..disaggregated_params import DisaggregatedParams
1111
from ..llmapi.llm_utils import KvCacheRetentionConfig
1212
from ..sampling_params import SamplingParams
13+
from ..schedule_params import ScheduleParams
1314
from .postproc_worker import PostprocParams
1415

1516
__all__ = [
@@ -86,6 +87,7 @@ def __init__(
8687
disaggregated_params: Optional[DisaggregatedParams] = None,
8788
postproc_params: Optional[PostprocParams] = None,
8889
multimodal_params: Optional[MultimodalParams] = None,
90+
schedule_params: Optional[ScheduleParams] = None,
8991
):
9092
if isinstance(prompt_token_ids, list):
9193
self.prompt_token_ids = prompt_token_ids
@@ -110,6 +112,7 @@ def __init__(
110112
self.kv_cache_retention_config = kv_cache_retention_config
111113
self.id: Optional[int] = None
112114
self.disaggregated_params = disaggregated_params
115+
self.schedule_params = schedule_params
113116

114117
def set_id(self, id):
115118
assert self.id is None, f"Request ID is already set: {self.id}"

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,9 @@ def _deduce_max_tokens(request: GenerationRequest,
508508
executor_request.py_logits_post_processors = lp if isinstance(
509509
lp, list) else [lp]
510510

511+
if self._is_pytorch_backend and request.schedule_params is not None:
512+
executor_request.py_schedule_params = request.schedule_params
513+
511514
if request.query_token_ids is not None:
512515
# pytorch star attention workflow
513516
# a workaround to avoid public interface update

tensorrt_llm/llmapi/llm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
create_input_processor_with_hash, prompt_inputs)
3131
from ..logger import logger
3232
from ..sampling_params import SamplingParams
33+
from ..schedule_params import ScheduleParams
3334
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
3435
TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
3536
TorchLlmArgs, TrtLlmArgs)
@@ -235,6 +236,8 @@ def generate(
235236
KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
236237
disaggregated_params: Optional[Union[
237238
DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
239+
schedule_params: Optional[Union[ScheduleParams,
240+
List[ScheduleParams]]] = None,
238241
) -> Union[RequestOutput, List[RequestOutput]]:
239242
"""Generate output for the given prompts in the synchronous mode.
240243
Synchronous generation accepts either single prompt or batched prompts.
@@ -282,6 +285,7 @@ def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
282285
kv_cache_retention_config=_item_at(kv_cache_retention_config,
283286
i),
284287
disaggregated_params=_item_at(disaggregated_params, i),
288+
schedule_params=_item_at(schedule_params, i),
285289
streaming=False)
286290
futures.append(future)
287291

@@ -307,6 +311,7 @@ def generate_async(
307311
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
308312
disaggregated_params: Optional[DisaggregatedParams] = None,
309313
_postproc_params: Optional[PostprocParams] = None,
314+
schedule_params: Optional[ScheduleParams] = None,
310315
) -> RequestOutput:
311316
"""Generate output for the given prompt in the asynchronous mode.
312317
Asynchronous generation accepts single prompt only.
@@ -417,6 +422,7 @@ def generate_async(
417422
disaggregated_params=disaggregated_params,
418423
postproc_params=_postproc_params,
419424
multimodal_params=multimodal_params,
425+
schedule_params=schedule_params,
420426
)
421427

422428
return RequestOutput._from_generation_result(result, prompt,

tensorrt_llm/schedule_params.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
4+
5+
@dataclass(slots=True, kw_only=True)
6+
class ScheduleParams:
7+
"""Schedule parameters.
8+
9+
Args:
10+
attention_dp_rank (int): The rank of target attention dp
11+
attention_dp_relax (bool): Whether to allow the request to be scheduled to other attention dp for better throughput
12+
"""
13+
14+
attention_dp_rank: Optional[int] = None
15+
attention_dp_relax: Optional[bool] = None

0 commit comments

Comments
 (0)