 from functools import lru_cache
-from typing import Callable, Dict, List, Set
+from typing import Any, Callable, Dict, List, Set, Tuple

 import torch
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


 def register_atomic_subgraph(
+    init_args: Tuple[Any, ...] = tuple(),
     is_core_aten: bool = False,
 ) -> Callable[[torch.nn.Module], torch.nn.Module]:

     def decorator(subgraph: torch.nn.Module) -> torch.nn.Module:
-        ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten))
+        ATOMIC_SUBGRAPHS.append((subgraph, init_args, is_core_aten))
         return subgraph

     return decorator


-@register_atomic_subgraph(is_core_aten=True)
-class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
+@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True)
+class ConvBNActivation(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self, activation: torch._ops.OpOverload) -> None:
         super().__init__()
+        self.activation = activation

     def forward(
         self,
@@ -56,46 +61,18 @@ def forward(
         x = aten._native_batch_norm_legit_no_training.default(
             x, bn_weight, bn_bias, running_mean, running_var, momentum, eps
         )[0]
-        x = aten.relu.default(x)
-        return x
-
-
-@register_atomic_subgraph(is_core_aten=True)
-class ConvReLU(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        weight: torch.Tensor,
-        bias: torch.Tensor,
-        stride: List[int],
-        padding: List[int],
-        dilation: List[int],
-        transposed: bool,
-        output_padding: List[int],
-        groups: int,
-    ) -> torch.Tensor:
-        x = aten.convolution.default(
-            x,
-            weight,
-            bias,
-            stride,
-            padding,
-            dilation,
-            transposed,
-            output_padding,
-            groups,
-        )
-        x = aten.relu.default(x)
+        x = self.activation(x)
         return x


-@register_atomic_subgraph(is_core_aten=True)
-class ConvGelu(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
+@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True)
+@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True)
+class ConvActivation(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self, activation: torch._ops.OpOverload) -> None:
         super().__init__()
+        self.activation = activation

     def forward(
         self,
@@ -120,26 +97,11 @@ def forward(
             output_padding,
             groups,
         )
-        x = aten.gelu.default(x)
-        return x
-
-
-@register_atomic_subgraph(is_core_aten=True)
-class ConvSilu(torch.nn.Module):  # type: ignore[misc]
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(
-        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
-    ) -> torch.Tensor:
-        x = aten.convolution.default(
-            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
-        )
-        x = aten.silu.default(x)
+        x = self.activation(x)
         return x


-@register_atomic_subgraph(is_core_aten=True)
+@register_atomic_subgraph(init_args=(), is_core_aten=True)
 class MulAdd(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -152,7 +114,7 @@ def forward(
         return x


-@register_atomic_subgraph(is_core_aten=True)
+@register_atomic_subgraph(init_args=(), is_core_aten=True)
 class MulMul(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -198,8 +160,8 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
     LRU cache the result to avoid recompiling the same pattern multiple times.
     """
     compiled_atomic_subgraphs = []
-    for pattern, is_core_aten in ATOMIC_SUBGRAPHS:
-        pattern_graph = trace_atomic_graph(pattern, is_core_aten)
+    for pattern, init_args, is_core_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = trace_atomic_graph(pattern, init_args, is_core_aten)
         if not is_core_aten:
             # TODO: Add decomposition and lowering if is_core_aten is False
             raise NotImplementedError(
@@ -211,10 +173,10 @@ def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:

 @lru_cache(maxsize=None)
 def trace_atomic_graph(
-    graph: torch.nn.Module, is_core_aten: bool = True
+    graph: torch.nn.Module, init_args: Any, is_core_aten: bool = True
 ) -> torch.fx.GraphModule:
     if is_core_aten:
-        return torch.fx.symbolic_trace(graph())
+        return torch.fx.symbolic_trace(graph(*init_args))
     else:
         raise NotImplementedError(
             "Resource partitioner currently does not support unlowered atomic subgraphs"
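
A minimal usage sketch, not part of the diff, of how the stacked decorators and init_args are expected to flow through registration and tracing. It assumes aten = torch.ops.aten and the ConvActivation class added above; the loop body mirrors what trace_atomic_graph does for each registered entry.

    import torch
    import torch.fx

    aten = torch.ops.aten

    # Stacking the decorator appends one (class, init_args, is_core_aten) entry per
    # activation op, e.g. (ConvActivation, (aten.relu.default,), True).
    for activation in (aten.relu.default, aten.gelu.default):
        # trace_atomic_graph instantiates the class with its init_args before tracing.
        pattern = torch.fx.symbolic_trace(ConvActivation(activation))
        # Each pattern graph is aten.convolution.default followed by the chosen activation.
        print([n.target for n in pattern.graph.nodes if n.op == "call_function"])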