
Improve dot lift rewrites #1471


Open · ricardoV94 wants to merge 11 commits into main from dot_lift_rewrite

Conversation

ricardoV94 (Member) commented Jun 13, 2025

This PR was motivated by the partial Jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment)

After #1228 it's actually easier to do this sort of optimization in PyTensor, since there's no Scan to worry about. We already had a bunch of rewrites to lift subtensor operations through Elemwise and Dot, but none to lift them through Blockwise (and blockwise dot, a.k.a. matmul). This PR addresses that.
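
To make the motivation concrete, here is a minimal sketch of the kind of partial-Jacobian computation being targeted. The function, the shapes, and the use of the vectorize flag (assumed here to be the one from #1228) are illustrative assumptions, not the actual benchmark code:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x", shape=(1000,))
    y = pt.exp(x) / pt.exp(x).sum()      # any smooth map R^1000 -> R^1000 (illustrative)
    J = jacobian(y, x, vectorize=True)   # graph for the full (1000, 1000) Jacobian
    J_partial = J[:5, :5]                # but only a 5x5 corner is actually wanted
    f = pytensor.function([x], J_partial)

With the subtensor lifts in this PR, the [:5, :5] indexing can be pushed below the (blockwise) dot products, so only the needed rows and columns are ever computed.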

Some notes on the major changes

  1. Do constant_folding in Python mode. This is not related to this PR, but I noticed a test was taking 10x longer than the others just because a simple constant-folding operation was being triggered in the rewrites, which loaded the whole C cache. That incurs a fairly large one-time penalty. For users not interested in the C backend at all, there's no reason to involve that machinery, and a single Python eval should be fast enough anyway. This was moved to Fix CheckAndRaise Op C implementation #1521, as it revealed an unrelated bug.

  2. Simplified local_upcast_elemwise. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims / alloc + cast. The upcasting is now done directly in NumPy. This reduces the number of rewrite iterations.

  3. A bunch of improvements to rewrites, including lifting index operations on the batch dimensions of Blockwise, and extending the dot subtensor lift to work with the Blockwise case (that rewrite predates Blockwise). The others are self-explanatory.

  4. Canonicalize matvec, vecmat, and vecdot internally to all use matmul (i.e., Blockwise of the 2D dot operation). This makes things simpler for our rewrites, because we only need to worry about one case (see the first NumPy sketch after this list for the identities involved).

  5. The pre-existing test_local_batched_matmul_to_core_matmul rewrite was extended to better address cases of batched matvec, vecmat, and vecdot (batch dimensions are moved to the core dimensions). It now moves non-overlapping batch dimensions of both inputs to their core dimensions. It further tries to avoid reshapes (needed when combining multiple batch/core dimensions), so that the subtensor_lift rewrites mentioned above can work through them (see the second NumPy sketch after this list).

  6. Prioritize gemv/ger, which also makes several xfail tests pass. Those xfails were probably misattributed.
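
The canonicalization in point 4 relies on the usual identities between matvec/vecmat/vecdot and matmul with dummy length-1 dimensions. A minimal NumPy sketch of those identities (illustrative only, not the rewrite's actual code):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(3, 4))
    v = rng.normal(size=(4,))
    w = rng.normal(size=(3,))

    # matvec: append a dummy length-1 column to the vector, matmul, then squeeze it out
    np.testing.assert_allclose(A @ v, (A @ v[:, None])[:, 0])
    # vecmat: prepend a dummy length-1 row instead
    np.testing.assert_allclose(w @ A, (w[None, :] @ A)[0, :])
    # vecdot: both tricks at once, giving a (1, 1) result to squeeze
    np.testing.assert_allclose(v @ v, (v[None, :] @ v[:, None])[0, 0])

And a sketch of the batch-to-core idea in point 5: a batch dimension present in only one operand can be folded into a core dimension, so a batched matvec becomes a single 2D matmul (again just an illustration under assumed shapes):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(3, 4))    # core matrix, no batch dimension
    V = rng.normal(size=(50, 4))   # 50 batched vectors

    batched = np.einsum("mn,bn->bm", A, V)  # one matvec per batch entry
    core = V @ A.T                          # same values from a single 2D matmul
    np.testing.assert_allclose(batched, core)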

Benchmark results added in the last commit:
(Note that vectorize=True goes from underperforming (~28 ms) to outperforming (~0.37 ms).)

Before
------------------------------------------------------------------------------------------------- benchmark: 2 tests ------------------------------------------------------------------------------------------------
Name (time in ms)                                        Min                Max               Mean            StdDev             Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=False]      1.9453 (1.0)       2.8201 (1.0)       2.2296 (1.0)      0.0963 (1.0)       2.2031 (1.0)      0.0855 (1.0)         52;25  448.5095 (1.0)         421           1
test_benchmark_partial_jacobian[vectorize=True]      28.8122 (14.81)    36.9261 (13.09)    34.1470 (15.32)    2.3973 (24.90)    34.8889 (15.84)    2.6797 (31.35)         8;1   29.2851 (0.07)         21           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
--------------------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------------------
Name (time in us)                                           Min                   Max                  Mean             StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=True]        345.7980 (1.0)        658.8850 (1.0)        370.9925 (1.0)      41.1362 (1.0)        357.2400 (1.0)      16.9117 (1.0)         24;34  2,695.4724 (1.0)         287           1
test_benchmark_partial_jacobian[vectorize=False]     2,148.9270 (6.21)     3,062.8910 (4.65)     2,215.2234 (5.97)     77.6787 (1.89)     2,194.7940 (6.14)     44.7890 (2.65)        33;34    451.4217 (0.17)        496           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Vectorized Jacobian graph before:

Subtensor{:stop, :stop} [id A] shape=(5, 5) 9
 ├─ DimShuffle{order=[1,0]} [id B] shape=(1000, 1000) 8
 │  └─ Reshape{3} [id C] shape=(1000, 1000, 1) 7
 │     ├─ Dot22 [id D] shape=(1000, 1000) 6
 │     │  ├─ [[0.903246 ... 74841955]] [id E] shape=(1000, 1000)
 │     │  └─ Reshape{2} [id F] shape=(1000, 1000) 5
 │     │     ├─ True_div [id G] shape=(1000, 1000, 1) 4
 │     │     │  ├─ [[[0.0005] ... [0.0005]]] [id H] shape=(1000, 1000, 1)
 │     │     │  └─ Composite{sqrt((0.001 * i0))} [id I] shape=(1000, 1, 1) 3
 │     │     │     └─ ExpandDims{axes=[1, 2]} [id J] shape=(1000, 1, 1) 2
 │     │     │        └─ CGemv{inplace} [id K] shape=(1000,) 1
 │     │     │           ├─ AllocEmpty{dtype='float64'} [id L] shape=(1000,) 0
 │     │     │           │  └─ 1000 [id M] shape=()
 │     │     │           ├─ 1.0 [id N] shape=()
 │     │     │           ├─ [[0.903246 ... 74841955]] [id O] shape=(1000, 1000)
 │     │     │           ├─ x [id P] shape=(?,)
 │     │     │           └─ 0.0 [id Q] shape=()
 │     │     └─ [1000   -1] [id R] shape=(2,)
 │     └─ [1000 1000    1] [id S] shape=(3,)
 ├─ 5 [id T] shape=()
 └─ 5 [id T] shape=()

and after:

Dot22 [id A] shape=(5, 5) 5
 ├─ True_div [id B] shape=(5, 1000) 4
 │  ├─ [[0.0005 0 ... 0.    ]] [id C] shape=(5, 1000)
 │  └─ Composite{sqrt((0.001 * i0))} [id D] shape=(1, 1000) 3
 │     └─ ExpandDims{axis=0} [id E] shape=(1, 1000) 2
 │        └─ CGemv{inplace} [id F] shape=(1000,) 1
 │           ├─ AllocEmpty{dtype='float64'} [id G] shape=(1000,) 0
 │           │  └─ 1000 [id H] shape=()
 │           ├─ 1.0 [id I] shape=()
 │           ├─ [[0.903246 ... 74841955]] [id J] shape=(1000, 1000)
 │           ├─ x [id K] shape=(?,)
 │           └─ 0.0 [id L] shape=()
 └─ [[0.903246 ... 45926986]] [id M] shape=(1000, 5)

📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/

ricardoV94 force-pushed the dot_lift_rewrite branch 2 times, most recently from 2dbf4f0 to bbd12a7 on July 9, 2025 09:24
ricardoV94 marked this pull request as ready for review on July 9, 2025 09:31

codecov bot commented Jul 9, 2025

Codecov Report

Attention: Patch coverage is 93.48837% with 14 lines in your changes missing coverage. Please review.

Project coverage is 81.83%. Comparing base (b9fc4f8) to head (425859b).
Report is 4 commits behind head on main.

Files with missing lines                    | Patch % | Lines
pytensor/tensor/rewriting/subtensor_lift.py | 85.24%  | 7 Missing and 2 partials ⚠️
pytensor/tensor/math.py                     | 85.71%  | 1 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/blas.py           | 50.00%  | 1 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/math.py           | 98.68%  | 0 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (93.48%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1471      +/-   ##
==========================================
- Coverage   81.85%   81.83%   -0.02%     
==========================================
  Files         230      230              
  Lines       52522    52609      +87     
  Branches     9345     9364      +19     
==========================================
+ Hits        42992    43055      +63     
- Misses       7095     7114      +19     
- Partials     2435     2440       +5     
Files with missing lines                    | Coverage         | Δ
pytensor/tensor/rewriting/elemwise.py       | 93.37% <100.00%> | +0.65% ⬆️
pytensor/tensor/rewriting/linalg.py         | 92.08% <100.00%> | ø
pytensor/tensor/rewriting/subtensor.py      | 90.00% <100.00%> | +0.18% ⬆️
pytensor/tensor/rewriting/math.py           | 89.27% <98.68%>  | -0.36% ⬇️
pytensor/tensor/math.py                     | 92.78% <85.71%>  | -0.20% ⬇️
pytensor/tensor/rewriting/blas.py           | 89.28% <50.00%>  | -0.36% ⬇️
pytensor/tensor/rewriting/subtensor_lift.py | 91.46% <85.24%>  | -0.83% ⬇️

... and 2 files with indirect coverage changes



Copilot AI left a comment


Pull Request Overview

This PR extends and simplifies subtensor lifting and matmul-related rewrites to support Blockwise ops, unifies all matmul variants under _matmul, and adds tests and performance benchmarks for partial Jacobian computations.

  • Extend local_subtensor_of_dot and local_subtensor_of_elemwise to handle batched/blockwise cases and add a new squeeze-based subtensor lift.
  • Unify all matmul-like ops (matvec, vecmat, vecdot, and matrix–matrix) to use a single _matmul core and implement batch‐to‐core‐matmul rewrites with optional reshape.
  • Add new tests for blockwise subtensor lifts, batched matvec rewrites, and partial Jacobian benchmarks; adjust tolerances and seeds for existing tests.

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

File                                          | Description
tests/test_gradient.py                        | Import sqrt and add test_benchmark_partial_jacobian
tests/tensor/test_math.py                     | Fix RNG seed and set atol for vector/matrix operation tests
tests/tensor/test_blas.py                     | Remove xfail markers, add skipif and rename parameters
tests/tensor/rewriting/test_subtensor_lift.py | Rename subtensor-of-elemwise tests, import Op, add blockwise tests
tests/tensor/rewriting/test_math.py           | Add test_batch_matvec_to_matmul parameterized test
tests/tensor/rewriting/test_blas.py           | Update imports, skip fast compile mode, adjust rewrite assertions
pytensor/tensor/rewriting/subtensor_lift.py   | Enhance local_subtensor_of_dot and local_subtensor_of_batch_dims, add squeeze lift
pytensor/tensor/rewriting/subtensor.py        | Minor cleanup in slice merging and useless-slice rewrites
pytensor/tensor/rewriting/math.py             | Replace DimShuffle-through-dot rewrite with unified _matmul, reposition specializations
pytensor/tensor/rewriting/linalg.py           | Update import of _matmul and use in transpose/blockwise rewrites
pytensor/tensor/rewriting/elemwise.py         | Simplify upcast-constant rewrite, add register_stabilize
pytensor/tensor/rewriting/blas.py             | Adjust rewrite positions and batched-dot reshaping logic
pytensor/tensor/math.py                       | Add dimension check to Dot22.make_node, unify matmul variants
Comments suppressed due to low confidence (1)

tests/tensor/rewriting/test_subtensor_lift.py:194

  • The test references tensor3 but it is not imported; add from pytensor.tensor import tensor3 to the file's imports to avoid a NameError.
        x = tensor3("x", shape=(7, 5, 11), dtype="float64")

if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):

if not (
is_matrix_transpose(node.out)

Copilot AI Jul 9, 2025


In local_lift_transpose_through_dot, node.out is not a valid attribute; it should be node.outputs[0] when checking is_matrix_transpose.

Suggested change
is_matrix_transpose(node.out)
is_matrix_transpose(node.outputs[0])


if not allow_reshape:
# TODO: We could allow the x rewrite to go on
# Or just move one axis (the largest) if y is col
return False

Copilot AI Jul 9, 2025


A node rewriter should return None to indicate no rewrite rather than False, which may be treated as a valid (but incorrect) rewrite result.

Suggested change
return False
return None
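
(For context, a node rewriter in PyTensor signals "does not apply" by returning None and signals a rewrite by returning the replacement outputs. A minimal sketch of that convention follows, with an illustrative identity-style rewrite; it is not code from this PR:)

    from pytensor.graph.rewriting.basic import node_rewriter
    from pytensor.tensor.math import Dot

    @node_rewriter([Dot])
    def local_example_rewrite(fgraph, node):
        x, y = node.inputs
        if x.ndim != 2 or y.ndim != 2:
            # Bail out: returning None means "no rewrite here"
            return None
        # When it applies, return the replacement output(s):
        # dot(x, y) == transpose(dot(y.T, x.T)) -- illustrative only
        return [(y.T @ x.T).T]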



# Apply indices directly on x
# Add empty slices on the axis that squeeze would have removed
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))

Copilot AI Jul 9, 2025


[nitpick] Building new_idxs as a NumPy array of objects can make subsequent tuple indexing confusing; consider constructing a Python list and inserting slices, then converting directly to a tuple for clarity.

Suggested change
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
new_idxs = list(idxs)
for dim in dropped_dims:
new_idxs.insert(dim, slice(None))

