
Commit 8c0a5fb

optimize: ADP schedule optimization

Signed-off-by: yunruis <[email protected]>

1 parent 2a147c4 commit 8c0a5fb

5 files changed: +136 -5 lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
```

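The new `WIDEEP` choice exposes the wide expert-parallel fused MoE path (the backend touched in the next file) from the quickstart script. As a quick sanity check, here is a standalone sketch of the updated argument definition, reconstructed from the hunk above rather than taken from the commit:

```python
# Standalone sketch: rebuild just the updated --moe_backend argument and
# confirm that 'WIDEEP' is now an accepted choice.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])

args = parser.parse_args(['--moe_backend', 'WIDEEP'])
assert args.moe_backend == 'WIDEEP'
```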
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -245,7 +245,7 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
     if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
         return AlltoallMethodType.NotEnabled
 
-    if mapping.moe_ep_size <= top_k:
+    if mapping.moe_ep_size < top_k:
         return AlltoallMethodType.NotEnabled
 
     if MnnvlMemory.supports_mnnvl():
```

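The relaxed comparison changes exactly one boundary case: when `moe_ep_size == top_k`, alltoall is no longer force-disabled. A minimal sketch of that boundary, using a hypothetical helper and a placeholder enum member (only the changed check mirrors the hunk above):

```python
from enum import Enum


class AlltoallMethodType(Enum):
    NotEnabled = 0
    Enabled = 1  # placeholder for the real alltoall variants


def select_alltoall(moe_ep_size: int, top_k: int) -> AlltoallMethodType:
    # New rule: disable alltoall only when there are strictly fewer EP ranks
    # than experts selected per token.
    if moe_ep_size < top_k:
        return AlltoallMethodType.NotEnabled
    return AlltoallMethodType.Enabled


# moe_ep_size == top_k used to hit the old `<=` check and fall back;
# with the new `<` check the alltoall path stays available.
assert select_alltoall(moe_ep_size=8, top_k=8) is AlltoallMethodType.Enabled
assert select_alltoall(moe_ep_size=4, top_k=8) is AlltoallMethodType.NotEnabled
```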
tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -44,6 +44,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
+    use_attention_dp_config: bool = False
+    attention_dp_time_out_iters: int = 500
+    attention_dp_batching_wait_iters: int = 0
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
 
```

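These three knobs are what the executor reads in the next file. A trimmed-down stand-in (not the real `PyTorchConfig`) showing just the new fields and their defaults:

```python
from dataclasses import dataclass


@dataclass
class AdpSchedulerKnobs:
    # Turns on the ADP-aware scheduling path in PyExecutor._schedule().
    use_attention_dp_config: bool = False
    # Iterations to keep holding context requests before scheduling them anyway.
    attention_dp_time_out_iters: int = 500
    # Extra iterations to hold context requests so ranks batch them together.
    attention_dp_batching_wait_iters: int = 0


# Example: enable the ADP scheduler with a shorter timeout.
knobs = AdpSchedulerKnobs(use_attention_dp_config=True,
                          attention_dp_time_out_iters=100)
```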
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 76 additions & 2 deletions

```diff
@@ -241,6 +241,10 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -287,6 +291,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters = 0
+        self.adp_ctx_batching_wait_iters = 0
 
         self.stats_lock = threading.Lock()
         self.stats = []
@@ -1228,7 +1235,16 @@ def _broadcast_new_requests(
     def _fetch_new_requests(self) -> List[RequestQueueItem]:
         if self.enable_attention_dp:
             all_ranks_num_active_requests = []
-            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            num_active_requests = len(self.active_requests)
+            responses_list = self.dist.tp_allgather(num_active_requests)
+            # Debug check - remove after verification
+            if not all(isinstance(x, int) for x in responses_list):
+                raise RuntimeError(
+                    f"tp_allgather returned non-integer values: {responses_list} "
+                    +
+                    f"Expected all ranks to return int from {num_active_requests} and {self.active_requests}."
+                )
+
             for num_active_requests in responses_list:
                 all_ranks_num_active_requests.append(num_active_requests)
             total_num_active_requests = sum(all_ranks_num_active_requests)
@@ -1518,8 +1534,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            if self.use_attention_dp_config:
+                # wait for all ranks have context requests
+                if all_ranks_have_multi_gen:
+                    if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                        self.adp_ctx_waiting_iters = 0
+                    else:
+                        self.adp_ctx_waiting_iters += 1
+                        context_requests = []
+                        if self.adp_ctx_waiting_iters >= self.attention_dp_time_out_iters:
+                            self.adp_ctx_waiting_iters = 0
+                            context_requests = scheduler_output.context_requests
+                # balance number of context requests across ranks
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests:
+                    if self.adp_ctx_batching_wait_iters <= self.attention_dp_batching_wait_iters:
+                        self.adp_ctx_batching_wait_iters += 1
+                        context_requests = []
+                    else:
+                        self.adp_ctx_batching_wait_iters = 0
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
```

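Taken together, `_schedule()` now allgathers per-rank counts of scheduled context and generation requests and holds back local context work until every attention-DP rank has context requests and a free batch slot, with `attention_dp_time_out_iters` bounding the wait and `attention_dp_batching_wait_iters` adding an optional extra delay so ranks enter their context phases together. Below is a rank-local simulation of that decision; it is a sketch with hypothetical names and no `tp_allgather`, not the executor code itself:

```python
# Rank-local sketch of the two counters added by this commit:
#  - waiting_iters delays context scheduling until all ranks can take context
#    work, with a timeout so no rank starves;
#  - batching_iters holds context a few extra iterations so ranks batch
#    their context phases together.
class AdpScheduleBalancer:
    def __init__(self, timeout_iters=500, batching_wait_iters=0, max_batch_size=8):
        self.timeout_iters = timeout_iters
        self.batching_wait_iters = batching_wait_iters
        self.max_batch_size = max_batch_size
        self.waiting_iters = 0
        self.batching_iters = 0

    def filter_context(self, ctx_reqs, all_num_ctx, all_num_gen):
        """Return the context requests this rank should actually run this iteration."""
        free_slots = all(n < self.max_batch_size for n in all_num_gen)
        multi_gen = all(n > 1 for n in all_num_gen)
        have_ctx = all(n > 0 for n in all_num_ctx)
        have_gen = all(n > 0 for n in all_num_gen)

        result = ctx_reqs
        if multi_gen:
            if free_slots and have_ctx:
                self.waiting_iters = 0
            else:
                self.waiting_iters += 1
                result = []
                if self.waiting_iters >= self.timeout_iters:
                    self.waiting_iters = 0
                    result = ctx_reqs  # timed out: stop holding context back
        if free_slots and have_ctx and have_gen:
            if self.batching_iters <= self.batching_wait_iters:
                self.batching_iters += 1
                result = []
            else:
                self.batching_iters = 0
        return result


balancer = AdpScheduleBalancer(timeout_iters=3, batching_wait_iters=0, max_batch_size=4)
# Rank 1 has no context work yet, so this rank holds its two context requests back.
print(balancer.filter_context(['c0', 'c1'], all_num_ctx=[2, 0], all_num_gen=[2, 3]))  # -> []
```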
tensorrt_llm/llmapi/llm_args.py

Lines changed: 54 additions & 1 deletion

```diff
@@ -88,6 +88,23 @@ def validate_cuda_graph_max_batch_size(cls, v):
         return v
 
 
+class AttentionDpConfig(BaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable balance.")
+    batching_wait_iters: int = Field(
+        default=10,
+        description="The number of iterations to wait for batching.")
+    timeout_iters: int = Field(
+        default=500, description="The number of iterations to timeout.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 class MoeConfig(BaseModel):
     """
     Configuration for MoE.
@@ -1789,6 +1806,11 @@ class TorchLlmArgs(BaseLlmArgs):
         since the input shapes are a function of the sequence lengths).\
         Note that each CUDA graph can use up to 200 MB of extra memory.")
 
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description=
+        "Attention DP config. If true, use attention DP optimized scheduler.")
+
     disable_overlap_scheduler: bool = Field(
         default=False, description="Disable the overlap scheduler.")
 
@@ -1993,6 +2015,29 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
 
         return self
 
+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater than 0
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater than 0
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be greater than 0 when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be greater than 0 when enable_balance is true"
+                )
+        return self
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2040,7 +2085,14 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            allreduce_strategy=self.allreduce_strategy,
+            use_attention_dp_config=bool(self.attention_dp_config is not None),
+            attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
+            if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters=self.attention_dp_config.
+            batching_wait_iters if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['batching_wait_iters'].default)
 
 
 def update_llm_args_with_extra_dict(
@@ -2057,6 +2109,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
```

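On the API side, the new `attention_dp_config` field can be supplied like the other nested configs handled by `update_llm_args_with_extra_dict`. A hedged usage sketch follows; the import path is inferred from the file location above, and the surrounding LLM-API wiring is not shown in this diff:

```python
from tensorrt_llm.llmapi.llm_args import AttentionDpConfig  # path assumed from the file above

# Built directly from a dict via the helper added in this commit...
cfg = AttentionDpConfig.from_dict({
    "enable_balance": True,     # checked by validate_attention_dp_config
    "timeout_iters": 300,       # stop waiting for other ranks after 300 iterations
    "batching_wait_iters": 10,  # hold context up to 10 extra iterations for batching
})

# ...or passed as a nested dict in an extra-args mapping; the new
# "attention_dp_config" entry in field_mapping converts it to AttentionDpConfig.
llm_args_extra = {
    "attention_dp_config": {
        "enable_balance": True,
        "timeout_iters": 300,
        "batching_wait_iters": 10,
    },
}
```

`get_pytorch_backend_config()` then forwards these values as `use_attention_dp_config`, `attention_dp_time_out_iters`, and `attention_dp_batching_wait_iters` on `PyTorchConfig`, falling back to the `AttentionDpConfig` field defaults when `attention_dp_config` is `None`.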