6 changes: 3 additions & 3 deletions doc/gallery/rewrites/graph_rewrites.ipynb
@@ -583,7 +583,7 @@
"    def tracks(self):\n",
"        return [pt.log]\n",
"    \n",
"    def transform(self, fgraph, node):\n",
"    def transform(self, fgraph, node, enforce_tracks=True):\n",
"        return local_log1p(node) \n",
"    \n",
"    def __str__(self):\n",
@@ -669,8 +669,8 @@
"@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n",
"    # Because of the tracks we don't need to check \n",
"    # that `node` has a `Sign` Op.\n",
"    # We still need to check whether it's input is an `Abs` Op\n",
"    # that `node` has an `Abs` Op.\n",
"    # We still need to check whether its input is an `Exp` Op\n",
"    exp_node = node.inputs[0].owner\n",
"    if exp_node is None or exp_node.op != pt.exp:\n",
"        return None\n",
330 changes: 254 additions & 76 deletions pytensor/graph/rewriting/basic.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pytensor/graph/rewriting/kanren.py
@@ -74,7 +74,7 @@ def results_filter(
        self.node_filter = node_filter
        super().__init__()

    def transform(self, fgraph, node):
    def transform(self, fgraph, node, enforce_tracks: bool = True):
        if self.node_filter(node) is False:
            return False

@@ -86,13 +86,13 @@ def transform(self, fgraph, node):
        q = var()
        kanren_results = run(None, q, self.kanren_relation(input_expr, q))

        chosen_res = self.results_filter(kanren_results)
        chosen_res = self.results_filter(kanren_results)  # type: ignore[arg-type]

        if chosen_res:
            if isinstance(chosen_res, list):
                new_outputs = [eval_if_etuple(v) for v in chosen_res]
            else:
                new_outputs = [eval_if_etuple(chosen_res)]
                new_outputs = [eval_if_etuple(chosen_res)]  # type: ignore[unreachable]

            return new_outputs
        else:
213 changes: 209 additions & 4 deletions pytensor/graph/rewriting/unify.py
@@ -10,8 +10,11 @@

"""

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from numbers import Number
from types import UnionType
from typing import Any, TypeAlias

import numpy as np
from cons.core import ConsError, _car, _cdr
@@ -254,6 +257,200 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)


@dataclass(frozen=True)
class LiteralString:
    value: str


OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType


@dataclass(unsafe_hash=True)
class OpPattern:
    """Class that can be unified with Op instances of a given type (or instance) and parameters.

    Parameters that are not specified in the OpPattern are ignored during unification.

    This is needed because some Ops can be complex to parametrize fully,
    and not all parameters are relevant for a given pattern.


    Examples
    --------

    OpPattern can be used in the `tracks` functionality of `node_rewriter` to filter nodes more flexibly.
    For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.

    .. testcode::

        from pytensor.graph.rewriting.basic import node_rewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.elemwise import CAReduce
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        @node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
        def local_car_reduce_all_rewriter(fgraph, node):
            # This will always be true!
            assert isinstance(node.op, CAReduce) and node.op.axis is None
            ...

        # Any Blockwise whose core_op is a Solve Op (or subclass) instance
        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
        def local_blockwise_solve_triangular_rewriter(fgraph, node):
            # This will always be true!
            assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
            ...

        # Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
        def local_blockwise_vector_solve_rewriter(fgraph, node):
            # This will always be true!
            assert (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Solve)
                and node.op.core_op.b_ndim == 1
            )
            ...


    OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
    The example below matches two nested CAReduce Ops with the same `scalar_op`,
    the outer one with `axis=None` (full reduction), and fuses them into a single CAReduce.
    Note that, because we didn't specify it, the axis of the inner CAReduce can be anything.
    The same goes for other properties of the Op that are not specified in the OpPattern.

    .. testcode::

        from pytensor.graph.rewriting.basic import PatternNodeRewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.elemwise import CAReduce

        PatternNodeRewriter(
            in_pattern=(
                OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
                (OpPattern(CAReduce, scalar_op="scalar_op"), "x"),
            ),
            out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"),
        )


    OpPattern can also be used with `unification.unify` to match Ops with specific parameters.
    This is used by PatternNodeRewriter but can also be used directly.

    .. testcode::

        from unification import var, unify
        from etuples import etuple

        import pytensor.tensor as pt
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        A = var("A")
        b = var("b")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")),
            A,
            b,
        )

        A_pt = pt.tensor3("A")
        b_pt = pt.tensor3("b")
        out1 = pt.linalg.solve(A_pt, b_pt)
        out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")

        assert unify(pattern, out1) == {A: A_pt, b: b_pt}
        assert unify(pattern, out2) is False

        assume_a = var("assume_a")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
            A,
            b,
        )
        assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
        assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}

    """

    op_type: OpPatternOpTypeType
    parameters: tuple[tuple[str, Any], ...]

    def __init__(
        self,
        op_type: OpPatternOpTypeType,
        parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,

[Review comment (Member), on the `parameters` argument]: parameters is not of type parameters 🤔

        **kwargs,
    ):
        if kwargs:
            if parameters is not None:
                raise ValueError(
                    "Cannot provide both parameters dict and keyword arguments"
                )
            parameters = kwargs
        if isinstance(parameters, dict):
            parameters = tuple(sorted(parameters.items()))
        elif isinstance(parameters, list | tuple):
            parameters = tuple(sorted(parameters))
        elif parameters is None:
            parameters = ()
        self.op_type = op_type
        self.parameters = parameters  # type: ignore[assignment]

    def match_op(self, op: Op):
        if not isinstance(op, self.op_type):
            return False
        return self.match_parameters(op)

    def match_parameters(self, op):
        # This is used by methods that already check the op_type is satisfied
        # Some methods may index on the op_type and know in advance the op is matched
        # Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below)
        for key, param in self.parameters:
            if isinstance(param, OpPattern):
                # Parameters can themselves be other OpPatterns
                # We check the op_type to avoid a nested call in cases we can reject early
                sub_op = getattr(op, key)
                if not isinstance(sub_op, param.op_type):
                    return False
                # Match the pattern of the inner Op
                # Skip if there are no parameters
                if param.parameters and not param.match_parameters(sub_op):
                    return False
            elif getattr(op, key) != param:
                return False
        return True

    def __str__(self):
        return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"


def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):
    if not isinstance(v, u.op_type):
        yield False
        return
    for parameter_key, parameter_pattern in u.parameters:
        parameter_value = getattr(v, parameter_key)
        new_s = yield _unify(parameter_value, parameter_pattern, s)
        if new_s is False:
            yield False
            return
        s = new_s
    yield s


_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op)
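
The `match_op`/`match_parameters` pair above can also be exercised directly, outside of `unify`. A minimal sketch (an editor's illustration, not part of this diff; it assumes the standard `pt.linalg.solve` constructor together with the `OpPattern` API introduced above):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve

A = pt.tensor3("A")
b = pt.tensor3("b")
# A batched solve: the node's op is a Blockwise whose core_op is a Solve instance
node = pt.linalg.solve(A, b, b_ndim=2).owner

# Nested pattern: a Blockwise whose core_op is a Solve with b_ndim == 2
assert OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=2)).match_op(node.op)

# Parameters left unspecified (e.g. assume_a) are ignored; a mismatched one rejects the Op
assert not OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1)).match_op(node.op)
```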


def convert_strs_to_vars(
    x: tuple | str | dict, var_map: dict[str, Var] | None = None
) -> ExpressionTuple | Var:
@@ -266,11 +463,13 @@ def convert_strs_to_vars(
    if var_map is None:
        var_map = {}

    def _convert(y):
    def _convert(y, op_prop=False):
        if isinstance(y, str):
            v = var_map.get(y, var(y))
            var_map[y] = v
            return v
        if isinstance(y, LiteralString):
            return y.value
        elif isinstance(y, dict):
            pattern = y["pattern"]
            if not isinstance(pattern, str):
@@ -282,8 +481,14 @@ def _convert(y):
            var_map[pattern] = v
            return v
        elif isinstance(y, tuple):
            return etuple(*(_convert(e) for e in y))
        elif isinstance(y, Number | np.ndarray):
            return etuple(*(_convert(e, op_prop=op_prop) for e in y))
        elif isinstance(y, OpPattern):
            return OpPattern(
                y.op_type,
                {k: _convert(v, op_prop=True) for k, v in y.parameters},
            )
        elif (not op_prop) and isinstance(y, Number | np.ndarray):
            # If we are converting an Op property, we don't want to convert numbers to PyTensor constants
            from pytensor.tensor import as_tensor_variable

            return as_tensor_variable(y)
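
The `LiteralString` wrapper and the `op_prop` flag handled above exist so that literal Op parameters survive pattern conversion: in `PatternNodeRewriter` patterns, bare strings become logic variables and bare numbers become PyTensor constants, which is not what you want for values such as `assume_a="gen"` or `b_ndim=1`. A small sketch (an editor's illustration based on the code above, not part of this diff):

```python
from unification import isvar

from pytensor.graph.rewriting.unify import (
    LiteralString,
    OpPattern,
    convert_strs_to_vars,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve

pattern = (
    OpPattern(
        Blockwise,
        core_op=OpPattern(Solve, assume_a=LiteralString("gen"), b_ndim=1),
    ),
    "A",  # bare strings become logic variables
    "b",
)
et = convert_strs_to_vars(pattern)

inner = dict(et[0].parameters)["core_op"]
inner_params = dict(inner.parameters)
assert inner_params["assume_a"] == "gen"  # LiteralString unwrapped to a plain str
assert inner_params["b_ndim"] == 1        # numbers stay plain for Op properties
assert isvar(et[1]) and isvar(et[2])      # graph-level inputs did become logic variables
```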
8 changes: 5 additions & 3 deletions pytensor/scalar/basic.py
@@ -1228,6 +1228,8 @@ def __init__(self, output_types_preference=None, name=None):
                    f"(got: {output_types_preference})"
                )
            self.output_types_preference = output_types_preference
        elif not hasattr(self, "output_types_preference"):
            self.output_types_preference = None

    def make_node(self, *inputs):
        if self.nin >= 0:
@@ -1247,7 +1249,7 @@ def make_node(self, *inputs):
        return Apply(self, inputs, outputs)

    def output_types(self, types):
        if hasattr(self, "output_types_preference"):
        if self.output_types_preference is not None:
            variables = self.output_types_preference(*types)
            if not isinstance(variables, list | tuple) or any(
                not isinstance(x, CType) for x in variables
@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
    nfunc_spec = ("sign", 1, 1)

    @staticmethod
    def output_types_preference(x):
    def _output_types_preference(x):
[Review thread on this rename]
Member (suggested change): `def _output_types_preference(x):` -> `def _get_output_types_preference(x):`, "Since it's a method in this case, not a member variable"
Member (author): "This is what init of the base class will look for if I don't pass a function"
Member: "I don't get it"
Member (author, ricardoV94, Sep 9, 2025): "Check for ScalarLoop.__init__, this plays a role there, and the name matters"
(The sketch after this hunk illustrates the attribute lookup being discussed.)
        if x == bool:
            raise TypeError(x)
        return same_out_nocomplex(x)
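
Why the rename matters, in plain Python: the base `__init__` shown earlier only falls back to `output_types_preference = None` when the attribute isn't already defined on the class, so a staticmethod that kept the old name would be picked up by that `hasattr` check. A self-contained sketch of just that mechanism (an editor's illustration with stand-in classes, not the real PyTensor ones):

```python
class Base:
    def __init__(self, output_types_preference=None):
        if output_types_preference is not None:
            self.output_types_preference = output_types_preference
        elif not hasattr(self, "output_types_preference"):
            self.output_types_preference = None


class OldStyle(Base):
    @staticmethod
    def output_types_preference(x):  # old name: visible to hasattr in Base.__init__
        return x


class NewStyle(Base):
    @staticmethod
    def _output_types_preference(x):  # new name: invisible to Base.__init__
        return x


assert callable(OldStyle().output_types_preference)    # class attribute is kept
assert NewStyle().output_types_preference is None      # falls back to None
assert callable(NewStyle(NewStyle._output_types_preference).output_types_preference)
```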
@@ -2737,7 +2739,7 @@ def c_code_cache_version(self):
return s


sign = Sign(name="sign")
sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)


class Ceil(UnaryScalarOp):
3 changes: 2 additions & 1 deletion pytensor/tensor/_linalg/solve/rewriting.py
@@ -14,6 +14,7 @@
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable
@@ -227,7 +228,7 @@


@register_specialize
@node_rewriter([Blockwise])
@node_rewriter([blockwise_of(Solve)])
def reuse_decomposition_multiple_solves(fgraph, node):
    return _split_decomp_and_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
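
`blockwise_of` is imported above from `pytensor.tensor.rewriting.blockwise`, whose diff is not shown on this page. A plausible reading, given the `OpPattern` tracks support added in this PR, is that it builds a nested pattern so the rewrite only fires on `Blockwise` nodes wrapping the given core Op type; a hypothetical sketch (the real helper may differ):

```python
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.blockwise import Blockwise


def blockwise_of(core_op_type):
    """Track Blockwise nodes whose core_op is an instance of `core_op_type` (hypothetical)."""
    return OpPattern(Blockwise, core_op=OpPattern(core_op_type))
```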