Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions pytensor/link/mlx/dispatch/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,37 @@

@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Funcify a ``Blockwise`` op for the MLX backend.

    Builds the MLX function for the core op and, when batch dimensions are
    present, vectorizes it with ``mx.vmap``. Inputs whose batch dimensions
    are all broadcastable (size 1) are treated as static (``in_axes=None``);
    if every input ends up static, the core function is returned directly,
    since calling ``mx.vmap`` with all-``None`` ``in_axes`` raises
    ``ValueError: At least one of in_axes must be non-None``.
    """
    # Get the core python function for this Blockwise operation.
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # Number of batch dimensions present in the output.
    n_batch = op.batch_ndim(node)

    # No batch dimensions at all: the core function is already correct.
    if n_batch == 0:
        return core_f

    # Build the in_axes specification for mx.vmap: each input is either
    # vectorized over its leading axis (0) or treated as static (None).
    in_axes: list[int | None] = []
    for inp, sig in zip(node.inputs, op.inputs_sig):
        batch_ndim = inp.type.ndim - len(sig)
        if batch_ndim == 0:
            # Input has no batch dimensions - treat as static.
            in_axes.append(None)
            continue

        batch_bcast = inp.type.broadcastable[:batch_ndim]
        # If all batch dims are broadcastable (size 1), treat input as static;
        # otherwise vectorize over the first dimension (axis=0).
        in_axes.append(0 if not all(batch_bcast) else None)

    # If no input is actually vectorized, skip mx.vmap entirely — it would
    # raise on an all-None in_axes tuple (see docstring).
    if not any(axis == 0 for axis in in_axes):
        return core_f

    # Apply mx.vmap to vectorize the core function over batch dimensions.
    return mx.vmap(core_f, in_axes=tuple(in_axes))
42 changes: 42 additions & 0 deletions tests/link/mlx/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,45 @@ def test_blockwise_conv1d():

# assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)


def test_blockwise_no_batch_dimensions():
    """Blockwise over purely core-dimensional inputs must bypass mx.vmap.

    Regression test for the MLX dispatcher: with zero batch dimensions the
    core function should be returned as-is rather than wrapped in mx.vmap.
    """
    rng = np.random.default_rng(42)

    # Plain matrices carry exactly the core (i,j)/(j,k) dims — no batch dims.
    x = pt.matrix("x")
    y = pt.matrix("y")

    matmul_op = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
    out = matmul_op(x, y)

    test_values = [
        rng.normal(size=(2, 3)),
        rng.normal(size=(3, 4)),
    ]

    compare_mlx_and_py([x, y], [out], test_values, must_be_device_array=True)


def test_blockwise_all_broadcastable_batch_dims():
    """Blockwise with only size-1 batch dims must also bypass mx.vmap.

    Every batch dimension here is broadcastable, so no input is actually
    vectorized and the dispatcher should fall back to the core function.
    """
    rng = np.random.default_rng(43)

    # Leading dimension of each input is a broadcastable (size-1) batch dim.
    x = tensor("x", shape=(1, 2, 3))
    y = tensor("y", shape=(1, 3, 4))

    matmul_op = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
    out = matmul_op(x, y)

    test_values = [
        rng.normal(size=(1, 2, 3)),
        rng.normal(size=(1, 3, 4)),
    ]

    compare_mlx_and_py([x, y], [out], test_values, must_be_device_array=True)