
Commit dbcd7d2

optimize: ADP schedule optimization
Signed-off-by: yunruis <[email protected]>
1 parent 03632a6 commit dbcd7d2

3 files changed, +125 -2 lines changed
tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
+    use_attention_dp_config: bool = False
+    attention_dp_time_out_iters: int = 0
+    attention_dp_batching_wait_iters: int = 0
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
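For orientation, a minimal hedged sketch of setting the three new fields directly on the backend config. The values below are placeholders; in practice they are filled in from AttentionDpConfig by get_pytorch_backend_config() in llm_args.py (see the last file in this commit).

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Illustrative values only: hold back context requests for at most 50
# iterations and batch context work for up to 10 iterations.
pytorch_config = PyTorchConfig(
    use_attention_dp_config=True,
    attention_dp_time_out_iters=50,
    attention_dp_batching_wait_iters=10,
)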

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 65 additions & 1 deletion
@@ -176,6 +176,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -214,6 +217,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters_count = 0
+        self.adp_ctx_batching_wait_iters_count = 0
 
         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1088,8 +1094,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+
+            if self.use_attention_dp_config:
+                # Wait until all ranks have context requests and free slots
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                    self.adp_ctx_waiting_iters_count = 0
+                    # Balance the number of context requests across ranks
+                    if all_ranks_have_gen_requests:
+                        if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
+                            self.adp_ctx_batching_wait_iters_count += 1
+                            context_requests = []
+                        else:
+                            self.adp_ctx_batching_wait_iters_count = 0
+                else:
+                    self.adp_ctx_waiting_iters_count += 1
+                    context_requests = []
+                    if self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters or not all_ranks_have_gen_requests:
+                        self.adp_ctx_waiting_iters_count = 0
+                        context_requests = scheduler_output.context_requests
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
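To make the control flow above easier to follow, here is a hedged, stand-alone restatement of the per-iteration decision as a pure function. The names mirror the diff, but the function itself is illustrative and not part of the commit.

def adp_balance_context_requests(ctx_requests,
                                 all_ranks_have_free_ctx_slots,
                                 all_ranks_have_ctx_requests,
                                 all_ranks_have_gen_requests,
                                 waiting_iters, batching_iters,
                                 timeout_iters, batching_wait_iters):
    """Illustrative restatement of the balancing policy in this diff.

    Returns the (possibly emptied) context requests plus the updated
    waiting/batching counters.
    """
    if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
        # Every rank can start context work: stop waiting, but optionally
        # hold a few iterations so context batches stay aligned across ranks.
        waiting_iters = 0
        if all_ranks_have_gen_requests:
            if batching_iters < batching_wait_iters:
                return [], waiting_iters, batching_iters + 1
            batching_iters = 0
        return ctx_requests, waiting_iters, batching_iters

    # Some rank has no context work or no free slot: hold back context
    # requests, but give up after timeout_iters or when generation would stall.
    waiting_iters += 1
    if waiting_iters >= timeout_iters or not all_ranks_have_gen_requests:
        return ctx_requests, 0, batching_iters
    return [], waiting_iters, batching_iters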

tensorrt_llm/llmapi/llm_args.py

Lines changed: 56 additions & 1 deletion
@@ -120,6 +120,23 @@ def validate_cuda_graph_max_batch_size(cls, v):
         return v
 
 
+class AttentionDpConfig(BaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable balance.")
+    batching_wait_iters: int = Field(
+        default=0,
+        description="The number of iterations to wait for batching.")
+    timeout_iters: int = Field(
+        default=0, description="The number of iterations to timeout.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 class MoeConfig(BaseModel):
     """
     Configuration for MoE.
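As a hedged usage sketch of the new Pydantic model (values are illustrative, not tuned recommendations):

from tensorrt_llm.llmapi.llm_args import AttentionDpConfig

adp_config = AttentionDpConfig(enable_balance=True,
                               timeout_iters=50,
                               batching_wait_iters=10)

# The from_dict helper added here accepts the same fields as a plain dict,
# e.g. when they come from an extra-options YAML file.
same_config = AttentionDpConfig.from_dict({
    "enable_balance": True,
    "timeout_iters": 50,
    "batching_wait_iters": 10,
})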
@@ -1876,6 +1893,11 @@ class TorchLlmArgs(BaseLlmArgs):
         Note that each CUDA graph can use up to 200 MB of extra memory.",
         status="beta")
 
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description=
+        "Attention DP config. If set, use the attention DP optimized scheduler.")
+
     disable_overlap_scheduler: bool = Field(
         default=False,
         description="Disable the overlap scheduler.",
@@ -2173,6 +2195,31 @@ def warn_on_unstable_feature_usage(self) -> 'TorchLlmArgs':
 
         return self
 
+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must not be negative.
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must not be negative.
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be non-negative when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be non-negative when enable_balance is true"
+                )
+        return self
+
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
        from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2223,7 +2270,14 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
22232270
enable_min_latency=self.enable_min_latency,
22242271
stream_interval=self.stream_interval,
22252272
force_dynamic_quantization=self.force_dynamic_quantization,
2226-
allreduce_strategy=self.allreduce_strategy)
2273+
allreduce_strategy=self.allreduce_strategy,
2274+
use_attention_dp_config=bool(self.attention_dp_config is not None),
2275+
attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
2276+
if self.attention_dp_config is not None else
2277+
AttentionDpConfig.model_fields['timeout_iters'].default,
2278+
attention_dp_batching_wait_iters=self.attention_dp_config.
2279+
batching_wait_iters if self.attention_dp_config is not None else
2280+
AttentionDpConfig.model_fields['batching_wait_iters'].default)
22272281

22282282

22292283
def update_llm_args_with_extra_dict(
@@ -2240,6 +2294,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
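For context, a hedged sketch of how the new field_mapping entry lets a plain nested dict (for example, parsed from an extra-options YAML file) be promoted to the typed config; the base-args dict and the two-argument call below are illustrative assumptions about the existing helper, not part of this diff.

from tensorrt_llm.llmapi.llm_args import update_llm_args_with_extra_dict

llm_args = {"model": "/path/to/model"}    # placeholder base arguments
extra = {
    "attention_dp_config": {              # plain dict, e.g. loaded from YAML
        "enable_balance": True,
        "timeout_iters": 50,
        "batching_wait_iters": 10,
    },
}

# The mapping entry added above converts the nested dict into an
# AttentionDpConfig instance before the merged arguments are validated.
llm_args = update_llm_args_with_extra_dict(llm_args, extra)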
