Commit 76d0c2b

yweng0828 authored and Ransiki committed
[nvbug/5322354] fix PD + MTP + overlap scheduler accuracy issue (NVIDIA#6136)
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>
1 parent 3ba1532 commit 76d0c2b

File tree: 3 files changed (+30, -20 lines)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 30 additions & 14 deletions

@@ -1323,7 +1323,6 @@ def previous_seq_slots_device():
 
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1340,6 +1339,10 @@ def previous_seq_slots_device():
         self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                          non_blocking=True)
         if next_draft_tokens_device is not None:
+            # Initialize these two buffers to zeros
+            self.previous_pos_id_offsets_cuda *= 0
+            self.previous_kv_lens_offsets_cuda *= 0
+
             if previous_batch_len > 0:
                 previous_slots = previous_seq_slots_device()
                 # previous input ids
@@ -1364,24 +1367,37 @@ def previous_seq_slots_device():
                     pin_memory=True)
                 self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
                     previous_pos_indices_host, non_blocking=True)
+
+                # The order of requests in a batch: [context requests, generation requests]
+                # generation requests: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
+                #   1) 'requests that do not have previous batch': the overlap scheduler is disabled, or this is the first step in the generation server of disaggregated serving.
+                #   2) 'requests that already have previous batch': requests from the previous iteration.
+                #   3) 'dummy requests': dummy requests padded for CUDA graph or attention dp.
+                # Therefore, self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda likewise have 3 segments.
+                # For 1) 'requests that do not have previous batch':
+                #     Set these requests' previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs.
+                #     Already set to '0' during initialization.
+                # For 2) 'requests that already have previous batch' (overlap scheduler enabled):
+                #     Set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device.
+                # For 3) 'dummy requests':
+                #     Already set to '0' during initialization.
+
+                num_extend_reqeust_wo_dummy = len(extend_requests) - len(
+                    extend_dummy_requests)
                 self.previous_pos_id_offsets_cuda[
-                    0:previous_batch_tokens].copy_(
+                    (num_extend_reqeust_wo_dummy - previous_batch_len) *
+                    (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
+                    (1 + self.max_draft_len)].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
                             0:previous_batch_tokens]],
                         non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
-                    kv_len_offsets_device[previous_slots], non_blocking=True)
-                # for the requests that do not have previous batch, set the previous_pos_id_offsets and
-                # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[
-                    previous_batch_tokens:num_requests *
-                    (1 + self.max_draft_len)] *= 0
+
                 self.previous_kv_lens_offsets_cuda[
-                    previous_batch_len:num_requests] *= 0
-            else:
-                # change the data to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda *= 0
-                self.previous_kv_lens_offsets_cuda *= 0
+                    num_extend_reqeust_wo_dummy -
+                    previous_batch_len:num_extend_reqeust_wo_dummy].copy_(
+                        kv_len_offsets_device[previous_slots],
+                        non_blocking=True)
+
             elif new_tokens_device is not None:
                 seq_slots_device = previous_seq_slots_device()
                 max_draft_len = max(draft_lens)
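
The comment block added above splits both offset buffers into three request segments. To make the slice arithmetic concrete, here is a minimal standalone sketch in plain PyTorch; the sizes and the stand-in offset values are hypothetical, and the names only mirror the diff rather than the library's real implementation:

import torch

# Hypothetical sizes, chosen only to make the slice arithmetic easy to check.
max_draft_len = 2                    # MTP draft tokens per request
tokens_per_req = 1 + max_draft_len   # one target token plus the draft tokens
max_num_requests = 8

# Zero-initialize both buffers, mirroring the `*= 0` lines above. Segments
# 1) (no previous batch) and 3) (dummy padding) then stay zero, so
# _preprocess_inputs leaves those requests untouched.
pos_id_offsets = torch.zeros(max_num_requests * tokens_per_req, dtype=torch.int32)
kv_lens_offsets = torch.zeros(max_num_requests, dtype=torch.int32)

# Say 5 non-dummy extend requests are scheduled and the last 3 of them were
# also in the previous iteration (segment 2).
num_extend_request_wo_dummy = 5      # the diff spells this `num_extend_reqeust_wo_dummy`
previous_batch_len = 3
previous_batch_tokens = previous_batch_len * tokens_per_req

# Segment 2) sits at the tail of the non-dummy requests, hence the
# (num_extend_request_wo_dummy - previous_batch_len) lower bound.
lo = (num_extend_request_wo_dummy - previous_batch_len) * tokens_per_req  # 6
hi = num_extend_request_wo_dummy * tokens_per_req                         # 15
pos_id_offsets[lo:hi] = 1            # stand-in for new_tokens_lens_device[...]
kv_lens_offsets[num_extend_request_wo_dummy - previous_batch_len:
                num_extend_request_wo_dummy] = 1  # stand-in for kv_len_offsets_device[previous_slots]

# Segments 1) and 3) remain zero.
assert pos_id_offsets[:lo].eq(0).all() and kv_lens_offsets[:2].eq(0).all()

Because the buffers are zeroed up front, only segment 2) is ever written. Under disaggregated serving, segment 1) requests precede segment 2), so the old pattern of copying at the front of the buffer and zeroing the tail appears to have applied offsets to the wrong requests; that reading matches the accuracy issue named in the commit title.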

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 4 deletions

@@ -1022,10 +1022,6 @@ def _executor_loop_overlap(self):
                 )
 
                 if self.kv_cache_transceiver:
-                    # For generation requests which have completed KV cache transfer
-                    self._prepare_disagg_gen_transmission_complete(
-                        scheduled_batch)
-
                     # Return the first token to the client
                     self._handle_first_token_response(scheduled_batch)

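For orientation, the sketch below is a toy of an overlap-style executor loop; every name is hypothetical and the control flow is heavily simplified, so it is not py_executor's actual logic. It only illustrates why a request's first generation step (for example, right after a disaggregated prefill) has no previous batch and must therefore see the zero offsets set up in model_engine.py above:

# Toy overlap-scheduler loop (hypothetical names, simplified control flow).
# The forward pass of batch i consumes offsets produced by batch i-1, so a
# request's first generation step after a disaggregated prefill has no
# previous batch and must read a zero offset.
def toy_forward(batch, pos_offsets):
    # stand-in for the async device forward: apply per-request offsets
    return {req: pos + pos_offsets.get(req, 0) for req, pos in batch.items()}

steps = [
    {"reqA": 100, "reqB": 200},  # step 1: fresh from the prefill server
    {"reqA": 101, "reqB": 201},  # step 2: a previous batch now exists
]

pos_offsets = {}  # empty == all-zero offsets, like the zero-initialization above
for batch in steps:
    print(toy_forward(batch, pos_offsets))
    # The tokens accepted for this batch (here simply one per request) become
    # the offsets that only these same requests consume in the next iteration.
    pos_offsets = {req: 1 for req in batch}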
tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions

@@ -371,8 +371,6 @@ perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512
 perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495)
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
 full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5292737)
 full:B200/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5295470)
 examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976)
