
Commit dcbfa7e

[https://nvbugs/5252313][fix] Fix torch compile + MTP (#6554)
Signed-off-by: Jin Li <[email protected]>
1 parent 61da2da commit dcbfa7e

File tree

5 files changed: +40 -53 lines changed


tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 4 additions & 6 deletions
@@ -10,7 +10,6 @@
 from torch.fx.passes.split_module import split_module
 
 from tensorrt_llm.llmapi.utils import enable_llm_debug
-from tensorrt_llm.logger import logger
 
 from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
                      make_weak_ref)
@@ -169,14 +168,11 @@ def __call__(self, *args):
         if entry.cuda_graph is None:
 
             if not get_enable_piecewise_cuda_graph_capture_flag():
-                logger.warning(
-                    f"Unexpectedly capture cuda graph for {self.name} with runtime_num_of_token {runtime_num_of_token}. Will fallback to non-CUDA graph execution."
-                )
                 return entry.callable(*args)
 
-            if entry.warmup_count < 2:
+            if entry.warmup_count < 3:
                 entry.warmup_count += 1
-                return self.default_callable(*args)
+                return entry.callable(*args)
 
             entry.input_addresses = [
                 i.data_ptr() for i in args if isinstance(i, torch.Tensor)
@@ -204,6 +200,8 @@ def __call__(self, *args):
                 i.data_ptr() for i in output if isinstance(i, torch.Tensor)
             ]
 
+            entry.cuda_graph.replay()
+
             return output
 
         if enable_llm_debug():
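
Taken together, the changes above drop the now-unused logger import and its warning, raise the eager warmup count from two to three passes through entry.callable, and replay the CUDA graph right after capture so the capturing call also returns fresh outputs. Below is a minimal sketch of that warmup-then-capture-then-replay flow; it is illustrative only, uses a simplified stand-in for the real PiecewiseRunner entry, and the capture block itself is assumed rather than taken from this hunk.

import torch

def run_piecewise_entry(entry, capture_enabled, *args):
    # Illustrative sketch, not the real PiecewiseRunner.__call__.
    if entry.cuda_graph is None:
        if not capture_enabled:
            # Capture disabled: stay on the eager compiled callable.
            return entry.callable(*args)
        if entry.warmup_count < 3:
            # Three eager warmup passes (previously two, via default_callable).
            entry.warmup_count += 1
            return entry.callable(*args)
        # Capture the callable into a CUDA graph, then replay it once so this
        # first "captured" call also produces freshly computed outputs.
        entry.cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(entry.cuda_graph):
            entry.output = entry.callable(*args)
        entry.cuda_graph.replay()
        return entry.output
    # Steady state: replay the captured graph.
    entry.cuda_graph.replay()
    return entry.output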

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]):
     return node.op == "call_function" and node.target == target
 
 
-_enable_piecewise_cuda_graph_capture = True
+_enable_piecewise_cuda_graph_capture = False
 
 
 def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
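
With the module-level default flipped to False, piecewise CUDA graph capture is now opt-in: callers enable it explicitly around a warmup window and restore the default afterwards, which is exactly what model_engine.py does below. A small usage sketch; it assumes the matching getter lives next to this setter in the same module, and run_warmup_step is a hypothetical placeholder.

from tensorrt_llm._torch.compilation.utils import (
    get_enable_piecewise_cuda_graph_capture_flag,
    set_enable_piecewise_cuda_graph_capture_flag)

def warmup_with_piecewise_capture(run_warmup_step, num_steps: int = 3):
    # Enable capture only for the warmup window, then restore the default.
    set_enable_piecewise_cuda_graph_capture_flag(True)
    try:
        for _ in range(num_steps):
            run_warmup_step()
    finally:
        set_enable_piecewise_cuda_graph_capture_flag(False)
    assert not get_enable_piecewise_cuda_graph_capture_flag()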

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 19 additions & 18 deletions
@@ -660,7 +660,6 @@ def disable_optimization(backend: Backend):
                     self._torch_compile_backend)
 
                 self._torch_compile_backend.enable_optimization()
-                set_enable_piecewise_cuda_graph_capture_flag(True)
 
                 # Disable cuda graph capture here so that we can properly capture it later
                 with self.no_cuda_graph():
@@ -748,26 +747,28 @@ def disable_optimization(backend: Backend):
                                  resource_manager=resource_manager)
                 torch.cuda.synchronize()
 
-                if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
-                    with self.no_cuda_graph():
-                        with release_batch(
-                                get_torch_compile_warmup_request(
-                                    1, bs)) as batch:
-                            logger.info(
-                                f"Run piecewise CUDA graph warmup for batch size={bs}"
-                            )
-
-                            for _ in range(3):
-                                self.forward(
-                                    batch,
-                                    new_tensors_device=None,
-                                    resource_manager=resource_manager)
+            if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
+                for seq_lens in cuda_graph_batch_sizes:
+                    set_enable_piecewise_cuda_graph_capture_flag(True)
+                    with self.no_cuda_graph():
+                        with release_batch(
+                                get_torch_compile_warmup_request(
+                                    1, seq_lens)) as batch:
+                            logger.info(
+                                f"Run piecewise CUDA graph warmup for seq_lens={seq_lens}"
+                            )
+                            # self.model.mtp_worker.stored_input_ids = []
+                            for _ in range(3):
                                 self.forward(batch,
                                              new_tensors_device=None,
                                              resource_manager=resource_manager)
-                torch.cuda.synchronize()
-                gc.collect()
-                torch.cuda.empty_cache()
+                            self.forward(batch,
+                                         new_tensors_device=None,
+                                         resource_manager=resource_manager)
+                    torch.cuda.synchronize()
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    set_enable_piecewise_cuda_graph_capture_flag(False)
 
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode
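
Condensed, the new warmup path loops over the CUDA graph batch sizes, turns piecewise capture on, runs three eager warmup forwards plus one capturing forward, synchronizes and frees memory, and turns capture back off. The standalone paraphrase below is a sketch only: make_warmup_batch, run_forward and set_capture_flag are hypothetical placeholders for the real release_batch/get_torch_compile_warmup_request, self.forward and set_enable_piecewise_cuda_graph_capture_flag calls.

import gc
import torch

def piecewise_cuda_graph_warmup(cuda_graph_batch_sizes, make_warmup_batch,
                                run_forward, set_capture_flag):
    # Illustrative paraphrase of the warmup loop added above.
    for seq_lens in cuda_graph_batch_sizes:
        set_capture_flag(True)
        try:
            with make_warmup_batch(seq_lens) as batch:
                for _ in range(3):
                    run_forward(batch)  # eager warmup passes
                run_forward(batch)  # pass that captures the piecewise graphs
                torch.cuda.synchronize()
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            set_capture_flag(False)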

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 16 additions & 11 deletions
@@ -1153,17 +1153,20 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
             position_ids = position_ids.squeeze(0)
             last_tokens_idx = torch.cumsum(
                 attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-            return position_ids, last_tokens_idx
+            last_tokens_idx_host = torch.cumsum(
+                attn_metadata.seq_lens, dim=0, dtype=torch.long) - 1
+            return position_ids, last_tokens_idx, last_tokens_idx_host
 
-        position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
+        position_ids, last_tokens_idx, last_tokens_idx_host = prepare_position_ids_and_last_tokens(
             position_ids, attn_metadata)
-        inputs = self.prepare_drafter_inputs(input_ids=input_ids,
-                                             position_ids=position_ids,
-                                             last_tokens_idx=last_tokens_idx,
-                                             hidden_states=hidden_states,
-                                             accepted_tokens=accepted_tokens,
-                                             attn_metadata=attn_metadata,
-                                             spec_metadata=spec_metadata)
+        inputs = self.prepare_drafter_inputs(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            last_tokens_idx_host=last_tokens_idx_host,
+            hidden_states=hidden_states,
+            accepted_tokens=accepted_tokens,
+            attn_metadata=attn_metadata,
+            spec_metadata=spec_metadata)
 
         # Predict draft tokens
         next_draft_tokens = []
@@ -1277,7 +1280,7 @@ def prepare_drafter_inputs(
         self,
         input_ids: torch.IntTensor,
         position_ids: torch.IntTensor,
-        last_tokens_idx: torch.LongTensor,
+        last_tokens_idx_host: torch.LongTensor,
         hidden_states: torch.Tensor,
         accepted_tokens: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -1292,7 +1295,9 @@ def prepare_drafter_inputs(
                                     device="cuda")
         input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
         input_ids_ctx[
-            last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
+            last_tokens_idx_host[:
+                                 num_contexts]] = accepted_tokens[:num_contexts,
+                                                                  0]
 
         # generation
         input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
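
The drafter-input change computes the last-token indices twice, once on device from attn_metadata.seq_lens_cuda as before and once on host from attn_metadata.seq_lens, and uses the host copy when indexing into input_ids_ctx. A small worked example of the host-side computation; the sequence lengths are made up.

import torch

# Per-request sequence lengths on the host, e.g. three context requests.
seq_lens = torch.tensor([4, 3, 5])
last_tokens_idx_host = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
print(last_tokens_idx_host)  # tensor([3, 6, 11])
# Each index is a request's final token position, where the first accepted
# token is written into input_ids_ctx.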

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 0 additions & 17 deletions
@@ -874,9 +874,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
                           [0, pytest.param(2, marks=skip_pre_hopper)])
     def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
                       overlap_scheduler, torch_compile):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
-
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -913,8 +910,6 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
     def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                             attention_dp, cuda_graph, overlap_scheduler,
                             torch_compile):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1004,8 +999,6 @@ def test_cute_dsl_fp8_block_scales(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
         if torch_compile and attention_dp:
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
@@ -1105,8 +1098,6 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
     def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                                     fp8kv, attention_dp, cuda_graph,
                                     overlap_scheduler, torch_compile):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1166,10 +1157,6 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
-        if torch_compile and attention_dp:
-            pytest.skip("https://nvbugs/5252559")
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
@@ -1298,8 +1285,6 @@ def test_nvfp4_4gpus_online_eplb(self, fp8kv):
     @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
     def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
                    torch_compile, mtp_nextn, moe_backend):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -1345,8 +1330,6 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
     def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
                          overlap_scheduler, tp_size, pp_size, ep_size,
                          torch_compile, mtp_nextn, moe_backend):
-        if torch_compile and mtp_nextn > 0:
-            pytest.skip("https://nvbugs/5252313")
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
         if moe_backend == "TRTLLM" and get_sm_version() == 120:
