Commit 01303b3

optimize: ADP schedule optimization
Signed-off-by: yunruis <[email protected]>
1 parent 2a147c4 commit 01303b3

5 files changed, 135 additions and 5 deletions


examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
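
The new 'WIDEEP' choice appears to correspond to the wide expert-parallel MoE backend touched in fused_moe_wide_ep.py below. As a quick check, a minimal sketch (a standalone parser mirroring the updated add_llm_args above, not the real quickstart script) showing that 'WIDEEP' now parses alongside --enable_attention_dp:

    import argparse

    # Standalone reconstruction of the two arguments involved in this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    args = parser.parse_args(['--moe_backend', 'WIDEEP', '--enable_attention_dp'])
    print(args.moe_backend, args.enable_attention_dp)  # -> WIDEEP True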

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
         if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
             return AlltoallMethodType.NotEnabled

-        if mapping.moe_ep_size <= top_k:
+        if mapping.moe_ep_size < top_k:
             return AlltoallMethodType.NotEnabled

         if MnnvlMemory.supports_mnnvl():
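
The relaxed guard means ranks where moe_ep_size equals top_k no longer fall back to AlltoallMethodType.NotEnabled. A hedged illustration of just this boundary (the real select_alltoall_method_type also checks TRTLLM_MOE_DISABLE_ALLTOALLV and MNNVL support, omitted here):

    def alltoall_gated_off(moe_ep_size: int, top_k: int) -> bool:
        # Before this commit the guard was `moe_ep_size <= top_k`; after it,
        # only strictly smaller EP sizes disable the alltoall path.
        return moe_ep_size < top_k

    print(alltoall_gated_off(8, 8))  # False: ep_size == top_k may now use alltoall
    print(alltoall_gated_off(4, 8))  # True: still disabled when ep_size < top_k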

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None

+    use_attention_dp_config: bool = False
+    attention_dp_time_out_iters: int = 500
+    attention_dp_batching_wait_iters: int = 0
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
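
A hedged sketch of the three new knobs with the defaults added above; constructing PyTorchConfig directly is only for illustration (the usual entry point is TorchLlmArgs.get_pytorch_backend_config(), extended at the bottom of this commit), and all other fields keep their defaults:

    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

    config = PyTorchConfig(
        use_attention_dp_config=True,        # opt in to the ADP-aware scheduling in py_executor.py
        attention_dp_time_out_iters=500,     # iterations to wait before scheduling context anyway
        attention_dp_batching_wait_iters=0,  # extra iterations to hold context for larger batches
    )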

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 75 additions & 2 deletions
@@ -241,6 +241,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -287,6 +290,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)

         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters = 0
+        self.adp_ctx_batching_wait_iters = 0

         self.stats_lock = threading.Lock()
         self.stats = []
@@ -1228,7 +1234,16 @@ def _broadcast_new_requests(
     def _fetch_new_requests(self) -> List[RequestQueueItem]:
         if self.enable_attention_dp:
             all_ranks_num_active_requests = []
-            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            num_active_requests = len(self.active_requests)
+            responses_list = self.dist.tp_allgather(num_active_requests)
+            # Debug check - remove after verification
+            if not all(isinstance(x, int) for x in responses_list):
+                raise RuntimeError(
+                    f"tp_allgather returned non-integer values: {responses_list} "
+                    +
+                    f"Expected all ranks to return int from {num_active_requests} and {self.active_requests}."
+                )
+
             for num_active_requests in responses_list:
                 all_ranks_num_active_requests.append(num_active_requests)
             total_num_active_requests = sum(all_ranks_num_active_requests)
@@ -1518,8 +1533,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+
+            if self.use_attention_dp_config:
+                # wait until all ranks have context requests
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                    self.adp_ctx_waiting_iters = 0
+                    # balance the number of context requests across ranks
+                    if all_ranks_have_gen_requests:
+                        if self.adp_ctx_batching_wait_iters <= self.attention_dp_batching_wait_iters:
+                            self.adp_ctx_batching_wait_iters += 1
+                            context_requests = []
+                        else:
+                            self.adp_ctx_batching_wait_iters = 0
+                else:
+                    self.adp_ctx_waiting_iters += 1
+                    context_requests = []
+                    if self.adp_ctx_waiting_iters >= self.attention_dp_time_out_iters or not all_ranks_have_gen_requests:
+                        self.adp_ctx_waiting_iters = 0
+                        context_requests = scheduler_output.context_requests

-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
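
For readability, a hedged, self-contained distillation of the context-request gate added to _schedule above. The counters mirror adp_ctx_waiting_iters and adp_ctx_batching_wait_iters, but this helper and its signature are illustrative only, not part of the commit:

    def gate_context_requests(context_requests, all_ranks_num_ctx, all_ranks_num_gen,
                              max_batch_size, state, timeout_iters, batching_wait_iters):
        """Return the context requests this rank should schedule this iteration."""
        all_have_free_slots = all(n < max_batch_size for n in all_ranks_num_gen)
        all_have_ctx = all(n > 0 for n in all_ranks_num_ctx)
        all_have_gen = all(n > 0 for n in all_ranks_num_gen)

        if all_have_free_slots and all_have_ctx:
            state['waiting_iters'] = 0
            if all_have_gen:
                # Briefly hold context so more of it can be batched across ranks.
                if state['batching_wait_iters'] <= batching_wait_iters:
                    state['batching_wait_iters'] += 1
                    return []
                state['batching_wait_iters'] = 0
            return context_requests

        state['waiting_iters'] += 1
        if state['waiting_iters'] >= timeout_iters or not all_have_gen:
            # Timed out, or some rank has no generation work: stop holding context.
            state['waiting_iters'] = 0
            return context_requests
        return []

    # Example: every rank has free slots, context, and generation work, so the
    # gate first holds context for batching, then releases it on the next call.
    state = {'waiting_iters': 0, 'batching_wait_iters': 0}
    print(gate_context_requests(['ctx0'], [1, 1], [2, 2], 8, state, 500, 0))  # []
    print(gate_context_requests(['ctx0'], [1, 1], [2, 2], 8, state, 500, 0))  # ['ctx0']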

tensorrt_llm/llmapi/llm_args.py

Lines changed: 54 additions & 1 deletion
@@ -88,6 +88,23 @@ def validate_cuda_graph_max_batch_size(cls, v):
         return v


+class AttentionDpConfig(BaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable balance.")
+    batching_wait_iters: int = Field(
+        default=10,
+        description="The number of iterations to wait for batching.")
+    timeout_iters: int = Field(
+        default=500, description="The number of iterations to timeout.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 class MoeConfig(BaseModel):
     """
     Configuration for MoE.
@@ -1789,6 +1806,11 @@ class TorchLlmArgs(BaseLlmArgs):
         since the input shapes are a function of the sequence lengths).\
         Note that each CUDA graph can use up to 200 MB of extra memory.")

+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description=
+        "Attention DP config. If set, the attention DP optimized scheduler is used.")
+
     disable_overlap_scheduler: bool = Field(
         default=False, description="Disable the overlap scheduler.")


@@ -1993,6 +2015,29 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':

         return self

+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater than 0
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater than 0
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be greater than 0 when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be greater than 0 when enable_balance is true"
+                )
+        return self
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2040,7 +2085,14 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            allreduce_strategy=self.allreduce_strategy,
+            use_attention_dp_config=bool(self.attention_dp_config is not None),
+            attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
+            if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters=self.attention_dp_config.
+            batching_wait_iters if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['batching_wait_iters'].default)


 def update_llm_args_with_extra_dict(
@@ -2057,6 +2109,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
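
End to end, a hedged usage sketch: AttentionDpConfig and the "attention_dp_config" key come from this commit, while the surrounding dict is illustrative and the thresholds simply repeat the defaults explicitly:

    from tensorrt_llm.llmapi.llm_args import AttentionDpConfig

    adp = AttentionDpConfig(
        enable_balance=True,      # turn on the balancing behaviour validated above
        batching_wait_iters=10,   # iterations to hold context requests for batching
        timeout_iters=500,        # iterations to wait before scheduling context anyway
    )

    # Equivalent dict form, as routed through update_llm_args_with_extra_dict via
    # the field_mapping entry added above:
    extra_llm_args = {
        "attention_dp_config": {
            "enable_balance": True,
            "batching_wait_iters": 10,
            "timeout_iters": 500,
        },
    }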

0 commit comments
