221 changes: 120 additions & 101 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Large diffs are not rendered by default.
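The cuda_graph_runner.py diff itself is not rendered, so the definition of the new CUDAGraphRunnerConfig is not visible here. Judging from the two call sites later in this diff (model_engine.py and tests/unittest/_torch/helpers.py), it is presumably a dataclass along the following lines; the field types and comments are assumptions, not the PR's actual code.

# Hypothetical sketch of CUDAGraphRunnerConfig, reconstructed from the call
# sites shown below. The real definition lives in the unrendered
# cuda_graph_runner.py change; the types here are best guesses.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class CUDAGraphRunnerConfig:
    use_cuda_graph: bool
    cuda_graph_padding_enabled: bool
    cuda_graph_batch_sizes: List[int]
    max_cuda_graph_batch_size: int
    max_beam_width: int
    spec_config: Optional[Any]          # speculative decoding config, or None
    cuda_graph_mem_pool: Optional[Any]
    max_num_tokens: int
    use_mrope: bool
    original_max_draft_len: int
    original_max_total_draft_tokens: int
    is_draft_model: bool
    enable_attention_dp: bool
    batch_size: int
    mapping: Any                        # tensorrt_llm.mapping.Mapping
    dist: Optional[Any]                 # distributed communicator; None in unit tests
    kv_cache_manager_key: Any           # ResourceManagerType member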

59 changes: 47 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -55,7 +55,7 @@
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
from .cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
@@ -370,9 +370,31 @@ def __init__(
# We look up this key in resource_manager during forward to find the
# kv cache manager. Can be changed to support multiple model engines
# with different KV cache managers.
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
self.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_runner = CUDAGraphRunner(self)

# Create config and runner
cuda_graph_runner_config = CUDAGraphRunnerConfig(
use_cuda_graph=self.cuda_graph_config is not None,
cuda_graph_padding_enabled=self._cuda_graph_padding_enabled,
cuda_graph_batch_sizes=self._cuda_graph_batch_sizes,
max_cuda_graph_batch_size=self._max_cuda_graph_batch_size,
max_beam_width=self.max_beam_width,
spec_config=self.spec_config,
cuda_graph_mem_pool=self._cuda_graph_mem_pool,
max_num_tokens=self.max_num_tokens,
use_mrope=self.use_mrope,
original_max_draft_len=self.original_max_draft_len,
original_max_total_draft_tokens=self.
original_max_total_draft_tokens,
is_draft_model=self.is_draft_model,
enable_attention_dp=self.enable_attention_dp,
batch_size=self.batch_size,
mapping=self.mapping,
dist=self.dist,
kv_cache_manager_key=self.kv_cache_manager_key,
)
self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)

# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
Expand Down Expand Up @@ -2319,11 +2341,21 @@ def forward(
return self._forward_step(inputs, gather_ids,
gather_context_logits)
with self.cuda_graph_runner.pad_batch(
scheduled_requests, resource_manager) as padded_requests:

maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests, spec_resource_manager)
if maybe_graph:
scheduled_requests, resource_manager,
self.runtime_draft_len) as padded_requests:

maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests,
iter_counter=self.iter_counter,
enable_spec_decode=self.enable_spec_decode,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_tokens_cuda=self.draft_tokens_cuda
if self.is_spec_decode else None,
spec_resource_manager=spec_resource_manager,
)
can_run_graph = key is not None
if can_run_graph:
attn_metadata = maybe_attn_metadata
spec_metadata = maybe_spec_metadata
else:
@@ -2339,7 +2371,7 @@

self.iter_counter += 1
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
if not maybe_graph:
if not can_run_graph:
# Fallback to eager execution if graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
@@ -2357,9 +2389,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)

self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)
self.cuda_graph_runner.capture(
key,
capture_forward_fn,
inputs,
enable_spec_decode=self.enable_spec_decode,
postprocess_fn=capture_postprocess_fn)

# here we don't need to use context since cuda graph capture didn't run kernel.
# maybe we need a cleaner way to do this.
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -384,7 +384,6 @@ def drafting_loop_wrapper(model):
# For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
if spec_config.spec_dec_mode.is_mtp_eagle():
draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
model_engine.model)
else:
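The explicit assignment removed above is no longer needed because ModelEngine.__init__ now picks the key based on is_draft_model (see the model_engine.py hunk earlier). A minimal sketch of that selection in isolation; the helper name here is hypothetical, only the enum values come from the diff.

# Hypothetical standalone helper mirroring the key selection that now happens
# inside ModelEngine.__init__.
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType


def select_kv_cache_manager_key(is_draft_model: bool) -> ResourceManagerType:
    return (ResourceManagerType.DRAFT_KV_CACHE_MANAGER
            if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER)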
60 changes: 22 additions & 38 deletions tests/unittest/_torch/helpers.py
@@ -3,7 +3,10 @@
import torch
import torch.nn.functional as F

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
CUDAGraphRunner, CUDAGraphRunnerConfig)
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.mapping import Mapping


def ceil_div(x: int, y: int) -> int:
@@ -166,42 +169,23 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
return results.view_as(x)


class MockPytorchBackendConfig:

def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
self.use_cuda_graph = use_cuda_graph
self.cuda_graph_padding_enabled = cuda_graph_padding_enabled


class MockEngine:
"""A replacement for SimpleNamespace that supports weak references."""

def __init__(self, **kwargs):
self.__dict__.update(kwargs)


def create_mock_engine(batch_size: int):

class MockSpecConfig:

class SpecDecMode:

def needs_kv_cache_recompute(self):
return False

spec_dec_mode = SpecDecMode()

return MockEngine(
llm_args=TorchLlmArgs(model="dummy"),
_cuda_graph_padding_enabled=True,
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
config = CUDAGraphRunnerConfig(
use_cuda_graph=True,
cuda_graph_padding_enabled=False,
cuda_graph_batch_sizes=[batch_size],
max_cuda_graph_batch_size=batch_size,
batch_size=batch_size,
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
enable_spec_decode=False,
spec_config=MockSpecConfig(),
max_num_tokens=1,
use_mrope=use_mrope,
spec_config=None,
cuda_graph_mem_pool=None,
enable_attention_dp=False,
original_max_draft_len=0,
original_max_total_draft_tokens=0,
is_draft_model=False,
_cuda_graph_mem_pool=None,
use_mrope=False,
)
mapping=Mapping(),
dist=None,
kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
return CUDAGraphRunner(config)
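For reference, the updated tests below construct the mock runner directly from this helper instead of building a fake engine. A minimal usage sketch (the use_mrope call only illustrates the new keyword; none of the tests in this diff passes it):

# Usage sketch for create_mock_cuda_graph_runner, mirroring the test updates
# in this diff.
from _torch.helpers import create_mock_cuda_graph_runner

use_cuda_graph = True
graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None

# The helper also accepts use_mrope for mrope-enabled metadata (assumed usage;
# not exercised by the tests shown here).
mrope_graph_runner = create_mock_cuda_graph_runner(1, use_mrope=True)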
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_exaone4.py
@@ -22,7 +22,7 @@ class Exaone4Config(PretrainedConfig):
# TODO: Remove this once we have a proper config for Exaone4
SKIP_EXAONE4_HF_ACCURACY_TEST = True

from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from transformers.cache_utils import HybridCache
from utils.util import getSMVersion

@@ -31,7 +31,6 @@ class Exaone4Config(PretrainedConfig):
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -338,10 +337,8 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -4,7 +4,7 @@
from typing import Any

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
@@ -16,7 +16,6 @@
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
@@ -331,10 +330,8 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import torch
import transformers
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import Llama4Config
from transformers import \
@@ -20,7 +20,6 @@
Llama4HfWeightMapper
from tensorrt_llm._torch.models.modeling_llama import \
Llama4ForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -406,10 +405,8 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
8 changes: 2 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mistral.py
@@ -8,7 +8,7 @@
import torch
import transformers
import transformers.models.mistral3
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from PIL import Image
from utils.util import getSMVersion

@@ -19,7 +19,6 @@
from tensorrt_llm._torch.attention_backend import utils as attention_utils
from tensorrt_llm._torch.models import modeling_mistral
from tensorrt_llm._torch.pyexecutor import resource_manager
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm.bindings import executor as executor_lib
from tensorrt_llm.models import modeling_utils

@@ -404,10 +403,7 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mixtral.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from transformers import MixtralConfig
from transformers import MixtralForCausalLM as HFMixtralForCausalLM
@@ -16,7 +16,6 @@
from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \
MixtralHfWeightMapper
from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -310,10 +309,8 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario):
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
9 changes: 3 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_mllama.py
@@ -4,7 +4,7 @@

import pytest
import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from test_modeling_llama import Scenario, reduce_llama_config
from transformers import MllamaConfig
@@ -17,7 +17,6 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mllama import \
MllamaForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
@@ -420,10 +419,8 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None:
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

graph_runner = None
if scenario.use_cuda_graph:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(
1) if scenario.use_cuda_graph else None

def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
6 changes: 2 additions & 4 deletions tests/unittest/_torch/modeling/test_modeling_multimodal.py
@@ -8,7 +8,7 @@
from typing import Dict, List, Optional, Tuple, Type

import torch
from _torch.helpers import create_mock_engine
from _torch.helpers import create_mock_cuda_graph_runner
from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
from utils.llm_data import llm_models_root

@@ -17,7 +17,6 @@
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -425,8 +424,7 @@ def run_trtllm_forward(self, trtllm_inputs, use_cuda_graph: bool = False):
trtllm_inputs["attn_metadata"].prepare()
return self.trtllm_model.forward(**trtllm_inputs)
else:
mock_engine = create_mock_engine(1)
graph_runner = CUDAGraphRunner(mock_engine)
graph_runner = create_mock_cuda_graph_runner(1)
trtllm_inputs["attn_metadata"] = trtllm_inputs[
"attn_metadata"
].create_cuda_graph_metadata(1)