|
7 | 7 | from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor |
8 | 8 | import graph_net.imp_util as imp_util |
9 | 9 |
|
| 10 | +NODE_COUNT_FILE = "./output/node_count.txt" |
| 11 | + |
| 12 | +def get_graph_node_count(gm: torch.fx.GraphModule): |
| 13 | + """返回任意 GraphModule 的节点数量(包括子图)""" |
| 14 | + return len(list(gm.graph.nodes)) |
| 15 | + |
| 16 | +def write_global_node_count(n: int): |
| 17 | + os.makedirs(os.path.dirname(NODE_COUNT_FILE), exist_ok=True) |
| 18 | + with open(NODE_COUNT_FILE, "w") as f: |
| 19 | + f.write(str(n)) |
| 20 | + |
| 21 | +def write_subgraph_node_count(path: str, n: int): |
| 22 | + """写入 path/node_count.txt""" |
| 23 | + os.makedirs(path, exist_ok=True) |
| 24 | + file = os.path.join(path, "node_count.txt") |
| 25 | + with open(file, "w") as f: |
| 26 | + f.write(str(n)) |
10 | 27 |
|
11 | 28 | class GraphExtractor: |
12 | 29 | def __init__( |
@@ -47,6 +64,10 @@ def make_config( |
47 | 64 | } |
48 | 65 |
|
49 | 66 | def __call__(self, gm: torch.fx.GraphModule, sample_inputs): |
| 67 | + node_count = len(list(gm.graph.nodes)) |
| 68 | + write_global_node_count(node_count) |
| 69 | + print(f"[naive_graph_decomposer] 总节点数: {node_count} -> {NODE_COUNT_FILE}") |
| 70 | + |
50 | 71 | config = { |
51 | 72 | k: v |
52 | 73 | for k, v in self.config.items() |
@@ -83,6 +104,17 @@ def __init__(self, parent_graph_extractor, submodule, seq_no): |
83 | 104 | def forward(self, *args): |
84 | 105 | if not self.extracted: |
85 | 106 | if self.need_extract(self.submodule, args): |
| 107 | + out_dir = os.path.join( |
| 108 | + self.parent_graph_extractor.config["output_dir"], |
| 109 | + f"{self.parent_graph_extractor.name}_{self.seq_no}" |
| 110 | + ) |
| 111 | + |
| 112 | + # 子图节点数 |
| 113 | + node_count = len(list(self.submodule.graph.nodes)) |
| 114 | + write_subgraph_node_count(out_dir, node_count) |
| 115 | + |
| 116 | + print(f"[naive_graph_decomposer] 子图 {self.seq_no} 节点数 = {node_count} -> {out_dir}/node_count.txt") |
| 117 | + |
86 | 118 | self.builtin_extractor(self.submodule, args) |
87 | 119 | self.extracted = True |
88 | 120 | return self.submodule(*args) |
|
0 commit comments