6 changes: 0 additions & 6 deletions core/runtime/execute_engine.cpp
@@ -107,12 +107,6 @@ void setup_input_tensors(
TORCHTRT_CHECK(
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());

auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());

auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
103 changes: 103 additions & 0 deletions examples/dynamo/autocast_example.py
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
import torch_tensorrt
import torchvision


class MyModule(torch.nn.Module):
def forward(self, a_float32, b_float32, c_float32, d_float32):
with torch.autocast(device_type="cuda"):
e_float16 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cuda", enabled=False):
# Calls e_float16.float() to ensure float32 execution
# (necessary because e_float16 was created in an autocasted region)
f_float32 = torch.mm(c_float32, e_float16.float())

# No manual casts are required when re-entering the autocast-enabled region.
# torch.mm again runs in float16 and produces float16 output, regardless of input types.
g_float16 = torch.mm(d_float32, f_float32)
return g_float16


class AutocastExample(nn.Module):
def __init__(self):
super(AutocastExample, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(
in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(16 * 8 * 8, 10)

def forward(self, x, y):
out = self.pool1(self.relu1(self.conv1(x))) # fp16
x = self.pool2(self.relu2(self.conv2(out))) # fp16
x = self.flatten(x)
with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
x = self.fc1(x) # fp32
with torch.autocast(x.device.type, enabled=False):
x = torch.sub(x.half(), y) # fp16
out2 = torch.add(x, x) # fp16
with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
out2 = torch.log(out2) # fp32
return x, out, out2


class MyResNet18Wrapper(torch.nn.Module):
def __init__(self, num_classes=1000, pretrained=True):
super(MyResNet18Wrapper, self).__init__()
self.resnet = torchvision.models.resnet18(
num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None
)

def forward(self, x):
x = self.resnet(x)
return x


if __name__ == "__main__":
# model = MyModule().cuda().eval()
# inputs = (torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),)

# model = AutocastExample().cuda().eval()
# inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
# torch.randn((1,), dtype=torch.float16, device="cuda"),)

model = MyResNet18Wrapper().cuda().eval()
inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)

ep = torch.export.export(model, inputs)

with torch_tensorrt.dynamo.Debugger(
"graphs",
logging_dir=".",
engine_builder_monitor=False,
):
trt_mod = torch_tensorrt.compile(
ep.module(),
arg_inputs=inputs,
min_block_size=1,
use_python_runtime=True,
##### weak typing #####
# use_explicit_typing=False,
# enabled_precisions={torch.float16},
##### strong typing + autocast #####
use_explicit_typing=True,
enable_autocast=True,
low_precision_type=torch.float16,
# nodes_to_exclude={"^conv2d$"},
targets_to_exclude=set(),
data_max=512,
max_depth_of_reduction=None,
)

trt_out = trt_mod(*inputs)
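
As a quick sanity check, one might append a comparison against the eager PyTorch baseline to the end of this example. The snippet below is illustrative only and reuses the `model`, `inputs`, and `trt_out` names defined in the script above.

# Optional sanity check (illustrative): compare the TensorRT module against
# the eager PyTorch baseline to gauge the accuracy impact of autocast.
with torch.no_grad():
    torch_out = model(*inputs)
    torch_outs = torch_out if isinstance(torch_out, (tuple, list)) else (torch_out,)
    trt_outs = trt_out if isinstance(trt_out, (tuple, list)) else (trt_out,)
    for ref, out in zip(torch_outs, trt_outs):
        max_abs_err = (ref.float() - out.float()).abs().max().item()
        print(f"max abs error: {max_abs_err:.6f}")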
72 changes: 71 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -141,7 +141,7 @@ def cross_compile_for_windows(
disable_tf32 (bool): Force FP32 layers to use the traditional FP32 format instead of the default behavior of rounding inputs to 10-bit mantissas before multiplying while accumulating the sum with 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -434,6 +434,14 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
low_precision_type: Optional[
Union[torch.dtype, dtype]
] = _defaults.LOW_PRECISION_TYPE,
nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
data_max: float = _defaults.DATA_MAX,
max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -511,6 +519,12 @@ def compile(
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): Regex patterns matching node names that should remain in FP32. Default is an empty set.
targets_to_exclude (Collection[Target]): Targets (ATen ops) that should remain in FP32. Default is an empty set.
data_max (float): Maximum absolute value for node outputs; nodes whose outputs exceed this value will remain in FP32. Default is 512.
max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -584,6 +598,10 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if enable_autocast:
use_explicit_typing = True
logger.debug("Autocast is enabled, setting use_explicit_typing to True.")

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions
@@ -593,6 +611,19 @@
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
)

if low_precision_type is not None:
if not isinstance(low_precision_type, (torch.dtype, dtype)):
raise ValueError(
f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}"
)
if low_precision_type not in {
torch.float16,
torch.bfloat16,
} and low_precision_type not in {dtype.f16, dtype.bf16}:
raise ValueError(
f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
)

if use_fp32_acc:
logger.debug(
"FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
@@ -622,6 +653,38 @@ def compile(
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# save intermediate outputs of each node for Autocast
intermediate_node_outputs = {}
if enable_autocast:

class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Dump intermediate outputs of each node"""

def run_node(self, n: torch.fx.Node) -> Any:
if (
n.op == "call_function"
and n.target != torch.ops.higher_order.wrap_with_autocast
):
out = super().run_node(n)
if not isinstance(out, torch.Tensor):
raise ValueError(
f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
)
intermediate_node_outputs[n.name] = out
return out
return super().run_node(n)

def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
"""Materialize an Input object to a tensor"""
if isinstance(x, Input):
return x.torch_tensor
return x

with torch.no_grad():
mat_args = tuple(_materialize(a) for a in arg_inputs)
mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
@@ -680,6 +743,13 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"enable_autocast": enable_autocast,
"low_precision_type": low_precision_type,
"nodes_to_exclude": nodes_to_exclude,
"targets_to_exclude": targets_to_exclude,
"data_max": data_max,
"max_depth_of_reduction": max_depth_of_reduction,
"intermediate_node_outputs": intermediate_node_outputs,
}

settings = CompilationSettings(**compilation_options)
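For context, a rule of the following shape could combine the recorded intermediate_node_outputs with data_max and nodes_to_exclude. This is only a sketch: the actual rule_based_autocast pass is not part of this diff, and the helper name should_keep_fp32 and its signature are assumptions for illustration.

# Illustrative sketch only: one way a rule-based autocast pass could decide,
# per node, whether to keep FP32 based on the recorded intermediate outputs.
# The helper name `should_keep_fp32` is hypothetical.
import re
from typing import Collection

import torch


def should_keep_fp32(
    node: torch.fx.Node,
    recorded_outputs: dict[str, torch.Tensor],
    data_max: float,
    nodes_to_exclude: Collection[str],
) -> bool:
    # Rule 1: node names matching a user-provided regex pattern stay in FP32.
    if any(re.match(pattern, node.name) for pattern in nodes_to_exclude):
        return True
    # Rule 2: nodes whose recorded outputs exceed data_max stay in FP32,
    # since large magnitudes are the ones most likely to overflow in FP16.
    recorded = recorded_outputs.get(node.name)
    if recorded is not None and recorded.abs().max().item() > data_max:
        return True
    return False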
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,12 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
ENABLE_AUTOCAST = False
LOW_PRECISION_TYPE = None
NODES_TO_EXCLUDE = set[str]()
TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]()
DATA_MAX = 512
MAX_DEPTH_OF_REDUCTION = None

if platform.system() == "Linux":
import pwd
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -1,17 +1,20 @@
from dataclasses import dataclass, field
from typing import Any, Collection, Optional, Set, Tuple, Union

import torch
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
DATA_MAX,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
@@ -21,8 +24,11 @@
IMMUTABLE_WEIGHTS,
L2_LIMIT_FOR_TILING,
LAZY_ENGINE_INIT,
LOW_PRECISION_TYPE,
MAX_AUX_STREAMS,
MAX_DEPTH_OF_REDUCTION,
MIN_BLOCK_SIZE,
NODES_TO_EXCLUDE,
NUM_AVG_TIMING_ITERS,
OFFLOAD_MODULE_TO_CPU,
OPTIMIZATION_LEVEL,
@@ -32,6 +38,7 @@
REUSE_CACHED_ENGINES,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TARGETS_TO_EXCLUDE,
TILING_OPTIMIZATION_LEVEL,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
@@ -97,6 +104,13 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for a better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Use aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in the distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): Regex patterns matching node names that should remain in FP32. Default is an empty set.
targets_to_exclude (Collection[Target]): Targets (ATen ops) that should remain in FP32. Default is an empty set.
data_max (float): Maximum absolute value for node outputs; nodes whose outputs exceed this value will remain in FP32. Default is 512.
max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
intermediate_node_outputs (dict[str, torch.Tensor]): Intermediate outputs recorded for each graph node, used by the autocast rules. Default is {}.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +154,17 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
enable_autocast: bool = ENABLE_AUTOCAST
low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE
nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE)
targets_to_exclude: Collection[Target] = field(
default_factory=lambda: TARGETS_TO_EXCLUDE
)
data_max: float = DATA_MAX
max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION
intermediate_node_outputs: dict[str, torch.Tensor] = field(
default_factory=lambda: {}
)

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -157,6 +182,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)


# If any of the following settings is changed, the engine should be rebuilt.
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
"enabled_precisions",
"max_aux_streams",
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -15,6 +15,13 @@
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast

pre_lowering_pass_list = [
remove_detach,
rule_based_autocast,
remove_assert_nodes, # rule_based_autocast might insert assert nodes
]

post_lowering_pass_list = [
remove_input_alias_fixing_clones,
Expand All @@ -27,10 +34,6 @@
complex_graph_detection,
]

pre_lowering_pass_list = [
remove_detach,
]

if not is_tegra_platform():
from .fuse_distributed_ops import fuse_distributed_ops

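For reference, passes registered in these lists appear to follow a (gm, settings) -> gm signature. A skeleton of what rule_based_autocast might look like is sketched below; the real implementation lives in rule_based_autocast.py, which is not shown in this diff, so the name rule_based_autocast_sketch and the body are placeholders under that assumption.

# Placeholder sketch of a pre-lowering pass with the same shape as the passes
# registered above; the actual rule_based_autocast implementation is not shown
# in this diff.
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings


def rule_based_autocast_sketch(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    if not settings.enable_autocast or settings.low_precision_type is None:
        return gm
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        # Apply the exclusion rules here (nodes_to_exclude, targets_to_exclude,
        # data_max via settings.intermediate_node_outputs, max_depth_of_reduction)
        # and insert casts around nodes selected for low precision.
        ...
    gm.graph.lint()
    gm.recompile()
    return gm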