 
 import psutil
 import torch
+from torch.fx.experimental.const_fold import _inline_module
 from torch.fx.passes.splitter_base import Subgraph, _SplitterBase
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
 from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import (
@@ -77,17 +78,16 @@ class ResourcePartitioner(_SplitterBase):  # type: ignore
     def __init__(
         self,
         module: torch.fx.GraphModule,
-        partitioned_module: torch.fx.GraphModule,
         cpu_memory_budget: int,
+        submodule_name: str,
     ):
 
         assert isinstance(module, torch.fx.GraphModule)
-        assert isinstance(partitioned_module, torch.fx.GraphModule)
 
         self.module = module
-        self.partitioned_module = partitioned_module
         self.cpu_memory_budget = cpu_memory_budget
-
+        self.resource_split_count = 0
+        self.submodule_name = submodule_name
         self.deps = self.find_deps()
 
         self.non_acc_submodule_name = "_run_on_gpu_"
@@ -119,64 +119,39 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
 
-        gm = self.split()
+        gm = self.split(remove_tag=True)
 
         return gm
 
-    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
-        """Map original graph nodes into capability-based subgraphs.
-
-        - Iterates `partitioned_module` submodules to establish which node names
-          belong to which subgraph (accelerated or not).
-        - Builds a fusion pattern map for each subgraph so that known fusion groups remain intact.
-          Note that since fusion map is built for each subgraph, the capability partitioning can still break the fusion groups.
-        - Put the nodes into the subgraphs based on the capability partitioning.
-        - Verifies the resulting list of subgraphs is topologically ordered.
+    def tag(self, subgraphs: list[Subgraph]) -> None:
+        self.tags = []
+        for subgraph in subgraphs:
+            tag = f"{self.submodule_name}_resource_split_{self.resource_split_count}"
+            self.resource_split_count += 1
+            self.tags.append(tag)
+            for node in subgraph.nodes:
+                node.tag = tag
+                self._node_submodule_map[node.name] = tag
 
+    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
+        """
+        Put the nodes into the subgraphs and erase the tag from previous partitioner if it exists.
         Returns:
-            list[Subgraph]: Ordered subgraphs consisting of nodes in `module` based on capability partitioning.
+            list[Subgraph]: Ordered subgraphs consisting of nodes in `module`.
         """
-        subgraphs_map = {}
-        subgraphs = []
-        name_to_node_map = (
-            {}
-        )  # We use this map to help map the nodes in partitioned module to the nodes in original module.
-        for name, _ in self.partitioned_module.named_children():
-            # We first iterate over the partitioned module to find the subgraphs based on capability partitioning.
-            submodule = getattr(self.partitioned_module, name)
-            if not isinstance(submodule, torch.fx.graph_module.GraphModule):
-                continue
-            subgraph = Subgraph(is_acc="acc" in name, nodes=[])
-            subgraphs.append(subgraph)
-            self.fusion_patterns.update(get_node_in_fusion_pattern(submodule.graph))
-
-            for node in submodule.graph.nodes:
-                # Erase the tag from previous partitioner if it exists
-                if hasattr(node, "tag"):
-                    delattr(node, "tag")
-
-                if node.op in CALLABLE_NODE_OPS:
-                    # Store which subgraph the node should be put in
-                    subgraphs_map[node.name] = subgraph
 
-        # We then iterate over the original module to put the nodes into the subgraphs.
+        nodes = []
         for node in self.module.graph.nodes:
             if hasattr(node, "tag"):
-                # Erase the tag from previous partitioner
-                delattr(node, "tag")
+                del node.tag
             if node.op in CALLABLE_NODE_OPS:
-                name_to_node_map[node.name] = node
-                subgraphs_map[node.name].nodes.append(node)
+                nodes.append(node)
+        subgraphs = [Subgraph(is_acc=True, nodes=nodes)]
+        self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph)
 
         assert self.check_topological_order(
             subgraphs
         ), "The subgraphs are not topologically ordered"
-        self.fusion_patterns = {
-            name_to_node_map[node.name]: [
-                name_to_node_map[n.name] for n in fusion_nodes
-            ]
-            for node, fusion_nodes in self.fusion_patterns.items()
-        }
 
         return subgraphs
 
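For orientation, here is a minimal sketch (not part of the commit) of the naming scheme the new `tag` pass introduces, assuming the partitioner runs on a capability-partitioned submodule called `_run_on_acc_0` that `break_subgraphs` has already split into two size-bounded pieces; these tag strings become the child-module names once `split(remove_tag=True)` runs:

```python
# Hypothetical illustration of the tag naming scheme; the submodule name and
# the number of pieces are assumptions, not values taken from the diff.
submodule_name = "_run_on_acc_0"
resource_split_count = 0
tags = []
for _piece in range(2):  # pretend break_subgraphs returned two subgraphs
    tag = f"{submodule_name}_resource_split_{resource_split_count}"
    resource_split_count += 1
    tags.append(tag)

assert tags == [
    "_run_on_acc_0_resource_split_0",
    "_run_on_acc_0_resource_split_1",
]
```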
@@ -240,6 +215,7 @@ def break_subgraphs(
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
         sizes = self.size_of_subgraphs(subgraphs)
+        # subgraph_size_budget = 500*1024*1024
         if sum(sizes) > subgraph_size_budget * 40:
             raise ValueError(
                 "CPU memory budget or available memory is too small to compile the model. "
@@ -255,7 +231,9 @@ def break_subgraphs(
                 size = size_1
             new_subgraphs.append(broken_subgraphs[0])
             subgraph = broken_subgraphs[1]
-            new_subgraphs.append(subgraph)
+
+            if len(subgraph.nodes) != 0:
+                new_subgraphs.append(subgraph)
 
         self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
 
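As a quick sanity check on the guard in the hunk above (the `* 40` threshold is pre-existing context, not changed by this commit), a purely illustrative calculation:

```python
# Illustrative numbers only; in the real pass the sizes come from
# size_of_subgraphs() and the budget from the available CPU memory.
GiB = 1024**3
subgraph_size_budget = 1 * GiB  # assume ~1 GiB left to build each piece in
total_model_size = 45 * GiB     # assume a ~45 GiB model

# 45 GiB > 40 * 1 GiB, so compilation would be rejected with the ValueError above.
assert total_model_size > subgraph_size_budget * 40
```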
@@ -325,8 +303,6 @@ def break_subgraph_by_size(
             if size_0 > size_to_break:
                 break
 
-        if len(new_subgraphs[1].nodes) == 0:
-            new_subgraphs.pop(1)
         return new_subgraphs, size_0, size_1
 
     def step_and_validate(
@@ -530,7 +506,6 @@ def validate_and_correct_subgraphs(
 
 def resource_partition(
     gm: torch.fx.GraphModule,
-    partitioned_module: torch.fx.GraphModule,
     cpu_memory_budget: int,
 ) -> torch.fx.GraphModule:
     """Resource-aware partitioning entry point.
@@ -552,12 +527,29 @@ def resource_partition(
     """
 
     # Construct
-    partitioner = ResourcePartitioner(
-        gm,
-        partitioned_module,
-        cpu_memory_budget=cpu_memory_budget,
-    )
+    for name, _ in gm.named_children():
+        submodule = getattr(gm, name)
+        if (
+            not isinstance(submodule, torch.fx.graph_module.GraphModule)
+            or "_run_on_acc" not in name
+        ):
+            continue
+        partitioner = ResourcePartitioner(
+            submodule,
+            submodule_name=name,
+            cpu_memory_budget=cpu_memory_budget,
+        )
+
+        partitioned_graph = partitioner.partition_graph()
+        setattr(gm, name, partitioned_graph)
 
-    partitioned_graph = partitioner.partition_graph()
+    for name, module in list(gm.named_children()):
+        if "_run_on_acc" in name:
+            for subname, submodule in module.named_children():
+                if "resource_split" in subname:
+                    setattr(gm, subname, submodule)
+            _inline_module(gm, name)
+            delattr(gm, name)
 
-    return partitioned_graph
+    gm.recompile()
+    return gm
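Putting the reworked entry point together, a hedged sketch of how `resource_partition` is now expected to be called (the surrounding helper and the memory budget value are assumptions for illustration; only the function signature and the `*_resource_split_*` naming come from this diff):

```python
# Hypothetical call site for the reworked resource_partition().
import torch

def apply_resource_partitioning(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # gm is assumed to already be capability-partitioned, i.e. it owns children
    # named "_run_on_acc_<i>" / "_run_on_gpu_<i>".
    gm = resource_partition(gm, cpu_memory_budget=2 * 1024**3)  # budget in bytes (assumed unit)

    # Each accelerated child has been replaced in place by its size-bounded
    # "_run_on_acc_<i>_resource_split_<j>" pieces and inlined into gm's graph.
    for name, child in gm.named_children():
        if "resource_split" in name:
            print(f"{name}: {len(list(child.graph.nodes))} nodes")
    return gm
```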