Change tolerance used to decide whether a constant is one in rewrite functions #1526
Conversation
to allow rewrites that would otherwise fail when the new and old dtype differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where x is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.
…tain dtype like MyType in the tests
…ion isclose, which uses 10 ULPs by default
pytensor/graph/rewriting/basic.py
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
    ret = pytensor.tensor.basic.cast(ret, out_dtype)
Not all types have a dtype; we should check that it's a TensorType before even trying to access dtype and do anything with it. I would perhaps write it like this:
The whole logic is weird though, with the `if ret.owner`: why do we care about the type of outputs we're not replacing? It's actually dangerous to try to replace only one of them without the user's consent. Since this is WIP I would change it to `if len(node.outputs) != 1: return False`, before we try to unify.
Then here we just have to worry about the final else branch below:
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
    if not (
        self.allow_cast
        and isinstance(old_out.type, TensorType)
        and isinstance(ret.type, TensorType)
    ):
        return False
    # Try to cast
    ret = ret.astype(old_out.type.dtype)
    if not old_out.type.is_super(ret.type):
        return False
I am happy to replace it as you suggest, but I am not sure how to fit it in with the rest. This is the current code:
if ret.owner:
    if not (
        len(node.outputs) == len(ret.owner.outputs)
        and all(
            o.type.is_super(new_o.type)
            for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
        )
    ):
        return False
else:
    # ret is just an input variable
    assert len(node.outputs) == 1
    if not node.outputs[0].type.is_super(ret.type):
        return False
You only need what I wrote above; the template would look something like this:
def transform(...):
    ...
    if node.op != self.op:
        return False
    if len(node.outputs) != 1:
        # PatternNodeRewriter doesn't support replacing multi-output nodes
        return False
    ...
    if not self.allow_multiple_clients:
        ...
    # New logic
    [old_out] = node.outputs
    if not old_out.type.is_super(ret.type):
        # Type doesn't match
        if not (
            self.allow_cast
            and isinstance(old_out.type, TensorType)
            and isinstance(ret.type, TensorType)
        ):
            return False
        # Try to cast tensors
        ret = ret.astype(old_out.type.dtype)
        if not old_out.type.is_super(ret.type):
            # Still doesn't match
            return False
    return [ret]
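For readers following along, here is a hedged sketch of how a pattern rewrite using the `allow_cast` flag proposed in this PR might be declared; the `1 - sigmoid(x)` pattern is purely illustrative and not necessarily how that rewrite is registered in the codebase:

    # Sketch only: `allow_cast` is the keyword this PR introduces.
    import pytensor.tensor as pt
    from pytensor.graph.rewriting.basic import PatternNodeRewriter

    # Rewrite 1.0 - sigmoid(x) -> sigmoid(-x). With allow_cast=True,
    # when the replacement's dtype differs from the original output's
    # (e.g. a float64 constant minus a float32 sigmoid), the transform
    # wraps the replacement in a cast instead of bailing out.
    one_minus_sigmoid = PatternNodeRewriter(
        (pt.sub, 1.0, (pt.sigmoid, "x")),
        (pt.sigmoid, (pt.neg, "x")),
        allow_cast=True,
        name="one_minus_sigmoid",
    )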
Are you sure `PatternNodeRewriter` is supposed to only work with single-output nodes? I get the following error:
def test_patternsub_different_output_lengths():
    # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
    ps = PatternNodeRewriter(
        (op1, "x"),
        ("x"),
        name="ps",
    )
    rewriter = in2out(ps)
    x = MyVariable("x")
    e1, e2 = op_multiple_outputs(x)
    o = op1(e1)
    fgraph = FunctionGraph(inputs=[x], outputs=[o])
    rewriter.rewrite(fgraph)
>   assert fgraph.outputs[0].owner.op == op1
E   assert OpMultipleOutputs == op1
E    + where OpMultipleOutputs = OpMultipleOutputs(x).op
E    + where OpMultipleOutputs(x) = OpMultipleOutputs.0.owner
I don't think that test makes sense. It's like saying you don't want to replace `log(exp(x))` if `x` comes from a multi-output node. We usually don't care about the provenance of a root variable in a rewrite. Nothing in that rewrite cares about `op_multiple_outputs`.
It was here: https://github.com/aesara-devs/aesara/pull/803/files
The problem before was that the zip would be shorter if node.outputs and the replacement didn't match in length. But the whole thing goes away if you just say it doesn't support replacing multi-output nodes, which it doesn't really.
That test can be removed in favor of one where it refuses to replace `OpMultipleOutputs`.
Thanks. It sort of makes sense to me, but I know too little of the PyTensor internals to fully understand.
Can you propose a quick way to modify/replace the test with one where it refuses to replace `OpMultipleOutputs`?
If you push your changes (if you haven't already), I can push the new test on top of it
I have pushed all my changes.
I pushed a commit that changes the behavior of the test; have a look and let me know if there's anything else missing.
But it's fine if they're just root inputs
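For concreteness, a hedged sketch of what such a test might look like, reusing `op1`, `op_multiple_outputs`, `MyVariable`, `FunctionGraph`, and `in2out` from the surrounding test module; the commit actually pushed may differ in its details:

    def test_patternsub_multi_output_nodes():
        # Sketch: a matched node that itself has multiple outputs is
        # refused, since replacing only one of them would be unsafe.
        ps = PatternNodeRewriter((op_multiple_outputs, "x"), "x", name="ps")
        rewriter = in2out(ps)
        x = MyVariable("x")
        e1, e2 = op_multiple_outputs(x)
        fgraph = FunctionGraph(inputs=[x], outputs=[e1])
        rewriter.rewrite(fgraph)
        # Refused: the matched node has two outputs
        assert fgraph.outputs[0].owner.op == op_multiple_outputs

        # A root that merely comes from a multi-output node is fine:
        ps2 = PatternNodeRewriter((op1, "x"), "x", name="ps2")
        o = op1(e1)
        fgraph2 = FunctionGraph(inputs=[x], outputs=[o])
        in2out(ps2).rewrite(fgraph2)
        # op1(e1) was replaced by e1, an output of op_multiple_outputs
        assert fgraph2.outputs[0].owner.op == op_multiple_outputs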
Codecov Report
Attention: Patch coverage is 88.23%.
❌ Your patch check has failed because the patch coverage (88.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
@@ Coverage Diff @@
## main #1526 +/- ##
==========================================
- Coverage 81.99% 81.85% -0.14%
==========================================
Files 231 230 -1
Lines 52253 52532 +279
Branches 9203 9345 +142
==========================================
+ Hits 42843 42999 +156
Misses 7099 7099
- Partials 2311 2434 +123
Description
The previous tolerance used within a rewrite to decide whether a constant is one (or minus one) is too large. For example, `c - sigmoid(x)` is rewritten as `sigmoid(-x)` even when $c = 1 - p$ where p is 1 in 10000. More generally, many rewrites currently use `np.isclose` and `np.allclose` with the default tolerances (rtol=1e-05, atol=1e-08), which are unnecessarily large (and independent of the data type of the constant being compared).

This PR implements a function `isclose`, used within all rewrites in place of `np.isclose` and `np.allclose`. The new function uses a much smaller tolerance by default, namely 10 units in the last place (ULPs). This tolerance is dtype dependent, so it is stricter for a float64 than for a float32. See #1497 for a back-of-the-envelope justification for choosing 10 ULPs.

This PR also implements `allow_cast` in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtype differ. For example, a rewrite attempt for `np.array(1., "float64") - sigmoid(x)` (where x is an `fmatrix`) currently fails because the rewrite `sigmoid(-x)` would change the type. This PR allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`.

Relevant tests added.
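As an illustration of the dtype-dependent tolerance, here is a minimal sketch of a ULP-based closeness check built on `np.spacing`; the helper actually implemented in this PR may differ in name, signature, and edge-case handling:

    import numpy as np

    def isclose_ulp(a, b, n_ulps=10):
        # Two values are "close" when they differ by at most n_ulps
        # units in the last place. np.spacing(b) is the ULP size at b
        # for b's dtype, so the tolerance tightens automatically for
        # float64 relative to float32.
        a, b = np.asarray(a), np.asarray(b)
        return bool(np.all(np.abs(a - b) <= n_ulps * np.spacing(np.abs(b))))

    # np.isclose's default rtol=1e-05 accepts a relative error of 1e-6 ...
    assert np.isclose(1.0, 1.0 - 1e-6)
    # ... but for float64 that is billions of ULPs away from 1.0:
    assert not isclose_ulp(1.0, 1.0 - 1e-6)
    # Values genuinely a few ULPs apart still pass:
    assert isclose_ulp(1.0, 1.0 + 3 * np.spacing(1.0))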
📚 Documentation preview 📚: https://pytensor--1526.org.readthedocs.build/en/1526/