Skip to content

Commit e74167f

Browse files
committed
Add interactive optimization mode
1 parent b8831aa commit e74167f

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

pytensor/configdefaults.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,22 @@ def add_compile_configvars():
515515
in_c_key=False,
516516
)
517517

518+
config.add(
519+
"optimizer_interactive",
520+
"If True, we interrupt after every optimization being applied and display how the graph changed",
521+
BoolParam(False),
522+
in_c_key=False,
523+
)
524+
525+
config.add(
526+
"optimizer_interactive_skip_rewrites",
527+
(
528+
"Do not interrupt after changes from optimizers with these names. Separate names with ',"
529+
),
530+
StrParam(""),
531+
in_c_key=False,
532+
)
533+
518534
config.add(
519535
"on_opt_error",
520536
(

pytensor/graph/features.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import warnings
55
from collections import OrderedDict
6+
from difflib import Differ
67
from functools import partial
78
from io import StringIO
89

@@ -564,7 +565,17 @@ def replace_all_validate(
564565
chk = fgraph.checkpoint()
565566

566567
if verbose is None:
567-
verbose = config.optimizer_verbose
568+
interactive = config.optimizer_interactive
569+
verbose = config.optimizer_verbose or interactive
570+
571+
if interactive:
572+
differ = Differ()
573+
bef = pytensor.dprint(
574+
fgraph, file="str", print_type=True, id_type="", print_topo_order=False
575+
)
576+
skip_rewrites = config.optimizer_interactive_skip_rewrites.replace(
577+
" ", ""
578+
).split(",")
568579

569580
for r, new_r in replacements:
570581
try:
@@ -611,6 +622,22 @@ def replace_all_validate(
611622
print(
612623
f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
613624
)
625+
if interactive and str(reason) not in skip_rewrites:
626+
aft = pytensor.dprint(
627+
fgraph,
628+
file="str",
629+
print_type=True,
630+
id_type="",
631+
print_topo_order=False,
632+
)
633+
if bef != aft:
634+
diff = list(
635+
differ.compare(
636+
bef.splitlines(keepends=True), aft.splitlines(keepends=True)
637+
)
638+
)
639+
sys.stdout.writelines(diff)
640+
input("Press any key to continue")
614641

615642
# The return is needed by replace_all_validate_remove
616643
return chk

pytensor/printing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def debugprint(
119119
print_destroy_map: bool = False,
120120
print_view_map: bool = False,
121121
print_fgraph_inputs: bool = False,
122+
print_topo_order: bool = True,
122123
) -> Union[str, TextIO]:
123124
r"""Print a graph as text.
124125
@@ -175,6 +176,8 @@ def debugprint(
175176
Whether to print the `view_map`\s of printed objects
176177
print_fgraph_inputs
177178
Print the inputs of `FunctionGraph`\s.
179+
print_topo_order:
180+
Whether to print the toposort ordering of nodes
178181
179182
Returns
180183
-------
@@ -231,7 +234,10 @@ def debugprint(
231234
else:
232235
storage_maps.extend([None for item in obj.maker.fgraph.outputs])
233236
topo = obj.maker.fgraph.toposort()
234-
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
237+
if print_topo_order:
238+
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
239+
else:
240+
topo_orders.extend([None for item in obj.maker.fgraph.outputs])
235241
elif isinstance(obj, FunctionGraph):
236242
if print_fgraph_inputs:
237243
inputs_to_print.extend(obj.inputs)
@@ -241,7 +247,10 @@ def debugprint(
241247
[getattr(obj, "storage_map", None) for item in obj.outputs]
242248
)
243249
topo = obj.toposort()
244-
topo_orders.extend([topo for item in obj.outputs])
250+
if print_topo_order:
251+
topo_orders.extend([topo for item in obj.outputs])
252+
else:
253+
topo_orders.extend([None for item in obj.outputs])
245254
elif isinstance(obj, (int, float, np.ndarray)):
246255
print(obj, file=_file)
247256
elif isinstance(obj, (In, Out)):

0 commit comments

Comments
 (0)