pytensor/tensor/rewriting/basic.py (37 additions, 11 deletions)
@@ -30,7 +30,7 @@
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph import FunctionGraph, Op
-from pytensor.graph.basic import Constant
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
     NodeRewriter,
@@ -82,7 +82,7 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_arrays, repeat
+from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
@@ -915,26 +915,52 @@ def local_join_make_vector(fgraph, node):
 def local_join_to_repeat(fgraph, node):
     """Join(axis, x, x, x, ...) -> repeat(x, n, axis)
 
-    When the same tensor is concatenated multiple times,
-    replace with a single repeat operation which is more efficient.
+    When the same tensor is concatenated multiple times along an axis
+    where it has size 1, replace it with a repeat operation, which is more efficient.
 
     Examples
     --------
-    concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0)
+    concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
     """
     # Extract the axis and the tensors being joined
-    axis, *tensors = node.inputs
+    axis_sym, *tensors = node.inputs
 
     # Need at least 2 tensors to consider the rewrite
     if len(tensors) <= 1:
-        return
+        return None
 
-    # Check if all tensors are identical
-    if not all(t == tensors[0] for t in tensors[1:]):
-        return
+    # Extract the join axis as a Python int
+    try:
+        axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True))
+    except NotScalarConstantError:
+        return None
+
+    # Get the first tensor and check that its ndim is known
+    first = tensors[0]
+    ndim = first.ndim
+    if ndim is None:
+        return None
+
+    # Normalize negative axes (e.g., -1 -> ndim - 1)
+    axis_val = axis_val % ndim
+
+    # All inputs must be structurally the same tensor.
+    # Use equal_computations to check structural equality, not symbolic ==.
+    for t in tensors[1:]:
+        if not equal_computations([t], [first]):
+            return None
+
+    # Only apply when the size along the join axis is statically 1
+    # (e.g., x[None] has a guaranteed 1 at that axis)
+    shp = first.type.shape  # tuple of ints/None
+    if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
+        return None
 
     # Replace with a repeat operation
-    result = repeat(tensors[0], len(tensors), axis)
+    from pytensor.tensor.extra_ops import repeat
+
+    n = len(tensors)
+    result = repeat(first, n, axis=axis_val)
 
     # Preserve debugging information
     copy_stack_trace(node.outputs[0], result)
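Note on why the rewrite is sound: when a tensor has size 1 along the join axis, concatenating n copies along that axis produces exactly repeat(x, n, axis). A minimal NumPy sketch of that equivalence (plain NumPy, independent of PyTensor):

import numpy as np

x = np.array([[1.0, 2.0, 3.0]])  # shape (1, 3): size 1 along axis 0

joined = np.concatenate([x, x, x], axis=0)  # what Join(0, x, x, x) computes
repeated = np.repeat(x, 3, axis=0)          # what the rewrite emits instead

assert joined.shape == repeated.shape == (3, 3)
assert np.array_equal(joined, repeated)

Because the repeated axis has static size 1, the repeat can later be lowered to a broadcast; the updated tests below rely on this when they assert on Alloc rather than Repeat.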
tests/tensor/rewriting/test_basic.py (48 additions, 29 deletions)
@@ -35,7 +35,6 @@
     tile,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.extra_ops import Repeat
 from pytensor.tensor.math import (
     add,
     bitwise_and,
@@ -1249,83 +1248,103 @@ def test_local_join_1():


 def test_local_join_to_repeat():
-    """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)"""
+    """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)
 
-    # Test with vector - concatenate same vector 3 times along axis 0
+    This optimization applies when joining the same tensor multiple times
+    along an axis where it has size 1 (e.g., after ExpandDims).
+    """
+
+    # Test with a vector expanded to (1, n) - concatenate along axis 0
     x = vector("x")
-    s = join(0, x, x, x)
+    x_expanded = x[None]  # Shape: (1, n)
+    s = join(0, x_expanded, x_expanded, x_expanded)  # Shape: (3, n)
     f = function([x], s, mode=rewrite_mode)
 
     # Check numerical correctness
     test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
     result = f(test_val)
     expected = np.array(
-        [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX
+        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
     )
     assert np.allclose(result, expected)
 
-    # Check that Join was replaced with Repeat
+    # Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 0
-    assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
 
-    # Test with matrix - concatenate same matrix along axis 0
-    a = matrix("a")
-    s = join(0, a, a, a, a)
+    # Test with a matrix - add a leading dimension and concatenate along it
+    a = matrix("a")  # Shape: (m, n)
+    a_expanded = a[None, :, :]  # Shape: (1, m, n)
+    s = join(0, a_expanded, a_expanded, a_expanded, a_expanded)  # Shape: (4, m, n)
     f = function([a], s, mode=rewrite_mode)
 
     test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
     result = f(test_mat)
-    expected = np.vstack([test_mat, test_mat, test_mat, test_mat])
+    expected = np.array([test_mat, test_mat, test_mat, test_mat])
     assert np.allclose(result, expected)
 
     # Check optimization applied
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 0
-    assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
 
-    # Test with matrix - concatenate along axis 1
-    s = join(1, a, a)
+    # Test with a matrix - expand along axis 1 and concatenate
+    a_expanded_ax1 = a[:, None, :]  # Shape: (m, 1, n)
+    s = join(1, a_expanded_ax1, a_expanded_ax1)  # Shape: (m, 2, n)
     f = function([a], s, mode=rewrite_mode)
 
     result = f(test_mat)
-    expected = np.hstack([test_mat, test_mat])
+    expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
     assert np.allclose(result, expected)
 
     # Check optimization applied
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 0
-    assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
 
     # Test that it does NOT apply when the tensors are different
-    b = matrix("b")
-    s = join(0, a, b)
-    f = function([a, b], s, mode=rewrite_mode)
-
-    test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX)
-    test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX)
-    result = f(test_mat1, test_mat2)
-    expected = np.vstack([test_mat1, test_mat2])
+    y = vector("y")
+    s = join(0, x[None], y[None])
+    f = function([x, y], s, mode=rewrite_mode)
+
+    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
+    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
+    result = f(test_vec1, test_vec2)
+    expected = np.array([[1.0, 2.0], [3.0, 4.0]])
     assert np.allclose(result, expected)
 
     # Join should still be present (not optimized)
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 1
 
+    # Test that it does NOT apply when the tensor doesn't have size 1 along the join axis
+    # (regular concatenation without ExpandDims)
+    s = join(0, x, x, x)  # Shape: (3n,), no ExpandDims involved
+    f = function([x], s, mode=rewrite_mode)
+
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
-    # Join should still be present (not optimized to Repeat)
+    # Join should still be present (the rewrite doesn't apply)
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 1
-    assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0
 
     # Test with 5 repetitions to ensure it works with larger counts
-    s = join(0, x, x, x, x, x)
+    s = join(0, x[None], x[None], x[None], x[None], x[None])
     f = function([x], s, mode=rewrite_mode)
 
     test_val = np.array([1.0, 2.0], dtype=config.floatX)
     result = f(test_val)
-    expected = np.tile(test_val, 5)
+    expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
    assert np.allclose(result, expected)
 
     # Check optimization applied
     ops = f.maker.fgraph.toposort()
     assert len([n for n in ops if isinstance(n.op, Join)]) == 0
-    assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
 
 
 def test_local_join_empty():
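A note on the equal_computations check in the rewrite: join(0, x[None], x[None], x[None]) builds a distinct ExpandDims (DimShuffle) variable for each argument, so the rewrite cannot recognize the copies by comparing the input variables directly; it has to compare the graphs that produce them. A minimal sketch of the distinction (assumes a PyTensor install; the variable names are illustrative):

import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

x = pt.vector("x")
a = x[None]  # one ExpandDims graph over x
b = x[None]  # a second, structurally identical variable

print(a is b)                        # False: two distinct Variable objects
print(equal_computations([a], [b]))  # True: they compute the same thing

This is why the implementation walks tensors[1:] with equal_computations([t], [first]) instead of relying on a direct comparison of the variables.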