From 8e65f8f1dc91ab771dd3220640cb8d8547444ea4 Mon Sep 17 00:00:00 2001 From: Stephen Skeirik Date: Wed, 12 Nov 2025 12:55:22 -0500 Subject: [PATCH] teach kmir to reduce SMIR K definitions over a set of CFG roots --- kmir/src/kmir/__main__.py | 4 ++++ kmir/src/kmir/kmir.py | 2 +- kmir/src/kmir/options.py | 7 +++++++ kmir/src/kmir/smir.py | 21 ++++++++++++++------- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/kmir/src/kmir/__main__.py b/kmir/src/kmir/__main__.py index 846cf91ef..2653d796c 100644 --- a/kmir/src/kmir/__main__.py +++ b/kmir/src/kmir/__main__.py @@ -429,6 +429,9 @@ def _arg_parser() -> ArgumentParser: prove_rs_parser.add_argument( '--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from' ) + prove_rs_parser.add_argument( + '--cfg-roots', type=Path, metavar='CFG_ROOTS', help='Path to file containing newline-separated possible control flow graph roots (used to prune `rustc` generated MIR symbol table)' + ) link_parser = command_parser.add_parser( 'link', help='Link together 2 or more SMIR JSON files', parents=[kcli_args.logging_args] @@ -499,6 +502,7 @@ def _parse_args(ns: Namespace) -> KMirOpts: save_smir=ns.save_smir, smir=ns.smir, start_symbol=ns.start_symbol, + cfg_roots=ns.cfg_roots, break_on_calls=ns.break_on_calls, break_on_function_calls=ns.break_on_function_calls, break_on_intrinsic_calls=ns.break_on_intrinsic_calls, diff --git a/kmir/src/kmir/kmir.py b/kmir/src/kmir/kmir.py index c5d629c62..5508eeb52 100644 --- a/kmir/src/kmir/kmir.py +++ b/kmir/src/kmir/kmir.py @@ -217,7 +217,7 @@ def prove_rs(opts: ProveRSOpts) -> APRProof: else: smir_info = SMIRInfo(cargo_get_smir_json(opts.rs_file, save_smir=opts.save_smir)) - smir_info = smir_info.reduce_to(opts.start_symbol) + smir_info = smir_info.reduce_to(opts.cfg_roots) # Report whether the reduced call graph includes any functions without MIR bodies missing_body_syms = [ sym diff --git a/kmir/src/kmir/options.py b/kmir/src/kmir/options.py index ff6371fdd..533b6a645 100644 --- a/kmir/src/kmir/options.py +++ b/kmir/src/kmir/options.py @@ -108,6 +108,7 @@ class ProveRSOpts(ProveOpts): save_smir: bool smir: bool start_symbol: str + cfg_roots: list[str] def __init__( self, @@ -120,6 +121,7 @@ def __init__( save_smir: bool = False, smir: bool = False, start_symbol: str = 'main', + cfg_roots: Path | None = None, break_on_calls: bool = False, break_on_function_calls: bool = False, break_on_intrinsic_calls: bool = False, @@ -136,6 +138,10 @@ def __init__( break_every_step: bool = False, terminate_on_thunk: bool = False, ) -> None: + # store each non-empty line in the cfg roots file + start symbol + cfg_roots = list(filter(None, [root.strip() for root in cfg_roots.read_text().splitlines()])) if cfg_roots is not None else [] + cfg_roots.append(start_symbol) + self.rs_file = rs_file self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None self.bug_report = bug_report @@ -145,6 +151,7 @@ def __init__( self.save_smir = save_smir self.smir = smir self.start_symbol = start_symbol + self.cfg_roots = cfg_roots self.break_on_calls = break_on_calls self.break_on_function_calls = break_on_function_calls self.break_on_intrinsic_calls = break_on_intrinsic_calls diff --git a/kmir/src/kmir/smir.py b/kmir/src/kmir/smir.py index e40157e50..211d6d078 100644 --- a/kmir/src/kmir/smir.py +++ b/kmir/src/kmir/smir.py @@ -10,6 +10,7 @@ from .ty import EnumT, RefT, StructT, Ty, TypeMetadata, UnionT if TYPE_CHECKING: + from collections.abc import Sequence from pathlib import Path from typing import Final @@ -180,13 +181,19 @@ def spans(self) -> dict[int, tuple[Path, int, int, int, int]]: def _is_func(item: dict[str, dict]) -> bool: return 'MonoItemFn' in item['mono_item_kind'] - def reduce_to(self, start_name: str) -> SMIRInfo: - # returns a new SMIRInfo with all _items_ removed that are not reachable from the named function - start_ty = self.function_tys[start_name] + def reduce_to(self, start_symbols: str | Sequence[str]) -> SMIRInfo: + # returns a new SMIRInfo with all _items_ removed that are not reachable from the named function(s) + match start_symbols: + case str(symbol): + start_tys = [Ty(self.function_tys[symbol])] + case [*symbols] if symbols and all(isinstance(sym, str) for sym in symbols): + start_tys = [Ty(self.function_tys[sym]) for sym in symbols] + case _: + raise ValueError("SMIRInfo.reduce_to() received an invalid start_symbol") - _LOGGER.debug(f'Reducing items, starting at {start_ty}. Call Edges {self.call_edges}') + _LOGGER.debug(f'Reducing items, starting at {start_tys}. Call Edges {self.call_edges}') - reachable = compute_closure(Ty(start_ty), self.call_edges) + reachable = compute_closure(start_tys, self.call_edges) _LOGGER.debug(f'Reducing to reachable Tys {reachable}') @@ -226,8 +233,8 @@ def call_edges(self) -> dict[Ty, set[Ty]]: return result -def compute_closure(start: Ty, edges: dict[Ty, set[Ty]]) -> set[Ty]: - work = deque([start]) +def compute_closure(start_nodes: Sequence[Ty], edges: dict[Ty, set[Ty]]) -> set[Ty]: + work = deque(start_nodes) reached = set() finished = False while not finished: