import os
import sys
import re
import json
import base64
import shutil
import argparse
import subprocess
from pathlib import Path

import torch

from graph_net.torch import utils as gn_utils
from graph_net.torch.decompose_util import convert_to_submodules_graph
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
import graph_net.imp_util as imp_util

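# Workflow implemented below: check the main run log for NaN/Inf; if any is
# found (or --force-extract is given), decompose the model into subgraphs with
# the naive graph decomposer, then bisect the subgraphs with test_compiler to
# isolate the ones whose outputs contain NaN/Inf.
#
# Example invocation (the script name and paths are illustrative only):
#   python nan_subgraph_bisect.py \
#       --log-file /tmp/main_run.log \
#       --model-path ./samples/torch/some_model \
#       --output-dir /tmp/naive_decompose_workspace \
#       --compiler inductor --device cuda
#
# Exit codes: 0 = no NaN in the main log or every subgraph passes,
# 2 = main log missing and extraction not forced, 3 = no subgraphs extracted,
# 4 = at least one bad subgraph detected.
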
# ----------------------------
# Helpers
# ----------------------------
def contains_nan_or_inf_in_file(path: str) -> bool:
    """Check whether the file mentions NaN or Inf (as whole words)."""
    if not os.path.exists(path):
        return False
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read().lower()
    # Match whole words only, so that e.g. "[INFO]" is not mistaken for "inf".
    return re.search(r"\b(nan|inf)\b", text) is not None


def parse_correctness_line_for_nan(log_path: str, log_prompt: str = "graph-net-test-compiler-log"):
    """Parse the log for the correctness line and check whether it reports NaN/Inf."""
    if not os.path.exists(log_path):
        return False
    pattern = re.compile(re.escape(log_prompt) + r".*\[Correctness\]\[max_diff\].*", re.IGNORECASE)
    with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if pattern.search(line):
                return re.search(r"\b(nan|inf)\b", line.lower()) is not None
    # No correctness line found; fall back to scanning the whole log.
    return contains_nan_or_inf_in_file(log_path)


def get_graph_net_root() -> str:
    """Return graph_net package root directory."""
    result = subprocess.run(
        [sys.executable, "-c", "import graph_net, os; print(os.path.dirname(graph_net.__file__))"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError("Cannot locate graph_net package root: " + result.stderr)
    return result.stdout.strip()


# ----------------------------
# Extraction (reuse run_model decorator approach)
# ----------------------------
def run_naive_extractor(model_path: str, output_dir: str, split_positions: list, group_head_and_tail: bool = True, filter_path: str = None):
    """Run naive graph decomposition to extract subgraphs."""
    graph_net_root = get_graph_net_root()

    decorator_config = {
        "decorator_path": f"{graph_net_root}/torch/extractor.py",
        "decorator_config": {
            "name": os.path.basename(model_path.rstrip("/")),
            "custom_extractor_path": f"{graph_net_root}/torch/naive_graph_decomposer.py",
            "custom_extractor_config": {
                "output_dir": output_dir,
                "split_positions": split_positions,
                "group_head_and_tail": group_head_and_tail,
                "filter_path": filter_path if filter_path else f"{graph_net_root}/torch/naive_subgraph_filter.py",
                "filter_config": {},
            },
        },
    }

    deco_json = json.dumps(decorator_config)
    deco_b64 = base64.b64encode(deco_json.encode()).decode()

    cmd = [
        sys.executable,
        "-m",
        "graph_net.torch.run_model",
        "--model-path",
        model_path,
        "--decorator-config",
        deco_b64,
    ]

    print("[RUN] extracting subgraphs with naive_graph_decomposer:")
    print(" ".join(cmd))
    proc = subprocess.run(cmd)
    if proc.returncode != 0:
        raise RuntimeError(f"Extractor failed (rc={proc.returncode})")


# ----------------------------
# Find subgraph directories under output_dir
# ----------------------------
def find_subgraphs(output_dir: str):
    """Find all subgraph directories in the output dir."""
    out = []
    for root, _dirs, files in os.walk(output_dir):
        if any(fname in files for fname in ("graph_code.json", "graph_net.json", "model.py")):
            out.append(os.path.abspath(root))
    return sorted(set(out))


# ----------------------------
# Run test_compiler on a subgraph dir and write log
# ----------------------------
def run_test_compiler_on_subgraph(subgraph_dir: str, log_path: str, compiler: str, device: str, warmup: int = 1, trials: int = 1, log_prompt: str = "graph-net-test-compiler-log"):
    """Run test_compiler on subgraph and log result."""
    cmd = [
        sys.executable,
        "-m",
        "graph_net.torch.test_compiler",
        "--model-path",
        subgraph_dir,
        "--compiler",
        compiler,
        "--device",
        device,
        "--warmup",
        str(warmup),
        "--trials",
        str(trials),
        "--log-prompt",
        log_prompt,
    ]

    print(f"[RUN] test_compiler on {subgraph_dir}")
    with open(log_path, "wb") as logf:
        proc = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT)

    has_nan = parse_correctness_line_for_nan(log_path, log_prompt=log_prompt)
    print(f"[LOG] {subgraph_dir} -> nan={has_nan} (rc={proc.returncode})")
    return has_nan


# ----------------------------
# Count nodes in the FX graph
# ----------------------------
def count_graph_nodes(gm: torch.fx.GraphModule):
    """Count the number of nodes in the FX graph (helper; not called by the pipeline below)."""
    return len(list(gm.graph.nodes))


# ----------------------------
# Recursive binary classification
# ----------------------------
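# A rough sketch of the behaviour (assuming the common case of one NaN-producing
# subgraph): each half of the list is scanned with test_compiler until the first
# NaN is hit; a clean half is accepted wholesale, while a half containing a NaN
# is split again, so the offending subgraph is narrowed down level by level.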
def binary_classify_subgraphs(subgraphs: list, tmp_log_dir: str, compiler: str, device: str, warmup: int, trials: int, log_prompt: str):
    """Recursively bisect the subgraph list into good (no NaN) and bad (NaN) sets."""
    good = []
    bad = []

    def solve(lst):
        if not lst:
            return
        if len(lst) == 1:
            g = lst[0]
            log_path = os.path.join(g, "compiler_test.log")
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            has_nan = run_test_compiler_on_subgraph(g, log_path, compiler, device, warmup, trials, log_prompt)
            if has_nan:
                bad.append(g)
            else:
                good.append(g)
            return

        mid = len(lst) // 2
        left = lst[:mid]
        right = lst[mid:]

        # Test the left half; stop at the first NaN.
        left_has_nan = False
        for g in left:
            tmp_log = os.path.join(tmp_log_dir, "batch_left.log")
            os.makedirs(os.path.dirname(tmp_log), exist_ok=True)
            if run_test_compiler_on_subgraph(g, tmp_log, compiler, device, warmup, trials, log_prompt):
                left_has_nan = True
                break

        if left_has_nan:
            solve(left)
        else:
            good.extend(left)

        # Test the right half; stop at the first NaN.
        right_has_nan = False
        for g in right:
            tmp_log = os.path.join(tmp_log_dir, "batch_right.log")
            os.makedirs(os.path.dirname(tmp_log), exist_ok=True)
            if run_test_compiler_on_subgraph(g, tmp_log, compiler, device, warmup, trials, log_prompt):
                right_has_nan = True
                break

        if right_has_nan:
            solve(right)
        else:
            good.extend(right)

    solve(subgraphs)
    return good, bad


# ----------------------------
# Main
# ----------------------------
def main():
    parser = argparse.ArgumentParser(description="GraphNet: check log -> if NaN -> extract subgraphs -> binary classify via test_compiler")
    parser.add_argument("--log-file", type=str, required=True, help="Path to the main run log to check for NaN")
    parser.add_argument("--model-path", type=str, required=True, help="GraphNet model dir (contains model.py, graph_net.json, inputs...)")
    parser.add_argument("--output-dir", type=str, default="/tmp/naive_decompose_workspace", help="workspace to dump extracted subgraphs")
    parser.add_argument("--split-positions", type=int, nargs="*", default=[], help="split positions to pass to the extractor")
    parser.add_argument("--compiler", type=str, default="inductor", help="compiler backend to use when running test_compiler")
    parser.add_argument("--device", type=str, default="cuda", help="device for test_compiler")
    parser.add_argument("--warmup", type=int, default=1, help="warmup runs for test_compiler")
    parser.add_argument("--trials", type=int, default=1, help="trials for test_compiler")
    parser.add_argument("--log-prompt", type=str, default="graph-net-test-compiler-log", help="log prompt used by test_compiler")
    parser.add_argument("--force-extract", action="store_true", help="always run the extractor even if no NaN is found in the main log")

    args = parser.parse_args()

    # 1) Check the main log.
    print(f"[INFO] Checking main log: {args.log_file}")
    if not os.path.exists(args.log_file):
        print(f"[WARN] main log not found: {args.log_file}")
        # A missing log is only tolerated when extraction is forced.
        if not args.force_extract:
            print("[ERROR] main log missing and not forcing extraction. Exiting.")
            sys.exit(2)

    main_log_has_nan = False
    if os.path.exists(args.log_file):
        main_log_has_nan = contains_nan_or_inf_in_file(args.log_file)
        print(f"[INFO] main log contains nan/inf? {main_log_has_nan}")

    if not main_log_has_nan and not args.force_extract:
        print("[INFO] No NaN found in main log. Exiting without extraction.")
        sys.exit(0)

    # 2) Run the extractor to produce subgraphs.
    print("[STEP] Running naive_graph_decomposer to extract subgraphs...")
    # Ensure a fresh output dir.
    if os.path.exists(args.output_dir):
        print(f"[INFO] clearing existing output dir: {args.output_dir}")
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    run_naive_extractor(args.model_path, args.output_dir, args.split_positions)

    # 3) Find subgraphs.
    print("[STEP] Searching for subgraphs in output dir...")
    subgraphs = find_subgraphs(args.output_dir)
    print(f"[INFO] Found {len(subgraphs)} candidate subgraph dirs")

    if not subgraphs:
        print("[ERROR] No subgraphs found; make sure the extractor produced files (graph_code.json, graph_net.json, or model.py)")
        sys.exit(3)

    # 4) Binary classify using test_compiler.
    print("[STEP] Running binary classification on extracted subgraphs...")
    tmp_log_dir = os.path.join(args.output_dir, "_tmp_logs")
    os.makedirs(tmp_log_dir, exist_ok=True)

    good, bad = binary_classify_subgraphs(subgraphs, tmp_log_dir, args.compiler, args.device, args.warmup, args.trials, args.log_prompt)

    # 5) Output results.
    print("\n===== RESULT =====")
    print(f"Good subgraphs ({len(good)}):")
    for g in good:
        print("  [GOOD]", g)
    print(f"\nBad subgraphs ({len(bad)}):")
    for g in bad:
        print("  [BAD]", g)

    if bad:
        print("\nDetected bad subgraphs -> exit code 4")
        sys.exit(4)
    else:
        print("\nAll subgraphs OK -> exit code 0")
        sys.exit(0)


if __name__ == "__main__":
    main()