
Commit 56d66e3

opt: add batch waiting when scheduling, to accumulate batch
Signed-off-by: yunruis <[email protected]>
1 parent c1e7fb9 commit 56d66e3

7 files changed: +122 −1 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,8 @@ def __init__(
         self.pytorch_backend_config.attention_dp_time_out_iters = 50
         self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.pytorch_backend_config.batch_wait_timeout_ms = 0
+        self.pytorch_backend_config.batch_wait_timeout_iters = 0
+        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -50,7 +50,11 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10

+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    batch_wait_timeout_iters: int = 0
+    batch_wait_max_tokens_ratio: float = 0

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
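
For reference, the two new knobs sit in PyTorchConfig next to the existing batch_wait_timeout_ms, and max_num_tokens gives the budget the ratio is applied to. A minimal sketch of how the fields combine; values are illustrative, and in practice the config is produced by TorchLlmArgs.get_pytorch_backend_config() rather than built by hand:

# Illustrative only: PyTorchConfig is normally created from TorchLlmArgs,
# not constructed directly like this.
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

cfg = PyTorchConfig(
    max_num_tokens=8192,               # token budget the ratio applies to
    batch_wait_timeout_iters=10,       # wait at most 10 scheduler iterations
    batch_wait_max_tokens_ratio=0.75,  # ...or until 75% of max_num_tokens is scheduled
)

# PyExecutor turns batch waiting on when either knob is non-zero
# (mirrors the py_executor.py change below).
enable_batch_waiting = (cfg.batch_wait_timeout_iters > 0
                        or cfg.batch_wait_max_tokens_ratio > 0)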

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 35 additions & 0 deletions
@@ -180,6 +180,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -188,6 +189,10 @@ def __init__(self,
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -232,6 +237,7 @@ def __init__(self,
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1266,6 +1272,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests

+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(
+            len(gen_req.get_tokens(0)) for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1276,6 +1303,14 @@ def _schedule(self):
                 scheduler_output.context_requests,
                 scheduler_output.generation_requests)

+        # if no generation requests, no need to wait, to avoid dead waiting
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
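
The waiting rule itself is simple: hold back the context requests until either the iteration budget runs out or the already-scheduled tokens reach the configured share of max_num_tokens. A standalone sketch of that decision, with token counts passed in directly instead of being derived from LlmRequest objects as the real _waiting_requests does:

# Simplified stand-in for the stop_waiting condition in _waiting_requests above.
def should_stop_waiting(wait_iters_count: int, num_scheduled_tokens: int,
                        batch_wait_timeout_iters: int,
                        batch_wait_max_tokens_ratio: float,
                        max_num_tokens: int) -> bool:
    # Release the held-back context requests once either threshold is reached.
    return (wait_iters_count >= batch_wait_timeout_iters
            or num_scheduled_tokens >= batch_wait_max_tokens_ratio * max_num_tokens)

# With batch_wait_timeout_iters=10, batch_wait_max_tokens_ratio=0.75, max_num_tokens=8192
# the token threshold is 0.75 * 8192 = 6144:
print(should_stop_waiting(3, 5000, 10, 0.75, 8192))   # False: keep deferring (5000 < 6144)
print(should_stop_waiting(3, 7000, 10, 0.75, 8192))   # True: enough tokens scheduled (7000 >= 6144)
print(should_stop_waiting(10, 5000, 10, 0.75, 8192))  # True: iteration budget exhausted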

tensorrt_llm/llmapi/llm_args.py

Lines changed: 29 additions & 1 deletion
@@ -2240,6 +2240,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")

+    batch_wait_timeout_iters: int = Field(
+        default=0,
+        description=
+        "Maximum number of scheduler iterations to wait for newly arrived requests before running the batch, to improve GPU utilization. If greater than 0, the scheduler delays batch processing to gather more requests, up to the specified iteration limit. If 0, disables iteration-based batching delays.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold ratio for batch scheduling. If greater than 0, the scheduler accumulates requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens, which improves GPU utilization by ensuring adequate batch sizes. If 0, disables token-based batching delays.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")

@@ -2508,6 +2520,19 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self

+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError("batch_wait_timeout_iters must be greater than 0")
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                "batch_wait_max_tokens_ratio must be between 0 and 1")
+        return self
+
     def get_executor_config(
         self,
         _hf_model_dir: Optional[Path] = None,
@@ -2583,7 +2608,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )


 def update_llm_args_with_extra_dict(
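
From the LLM API side, both knobs are prototype-status TorchLlmArgs fields, validated to be non-negative (the ratio must stay within [0, 1]). A hedged usage sketch; the model path is a placeholder and the values are only examples:

# Usage sketch only; "/path/to/model" is a placeholder checkpoint.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",
    batch_wait_timeout_iters=10,       # hold back context requests for up to 10 iterations
    batch_wait_max_tokens_ratio=0.75,  # ...or until 75% of max_num_tokens is accumulated
)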

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 43 additions & 0 deletions
@@ -1483,6 +1483,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
+                                                                 (10, 0.75),
+                                                                 (10, 0),
+                                                                 (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
+                                 overlap_scheduler, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 1 addition & 0 deletions
@@ -501,6 +501,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]

tests/unittest/api_stability/references/llm.yaml

Lines changed: 8 additions & 0 deletions
@@ -131,6 +131,14 @@ methods:
        annotation: float
        default: 0
        status: prototype
+      batch_wait_timeout_iters:
+        annotation: int
+        default: 0
+        status: prototype
+      batch_wait_max_tokens_ratio:
+        annotation: float
+        default: 0
+        status: prototype
       print_iter_log:
         annotation: bool
         default: False
