Commit a061180

Added atomic_subgraph template
1 parent: 1b524f9

py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py

Lines changed: 25 additions & 63 deletions
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Callable, Dict, List, Set
+from typing import Any, Callable, Dict, List, Set, Tuple
 
 import torch
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -9,20 +9,25 @@
 
 
 def register_atomic_subgraph(
+    init_args: Tuple[Any, ...] = tuple(),
     is_core_aten: bool = False,
 ) -> Callable[[torch.nn.Module], torch.nn.Module]:
 
     def decorator(subgraph: torch.nn.Module) -> torch.nn.Module:
-        ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten))
+        ATOMIC_SUBGRAPHS.append((subgraph, init_args, is_core_aten))
         return subgraph
 
     return decorator
 
 
-@register_atomic_subgraph(is_core_aten=True)
-class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
+@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True)
+class ConvBNActivation(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self, activation: torch._ops.OpOverload) -> None:
         super().__init__()
+        self.activation = activation
 
     def forward(
         self,
@@ -56,46 +61,18 @@ def forward(
         x = aten._native_batch_norm_legit_no_training.default(
             x, bn_weight, bn_bias, running_mean, running_var, momentum, eps
         )[0]
-        x = aten.relu.default(x)
-        return x
-
-
-@register_atomic_subgraph(is_core_aten=True)
-class ConvReLU(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        weight: torch.Tensor,
-        bias: torch.Tensor,
-        stride: List[int],
-        padding: List[int],
-        dilation: List[int],
-        transposed: bool,
-        output_padding: List[int],
-        groups: int,
-    ) -> torch.Tensor:
-        x = aten.convolution.default(
-            x,
-            weight,
-            bias,
-            stride,
-            padding,
-            dilation,
-            transposed,
-            output_padding,
-            groups,
-        )
-        x = aten.relu.default(x)
+        x = self.activation(x)
         return x
 
 
-@register_atomic_subgraph(is_core_aten=True)
-class ConvGelu(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
+@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True)
+class ConvActivation(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self, activation: torch._ops.OpOverload) -> None:
         super().__init__()
+        self.activation = activation
 
     def forward(
         self,
@@ -120,26 +97,11 @@ def forward(
             output_padding,
             groups,
         )
-        x = aten.gelu.default(x)
-        return x
-
-
-@register_atomic_subgraph(is_core_aten=True)
-class ConvSilu(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(
-        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
-    ) -> torch.Tensor:
-        x = aten.convolution.default(
-            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
-        )
-        x = aten.silu.default(x)
+        x = self.activation(x)
         return x
 
 
-@register_atomic_subgraph(is_core_aten=True)
+@register_atomic_subgraph(init_args=(), is_core_aten=True)
 class MulAdd(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -152,7 +114,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_core_aten=True)
+@register_atomic_subgraph(init_args=(), is_core_aten=True)
 class MulMul(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -198,8 +160,8 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
     LRU cache the result to avoid recompiling the same pattern multiple times.
     """
     compiled_atomic_subgraphs = []
-    for pattern, is_core_aten in ATOMIC_SUBGRAPHS:
-        pattern_graph = trace_atomic_graph(pattern, is_core_aten)
+    for pattern, init_args, is_core_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = trace_atomic_graph(pattern, init_args, is_core_aten)
         if not is_core_aten:
            # TODO: Add decomposition and lowering if is_core_aten is False
            raise NotImplementedError(
@@ -211,10 +173,10 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
 
 @lru_cache(maxsize=None)
 def trace_atomic_graph(
-    graph: torch.nn.Module, is_core_aten: bool = True
+    graph: torch.nn.Module, init_args: Any, is_core_aten: bool = True
 ) -> torch.fx.GraphModule:
     if is_core_aten:
-        return torch.fx.symbolic_trace(graph())
+        return torch.fx.symbolic_trace(graph(*init_args))
     else:
         raise NotImplementedError(
             "Resource partitioner currently does not support unlowered atomic subgraphs"
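The heart of this change is that `register_atomic_subgraph` now captures constructor arguments, so a single parametrized pattern class (`ConvActivation`, `ConvBNActivation`) is registered once per activation overload instead of keeping near-duplicate `ConvReLU`, `ConvGelu`, and `ConvSilu` modules. Below is a minimal standalone sketch of that mechanism, with simplified names and a trimmed registry; it is not the actual torch_tensorrt module, only an illustration of how stacked decorators plus `init_args` turn into separately traced pattern graphs.

```python
# Standalone sketch (assumed, simplified names): the decorator captures
# init_args so one parametrized pattern class is registered per activation.
from typing import Any, Callable, List, Tuple

import torch

aten = torch.ops.aten

# Registry of (pattern class, constructor args, is_core_aten) triples.
ATOMIC_SUBGRAPHS: List[Tuple[type, Tuple[Any, ...], bool]] = []


def register_atomic_subgraph(
    init_args: Tuple[Any, ...] = tuple(),
    is_core_aten: bool = False,
) -> Callable[[type], type]:
    def decorator(subgraph: type) -> type:
        # Each stacked application appends one entry, so the same class
        # ends up registered once per activation overload.
        ATOMIC_SUBGRAPHS.append((subgraph, init_args, is_core_aten))
        return subgraph

    return decorator


@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True)
@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True)
class ConvActivation(torch.nn.Module):
    def __init__(self, activation: torch._ops.OpOverload) -> None:
        super().__init__()
        self.activation = activation

    def forward(
        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    ) -> torch.Tensor:
        x = aten.convolution.default(
            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
        )
        return self.activation(x)


# Tracing instantiates each class with its captured args, which bakes the
# chosen activation op into the resulting pattern graph.
for pattern, init_args, _ in ATOMIC_SUBGRAPHS:
    gm = torch.fx.symbolic_trace(pattern(*init_args))
    print([n.target for n in gm.graph.nodes if n.op == "call_function"])
```

Because `trace_atomic_graph` is wrapped in `@lru_cache`, every argument it receives must be hashable; passing `init_args` as a tuple of `OpOverload`s rather than a list keeps the memoized call valid.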
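Since the file imports `SubgraphMatcher`, the compiled pattern graphs are presumably fed to torch.fx's subgraph matcher when the partitioner looks for atomic subgraphs. Here is a hedged sketch of how a traced Conv+ReLU pattern could be matched inside a larger aten-level graph; the `Target` module below is hypothetical and exists only to show the mechanics.

```python
# Hedged sketch: match a traced Conv+activation pattern in a larger graph
# using torch.fx's SubgraphMatcher (the class imported at the top of the file).
import torch
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

aten = torch.ops.aten


class ConvReluPattern(torch.nn.Module):
    def forward(self, x, weight, bias):
        x = aten.convolution.default(
            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
        )
        return aten.relu.default(x)


class Target(torch.nn.Module):  # hypothetical graph containing the pattern
    def forward(self, x, weight, bias):
        x = aten.convolution.default(
            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
        )
        x = aten.relu.default(x)
        return aten.add.Tensor(x, 1)


pattern_graph = torch.fx.symbolic_trace(ConvReluPattern()).graph
target_graph = torch.fx.symbolic_trace(Target()).graph

matcher = SubgraphMatcher(pattern_graph)
matches = matcher.match(target_graph)  # one conv+relu occurrence expected here
print(len(matches), [m.nodes_map for m in matches])
```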