Improve dot lift rewrites #1471
Conversation
This reduces the number of rewrite passes by avoiding constant folding of cast/expand_dims/alloc.
A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products.
The test marked xfail was failing because Ger wasn't being introduced, not because of the complex dtype.
Codecov Report

❌ Attention: Your patch check has failed because the patch coverage (93.48%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```diff
@@            Coverage Diff             @@
##             main    #1471      +/-  ##
==========================================
- Coverage   81.85%   81.83%    -0.02%
==========================================
  Files         230      230
  Lines       52522    52609      +87
  Branches     9345     9364      +19
==========================================
+ Hits        42992    43055      +63
- Misses       7095     7114      +19
- Partials     2435     2440       +5
```
Pull Request Overview

This PR extends and simplifies subtensor lifting and matmul-related rewrites to support Blockwise ops, unifies all matmul variants under `_matmul`, and adds tests and performance benchmarks for partial Jacobian computations.

- Extend `local_subtensor_of_dot` and `local_subtensor_of_elemwise` to handle batched/blockwise cases and add a new `squeeze`-based subtensor lift.
- Unify all matmul-like ops (`matvec`, `vecmat`, `vecdot`, and matrix–matrix) to use a single `_matmul` core and implement batch-to-core-matmul rewrites with optional reshape.
- Add new tests for blockwise subtensor lifts, batched matvec rewrites, and partial Jacobian benchmarks; adjust tolerances and seeds for existing tests.
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/test_gradient.py | Import `sqrt` and add `test_benchmark_partial_jacobian` |
| tests/tensor/test_math.py | Fix RNG seed and set `atol` for vector/matrix operation tests |
| tests/tensor/test_blas.py | Remove xfail markers, add skipif and rename parameters |
| tests/tensor/rewriting/test_subtensor_lift.py | Rename subtensor-of-elemwise tests, import `Op`, add blockwise tests |
| tests/tensor/rewriting/test_math.py | Add `test_batch_matvec_to_matmul` parameterized test |
| tests/tensor/rewriting/test_blas.py | Update imports, skip fast compile mode, adjust rewrite assertions |
| pytensor/tensor/rewriting/subtensor_lift.py | Enhance `local_subtensor_of_dot` and `local_subtensor_of_batch_dims`, add squeeze lift |
| pytensor/tensor/rewriting/subtensor.py | Minor cleanup in slice merging and useless-slice rewrites |
| pytensor/tensor/rewriting/math.py | Replace DimShuffle-through-dot rewrite with unified `_matmul`, reposition specializations |
| pytensor/tensor/rewriting/linalg.py | Update import of `_matmul` and use in transpose/blockwise rewrites |
| pytensor/tensor/rewriting/elemwise.py | Simplify upcast-constant rewrite, add `register_stabilize` |
| pytensor/tensor/rewriting/blas.py | Adjust rewrite positions and batched-dot reshaping logic |
| pytensor/tensor/math.py | Add dimension check to `Dot22.make_node`, unify matmul variants |
Comments suppressed due to low confidence (1)

tests/tensor/rewriting/test_subtensor_lift.py:194

The test references `tensor3` but it is not imported; add `from pytensor.tensor import tensor3` to the file's imports to avoid a `NameError`.

```python
x = tensor3("x", shape=(7, 5, 11), dtype="float64")
```
```python
if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
    ...

if not (
    is_matrix_transpose(node.out)
```
In `local_lift_transpose_through_dot`, `node.out` is not a valid attribute; it should be `node.outputs[0]` when checking `is_matrix_transpose`.

```diff
-    is_matrix_transpose(node.out)
+    is_matrix_transpose(node.outputs[0])
```
```python
if not allow_reshape:
    # TODO: We could allow the x rewrite to go on
    # Or just move one axis (the largest) if y is col
    return False
```
A node rewriter should return `None` to indicate no rewrite rather than `False`, which may be treated as a valid (but incorrect) rewrite result.

```diff
-    return False
+    return None
```
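The distinction matters because rewrite drivers typically check the result against `None`, not for truthiness. A minimal sketch of that convention (a hypothetical dispatcher for illustration, not PyTensor's actual code):

```python
def apply_rewrites(node, rewriters):
    # Convention: a rewriter returns None for "no match"; any other
    # value (even a falsy one like False or []) counts as a replacement.
    for rewrite in rewriters:
        result = rewrite(node)
        if result is not None:
            return result
    return node

def never_matches(node):
    return None   # correct way to decline a rewrite

def buggy_decline(node):
    return False  # wrong: False is not None, so it is taken as a result

print(apply_rewrites("x", [never_matches]))  # node survives: "x"
print(apply_rewrites("x", [buggy_decline]))  # node replaced by False!
```

This is why the review suggests `return None`: returning `False` would silently replace the node under an `is not None` check.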
```python
# Apply indices directly on x
# Add empty slices on the axis that squeeze would have removed
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
```
[nitpick] Building `new_idxs` as a NumPy array of objects can make subsequent tuple indexing confusing; consider constructing a Python list and inserting slices, then converting directly to a tuple for clarity.

```diff
-    new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
+    new_idxs = list(idxs)
+    for dim in dropped_dims:
+        new_idxs.insert(dim, slice(None))
```
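For concreteness, both routes produce the same index tuple; the values of `idxs` and `dropped_dims` below are hypothetical examples, chosen only to exercise the two constructions:

```python
import numpy as np

# Hypothetical values: indices into a squeezed tensor, where
# dimension 1 was the one removed by squeeze.
idxs = (0, slice(1, 3))
dropped_dims = [1]

# Object-array route (as written in the PR):
as_array = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
new_idxs_array = tuple(as_array)

# Plain-list route suggested in the review:
new_idxs_list = list(idxs)
for dim in dropped_dims:
    new_idxs_list.insert(dim, slice(None))
new_idxs_list = tuple(new_idxs_list)

assert new_idxs_array == new_idxs_list == (0, slice(None), slice(1, 3))
```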
This PR was motivated by the partial Jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment).
After #1228 it's actually easier to do this sort of optimization in PyTensor, since there's no Scan to worry about. We already had a bunch of rewrites to lift subtensor operations through elemwise and dot, but none to lift them through blockwise (including blockwise dot, a.k.a. matmul). This PR addresses that.
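The identity these lifts exploit can be sketched in plain NumPy (shapes chosen arbitrarily for illustration): indexing the result of a dot on a non-contracted dimension equals indexing the input first, so the rewrite only has to compute the slice that is actually needed.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))
y = rng.normal(size=(4, 3))

# Row-indexing the product only requires one row of x:
assert np.allclose((x @ y)[2], x[2] @ y)

# The same lift applies through a batched (Blockwise) matmul,
# when the index hits a batch dimension:
xb = rng.normal(size=(7, 5, 4))
assert np.allclose((xb @ y)[3], xb[3] @ y)
```

This is what makes partial Jacobians cheap: if only a few rows of the output are requested, the subtensor can be pushed toward the inputs instead of materializing the full product.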
Some notes on the major changes
- Do `constant_folding` in Python mode. This is not related to this PR, but I noticed a test was taking 10x longer than the others just because a simple constant-folding operation was being triggered in the rewrites, and the whole C cache was being loaded. This incurs a one-time penalty that's pretty large. For users not interested in the C backend at all, there's no reason to involve that machinery; a single Python eval should be pretty fast anyway. This was moved to Fix `CheckAndRaise` Op C implementation #1521, as it revealed an unrelated bug.
- Simplified `local_upcast_elemwise`. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims/alloc + cast. I now just do it in numpy directly. This reduces the number of rewrite iterations.
- A bunch of improvements to rewrites, including lifting index operations on the batch dimensions of Blockwise, and expanding the dot subtensor lift to work with the Blockwise case (that rewrite predates Blockwise). Others are self-explanatory.
- Canonicalize `matvec`, `vecmat`, and `vecdot` internally to all use `matmul` (i.e., Blockwise of the matrix–matrix Dot operation). This makes things simpler for our rewrites, because we only need to worry about one case.
- The pre-existing `test_local_batched_matmul_to_core_matmul` rewrite was extended to better address cases of batched matvec, vecmat, and vecdot (batch dimensions are moved to the core dimensions). It now moves non-overlapping batch dimensions of both inputs to their core dimensions. It further tries to avoid reshape (needed when combining multiple batch/core dimensions), so that the subtensor_lift rewrites mentioned above can work through them.
- Prioritize gemv/ger, which also makes several xfail tests pass. The reason for those xfails was probably misattributed.
Benchmark results added in the last commit:
(Note that `vectorize=True` goes from underperforming (28 ms) to overperforming (0.37 ms).)
Vectorized Jacobian code before:
and after:
📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/