
Commit f6778f2

opt: add tp attn waiting
Signed-off-by: yunruis <[email protected]>
1 parent abdb273 commit f6778f2

File tree

5 files changed: +121 -1 lines changed


tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -50,7 +50,11 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10
 
+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    batch_wait_timeout_iters: int = 0
+    batch_wait_max_tokens_ratio: float = 0
 
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
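
A worked illustration of how the two new knobs are intended to combine (a minimal sketch; the concrete values below are assumptions, not defaults from this commit):

    # Hypothetical settings: wait at most 10 scheduler iterations, or until the
    # scheduled token count reaches 75% of max_num_tokens, whichever comes first.
    batch_wait_timeout_iters = 10
    batch_wait_max_tokens_ratio = 0.75
    max_num_tokens = 8192

    token_threshold = batch_wait_max_tokens_ratio * max_num_tokens  # 6144 tokens

    def stop_waiting(waited_iters: int, num_scheduled_tokens: int) -> bool:
        # Release the held-back context requests once either limit is reached.
        return (waited_iters >= batch_wait_timeout_iters
                or num_scheduled_tokens >= token_threshold)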

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 35 additions & 0 deletions
@@ -180,6 +180,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -188,6 +189,10 @@ def __init__(self,
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -232,6 +237,7 @@ def __init__(self,
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0
 
         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1257,6 +1263,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests
 
+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(
+            len(gen_req.get_tokens(0)) for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1267,6 +1294,14 @@ def _schedule(self):
             scheduler_output.context_requests,
             scheduler_output.generation_requests)
 
+        # if no generation requests, no need to wait, to avoid dead waiting
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                    scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
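
A standalone sketch of the waiting behavior implemented in _waiting_requests above (the BatchWaiter class and the trace values are illustrative assumptions, not code from the commit): context requests are held back until either the iteration timeout or the token threshold is reached, and the mechanism is bypassed entirely when there are no generation requests to overlap with.

    # Simplified model of PyExecutor's batch-waiting counter.
    class BatchWaiter:
        def __init__(self, timeout_iters: int, token_threshold: int):
            self.timeout_iters = timeout_iters
            self.token_threshold = token_threshold
            self.iters_count = 0

        def admit(self, context_requests, num_scheduled_tokens, has_generation):
            # No in-flight generation requests: admit immediately so the
            # scheduler never waits on an empty pipeline ("dead waiting").
            if not has_generation:
                return context_requests
            if (self.iters_count >= self.timeout_iters
                    or num_scheduled_tokens >= self.token_threshold):
                self.iters_count = 0
                return context_requests
            self.iters_count += 1
            return []

    waiter = BatchWaiter(timeout_iters=3, token_threshold=6144)
    for tokens in (1024, 2048, 7000):
        print(waiter.admit(["ctx"], tokens, has_generation=True))
    # [], [], ['ctx'] -- released once the token threshold is crossed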

tensorrt_llm/llmapi/llm_args.py

Lines changed: 30 additions & 1 deletion
@@ -2240,6 +2240,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")
 
+    batch_wait_timeout_iters: float = Field(
+        default=0,
+        description=
+        "Maximum number of iterations the scheduler will wait to accumulate new coming requests for improved GPU utilization efficiency. If greater than 0, the scheduler will delay batch processing to gather more requests up to the specified iteration limit. If 0, disables timeout-iters-based batching delays.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes. If 0, disables token-based batching delays.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")
 
@@ -2508,6 +2520,20 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self
 
+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError("batch_wait_timeout_iters must be greater than 0")
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                "batch_wait_max_tokens_ratio must be greater than or equal to 0 and less than or equal to 1"
+            )
+        return self
+
     def get_executor_config(
         self,
         _hf_model_dir: Optional[Path] = None,
@@ -2583,7 +2609,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )
 
 
 def update_llm_args_with_extra_dict(
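
For reference, a hedged usage sketch of the new TorchLlmArgs fields through the LLM API, mirroring how the test below passes them; the model path and values are placeholders:

    from tensorrt_llm import LLM

    # Hold context-only work for up to 10 iterations, or until the scheduled
    # token count reaches 75% of max_num_tokens, whichever happens first.
    llm = LLM(
        "/path/to/model",                  # placeholder model path
        batch_wait_timeout_iters=10,
        batch_wait_max_tokens_ratio=0.75,
    )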

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 44 additions & 0 deletions
@@ -1483,6 +1483,50 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters, batch_wait_max_tokens_ratio", [(0, 0),
+                                                                  (10, 0.75),
+                                                                  (10, 0),
+                                                                  (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, cuda_graph, overlap_scheduler,
+                                 torch_compile, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        fp8kv = True
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])

tests/unittest/api_stability/references/llm.yaml

Lines changed: 8 additions & 0 deletions
@@ -131,6 +131,14 @@ methods:
         annotation: float
         default: 0
         status: prototype
+      batch_wait_timeout_iters:
+        annotation: int
+        default: 0
+        status: prototype
+      batch_wait_max_tokens_ratio:
+        annotation: float
+        default: 0
+        status: prototype
       print_iter_log:
         annotation: bool
         default: False
