From 9ad954006c13a1a47d188d86340f344c4d7bc618 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 19:50:56 +0200 Subject: [PATCH 1/9] block_diag dot rewrite --- pytensor/tensor/rewriting/math.py | 73 +++++++++++++++++++++++++++-- tests/tensor/rewriting/test_math.py | 73 +++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index aef363655e..a0d5a9dc7b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -29,9 +29,11 @@ cast, constant, get_underlying_scalar_constant_value, + join, moveaxis, ones_like, register_infer_shape, + split, switch, zeros_like, ) @@ -99,6 +101,7 @@ ) from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import Shape, Shape_i +from pytensor.tensor.slinalg import BlockDiagonal from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( complex_dtypes, @@ -167,6 +170,72 @@ def local_0_dot_x(fgraph, node): return [constant_zero] +@register_canonicalize +@register_specialize +@register_stabilize +@node_rewriter([Dot]) +def local_block_diag_dot_to_dot_block_diag(fgraph, node): + r""" + Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))`` + + BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity + of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than + a single dot on the larger matrix. + """ + x, y = node.inputs + op = node.op + + def check_for_block_diag(x): + return x.owner and ( + isinstance(x.owner.op, BlockDiagonal) + or isinstance(x.owner.op, Blockwise) + and isinstance(x.owner.op.core_op, BlockDiagonal) + ) + + if not (check_for_block_diag(x) or check_for_block_diag(y)): + return None + + # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the + # non-block diagonal, and return a new block diagonal + if check_for_block_diag(x) and not check_for_block_diag(y): + components = x.owner.inputs + y_splits = split( + y, + splits_size=[component.shape[-1] for component in components], + n_splits=len(components), + ) + new_components = [ + op(component, y_split) for component, y_split in zip(components, y_splits) + ] + new_output = join(0, *new_components) + elif not check_for_block_diag(x) and check_for_block_diag(y): + components = y.owner.inputs + new_components = [op(x, component) for component in components] + new_output = join(0, *new_components) + + # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. 
In + # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case + elif any(shape is None for shape in (*x.type.shape, *y.type.shape)): + return None + elif x.ndim == y.ndim and all( + x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape) + ): + x_components = x.owner.inputs + y_components = y.owner.inputs + + if len(x_components) != len(y_components): + return None + + new_output = BlockDiagonal(len(x_components))( + *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)] + ) + else: + return None + + copy_stack_trace(node.outputs[0], new_output) + return [new_output] + + @register_canonicalize @node_rewriter([DimShuffle]) def local_lift_transpose_through_dot(fgraph, node): @@ -2496,7 +2565,6 @@ def add_calculate(num, denum, aslist=False, out_type=None): name="add_canonizer_group", ) - register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer") @@ -3619,7 +3687,6 @@ def logmexpm1_to_log1mexp(fgraph, node): ) register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff") - # log(sigmoid(x) / (1 - sigmoid(x))) -> x # i.e logit(sigmoid(x)) -> x local_logit_sigmoid = PatternNodeRewriter( @@ -3633,7 +3700,6 @@ def logmexpm1_to_log1mexp(fgraph, node): register_canonicalize(local_logit_sigmoid) register_specialize(local_logit_sigmoid) - # sigmoid(log(x / (1-x)) -> x # i.e., sigmoid(logit(x)) -> x local_sigmoid_logit = PatternNodeRewriter( @@ -3674,7 +3740,6 @@ def local_useless_conj(fgraph, node): register_specialize(local_polygamma_to_tri_gamma) - local_log_kv = PatternNodeRewriter( # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x # During stabilize -x is converted to -1.0 * x diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index c4999fcd33..3be12da3e5 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -113,6 +113,7 @@ simplify_mul, ) from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape +from pytensor.tensor.slinalg import BlockDiagonal from pytensor.tensor.type import ( TensorType, cmatrix, @@ -4654,3 +4655,75 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): out.eval({a: a_test, b: b_test}, mode=test_mode), rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode), ) + + +def test_local_block_diag_dot_to_dot_block_diag(): + """ + Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) + """ + a = tensor("a", shape=(4, 2)) + b = tensor("b", shape=(2, 4)) + c = tensor("c", shape=(4, 4)) + d = tensor("d", shape=(10,)) + + x = pt.linalg.block_diag(a, b, c) + out = x @ d + + fn = pytensor.function([a, b, c, d], out) + assert not any( + isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort() + ) + + fn_expected = pytensor.function( + [a, b, c, d], + out, + mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"), + ) + + rng = np.random.default_rng() + a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) + b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) + c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) + d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + + np.testing.assert_allclose( + fn(a_val, b_val, c_val, d_val), + fn_expected(a_val, b_val, c_val, d_val), + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) + + +@pytest.mark.parametrize("rewrite", [True, 
False], ids=["rewrite", "no_rewrite"]) +@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"]) +def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite): + rng = np.random.default_rng() + a_size = int(rng.uniform(0, size)) + b_size = int(rng.uniform(0, size - a_size)) + c_size = size - a_size - b_size + + a = tensor("a", shape=(a_size, a_size)) + b = tensor("b", shape=(b_size, b_size)) + c = tensor("c", shape=(c_size, c_size)) + d = tensor("d", shape=(size,)) + + x = pt.linalg.block_diag(a, b, c) + out = x @ d + + mode = get_default_mode() + if not rewrite: + mode = mode.excluding("local_block_diag_dot_to_dot_block_diag") + fn = pytensor.function([a, b, c, d], out, mode=mode) + + a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) + b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) + c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) + d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + + benchmark( + fn, + a_val, + b_val, + c_val, + d_val, + ) From ffb71d30c2e8516f811cdd1e88fb48f01e06bc21 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 22:02:28 +0200 Subject: [PATCH 2/9] Handle right-multiplication case --- pytensor/tensor/rewriting/math.py | 29 +++++++++++------------------ tests/tensor/rewriting/test_math.py | 11 ++++++++--- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a0d5a9dc7b..a35363a170 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -176,7 +176,7 @@ def local_0_dot_x(fgraph, node): @node_rewriter([Dot]) def local_block_diag_dot_to_dot_block_diag(fgraph, node): r""" - Perform the rewrite ``dot(block_diag(A, B), C) -> block_diag(dot(A, C), dot(B, C))`` + Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than @@ -210,25 +210,18 @@ def check_for_block_diag(x): new_output = join(0, *new_components) elif not check_for_block_diag(x) and check_for_block_diag(y): components = y.owner.inputs - new_components = [op(x, component) for component in components] - new_output = join(0, *new_components) - - # Case 2: Both inputs are BlockDiagonal. Here we can proceed only if the static shapes are known and identical. 
In - # that case, blockdiag(a,b) @ blockdiag(c, d) = blockdiag(a @ c, b @ d), but this is not true in the general case - elif any(shape is None for shape in (*x.type.shape, *y.type.shape)): - return None - elif x.ndim == y.ndim and all( - x_shape == y_shape for x_shape, y_shape in zip(x.type.shape, y.type.shape) - ): - x_components = x.owner.inputs - y_components = y.owner.inputs + x_splits = split( + x, + splits_size=[component.shape[0] for component in components], + n_splits=len(components), + axis=1, + ) - if len(x_components) != len(y_components): - return None + new_components = [ + op(x_split, component) for component, x_split in zip(components, x_splits) + ] + new_output = join(1, *new_components) - new_output = BlockDiagonal(len(x_components))( - *[op(x_comp, y_comp) for x_comp, y_comp in zip(x_components, y_components)] - ) else: return None diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 3be12da3e5..b1451825ab 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4657,17 +4657,22 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): ) -def test_local_block_diag_dot_to_dot_block_diag(): +@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) +def test_local_block_diag_dot_to_dot_block_diag(left_multiply): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ a = tensor("a", shape=(4, 2)) b = tensor("b", shape=(2, 4)) c = tensor("c", shape=(4, 4)) - d = tensor("d", shape=(10,)) + d = tensor("d", shape=(10, 10)) x = pt.linalg.block_diag(a, b, c) - out = x @ d + + if left_multiply: + out = x @ d + else: + out = d @ x fn = pytensor.function([a, b, c, d], out) assert not any( From c5137d7214aff6d35cd2fbcee7a30e9e2459b1d7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 21 Jun 2025 22:17:33 +0200 Subject: [PATCH 3/9] The robot was right! 
--- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index b1451825ab..32bcb5c471 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4676,7 +4676,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn = pytensor.function([a, b, c, d], out) assert not any( - isinstance(node, BlockDiagonal) for node in fn.maker.fgraph.toposort() + isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) fn_expected = pytensor.function( From 3b66eba6cd44675ad900561c6879191ceec66d7c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 22 Jun 2025 12:33:12 +0200 Subject: [PATCH 4/9] Respond to feedback --- pytensor/tensor/rewriting/math.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a35363a170..b12d75ae35 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -170,10 +170,8 @@ def local_0_dot_x(fgraph, node): return [constant_zero] -@register_canonicalize -@register_specialize @register_stabilize -@node_rewriter([Dot]) +@node_rewriter([Blockwise]) def local_block_diag_dot_to_dot_block_diag(fgraph, node): r""" Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))`` @@ -182,8 +180,8 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than a single dot on the larger matrix. """ - x, y = node.inputs - op = node.op + if not isinstance(node.op.core_op, BlockDiagonal): + return def check_for_block_diag(x): return x.owner and ( @@ -192,6 +190,15 @@ def check_for_block_diag(x): and isinstance(x.owner.op.core_op, BlockDiagonal) ) + # Check that the BlockDiagonal is an input to a Dot node: + clients = list(get_clients_at_depth(fgraph, node, depth=1)) + if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot): + return + + [dot_node] = clients + op = dot_node.op + x, y = dot_node.inputs + if not (check_for_block_diag(x) or check_for_block_diag(y)): return None @@ -208,6 +215,7 @@ def check_for_block_diag(x): op(component, y_split) for component, y_split in zip(components, y_splits) ] new_output = join(0, *new_components) + elif not check_for_block_diag(x) and check_for_block_diag(y): components = y.owner.inputs x_splits = split( @@ -222,11 +230,14 @@ def check_for_block_diag(x): ] new_output = join(1, *new_components) + # Case 2: Both inputs are BlockDiagonal. 
Do nothing else: + # TODO: If shapes are statically known and all components have equal shapes, we could rewrite + # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) return None copy_stack_trace(node.outputs[0], new_output) - return [new_output] + return {dot_node.outputs[0]: new_output} @register_canonicalize From 09bddf1112246841c302d6ae1d821d1e79914431 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 22 Jun 2025 13:00:24 +0200 Subject: [PATCH 5/9] Use `rewrite_mode` defined in `test_math.py` for testing --- tests/tensor/rewriting/test_math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 32bcb5c471..137b91fb34 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4674,7 +4674,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): else: out = d @ x - fn = pytensor.function([a, b, c, d], out) + fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) @@ -4682,7 +4682,7 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn_expected = pytensor.function( [a, b, c, d], out, - mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"), + mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), ) rng = np.random.default_rng() From 17fbeb33c309b4e74a445de48307a763f9d1d7ec Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 26 Jun 2025 14:25:24 +0800 Subject: [PATCH 6/9] Handle case with multiple clients --- pytensor/tensor/rewriting/math.py | 83 +++++++++++++++-------------- tests/tensor/rewriting/test_math.py | 15 +++--- 2 files changed, 51 insertions(+), 47 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b12d75ae35..e1af06f8b4 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -191,53 +191,54 @@ def check_for_block_diag(x): ) # Check that the BlockDiagonal is an input to a Dot node: - clients = list(get_clients_at_depth(fgraph, node, depth=1)) - if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot): - return + for client in get_clients_at_depth(fgraph, node, depth=1): + if not isinstance(client.op, Dot): + return - [dot_node] = clients - op = dot_node.op - x, y = dot_node.inputs + op = client.op + x, y = client.inputs - if not (check_for_block_diag(x) or check_for_block_diag(y)): - return None + if not (check_for_block_diag(x) or check_for_block_diag(y)): + return None - # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the - # non-block diagonal, and return a new block diagonal - if check_for_block_diag(x) and not check_for_block_diag(y): - components = x.owner.inputs - y_splits = split( - y, - splits_size=[component.shape[-1] for component in components], - n_splits=len(components), - ) - new_components = [ - op(component, y_split) for component, y_split in zip(components, y_splits) - ] - new_output = join(0, *new_components) - - elif not check_for_block_diag(x) and check_for_block_diag(y): - components = y.owner.inputs - x_splits = split( - x, - splits_size=[component.shape[0] for component in components], - n_splits=len(components), - axis=1, - ) + # Case 1: Only one input is BlockDiagonal. 
In this case, multiply all components of the block-diagonal with the + # non-block diagonal, and return a new block diagonal + if check_for_block_diag(x) and not check_for_block_diag(y): + components = x.owner.inputs + y_splits = split( + y, + splits_size=[component.shape[-1] for component in components], + n_splits=len(components), + ) + new_components = [ + op(component, y_split) + for component, y_split in zip(components, y_splits) + ] + new_output = join(0, *new_components) + + elif not check_for_block_diag(x) and check_for_block_diag(y): + components = y.owner.inputs + x_splits = split( + x, + splits_size=[component.shape[0] for component in components], + n_splits=len(components), + axis=1, + ) - new_components = [ - op(x_split, component) for component, x_split in zip(components, x_splits) - ] - new_output = join(1, *new_components) + new_components = [ + op(x_split, component) + for component, x_split in zip(components, x_splits) + ] + new_output = join(1, *new_components) - # Case 2: Both inputs are BlockDiagonal. Do nothing - else: - # TODO: If shapes are statically known and all components have equal shapes, we could rewrite - # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - return None + # Case 2: Both inputs are BlockDiagonal. Do nothing + else: + # TODO: If shapes are statically known and all components have equal shapes, we could rewrite + # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) + return None - copy_stack_trace(node.outputs[0], new_output) - return {dot_node.outputs[0]: new_output} + copy_stack_trace(node.outputs[0], new_output) + return {client.outputs[0]: new_output} @register_canonicalize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 137b91fb34..d30519d36d 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4666,21 +4666,23 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): b = tensor("b", shape=(2, 4)) c = tensor("c", shape=(4, 4)) d = tensor("d", shape=(10, 10)) + e = tensor("e", shape=(10, 10)) x = pt.linalg.block_diag(a, b, c) + # Test multiple clients are all rewritten if left_multiply: - out = x @ d + out = [x @ d, x @ e] else: - out = d @ x + out = [d @ x, e @ x] - fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) fn_expected = pytensor.function( - [a, b, c, d], + [a, b, c, d, e], out, mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), ) @@ -4690,10 +4692,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) + e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) np.testing.assert_allclose( - fn(a_val, b_val, c_val, d_val), - fn_expected(a_val, b_val, c_val, d_val), + fn(a_val, b_val, c_val, d_val, e_val), + fn_expected(a_val, b_val, c_val, d_val, e_val), atol=1e-6 if config.floatX == "float32" else 1e-12, rtol=1e-6 if config.floatX == "float32" else 1e-12, ) From 7cef064fd602a62e2f026c6b2c329c9875beb77d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 26 Jun 2025 14:41:52 +0800 Subject: [PATCH 7/9] use `continue` on rewrite failures when checking 
clients --- pytensor/tensor/rewriting/math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e1af06f8b4..5a296aecc5 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -193,13 +193,13 @@ def check_for_block_diag(x): # Check that the BlockDiagonal is an input to a Dot node: for client in get_clients_at_depth(fgraph, node, depth=1): if not isinstance(client.op, Dot): - return + continue op = client.op x, y = client.inputs if not (check_for_block_diag(x) or check_for_block_diag(y)): - return None + continue # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the # non-block diagonal, and return a new block diagonal @@ -235,7 +235,7 @@ def check_for_block_diag(x): else: # TODO: If shapes are statically known and all components have equal shapes, we could rewrite # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - return None + continue copy_stack_trace(node.outputs[0], new_output) return {client.outputs[0]: new_output} From 9455b86757b7e35fe77479d12380babe937729d7 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 8 Jul 2025 18:35:28 +0800 Subject: [PATCH 8/9] pair coding results --- pytensor/tensor/rewriting/elemwise.py | 2 + pytensor/tensor/rewriting/math.py | 69 ++++++++++----------------- tests/tensor/rewriting/test_math.py | 36 +++++++++----- 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 98fc4e074c..b83681b12b 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -41,6 +41,7 @@ broadcasted_by, register_canonicalize, register_specialize, + register_stabilize, ) from pytensor.tensor.shape import shape_padleft from pytensor.tensor.variable import TensorConstant @@ -395,6 +396,7 @@ def is_dimshuffle_useless(new_order, input): @register_canonicalize +@register_stabilize @register_specialize @node_rewriter([DimShuffle]) def local_dimshuffle_lift(fgraph, node): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 5a296aecc5..16e4081924 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -183,59 +183,40 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): if not isinstance(node.op.core_op, BlockDiagonal): return - def check_for_block_diag(x): - return x.owner and ( - isinstance(x.owner.op, BlockDiagonal) - or isinstance(x.owner.op, Blockwise) - and isinstance(x.owner.op.core_op, BlockDiagonal) - ) - # Check that the BlockDiagonal is an input to a Dot node: for client in get_clients_at_depth(fgraph, node, depth=1): - if not isinstance(client.op, Dot): + if not ( + ( + isinstance(client.op, Dot) + and all(input.ndim == 2 for input in client.inputs) + ) + or client.op == _matrix_matrix_matmul + ): continue op = client.op - x, y = client.inputs - if not (check_for_block_diag(x) or check_for_block_diag(y)): - continue + client_idx = client.inputs.index(node.outputs[0]) - # Case 1: Only one input is BlockDiagonal. 
In this case, multiply all components of the block-diagonal with the - # non-block diagonal, and return a new block diagonal - if check_for_block_diag(x) and not check_for_block_diag(y): - components = x.owner.inputs - y_splits = split( - y, - splits_size=[component.shape[-1] for component in components], - n_splits=len(components), - ) - new_components = [ - op(component, y_split) - for component, y_split in zip(components, y_splits) - ] - new_output = join(0, *new_components) - - elif not check_for_block_diag(x) and check_for_block_diag(y): - components = y.owner.inputs - x_splits = split( - x, - splits_size=[component.shape[0] for component in components], - n_splits=len(components), - axis=1, - ) + other_input = client.inputs[1 - client_idx] + components = node.inputs - new_components = [ - op(x_split, component) - for component, x_split in zip(components, x_splits) - ] - new_output = join(1, *new_components) + split_axis = -2 if client_idx == 0 else -1 + shape_idx = -1 if client_idx == 0 else -2 - # Case 2: Both inputs are BlockDiagonal. Do nothing - else: - # TODO: If shapes are statically known and all components have equal shapes, we could rewrite - # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)]) - continue + other_dot_input_split = split( + other_input, + splits_size=[component.shape[shape_idx] for component in components], + n_splits=len(components), + axis=split_axis, + ) + new_components = [ + op(component, other_split) + if client_idx == 0 + else op(other_split, component) + for component, other_split in zip(components, other_dot_input_split) + ] + new_output = join(split_axis, *new_components) copy_stack_trace(node.outputs[0], new_output) return {client.outputs[0]: new_output} diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index d30519d36d..3b43db4fb4 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4658,15 +4658,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) -def test_local_block_diag_dot_to_dot_block_diag(left_multiply): +@pytest.mark.parametrize( + "batch_left", [True, False], ids=["batched_left", "unbatched_left"] +) +@pytest.mark.parametrize( + "batch_right", [True, False], ids=["batched_right", "unbatched_right"] +) +def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ a = tensor("a", shape=(4, 2)) - b = tensor("b", shape=(2, 4)) + b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4)) c = tensor("c", shape=(4, 4)) d = tensor("d", shape=(10, 10)) - e = tensor("e", shape=(10, 10)) + e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10)) x = pt.linalg.block_diag(a, b, c) @@ -4676,7 +4682,9 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): else: out = [d @ x, e @ x] - fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) + with config.change_flags(optimizer_verbose=True): + fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) + assert not any( isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() ) @@ -4684,9 +4692,11 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): fn_expected = pytensor.function( [a, b, c, d, e], out, - mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"), + 
mode=Mode(linker="py", optimizer=None), ) + # TODO: Count Dots + rng = np.random.default_rng() a_val = rng.normal(size=a.type.shape).astype(a.type.dtype) b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) @@ -4694,12 +4704,16 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply): d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) - np.testing.assert_allclose( - fn(a_val, b_val, c_val, d_val, e_val), - fn_expected(a_val, b_val, c_val, d_val, e_val), - atol=1e-6 if config.floatX == "float32" else 1e-12, - rtol=1e-6 if config.floatX == "float32" else 1e-12, - ) + rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val) + expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val) + + for out, expected in zip(rewrite_outs, expected_outs): + np.testing.assert_allclose( + out, + expected, + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"]) From 68591adb72418c4a4f42cf78f3d86afa8c086136 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 8 Jul 2025 12:57:23 +0200 Subject: [PATCH 9/9] Cleanup test --- tests/tensor/rewriting/test_math.py | 66 +++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 3b43db4fb4..b27e0de7e6 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4659,41 +4659,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) @pytest.mark.parametrize( - "batch_left", [True, False], ids=["batched_left", "unbatched_left"] + "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"] ) @pytest.mark.parametrize( - "batch_right", [True, False], ids=["batched_right", "unbatched_right"] + "batch_other", [True, False], ids=["batched_other", "unbatched_other"] ) -def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right): +def test_local_block_diag_dot_to_dot_block_diag( + left_multiply, batch_blockdiag, batch_other +): """ Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) """ + + def has_blockdiag(graph): + return any( + ( + var.owner + and ( + isinstance(var.owner.op, BlockDiagonal) + or ( + isinstance(var.owner.op, Blockwise) + and isinstance(var.owner.op.core_op, BlockDiagonal) + ) + ) + ) + for var in ancestors([graph]) + ) + a = tensor("a", shape=(4, 2)) - b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4)) + b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4)) c = tensor("c", shape=(4, 4)) - d = tensor("d", shape=(10, 10)) - e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10)) - x = pt.linalg.block_diag(a, b, c) + d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10)) + # Test multiple clients are all rewritten if left_multiply: - out = [x @ d, x @ e] + out = x @ d else: - out = [d @ x, e @ x] + out = d @ x - with config.change_flags(optimizer_verbose=True): - fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode) - - assert not any( - isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort() - ) + assert has_blockdiag(out) + fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + assert not has_blockdiag(fn.maker.fgraph.outputs[0]) 
fn_expected = pytensor.function( - [a, b, c, d, e], + [a, b, c, d], out, mode=Mode(linker="py", optimizer=None), ) + assert has_blockdiag(fn_expected.maker.fgraph.outputs[0]) # TODO: Count Dots @@ -4702,18 +4717,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch b_val = rng.normal(size=b.type.shape).astype(b.type.dtype) c_val = rng.normal(size=c.type.shape).astype(c.type.dtype) d_val = rng.normal(size=d.type.shape).astype(d.type.dtype) - e_val = rng.normal(size=e.type.shape).astype(e.type.dtype) - rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val) - expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val) - - for out, expected in zip(rewrite_outs, expected_outs): - np.testing.assert_allclose( - out, - expected, - atol=1e-6 if config.floatX == "float32" else 1e-12, - rtol=1e-6 if config.floatX == "float32" else 1e-12, - ) + rewrite_out = fn(a_val, b_val, c_val, d_val) + expected_out = fn_expected(a_val, b_val, c_val, d_val) + np.testing.assert_allclose( + rewrite_out, + expected_out, + atol=1e-6 if config.floatX == "float32" else 1e-12, + rtol=1e-6 if config.floatX == "float32" else 1e-12, + ) @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
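
Editor's note (not part of the patch series): the rewrite implemented above rests on the identity that a dot product with a block-diagonal matrix equals splitting the other operand along the shared dimension, taking one small dot product per block, and concatenating the results. A minimal NumPy/SciPy sketch of that identity follows; the shapes are hypothetical examples chosen to match the tests above, and the sketch deliberately uses scipy.linalg.block_diag rather than PyTensor so it can be checked independently of the rewrite.

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))
B = rng.normal(size=(2, 4))
C = rng.normal(size=(4, 4))
D = rng.normal(size=(10, 7))  # block_diag(A, B, C) has shape (10, 10), so D needs 10 rows

# Direct product against the dense block-diagonal matrix.
full = block_diag(A, B, C) @ D

# Rewritten form: split D by the blocks' column counts (2, 4, 4), take three small
# dot products, and join the pieces along the rows.
D1, D2, D3 = np.split(D, [2, 2 + 4], axis=0)
joined = np.concatenate([A @ D1, B @ D2, C @ D3], axis=0)
np.testing.assert_allclose(full, joined)

# Right-multiplication mirrors this along the other axis: split by the blocks'
# row counts (4, 2, 4) and join along the columns.
E = rng.normal(size=(7, 10))
E1, E2, E3 = np.split(E, [4, 4 + 2], axis=1)
np.testing.assert_allclose(E @ block_diag(A, B, C),
                           np.concatenate([E1 @ A, E2 @ B, E3 @ C], axis=1))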