
Commit cdf234a

optimize: ADP schedule optimization
Signed-off-by: yunruis <[email protected]>
1 parent 2a147c4 commit cdf234a

5 files changed: +126 −6 lines

examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
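For reference, a self-contained sketch of the extended flag; the argument definitions are copied from the diff above, while the parse_args call is illustrative.

# Sketch: 'WIDEEP' is now an accepted --moe_backend choice.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
parser.add_argument('--enable_attention_dp', default=False, action='store_true')

args = parser.parse_args(['--moe_backend', 'WIDEEP', '--enable_attention_dp'])
print(args.moe_backend, args.enable_attention_dp)  # WIDEEP True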

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
         if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
             return AlltoallMethodType.NotEnabled
 
-        if mapping.moe_ep_size <= top_k:
+        if mapping.moe_ep_size < top_k:
             return AlltoallMethodType.NotEnabled
 
         if MnnvlMemory.supports_mnnvl():
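The operative change is the boundary case moe_ep_size == top_k: the old <= check fell back to NotEnabled there, while the new < check lets all-to-all remain a candidate. A toy restatement of just this early-return condition (not the full method):

# Illustrative: models only the single early-return check changed in this hunk.
def alltoall_not_disabled_by_ep_size(moe_ep_size: int, top_k: int) -> bool:
    return not (moe_ep_size < top_k)  # new condition from the diff

print(alltoall_not_disabled_by_ep_size(8, 8))  # True  (was False with `<=`)
print(alltoall_not_disabled_by_ep_size(4, 8))  # False (unchanged)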

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
+    use_attention_dp_config: bool = False
+    attention_dp_time_out_iters: int = 500
+    attention_dp_batching_wait_iters: int = 0
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
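A minimal sketch of setting the new knobs directly on the backend config; the import path matches the one used later in llm_args.py, the values are illustrative rather than recommendations, and it assumes PyTorchConfig stays keyword-constructible as a plain config dataclass.

# Illustrative values only.
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

cfg = PyTorchConfig(
    use_attention_dp_config=True,         # turn on the ADP-aware scheduling path
    attention_dp_time_out_iters=100,      # stop waiting for balance after 100 iterations
    attention_dp_batching_wait_iters=10,  # extra iterations to accumulate context requests
)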

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 75 additions & 2 deletions
@@ -241,6 +241,10 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -287,6 +291,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters = 0
+        self.adp_ctx_batching_wait_iters = 0
 
         self.stats_lock = threading.Lock()
         self.stats = []
@@ -1228,7 +1235,15 @@ def _broadcast_new_requests(
     def _fetch_new_requests(self) -> List[RequestQueueItem]:
         if self.enable_attention_dp:
             all_ranks_num_active_requests = []
-            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            num_active_requests = len(self.active_requests)
+            responses_list = self.dist.tp_allgather(num_active_requests)
+            # Debug check - remove after verification
+            if not all(isinstance(x, int) for x in responses_list):
+                raise RuntimeError(
+                    f"tp_allgather returned non-integer values: {responses_list} " +
+                    f"Expected all ranks to return int from {num_active_requests} and {self.active_requests}."
+                )
+
             for num_active_requests in responses_list:
                 all_ranks_num_active_requests.append(num_active_requests)
             total_num_active_requests = sum(all_ranks_num_active_requests)
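For context, the per-rank bookkeeping this hunk hardens can be pictured without the distributed runtime; a plain list stands in for the tp_allgather result (illustrative only).

# Each attention-DP rank reports len(self.active_requests); the gathered
# counts drive how many new requests this rank may fetch.
all_ranks_num_active_requests = [3, 1, 0, 2]  # stand-in for self.dist.tp_allgather(...)
assert all(isinstance(x, int) for x in all_ranks_num_active_requests)
total_num_active_requests = sum(all_ranks_num_active_requests)
print(total_num_active_requests)  # 6 active requests across the four ranks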
@@ -1518,8 +1533,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            if self.use_attention_dp_config:
+                # wait for all ranks have context requests
+                if all_ranks_have_multi_gen:
+                    if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                        self.adp_ctx_waiting_iters = 0
+                    else:
+                        self.adp_ctx_waiting_iters += 1
+                        context_requests = []
+                        if self.adp_ctx_waiting_iters >= self.attention_dp_time_out_iters:
+                            self.adp_ctx_waiting_iters = 0
+                            context_requests = scheduler_output.context_requests
+                # balance number of context requests across ranks
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests and all_ranks_have_gen_requests:
+                    if self.adp_ctx_batching_wait_iters <= self.attention_dp_batching_wait_iters:
+                        self.adp_ctx_batching_wait_iters += 1
+                        context_requests = []
+                    else:
+                        self.adp_ctx_batching_wait_iters = 0
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
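Taken together, the new block holds context (prefill) requests back until every attention-DP rank has context work and a free generation slot, releases them anyway after attention_dp_time_out_iters, and optionally accumulates them for attention_dp_batching_wait_iters extra iterations so ranks release them together. The sketch below restates that policy outside the executor, with plain lists standing in for the allgathered per-rank stats; it is an illustrative harness, not the executor code.

# Illustrative restatement of the ADP scheduling policy above.
def adp_filter_context(ctx_per_rank, gen_per_rank, state, max_batch_size,
                       timeout_iters, batching_wait_iters):
    """Return True if this iteration should keep its scheduled context requests."""
    free_slots = all(g < max_batch_size for g in gen_per_rank)
    multi_gen = all(g > 1 for g in gen_per_rank)
    have_ctx = all(c > 0 for c in ctx_per_rank)
    have_gen = all(g > 0 for g in gen_per_rank)

    keep_ctx = True
    # Hold context requests until every rank has context work and a free slot,
    # or give up after timeout_iters iterations.
    if multi_gen:
        if free_slots and have_ctx:
            state['waiting'] = 0
        else:
            state['waiting'] += 1
            keep_ctx = False
            if state['waiting'] >= timeout_iters:
                state['waiting'] = 0
                keep_ctx = True
    # Optionally accumulate context requests for a few extra iterations.
    if free_slots and have_ctx and have_gen:
        if state['batching'] <= batching_wait_iters:
            state['batching'] += 1
            keep_ctx = False
        else:
            state['batching'] = 0
    return keep_ctx

state = {'waiting': 0, 'batching': 0}
print(adp_filter_context([2, 0], [3, 4], state, max_batch_size=8,
                         timeout_iters=500, batching_wait_iters=0))  # False: rank 1 has no context work yet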

tensorrt_llm/llmapi/llm_args.py

Lines changed: 45 additions & 2 deletions
@@ -87,6 +87,18 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v
 
+class AttentionDpConfig(BaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False, description="Whether to enable balance.")
+    batching_wait_iters: int = Field(default=10, description="The number of iterations to wait for batching.")
+    timeout_iters: int = Field(default=500, description="The number of iterations to timeout.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
 
 class MoeConfig(BaseModel):
     """
@@ -1789,6 +1801,9 @@ class TorchLlmArgs(BaseLlmArgs):
         since the input shapes are a function of the sequence lengths).\
         Note that each CUDA graph can use up to 200 MB of extra memory.")
 
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None, description="Attention DP config. If true, use attention DP optimized scheduler.")
+
     disable_overlap_scheduler: bool = Field(
         default=False, description="Disable the overlap scheduler.")
 
@@ -1993,6 +2008,29 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
 
         return self
 
+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater than 0
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater than 0
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be greater than 0 when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be greater than 0 when enable_balance is true"
+                )
+        return self
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2039,8 +2077,12 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             load_format=self.load_format,
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
-            force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            force_dynamic_quantization=self.
+            force_dynamic_quantization,
+            allreduce_strategy=self.allreduce_strategy,
+            use_attention_dp_config = bool(self.attention_dp_config is not None),
+            attention_dp_time_out_iters = self.attention_dp_config.timeout_iters if self.attention_dp_config is not None else AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters = self.attention_dp_config.batching_wait_iters if self.attention_dp_config is not None else AttentionDpConfig.model_fields['batching_wait_iters'].default)
 
 
 def update_llm_args_with_extra_dict(
@@ -2057,6 +2099,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
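With the mapping entry in place, an attention_dp_config supplied as a plain dict (for example via extra LLM-API options) can be converted to the typed model. The snippet below is a simplified stand-in for what the field_mapping loop does with each mapped field; the real update_llm_args_with_extra_dict may perform the conversion differently.

# Simplified, illustrative stand-in for the mapping-loop conversion.
from tensorrt_llm.llmapi.llm_args import AttentionDpConfig

llm_args_dict = {
    "attention_dp_config": {"enable_balance": True, "timeout_iters": 300},
}
field_name, field_type = "attention_dp_config", AttentionDpConfig
if field_name in llm_args_dict and isinstance(llm_args_dict[field_name], dict):
    llm_args_dict[field_name] = field_type.from_dict(llm_args_dict[field_name])
print(llm_args_dict["attention_dp_config"].timeout_iters)  # 300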
