Commit 1b524f9
supported global partitioner
1 parent 07d8786 commit 1b524f9

4 files changed: +251 −87 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 2 deletions

@@ -867,7 +867,6 @@ def preserve_module_specs(
     )
 
     partitioned_module = resource_partition(
-        gm,
         partitioned_module,
         cpu_memory_budget=settings.cpu_memory_budget,
     )
@@ -895,6 +894,7 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1357,7 +1357,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )
 
     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
     )[0]
 
     try:
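For orientation, the updated call site can be read as the small wrapper sketched below (hedged: the wrapper function itself is hypothetical, and `resource_partition` is imported via the private module path this commit lays out). The point of the change is that the resource partitioner now receives the capability-partitioned module directly, instead of the original `gm` alongside it.

from torch.fx import GraphModule

from torch_tensorrt.dynamo.partitioning._resource_partitioner import resource_partition


def run_resource_partitioning(
    partitioned_module: GraphModule, cpu_memory_budget: int
) -> GraphModule:
    # Hypothetical wrapper mirroring the updated call site in _compiler.py:
    # only the capability-partitioned module and the CPU budget are passed in.
    return resource_partition(
        partitioned_module,
        cpu_memory_budget=cpu_memory_budget,
    )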

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ def partition_graph(self) -> torch.fx.GraphModule:
 
         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
-        return self.split()
+        return self.split(remove_tag=True)
 
     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""

py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py

Lines changed: 51 additions & 59 deletions

@@ -50,6 +50,7 @@
 
 import psutil
 import torch
+from torch.fx.experimental.const_fold import _inline_module
 from torch.fx.passes.splitter_base import Subgraph, _SplitterBase
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
 from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import (
@@ -77,17 +78,16 @@ class ResourcePartitioner(_SplitterBase):  # type: ignore
     def __init__(
         self,
         module: torch.fx.GraphModule,
-        partitioned_module: torch.fx.GraphModule,
         cpu_memory_budget: int,
+        submodule_name: str,
     ):
 
         assert isinstance(module, torch.fx.GraphModule)
-        assert isinstance(partitioned_module, torch.fx.GraphModule)
 
         self.module = module
-        self.partitioned_module = partitioned_module
         self.cpu_memory_budget = cpu_memory_budget
-
+        self.resource_split_count = 0
+        self.submodule_name = submodule_name
         self.deps = self.find_deps()
 
         self.non_acc_submodule_name = "_run_on_gpu_"
@@ -119,64 +119,39 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
 
-        gm = self.split()
+        gm = self.split(remove_tag=True)
 
         return gm
 
-    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
-        """Map original graph nodes into capability-based subgraphs.
-
-        - Iterates `partitioned_module` submodules to establish which node names
-          belong to which subgraph (accelerated or not).
-        - Builds a fusion pattern map for each subgraph so that known fusion groups remain intact.
-          Note that since fusion map is built for each subgraph, the capability partitioning can still break the fusion groups.
-        - Put the nodes into the subgraphs based on the capability partitioning.
-        - Verifies the resulting list of subgraphs is topologically ordered.
+    def tag(self, subgraphs: list[Subgraph]) -> None:
+        self.tags = []
+        for subgraph in subgraphs:
+            tag = f"{self.submodule_name}_resource_split_{self.resource_split_count}"
+            self.resource_split_count += 1
+            self.tags.append(tag)
+            for node in subgraph.nodes:
+                node.tag = tag
+                self._node_submodule_map[node.name] = tag
 
+    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
+        """
+        Put the nodes into the subgraphs and erase the tag from previous partitioner if it exists.
         Returns:
-            list[Subgraph]: Ordered subgraphs consisting of nodes in `module` based on capability partitioning.
+            list[Subgraph]: Ordered subgraphs consisting of nodes in `module`.
         """
-        subgraphs_map = {}
-        subgraphs = []
-        name_to_node_map = (
-            {}
-        )  # We use this map to help map the nodes in partitioned module to the nodes in original module.
-        for name, _ in self.partitioned_module.named_children():
-            # We first iterate over the partitioned module to find the subgraphs based on capability partitioning.
-            submodule = getattr(self.partitioned_module, name)
-            if not isinstance(submodule, torch.fx.graph_module.GraphModule):
-                continue
-            subgraph = Subgraph(is_acc="acc" in name, nodes=[])
-            subgraphs.append(subgraph)
-            self.fusion_patterns.update(get_node_in_fusion_pattern(submodule.graph))
-
-            for node in submodule.graph.nodes:
-                # Erase the tag from previous partitioner if it exists
-                if hasattr(node, "tag"):
-                    delattr(node, "tag")
-
-                if node.op in CALLABLE_NODE_OPS:
-                    # Store which subgraph the node should be put in
-                    subgraphs_map[node.name] = subgraph
 
-        # We then iterate over the original module to put the nodes into the subgraphs.
+        nodes = []
         for node in self.module.graph.nodes:
             if hasattr(node, "tag"):
-                # Erase the tag from previous partitioner
-                delattr(node, "tag")
+                del node.tag
             if node.op in CALLABLE_NODE_OPS:
-                name_to_node_map[node.name] = node
-                subgraphs_map[node.name].nodes.append(node)
+                nodes.append(node)
+        subgraphs = [Subgraph(is_acc=True, nodes=nodes)]
+        self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph)
 
         assert self.check_topological_order(
             subgraphs
         ), "The subgraphs are not topologically ordered"
-        self.fusion_patterns = {
-            name_to_node_map[node.name]: [
-                name_to_node_map[n.name] for n in fusion_nodes
-            ]
-            for node, fusion_nodes in self.fusion_patterns.items()
-        }
 
         return subgraphs
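Scoped to a single accelerated submodule, `put_nodes_into_subgraphs` now just gathers every callable node into one accelerated subgraph, and `tag` names each resource split after the parent submodule. The naming scheme works out as in the small example below (the submodule name and split count are hypothetical):

# Hypothetical illustration of the tag naming introduced in tag() above.
submodule_name = "_run_on_acc_0"
tags = [f"{submodule_name}_resource_split_{i}" for i in range(3)]
print(tags)
# ['_run_on_acc_0_resource_split_0',
#  '_run_on_acc_0_resource_split_1',
#  '_run_on_acc_0_resource_split_2']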

@@ -240,6 +215,7 @@ def break_subgraphs(
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
         sizes = self.size_of_subgraphs(subgraphs)
+        # subgraph_size_budget = 500*1024*1024
         if sum(sizes) > subgraph_size_budget * 40:
             raise ValueError(
                 "CPU memory budget or available memory is too small to compile the model. "
@@ -255,7 +231,9 @@ def break_subgraphs(
             size = size_1
             new_subgraphs.append(broken_subgraphs[0])
             subgraph = broken_subgraphs[1]
-            new_subgraphs.append(subgraph)
+
+            if len(subgraph.nodes) != 0:
+                new_subgraphs.append(subgraph)
 
         self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
 
@@ -325,8 +303,6 @@ def break_subgraph_by_size(
             if size_0 > size_to_break:
                 break
 
-        if len(new_subgraphs[1].nodes) == 0:
-            new_subgraphs.pop(1)
         return new_subgraphs, size_0, size_1
 
     def step_and_validate(
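The guard in `break_subgraphs` compares the total size of all subgraphs against 40x the per-subgraph size budget and aborts compilation when the model is far larger than the available headroom. A worked example with made-up numbers (how `subgraph_size_budget` is actually derived is not shown in this hunk):

# Illustrative numbers only.
GiB = 1024**3
subgraph_size_budget = 1 * GiB               # assumed per-subgraph working budget
sizes = [25 * GiB, 20 * GiB]                 # hypothetical subgraph sizes
if sum(sizes) > subgraph_size_budget * 40:   # 45 GiB > 40 GiB -> abort
    raise ValueError(
        "CPU memory budget or available memory is too small to compile the model. "
    )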
@@ -530,7 +506,6 @@ def validate_and_correct_subgraphs(
 
 def resource_partition(
     gm: torch.fx.GraphModule,
-    partitioned_module: torch.fx.GraphModule,
     cpu_memory_budget: int,
 ) -> torch.fx.GraphModule:
     """Resource-aware partitioning entry point.
@@ -552,12 +527,29 @@
     """
 
     # Construct
-    partitioner = ResourcePartitioner(
-        gm,
-        partitioned_module,
-        cpu_memory_budget=cpu_memory_budget,
-    )
+    for name, _ in gm.named_children():
+        submodule = getattr(gm, name)
+        if (
+            not isinstance(submodule, torch.fx.graph_module.GraphModule)
+            or "_run_on_acc" not in name
+        ):
+            continue
+        partitioner = ResourcePartitioner(
+            submodule,
+            submodule_name=name,
+            cpu_memory_budget=cpu_memory_budget,
+        )
+
+        partitioned_graph = partitioner.partition_graph()
+        setattr(gm, name, partitioned_graph)
 
-    partitioned_graph = partitioner.partition_graph()
+    for name, module in list(gm.named_children()):
+        if "_run_on_acc" in name:
+            for subname, submodule in module.named_children():
+                if "resource_split" in subname:
+                    setattr(gm, subname, submodule)
+            _inline_module(gm, name)
+            delattr(gm, name)
 
-    return partitioned_graph
+    gm.recompile()
+    return gm
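Taken together, the new `resource_partition` loop rewrites the capability-partitioned module in place: each `_run_on_acc_*` child is resource-partitioned, its `*_resource_split_*` children are hoisted onto the parent, and the now-trivial wrapper is inlined via `_inline_module` and deleted before `gm.recompile()`. A hedged sketch of the resulting layout (child names and split counts are illustrative, and the helper below is hypothetical):

from torch.fx import GraphModule


def show_children(gm: GraphModule) -> None:
    # Small helper to inspect the flattened layout produced by resource_partition.
    for name, child in gm.named_children():
        print(name, type(child).__name__)

# Before: _run_on_acc_0, _run_on_gpu_1, _run_on_acc_2
# After:  _run_on_acc_0_resource_split_0, _run_on_acc_0_resource_split_1,
#         _run_on_gpu_1, _run_on_acc_2_resource_split_0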
