Skip to content

Commit 621743b

Browse files
author
TelGome
committed
Binary search to locate the faulty subgraph
1 parent d7cf909 commit 621743b

File tree

3 files changed

+380
-3
lines changed

3 files changed

+380
-3
lines changed

graph_net/torch/decompose_util.py

100644100755
Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,22 @@ def _get_submodule_inputs_and_outputs(
215215
)
216216
node_list = list(gm.graph.nodes)
217217

218+
def _hashable(obj):
219+
if isinstance(obj, slice):
220+
return ("__slice__", obj.start, obj.stop, obj.step)
221+
elif isinstance(obj, (list, tuple)):
222+
return tuple(_hashable(x) for x in obj)
223+
else:
224+
return obj
225+
218226
def get_related_node(node):
219227
for arg in node.args:
220228
if isinstance(arg, tuple):
221-
yield from arg
229+
for x in arg:
230+
yield _hashable(x)
222231
else:
223-
yield arg
224-
yield node
232+
yield _hashable(arg)
233+
yield _hashable(node)
225234

226235
for node in node_list[0:start_node_idx]:
227236
for related_node in get_related_node(node):

graph_net/torch/naive_graph_decomposer.py

100644100755
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,23 @@
77
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
88
import graph_net.imp_util as imp_util
99

10+
NODE_COUNT_FILE = "./output/node_count.txt"
11+
12+
def get_graph_node_count(gm: torch.fx.GraphModule):
13+
"""返回任意 GraphModule 的节点数量(包括子图)"""
14+
return len(list(gm.graph.nodes))
15+
16+
def write_global_node_count(n: int):
17+
os.makedirs(os.path.dirname(NODE_COUNT_FILE), exist_ok=True)
18+
with open(NODE_COUNT_FILE, "w") as f:
19+
f.write(str(n))
20+
21+
def write_subgraph_node_count(path: str, n: int):
22+
"""写入 path/node_count.txt"""
23+
os.makedirs(path, exist_ok=True)
24+
file = os.path.join(path, "node_count.txt")
25+
with open(file, "w") as f:
26+
f.write(str(n))
1027

1128
class GraphExtractor:
1229
def __init__(
@@ -47,6 +64,10 @@ def make_config(
4764
}
4865

4966
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
67+
node_count = len(list(gm.graph.nodes))
68+
write_global_node_count(node_count)
69+
print(f"[naive_graph_decomposer] 总节点数: {node_count} -> {NODE_COUNT_FILE}")
70+
5071
config = {
5172
k: v
5273
for k, v in self.config.items()
@@ -83,6 +104,17 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
83104
def forward(self, *args):
84105
if not self.extracted:
85106
if self.need_extract(self.submodule, args):
107+
out_dir = os.path.join(
108+
self.parent_graph_extractor.config["output_dir"],
109+
f"{self.parent_graph_extractor.name}_{self.seq_no}"
110+
)
111+
112+
# 子图节点数
113+
node_count = len(list(self.submodule.graph.nodes))
114+
write_subgraph_node_count(out_dir, node_count)
115+
116+
print(f"[naive_graph_decomposer] 子图 {self.seq_no} 节点数 = {node_count} -> {out_dir}/node_count.txt")
117+
86118
self.builtin_extractor(self.submodule, args)
87119
self.extracted = True
88120
return self.submodule(*args)

0 commit comments

Comments
 (0)