Skip to content

Commit 8019407

Browse files
author
TelGome
committed
Binary search to locate the faulty subgraph
1 parent d7cf909 commit 8019407

File tree

2 files changed

+299
-3
lines changed

2 files changed

+299
-3
lines changed

graph_net/torch/decompose_util.py

100644100755
Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,22 @@ def _get_submodule_inputs_and_outputs(
215215
)
216216
node_list = list(gm.graph.nodes)
217217

218+
def _hashable(obj):
219+
if isinstance(obj, slice):
220+
return ("__slice__", obj.start, obj.stop, obj.step)
221+
elif isinstance(obj, (list, tuple)):
222+
return tuple(_hashable(x) for x in obj)
223+
else:
224+
return obj
225+
218226
def get_related_node(node):
219227
for arg in node.args:
220228
if isinstance(arg, tuple):
221-
yield from arg
229+
for x in arg:
230+
yield _hashable(x)
222231
else:
223-
yield arg
224-
yield node
232+
yield _hashable(arg)
233+
yield _hashable(node)
225234

226235
for node in node_list[0:start_node_idx]:
227236
for related_node in get_related_node(node):
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import os
2+
import sys
3+
import re
4+
import json
5+
import base64
6+
import shutil
7+
import argparse
8+
import subprocess
9+
from pathlib import Path
10+
import torch
11+
from graph_net.torch import utils as gn_utils
12+
from graph_net.torch.decompose_util import convert_to_submodules_graph
13+
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
14+
import graph_net.imp_util as imp_util
15+
16+
# ----------------------------
17+
# Helpers
18+
# ----------------------------
19+
def contains_nan_or_inf_in_file(path: str) -> bool:
20+
"""Check if the file contains NaN or INF."""
21+
if not os.path.exists(path):
22+
return False
23+
with open(path, "r", encoding="utf-8", errors="ignore") as f:
24+
text = f.read().lower()
25+
return ("nan" in text) or ("inf" in text)
26+
27+
28+
def parse_correctness_line_for_nan(log_path: str, log_prompt: str = "graph-net-test-compiler-log"):
29+
"""Parse log for correctness and check if it contains NaN."""
30+
if not os.path.exists(log_path):
31+
return False
32+
pattern = re.compile(re.escape(log_prompt) + r".*\[Correctness\]\[max_diff\].*", re.IGNORECASE)
33+
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
34+
for line in f:
35+
if pattern.search(line):
36+
return "nan" in line.lower() or "inf" in line.lower()
37+
return contains_nan_or_inf_in_file(log_path)
38+
39+
40+
def get_graph_net_root() -> str:
41+
"""Return graph_net package root directory."""
42+
result = subprocess.run(
43+
[sys.executable, "-c", "import graph_net, os; print(os.path.dirname(graph_net.__file__))"],
44+
capture_output=True,
45+
text=True,
46+
)
47+
if result.returncode != 0:
48+
raise RuntimeError("Cannot locate graph_net package root: " + result.stderr)
49+
return result.stdout.strip()
50+
51+
52+
# ----------------------------
53+
# Extraction (reuse run_model decorator approach)
54+
# ----------------------------
55+
def run_naive_extractor(model_path: str, output_dir: str, split_positions: list, group_head_and_tail: bool = True, filter_path: str = None):
56+
"""Run naive graph decomposition to extract subgraphs."""
57+
GRAPH_NET_ROOT = get_graph_net_root()
58+
59+
decorator_config = {
60+
"decorator_path": f"{GRAPH_NET_ROOT}/torch/extractor.py",
61+
"decorator_config": {
62+
"name": os.path.basename(model_path.rstrip("/")),
63+
"custom_extractor_path": f"{GRAPH_NET_ROOT}/torch/naive_graph_decomposer.py",
64+
"custom_extractor_config": {
65+
"output_dir": output_dir,
66+
"split_positions": split_positions,
67+
"group_head_and_tail": group_head_and_tail,
68+
"filter_path": filter_path if filter_path else f"{GRAPH_NET_ROOT}/torch/naive_subgraph_filter.py",
69+
"filter_config": {}
70+
}
71+
}
72+
}
73+
74+
deco_json = json.dumps(decorator_config)
75+
deco_b64 = base64.b64encode(deco_json.encode()).decode()
76+
77+
cmd = [
78+
sys.executable,
79+
"-m",
80+
"graph_net.torch.run_model",
81+
"--model-path",
82+
model_path,
83+
"--decorator-config",
84+
deco_b64,
85+
]
86+
87+
print("[RUN] extracting subgraphs with naive_graph_decomposer:")
88+
print(" ".join(cmd))
89+
proc = subprocess.run(cmd)
90+
if proc.returncode != 0:
91+
raise RuntimeError(f"Extractor failed (rc={proc.returncode})")
92+
93+
94+
# ----------------------------
95+
# Find subgraph directories under output_dir
96+
# ----------------------------
97+
def find_subgraphs(output_dir: str):
98+
"""Find all subgraph directories in the output dir."""
99+
out = []
100+
for root, dirs, files in os.walk(output_dir):
101+
if any(fname in files for fname in ("graph_code.json", "graph_net.json", "model.py")):
102+
out.append(os.path.abspath(root))
103+
out = sorted(set(out))
104+
return out
105+
106+
107+
# ----------------------------
108+
# Run test_compiler on a subgraph dir and write log
109+
# ----------------------------
110+
def run_test_compiler_on_subgraph(subgraph_dir: str, log_path: str, compiler: str, device: str, warmup: int = 1, trials: int = 1, log_prompt: str = "graph-net-test-compiler-log"):
111+
"""Run test_compiler on subgraph and log result."""
112+
cmd = [
113+
sys.executable,
114+
"-m",
115+
"graph_net.torch.test_compiler",
116+
"--model-path",
117+
subgraph_dir,
118+
"--compiler",
119+
compiler,
120+
"--device",
121+
device,
122+
"--warmup",
123+
str(warmup),
124+
"--trials",
125+
str(trials),
126+
"--log-prompt",
127+
log_prompt,
128+
]
129+
130+
print(f"[RUN] test_compiler on {subgraph_dir}")
131+
with open(log_path, "wb") as logf:
132+
proc = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT)
133+
134+
has_nan = parse_correctness_line_for_nan(log_path, log_prompt=log_prompt)
135+
print(f"[LOG] {subgraph_dir} -> nan={has_nan} (rc={proc.returncode})")
136+
return has_nan
137+
138+
139+
# ----------------------------
140+
# Count nodes in the FX graph
141+
# ----------------------------
142+
def count_graph_nodes(gm: torch.fx.GraphModule):
143+
"""Count the number of nodes in the FX graph."""
144+
return len(list(gm.graph.nodes))
145+
146+
147+
# ----------------------------
148+
# Recursive binary classification
149+
# ----------------------------
150+
def binary_classify_subgraphs(subgraphs: list, tmp_log_dir: str, compiler: str, device: str, warmup: int, trials: int, log_prompt: str):
151+
"""
152+
Recursive binary classification of subgraphs based on node count.
153+
"""
154+
good = []
155+
bad = []
156+
157+
def solve(lst):
158+
if not lst:
159+
return
160+
if len(lst) == 1:
161+
g = lst[0]
162+
log_path = os.path.join(g, "compiler_test.log")
163+
os.makedirs(os.path.dirname(log_path), exist_ok=True)
164+
has_nan = run_test_compiler_on_subgraph(g, log_path, compiler, device, warmup, trials, log_prompt)
165+
if has_nan:
166+
bad.append(g)
167+
else:
168+
good.append(g)
169+
return
170+
171+
mid = len(lst) // 2
172+
left = lst[:mid]
173+
right = lst[mid:]
174+
175+
# Test left side
176+
left_has_nan = False
177+
for g in left:
178+
tmp_log = os.path.join(tmp_log_dir, "batch_left.log")
179+
os.makedirs(os.path.dirname(tmp_log), exist_ok=True)
180+
if run_test_compiler_on_subgraph(g, tmp_log, compiler, device, warmup, trials, log_prompt):
181+
left_has_nan = True
182+
break
183+
184+
if left_has_nan:
185+
solve(left)
186+
else:
187+
good.extend(left)
188+
189+
# Test right side
190+
right_has_nan = False
191+
for g in right:
192+
tmp_log = os.path.join(tmp_log_dir, "batch_right.log")
193+
os.makedirs(os.path.dirname(tmp_log), exist_ok=True)
194+
if run_test_compiler_on_subgraph(g, tmp_log, compiler, device, warmup, trials, log_prompt):
195+
right_has_nan = True
196+
break
197+
198+
if right_has_nan:
199+
solve(right)
200+
else:
201+
good.extend(right)
202+
203+
solve(subgraphs)
204+
return good, bad
205+
206+
207+
# ----------------------------
208+
# Main
209+
# ----------------------------
210+
def main():
211+
parser = argparse.ArgumentParser(description="GraphNet: check log -> if nan -> extract subgraphs -> binary classify via test_compiler")
212+
parser.add_argument("--log-file", type=str, required=True, help="Path to main run log to check for nan")
213+
parser.add_argument("--model-path", type=str, required=True, help="GraphNet model dir (contains model.py, graph_net.json, inputs...)")
214+
parser.add_argument("--output-dir", type=str, default="/tmp/naive_decompose_workspace", help="workspace to dump extracted subgraphs")
215+
parser.add_argument("--split-positions", type=int, nargs="*", default=[], help="split positions to pass to extractor")
216+
parser.add_argument("--compiler", type=str, default="inductor", help="compiler backend to use when running test_compiler")
217+
parser.add_argument("--device", type=str, default="cuda", help="device for test_compiler")
218+
parser.add_argument("--warmup", type=int, default=1, help="warmup for test_compiler runs")
219+
parser.add_argument("--trials", type=int, default=1, help="trials for test_compiler runs")
220+
parser.add_argument("--log-prompt", type=str, default="graph-net-test-compiler-log", help="log prompt used by test_compiler")
221+
parser.add_argument("--force-extract", action="store_true", help="always run extractor even if no nan in main log")
222+
223+
args = parser.parse_args()
224+
225+
# 1) check main log
226+
print(f"[INFO] Checking main log: {args.log_file}")
227+
if not os.path.exists(args.log_file):
228+
print(f"[WARN] main log not found: {args.log_file}")
229+
# we allow forcing extractor or abort
230+
if not args.force_extract:
231+
print("[ERROR] main log missing and not forcing extraction. Exiting.")
232+
sys.exit(2)
233+
234+
main_log_has_nan = False
235+
if os.path.exists(args.log_file):
236+
main_log_has_nan = contains_nan_or_inf_in_file(args.log_file)
237+
print(f"[INFO] main log contains nan/inf? {main_log_has_nan}")
238+
239+
if not main_log_has_nan and not args.force_extract:
240+
print("[INFO] No NaN found in main log. Exiting without extraction.")
241+
sys.exit(0)
242+
243+
# 2) run extractor to produce subgraphs
244+
print("[STEP] Running naive_graph_decomposer to extract subgraphs...")
245+
# ensure fresh output dir
246+
if os.path.exists(args.output_dir):
247+
print(f"[INFO] clearing existing output dir: {args.output_dir}")
248+
shutil.rmtree(args.output_dir)
249+
os.makedirs(args.output_dir, exist_ok=True)
250+
251+
run_naive_extractor(args.model_path, args.output_dir, args.split_positions)
252+
253+
# 3) find subgraphs
254+
print("[STEP] Searching for subgraphs in output dir...")
255+
subgraphs = find_subgraphs(args.output_dir)
256+
print(f"[INFO] Found {len(subgraphs)} candidate subgraph dirs")
257+
258+
if not subgraphs:
259+
print("[ERROR] No subgraphs found; make sure extractor produced files (graph_code.json or graph_net.json or model.py)")
260+
sys.exit(3)
261+
262+
# 4) binary classify using test_compiler
263+
print("[STEP] Running binary classification on extracted subgraphs...")
264+
tmp_log_dir = os.path.join(args.output_dir, "_tmp_logs")
265+
os.makedirs(tmp_log_dir, exist_ok=True)
266+
267+
good, bad = binary_classify_subgraphs(subgraphs, tmp_log_dir, args.compiler, args.device, args.warmup, args.trials, args.log_prompt)
268+
269+
# 5) output result
270+
print("\n===== RESULT =====")
271+
print(f"Good subgraphs ({len(good)}):")
272+
for g in good:
273+
print(" [GOOD]", g)
274+
print(f"\nBad subgraphs ({len(bad)}):")
275+
for g in bad:
276+
print(" [BAD]", g)
277+
278+
if bad:
279+
print("\nDetected bad subgraphs -> exit code 4")
280+
sys.exit(4)
281+
else:
282+
print("\nAll subgraphs OK -> exit code 0")
283+
sys.exit(0)
284+
285+
286+
if __name__ == "__main__":
287+
main()

0 commit comments

Comments
 (0)