Change tolerance used to decide whether a constant is one in rewrite functions #1526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions pytensor/graph/rewriting/basic.py
@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """

@@ -1572,6 +1573,10 @@ def __init__(
             If you provide `tracks`, you must provide this parameter. It must be a
             function that takes the tracked node and returns a list of nodes on
             which we will try this rewrite.
+        values_eq_approx
+            TODO
+        allow_cast
+            Automatically cast the output of the rewrite whenever new and old types differ

         Notes
         -----
@@ -1586,6 +1591,7 @@ def __init__(
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1630,6 +1636,10 @@ def transform(self, fgraph, node, get_nodes=True):
         if node.op != self.op:
             return False

+        if len(node.outputs) != 1:
+            # PatternNodeRewriter doesn't support replacing multi-output nodes
+            return False
+
         s = unify(self.in_pattern, node.out)

         if s is False:
@@ -1652,19 +1662,20 @@ def transform(self, fgraph, node, get_nodes=True):
         ):
             return False

-        if ret.owner:
+        [old_out] = node.outputs
+        if not old_out.type.is_super(ret.type):
+            # Type doesn't match
             if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
-                    o.type.is_super(new_o.type)
-                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
+                self.allow_cast
+                and isinstance(old_out.type, pytensor.tensor.TensorType)
+                and isinstance(ret.type, pytensor.tensor.TensorType)
             ):
                 return False
-        else:
-            # ret is just an input variable
-            assert len(node.outputs) == 1
-            if not node.outputs[0].type.is_super(ret.type):
+
+            # Try to cast tensors
+            ret = ret.astype(old_out.type.dtype)
+            if not old_out.type.is_super(ret.type):
+                # Still doesn't match
                 return False

         return [ret]
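The net effect of the `transform` changes: multi-output nodes are skipped outright, and a dtype-only mismatch between the matched node and its replacement is now repaired rather than rejected when `allow_cast=True`. A minimal sketch of the recovery step in isolation, with illustrative variable names rather than an excerpt from the diff:

import pytensor.tensor as pt

# Stand-ins for the matched node's single output and the candidate
# replacement built from `out_pattern`: same shape, different dtype.
old_out = pt.vector("old_out", dtype="float32")
ret = pt.vector("ret", dtype="float64")

if not old_out.type.is_super(ret.type):
    # Mirrors the new branch above: a plain dtype cast, then a re-check.
    ret = ret.astype(old_out.type.dtype)

print(old_out.type.is_super(ret.type))  # True: only the dtype differed

If the types still differ after the cast (for example a static-shape mismatch rather than a dtype mismatch), the rewrite is rejected exactly as before.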
25 changes: 20 additions & 5 deletions pytensor/tensor/rewriting/math.py
@@ -2385,7 +2385,7 @@ def local_log1p(fgraph, node):
         log_arg.owner.inputs, only_process_constants=True
     )
     # scalar_inputs are potentially dimshuffled and fill'd scalars
-    if scalars and np.allclose(np.sum(scalars), 1):
+    if scalars and isclose(np.sum(scalars), 1):
         if nonconsts:
             ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
@@ -2990,6 +2990,21 @@ def check_input(inputs):
     return [ret]


+def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
+    """
+
+    Returns
+    -------
+    bool
+        True iff x is a constant close to ref (by default 10 ULPs).
+
+    """
+    x = np.asarray(x)
+    if np.issubdtype(x.dtype, np.floating):
+        atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
+    return np.allclose(x, ref, rtol=rtol, atol=atol)
+
+
 def _skip_mul_1(r):
     if r.owner and r.owner.op == mul:
         not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
@@ -3008,7 +3023,7 @@ def _is_1(expr):
     """
     try:
         v = get_underlying_scalar_constant_value(expr)
-        return np.isclose(v, 1)
+        return isclose(v, 1)
     except NotScalarConstantError:
         return False

@@ -3069,7 +3084,7 @@ def is_1pexp(t, only_process_constants=True):
                 scal_sum = scalars[0]
                 for s in scalars[1:]:
                     scal_sum = scal_sum + s
-                if np.allclose(scal_sum, 1):
+                if isclose(scal_sum, 1):
                     return False, maybe_exp.owner.inputs[0]
     return None

@@ -3169,7 +3184,7 @@ def is_neg(var):
     for idx, mul_input in enumerate(var_node.inputs):
         try:
             constant = get_underlying_scalar_constant_value(mul_input)
-            is_minus_1 = np.isclose(constant, -1)
+            is_minus_1 = isclose(constant, -1)
         except NotScalarConstantError:
             is_minus_1 = False
         if is_minus_1:
@@ -3577,7 +3592,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
         # scalar_inputs are potentially dimshuffled and fill'd scalars
         if len(nonconsts) == 1:
             if nonconsts[0].owner and nonconsts[0].owner.op == exp:
-                if scalars_ and np.allclose(np.sum(scalars_), 1):
+                if scalars_ and isclose(np.sum(scalars_), 1):
                     out = [
                         alloc_like(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
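The constants used in the updated tests make the change in tolerance concrete. A standalone NumPy demonstration of the `isclose` helper added above:

import numpy as np

# Old check: np.isclose/np.allclose defaults (rtol=1e-5, atol=1e-8) treat
# any constant within ~1e-5 of 1 as "one", so rewrites such as
# 1 - sigmoid(x) -> sigmoid(-x) could fire for constants that are not 1.
print(np.isclose(np.float32(1 - 9e-6), 1))  # True
print(np.isclose(np.float64(1 - 1e-9), 1))  # True

# New check: the tolerance is num_ulps spacings of the reference value in
# the constant's own dtype (10 * np.spacing(1.0) in that dtype, here).
def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.floating):
        atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
    return np.allclose(x, ref, rtol=rtol, atol=atol)

print(isclose(np.float32(1 - 9e-6), 1))  # False: ~75 float32 ULPs from 1
print(isclose(np.float64(1 - 1e-9), 1))  # False: millions of float64 ULPs from 1
print(isclose(np.float32(1.000001), 1))  # True: ~8 ULPs, still counts as 1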
32 changes: 23 additions & 9 deletions tests/graph/rewriting/test_basic.py
@@ -41,6 +41,7 @@
     op_y,
     op_z,
 )
+from tests.unittest_tools import assert_equal_computations


 class AssertNoChanges(Feature):
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
     assert e.type.is_super(fg.outputs[0].type)


-def test_patternsub_different_output_lengths():
-    # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
-    ps = PatternNodeRewriter(
-        (op1, "x"),
-        ("x"),
+def test_patternsub_multi_output_nodes():
+    # Test that PatternNodeRewriter won't attempt to replace multi-output nodes
+    multiple_op_ps = PatternNodeRewriter(
+        (op_multiple_outputs, "x"),
+        "x",
         name="ps",
     )
-    rewriter = in2out(ps)
+
+    single_op_ps = PatternNodeRewriter(
+        (op_y, "x"),
+        "x",
+        name="ps",
+    )
+
+    rewriter = in2out(multiple_op_ps, single_op_ps)

     x = MyVariable("x")
     e1, e2 = op_multiple_outputs(x)
-    o = op1(e1)
+    o1, o2 = op_y(e1), op_y(e2)

-    fgraph = FunctionGraph(inputs=[x], outputs=[o])
+    fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
+    rewriter.rewrite(fgraph)
+    # This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched
+    assert_equal_computations(fgraph.outputs, [e2, e1])
+
+    fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
     rewriter.rewrite(fgraph)
-    assert fgraph.outputs[0].owner.op == op1
+    # Having a variable that comes out of a multi-output node should be fine
+    assert_equal_computations(fgraph.outputs, [e2, e1])


 class TestSequentialNodeRewriter:
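The first `FunctionGraph` above must stay unchanged because of the `len(node.outputs) != 1` guard added to `transform`: a pattern has no way to express which output of a multi-output node it is matching. A small illustration using the same test helpers (assuming `tests.graph.utils` is importable, as it is when running the test suite):

from tests.graph.utils import MyVariable, op_multiple_outputs

x = MyVariable("x")
e1, e2 = op_multiple_outputs(x)

# Both variables come from a single Apply node with two outputs, so the
# new guard makes PatternNodeRewriter.transform bail out immediately.
assert e1.owner is e2.owner
print(len(e1.owner.outputs))  # 2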
3 changes: 3 additions & 0 deletions tests/graph/utils.py
@@ -107,6 +107,9 @@ def make_node(self, *inputs):


 class MyOpMultipleOutputs(MyOp):
+    def __init__(self, name, dmap=None, x=None):
+        super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
+
     def make_node(self, input):
         outputs = [input.type(), input.type()]
         return Apply(self, [input], outputs)
30 changes: 22 additions & 8 deletions tests/tensor/rewriting/test_math.py
@@ -49,6 +49,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,
@@ -123,6 +124,7 @@
     dvector,
     fmatrices,
     fmatrix,
+    fscalar,
     ftensor4,
     fvector,
     imatrices,
@@ -4114,25 +4116,36 @@ def test_exp_over_1_plus_exp(self):

     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
-        x = fmatrix()
+        x = fscalar()
+        xd = dscalar()

         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array(1.0, "float64") - sigmoid(xd), sigmoid(-xd)),
+            (np.sum(1 / np.array([2, 3, 6], "float32")) - sigmoid(x), sigmoid(-x)),
+            (np.sum(1 / np.array([2, 3, 6], "float64")) - sigmoid(xd), sigmoid(-xd)),
+            (np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
+            (np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
+        ]:
+            rewritten = rewrite_graph(
+                out, include=["canonicalize", "specialize", "stabilize"]
+            )
+            utt.assert_equal_computations([rewritten], [expected], original=out)

     def test_local_sigm_times_exp(self):
         """
@@ -4280,7 +4293,8 @@ def test_log1msigm_to_softplus(self):
         f(np.random.random((54, 11)).astype(config.floatX))

         # Test close to 1
-        out = log(1.000001 - sigmoid(x))
+        x_dtype = np.dtype(x.dtype).type
+        out = log(np.nextafter(x_dtype(1), x_dtype(2)) - sigmoid(x))
         f = pytensor.function([x], out, mode=self.m)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 2
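Each tuple in the rewritten test is a (graph, expected-after-rewrite) pair, and pairs whose expected graph equals the input assert that the rewrite must not fire. Reproducing one entry of each kind by hand, a sketch assuming `rewrite_graph` from `pytensor.graph.rewriting.utils`:

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.fscalar("x")

# An exact 1: local_1msigmoid should fire and yield sigmoid(-x).
fires = rewrite_graph(
    np.float32(1.0) - pt.sigmoid(x),
    include=["canonicalize", "specialize", "stabilize"],
)

# A constant ~75 ULPs below 1: with the ULP-based isclose the rewrite
# should no longer fire, and the subtraction is kept.
kept = rewrite_graph(
    np.float32(1 - 9e-6) - pt.sigmoid(x),
    include=["canonicalize", "specialize", "stabilize"],
)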
36 changes: 36 additions & 0 deletions tests/unittest_tools.py
@@ -11,6 +11,7 @@
 from pytensor.compile.debugmode import str_diagnostic
 from pytensor.configdefaults import config
 from pytensor.gradient import verify_grad as orig_verify_grad
+from pytensor.graph.basic import equal_computations
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.math import add as pt_add
@@ -279,6 +280,41 @@ def assert_allclose(expected, value, rtol=None, atol=None):
         raise WrongValue(expected, value, rtol, atol)


+def assert_equal_computations(rewritten, expected, *args, original=None, **kwargs):
+    """
+    Assert that `rewritten` computes the same as `expected`.
+
+    Parameters
+    ----------
+    rewritten
+        The expression after the rewrite pass.
+    expected
+        The reference expression to compare against.
+    *args, **kwargs
+        Extra arguments forwarded to equal_computations.
+    original : optional
+        If given, will be printed in the error message.
+    """
+    __tracebackhide__ = True  # Hide traceback for py.test
+
+    ok = equal_computations(rewritten, expected, *args, **kwargs)
+
+    if not ok:
+        parts = []
+
+        def _dprint(expr):
+            return pytensor.dprint(expr, print_type=True, file="str")
+
+        if original is not None:
+            parts.append(f"\nOriginal:\n{_dprint(original)}")
+        parts.append(f"\nRewritten:\n{_dprint(rewritten)}")
+        parts.append(f"\nExpected:\n{_dprint(expected)}")
+
+        raise AssertionError("equal_computations failed\n" + "".join(parts))
+
+    return True
+
+
 class AttemptManyTimes:
     """Decorator for unit tests that forces a unit test to be attempted
     multiple times. The test needs to pass a certain number of times for it to
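For reference, the new helper is a drop-in replacement for a bare `equal_computations` assertion; on failure it raises with `pytensor.dprint` renderings of the graphs involved. A minimal usage sketch (assuming it runs from the repository so that `tests.unittest_tools` is importable):

import pytensor.tensor as pt
from tests.unittest_tools import assert_equal_computations

x = pt.vector("x")

# Identical graphs: passes silently and returns True.
assert_equal_computations([pt.exp(x)], [pt.exp(x)])

# Different graphs: raises AssertionError listing the Original (when
# given), Rewritten, and Expected graphs.
assert_equal_computations([pt.exp(x)], [pt.log(x)], original=pt.exp(x))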