Commit 35da779

optimize: ADP schedule optimization
Signed-off-by: yunruis <[email protected]>
1 parent 93a0fd0 commit 35da779

3 files changed: +125 −2 lines changed

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
+    use_attention_dp_config: bool = False
+    attention_dp_time_out_iters: int = 0
+    attention_dp_batching_wait_iters: int = 0
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'

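All three new fields default to "off" (False / 0), so callers opt in explicitly. A minimal sketch, assuming the remaining PyTorchConfig fields have usable defaults (the values below are illustrative, not tuned recommendations):

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Illustrative opt-in: hold context requests for up to 10 iterations to
# batch them, and force-schedule them after 50 iterations of waiting.
config = PyTorchConfig(
    use_attention_dp_config=True,
    attention_dp_time_out_iters=50,
    attention_dp_batching_wait_iters=10,
)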
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 65 additions & 1 deletion
@@ -176,6 +176,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -214,6 +217,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters_count = 0
+        self.adp_ctx_batching_wait_iters_count = 0
 
         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1119,8 +1125,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+
+            if self.use_attention_dp_config:
+                # Wait until every rank has context requests to schedule.
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                    self.adp_ctx_waiting_iters_count = 0
+                    # Balance the number of context requests across ranks.
+                    if all_ranks_have_gen_requests:
+                        if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
+                            self.adp_ctx_batching_wait_iters_count += 1
+                            context_requests = []
+                        else:
+                            self.adp_ctx_batching_wait_iters_count = 0
+                else:
+                    self.adp_ctx_waiting_iters_count += 1
+                    context_requests = []
+                    if self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters or not all_ranks_have_gen_requests:
+                        self.adp_ctx_waiting_iters_count = 0
+                        context_requests = scheduler_output.context_requests
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
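
The gating above is a small per-iteration state machine. Here is a hedged, self-contained sketch of the same decision logic pulled out of the executor; AdpState and decide_context_admission are hypothetical names for illustration, but the counter semantics mirror the diff:

from dataclasses import dataclass

@dataclass
class AdpState:
    # Mirrors the two executor counters added in __init__ above.
    waiting_iters: int = 0        # adp_ctx_waiting_iters_count
    batching_wait_iters: int = 0  # adp_ctx_batching_wait_iters_count

def decide_context_admission(state: AdpState,
                             all_have_free_ctx_slots: bool,
                             all_have_ctx_requests: bool,
                             all_have_gen_requests: bool,
                             timeout_iters: int,
                             max_batching_wait_iters: int) -> bool:
    """Return True to admit this iteration's context requests,
    False to hold them back so attention-DP ranks stay balanced."""
    if all_have_free_ctx_slots and all_have_ctx_requests:
        # Every rank has free slots and context work: stop waiting.
        state.waiting_iters = 0
        if all_have_gen_requests:
            # Generation is running everywhere, so context can be held a
            # few more iterations to batch more of it together.
            if state.batching_wait_iters < max_batching_wait_iters:
                state.batching_wait_iters += 1
                return False
            state.batching_wait_iters = 0
        return True
    # Some rank lacks context work or free slots: hold context back,
    # unless the wait has timed out or a rank would otherwise sit idle
    # with no generation work at all.
    state.waiting_iters += 1
    if state.waiting_iters >= timeout_iters or not all_have_gen_requests:
        state.waiting_iters = 0
        return True
    return False

In words: context is admitted when every rank has both free batch slots and context work; it is deliberately delayed for up to attention_dp_batching_wait_iters iterations when generation is already running everywhere; and it is force-admitted once attention_dp_time_out_iters is reached or some rank has no generation work to keep it busy.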

tensorrt_llm/llmapi/llm_args.py

Lines changed: 56 additions & 1 deletion
@@ -89,6 +89,23 @@ def validate_cuda_graph_max_batch_size(cls, v):
         return v
 
 
+class AttentionDpConfig(BaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable attention DP load balancing.")
+    batching_wait_iters: int = Field(
+        default=0,
+        description="The number of iterations to hold context requests back for batching.")
+    timeout_iters: int = Field(
+        default=0,
+        description="The number of iterations after which held context requests are scheduled anyway.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 class MoeConfig(BaseModel):
     """
     Configuration for MoE.
@@ -1817,6 +1834,11 @@ class TorchLlmArgs(BaseLlmArgs):
         since the input shapes are a function of the sequence lengths).\
         Note that each CUDA graph can use up to 200 MB of extra memory.")
 
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description=
+        "Attention DP config. If set, the attention-DP-optimized scheduler is used.")
+
     disable_overlap_scheduler: bool = Field(
         default=False, description="Disable the overlap scheduler.")
 
@@ -2075,6 +2097,31 @@ def sync_quant_config_with_kv_cache_config_dtype(self) -> 'TorchLlmArgs':
                 "please update the validator")
         return self
 
+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be non-negative.
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be non-negative.
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be non-negative when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be non-negative when enable_balance is true"
+                )
+        return self
+
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2125,7 +2172,14 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            allreduce_strategy=self.allreduce_strategy,
+            use_attention_dp_config=self.attention_dp_config is not None,
+            attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
+            if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters=self.attention_dp_config.
+            batching_wait_iters if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['batching_wait_iters'].default)
 
 
 def update_llm_args_with_extra_dict(
@@ -2142,6 +2196,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:

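End to end, the knobs ride through update_llm_args_with_extra_dict as a nested dict. A hedged sketch of such a payload (only the "attention_dp_config" key and its field names come from this commit; the values are illustrative):

extra_llm_args = {
    "attention_dp_config": {
        "enable_balance": True,     # turn on the balanced ADP scheduler
        "timeout_iters": 50,        # force-admit held context after 50 iterations
        "batching_wait_iters": 10,  # hold context for up to 10 iterations to batch
    },
}

update_llm_args_with_extra_dict maps the nested dict onto AttentionDpConfig via the field_mapping above, and get_pytorch_backend_config then flattens it into the three PyTorchConfig fields from the first file.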