From d25f214af7793e104f8f2da940eac3b4ad25293d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 11:32:13 -0400 Subject: [PATCH 01/71] mlx poc --- pytensor/compile/mode.py | 19 +++++ pytensor/link/mlx/dispatch/__init__.py | 5 ++ pytensor/link/mlx/dispatch/basic.py | 61 +++++++++++++ pytensor/link/mlx/dispatch/math.py | 12 +++ pytensor/link/mlx/linker.py | 113 +++++++++++++++++++++++++ pytensor/link/pytorch/linker.py | 16 ++-- 6 files changed, 218 insertions(+), 8 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/__init__.py create mode 100644 pytensor/link/mlx/dispatch/basic.py create mode 100644 pytensor/link/mlx/dispatch/math.py create mode 100644 pytensor/link/mlx/linker.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index f80dfaaf5c..ce58561212 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -27,6 +27,7 @@ from pytensor.link.basic import Linker, PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker from pytensor.link.numba.linker import NumbaLinker from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -50,6 +51,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), + "mlx": MLXLinker(), } @@ -494,6 +496,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) +MLX = Mode( + MLXLinker(), + RewriteDatabaseQuery( + include=["fast_run"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "scan_save_mem_prealloc", + ], + ), +) + predefined_modes = { "FAST_COMPILE": FAST_COMPILE, @@ -501,6 +517,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "JAX": JAX, "NUMBA": NUMBA, "PYTORCH": PYTORCH, + "MLX": MLX, } _CACHED_RUNTIME_MODES: dict[str, Mode] = {} @@ -585,6 +602,8 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], return ("py",) if isinstance(linker, CLinker): return ("c",) + if isinstance(linker, MLXLinker): + return ("py",) if isinstance(linker, VMLinker | OpWiseCLinker): return ("c", "py") if config.cxx else ("py",) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py new file mode 100644 index 0000000000..7acb41e1b5 --- /dev/null +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -0,0 +1,5 @@ +# isort: off +from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify + +import pytensor.link.mlx.dispatch.math +# isort: on diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py new file mode 100644 index 0000000000..9cbb92118d --- /dev/null +++ b/pytensor/link/mlx/dispatch/basic.py @@ -0,0 +1,61 @@ +from functools import singledispatch +from types import NoneType + +import mlx.core as mx +import numpy as np + +from pytensor.compile.ops import DeepCopyOp +from pytensor.graph.fg import FunctionGraph +from pytensor.link.utils import fgraph_to_python + + +@singledispatch +def mlx_typify(data, **kwargs): + raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}") + + +@mlx_typify.register(np.ndarray) +@mlx_typify.register(mx.array) +def mlx_typify_tensor(data, dtype=None, **kwargs): + return mx.array(data, dtype=dtype) + + +@mlx_typify.register(slice) +@mlx_typify.register(NoneType) +@mlx_typify.register(np.number) +def mlx_typify_no_conversion_needed(data, **kwargs): + return data + + +@singledispatch +def mlx_funcify(op, node=None, storage_map=None, **kwargs): + """Create a MLX 
compatible function from a PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
+    )
+
+
+@mlx_funcify.register(FunctionGraph)
+def mlx_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="mlx_funcified_fgraph",
+    conversion_func=mlx_funcify,
+    **kwargs,
+):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
+    return fgraph_to_python(
+        fgraph,
+        conversion_func,
+        type_conversion_fn=mlx_typify,
+        fgraph_name=fgraph_name,
+        **built_kwargs,
+    )
+
+
+@mlx_funcify.register(DeepCopyOp)
+def mlx_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return x.copy()
+
+    return deepcopyop
diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py
new file mode 100644
index 0000000000..1ef7ec4608
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/math.py
@@ -0,0 +1,12 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.math import Dot
+
+
+@mlx_funcify.register(Dot)
+def mlx_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return mx.matmul(x, y)
+
+    return dot
diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py
new file mode 100644
index 0000000000..8cfd9a0ff5
--- /dev/null
+++ b/pytensor/link/mlx/linker.py
@@ -0,0 +1,113 @@
+from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator
+
+
+class MLXLinker(JITLinker):
+    """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
+    def fgraph_convert(
+        self,
+        fgraph,
+        order,
+        input_storage,
+        output_storage,
+        storage_map,
+        **kwargs,
+    ):
+        """Convert a PyTensor FunctionGraph to an MLX-compatible function.
+
+        Parameters
+        ----------
+        fgraph : FunctionGraph
+            The function graph to convert
+        order : list
+            The order in which to compute the nodes
+        input_storage : list
+            Storage for the input variables
+        output_storage : list
+            Storage for the output variables
+        storage_map : dict
+            Map from variables to their storage
+
+        Returns
+        -------
+        callable
+            An MLX-compatible function
+        """
+        from pytensor.link.mlx.dispatch import mlx_funcify
+
+        # We want to have globally unique names
+        # across the entire pytensor graph, not
+        # just the subgraph
+        generator = unique_name_generator(["mlx_linker"])
+
+        # Ensure that torch is aware of the generated
+        # code so we can compile without graph breaks
+        def conversion_func_register(*args, **kwargs):
+            functor = mlx_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
+        return mlx_funcify(
+            fgraph,
+            input_storage=input_storage,
+            storage_map=storage_map,
+            **built_kwargs,
+        )
+
+    def jit_compile(self, fn):
+        """JIT compile an MLX function.
+
+        Parameters
+        ----------
+        fn : callable
+            The function to compile
+
+        Returns
+        -------
+        callable
+            The compiled function
+        """
+        import mlx.core as mx
+
+        return mx.compile(fn)
+
+    def create_thunk_inputs(self, storage_map):
+        """Create inputs for the MLX thunk.
+ + Parameters + ---------- + storage_map : dict + Map from variables to their storage + + Returns + ------- + list + The inputs for the thunk + """ + from numpy.random import Generator, RandomState + + from pytensor.link.mlx.dispatch import mlx_typify + + thunk_inputs = [] + for n in self.fgraph.inputs: + sinput = storage_map[n] + # Handle random number generators specially + if isinstance(sinput[0], RandomState | Generator): + new_value = mlx_typify( + sinput[0], dtype=getattr(sinput[0], "dtype", None) + ) + sinput[0] = new_value + thunk_inputs.append(sinput) + + return thunk_inputs diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index b8475e3157..0a057a9e8d 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -31,16 +31,16 @@ def conversion_func_register(*args, **kwargs): **kwargs, } return pytorch_funcify( - fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs + fgraph, + input_storage=input_storage, + storage_map=storage_map, + **built_kwargs, ) def jit_compile(self, fn): - import torch + import mlx.core as mx - # flag that tend to help our graphs - torch._dynamo.config.capture_dynamic_output_shape_ops = True - - from pytensor.link.pytorch.dispatch import pytorch_typify + from pytensor.link.mlx.dispatch import mlx_typify class wrapper: """ @@ -54,7 +54,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = torch.compile(fn) + self.fn = mx.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -65,7 +65,7 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) # unset attrs for n, _ in self.gen_functors: From edacc0ed246c13a7bc664a5b4a5f078fccba68eb Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 11:38:56 -0400 Subject: [PATCH 02/71] add test for dot --- tests/link/mlx/dispatch/test_math.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/link/mlx/dispatch/test_math.py diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/dispatch/test_math.py new file mode 100644 index 0000000000..5608321a80 --- /dev/null +++ b/tests/link/mlx/dispatch/test_math.py @@ -0,0 +1,19 @@ +import numpy as np + +import pytensor +from pytensor.tensor.type import matrix + + +def test_mlx_dot(): + x = matrix("x") + y = matrix("y") + + out = x.dot(y) + fn = pytensor.function([x, y], out, mode="MLX") + + test_x = np.random.normal(size=(3, 2)) + test_y = np.random.normal(size=(2, 4)) + np.testing.assert_allclose( + fn(test_x, test_y), + np.dot(test_x, test_y), + ) From 052fdc23e80ad097373e5782eb06b65f75dead17 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:26:49 -0400 Subject: [PATCH 03/71] restore pytorch --- pytensor/link/pytorch/linker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 0a057a9e8d..18824a5b71 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -38,9 +38,11 @@ def conversion_func_register(*args, **kwargs): ) def jit_compile(self, fn): - import mlx.core as mx + import torch - from pytensor.link.mlx.dispatch import mlx_typify + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + from 
pytensor.link.pytorch.dispatch import pytorch_typify class wrapper: """ @@ -54,7 +56,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = mx.compile(fn) + self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -65,7 +67,7 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) + outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) # unset attrs for n, _ in self.gen_functors: From a9ecad0f8e41ae32e0b0148e2af15695d9c735f7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:31:07 -0400 Subject: [PATCH 04/71] wrap in mx.array --- tests/link/mlx/dispatch/test_math.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/dispatch/test_math.py index 5608321a80..3b2c41167f 100644 --- a/tests/link/mlx/dispatch/test_math.py +++ b/tests/link/mlx/dispatch/test_math.py @@ -1,3 +1,4 @@ +import mlx.core as mx import numpy as np import pytensor @@ -11,8 +12,8 @@ def test_mlx_dot(): out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") - test_x = np.random.normal(size=(3, 2)) - test_y = np.random.normal(size=(2, 4)) + test_x = mx.array(np.random.normal(size=(3, 2))) + test_y = mx.array(np.random.normal(size=(2, 4))) np.testing.assert_allclose( fn(test_x, test_y), np.dot(test_x, test_y), From e690bff174f4030d713bd51c61a3af46ad6652f6 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:32:29 -0400 Subject: [PATCH 05/71] modify the pytorch jit --- pytensor/link/mlx/linker.py | 44 +++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 8cfd9a0ff5..c2c970aebf 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -66,21 +66,41 @@ def conversion_func_register(*args, **kwargs): ) def jit_compile(self, fn): - """JIT compile an MLX function. + import mlx.core as mx - Parameters - ---------- - fn : callable - The function to compile + from pytensor.link.mlx.dispatch import mlx_typify - Returns - ------- - callable - The compiled function - """ - import mlx.core as mx + class wrapper: + def __init__(self, fn, gen_functors): + self.fn = mx.compile(fn) + self.gen_functors = gen_functors.copy() + + def __call__(self, *inputs, **kwargs): + import pytensor.link.utils + + # set attrs + for n, fn in self.gen_functors: + setattr(pytensor.link.utils, n[1:], fn) + + # MLX doesn't support np.ndarray as input + outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) + + return outs + + # unset attrs + for n, _ in self.gen_functors: + if getattr(pytensor.link.utils, n[1:], False): + delattr(pytensor.link.utils, n[1:]) + + return tuple(out.cpu().numpy() for out in outs) + + def __del__(self): + del self.gen_functors + + inner_fn = wrapper(fn, self.gen_functors) + self.gen_functors = [] - return mx.compile(fn) + return inner_fn def create_thunk_inputs(self, storage_map): """Create inputs for the MLX thunk. 
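
A note on how the pieces introduced so far fit together: `fgraph_convert` turns the
PyTensor graph into a Python callable by dispatching each `Op` through `mlx_funcify`,
`jit_compile` hands that callable to `mx.compile`, and the wrapper above converts
incoming NumPy buffers with `mlx_typify` because, as its comment notes, MLX compiled
functions don't take `np.ndarray` inputs. A minimal standalone sketch of those two
runtime steps (illustrative only, not part of the patch series; `dot` is a stand-in
for the callable that `mlx_funcify` would generate from a graph):

    import mlx.core as mx
    import numpy as np


    def dot(x, y):
        # stand-in for the function produced by mlx_funcify(fgraph)
        return mx.matmul(x, y)


    compiled = mx.compile(dot)  # the jit_compile step

    # the mlx_typify step: convert NumPy inputs to mx.array before calling
    rng = np.random.default_rng(0)
    out = compiled(mx.array(rng.normal(size=(3, 2))), mx.array(rng.normal(size=(2, 4))))

The same round trip is exercised end to end by `test_mlx_dot` from PATCH 02 through
the user-facing API (again a sketch, assuming the `mlx` package is installed):

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")
    fn = pytensor.function([x, y], x.dot(y), mode="MLX")  # the mode registered in PATCH 01
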
From ad29c1780ef05a16cea7a06a08e258f97852eeb1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:38:06 -0400 Subject: [PATCH 06/71] move file --- tests/link/mlx/{dispatch => }/test_math.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/link/mlx/{dispatch => }/test_math.py (100%) diff --git a/tests/link/mlx/dispatch/test_math.py b/tests/link/mlx/test_math.py similarity index 100% rename from tests/link/mlx/dispatch/test_math.py rename to tests/link/mlx/test_math.py From ba29b373df4fd283a1b542e0d0df50ab3fb35269 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:47:38 -0400 Subject: [PATCH 07/71] dont wrap --- tests/link/mlx/test_math.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 3b2c41167f..28397d7643 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,4 +1,3 @@ -import mlx.core as mx import numpy as np import pytensor @@ -12,9 +11,12 @@ def test_mlx_dot(): out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") - test_x = mx.array(np.random.normal(size=(3, 2))) - test_y = mx.array(np.random.normal(size=(2, 4))) - np.testing.assert_allclose( - fn(test_x, test_y), - np.dot(test_x, test_y), - ) + seed = sum(map(ord, "test_mlx_dot")) + rng = np.random.default_rng(seed) + + test_x = rng.normal(size=(3, 2)) + test_y = rng.normal(size=(2, 4)) + + actual = fn(test_x, test_y) + expected = np.dot(test_x, test_y) + np.testing.assert_allclose(actual, expected) From 87168707285534ca73b449c99ca2a3b2ecd588ed Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:53:48 -0400 Subject: [PATCH 08/71] attempt to fix github action --- .github/workflows/test.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..46800a1e13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,6 +82,7 @@ jobs: install-numba: [0] install-jax: [0] install-torch: [0] + install-mlx: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -115,6 +116,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 + install-mlx: 0 - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" @@ -150,6 +152,13 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" + - install-mlx: 1 + os: "ubuntu-latest" + python-version: "3.10" + numpy-version: ">=2.0" + fast-compile: 0 + float32: 0 + part: "tests/link/mlx" - os: macos-15 python-version: "3.13" numpy-version: ">=2.0" @@ -196,6 +205,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi pip install pytest-sphinx pip install -e ./ @@ -212,6 +222,7 @@ jobs: INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} + INSTALL_MLX: ${{ matrix.install-mlx }} OS: ${{ matrix.os}} - name: Run tests From 9bf7edfb9a6e73675594f6ee1964086ff9f67d75 
Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 12:55:58 -0400 Subject: [PATCH 09/71] change the rtol --- tests/link/mlx/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 28397d7643..8a9c700a52 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -19,4 +19,4 @@ def test_mlx_dot(): actual = fn(test_x, test_y) expected = np.dot(test_x, test_y) - np.testing.assert_allclose(actual, expected) + np.testing.assert_allclose(actual, expected, rtol=1e-6) From 96ba1162e5e87ec635e3dfbd36f9061f6d6fbac0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:17:57 -0400 Subject: [PATCH 10/71] add init file --- pytensor/link/mlx/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 pytensor/link/mlx/__init__.py diff --git a/pytensor/link/mlx/__init__.py b/pytensor/link/mlx/__init__.py new file mode 100644 index 0000000000..d5a6ab19ff --- /dev/null +++ b/pytensor/link/mlx/__init__.py @@ -0,0 +1 @@ +from pytensor.link.mlx.linker import MLXLinker From e116fa1d4c26814b663be3d88aebaeb718416b4e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:21:47 -0400 Subject: [PATCH 11/71] skip if not installed --- tests/link/mlx/test_basic.py | 4 ++++ tests/link/mlx/test_math.py | 1 + 2 files changed, 5 insertions(+) create mode 100644 tests/link/mlx/test_basic.py diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py new file mode 100644 index 0000000000..f4e5149d67 --- /dev/null +++ b/tests/link/mlx/test_basic.py @@ -0,0 +1,4 @@ +import pytest + + +mx = pytest.importorskip("mlx.core") diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 8a9c700a52..f3839a1cac 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,6 +1,7 @@ import numpy as np import pytensor +import tests.link.mlx.test_basic # noqa: F401 from pytensor.tensor.type import matrix From 5d5f7546d53cd19c3c4a0f88a57bbfd82c853f07 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 11 Apr 2025 13:28:11 -0400 Subject: [PATCH 12/71] remove torch related code / comments --- pytensor/link/mlx/linker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index c2c970aebf..f8159a120b 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -45,8 +45,6 @@ def fgraph_convert( # just the subgraph generator = unique_name_generator(["mlx_linker"]) - # Ensure that torch is aware of the generated - # code so we can compile without graph breaks def conversion_func_register(*args, **kwargs): functor = mlx_funcify(*args, **kwargs) name = kwargs["unique_name"](functor) @@ -85,14 +83,12 @@ def __call__(self, *inputs, **kwargs): # MLX doesn't support np.ndarray as input outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) - return outs - # unset attrs for n, _ in self.gen_functors: if getattr(pytensor.link.utils, n[1:], False): delattr(pytensor.link.utils, n[1:]) - return tuple(out.cpu().numpy() for out in outs) + return outs def __del__(self): del self.gen_functors From b8cee3f779a4472431022459cc198003be671ac9 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 12 Apr 2025 16:25:22 -0400 Subject: [PATCH 13/71] simplify the fgraph_convert --- pytensor/link/mlx/linker.py | 39 ++----------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 
f8159a120b..f512c041d3 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -1,5 +1,4 @@ from pytensor.link.basic import JITLinker -from pytensor.link.utils import unique_name_generator class MLXLinker(JITLinker): @@ -9,29 +8,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] - def fgraph_convert( - self, - fgraph, - order, - input_storage, - output_storage, - storage_map, - **kwargs, - ): + def fgraph_convert(self, fgraph, **kwargs): """Convert a PyTensor FunctionGraph to an MLX-compatible function. Parameters ---------- fgraph : FunctionGraph The function graph to convert - order : list - The order in which to compute the nodes - input_storage : list - Storage for the input variables - output_storage : list - Storage for the output variables - storage_map : dict - Map from variables to their storage Returns ------- @@ -40,27 +23,9 @@ def fgraph_convert( """ from pytensor.link.mlx.dispatch import mlx_funcify - # We want to have globally unique names - # across the entire pytensor graph, not - # just the subgraph - generator = unique_name_generator(["mlx_linker"]) - - def conversion_func_register(*args, **kwargs): - functor = mlx_funcify(*args, **kwargs) - name = kwargs["unique_name"](functor) - self.gen_functors.append((f"_{name}", functor)) - return functor - - built_kwargs = { - "unique_name": generator, - "conversion_func": conversion_func_register, - **kwargs, - } return mlx_funcify( fgraph, - input_storage=input_storage, - storage_map=storage_map, - **built_kwargs, + **kwargs, ) def jit_compile(self, fn): From d057453c071d441e375b0deaa05122a774ba4e31 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 12 Apr 2025 16:25:44 -0400 Subject: [PATCH 14/71] assert type --- tests/link/mlx/test_math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index f3839a1cac..1380e01ca4 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,8 +1,8 @@ import numpy as np import pytensor -import tests.link.mlx.test_basic # noqa: F401 from pytensor.tensor.type import matrix +from tests.link.mlx.test_basic import mx def test_mlx_dot(): @@ -19,5 +19,6 @@ def test_mlx_dot(): test_y = rng.normal(size=(2, 4)) actual = fn(test_x, test_y) + assert isinstance(actual, mx.array) expected = np.dot(test_x, test_y) np.testing.assert_allclose(actual, expected, rtol=1e-6) From ae202e669b238f285da04ca9daee43eb74f849de Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 13:51:02 -0400 Subject: [PATCH 15/71] simplify the internal --- pytensor/link/mlx/linker.py | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index f512c041d3..e057bb942c 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -33,35 +33,12 @@ def jit_compile(self, fn): from pytensor.link.mlx.dispatch import mlx_typify - class wrapper: - def __init__(self, fn, gen_functors): - self.fn = mx.compile(fn) - self.gen_functors = gen_functors.copy() + inner_fn = mx.compile(fn) - def __call__(self, *inputs, **kwargs): - import pytensor.link.utils + def fn(*inputs, inner_fn=inner_fn): + return inner_fn(*(mlx_typify(inp) for inp in inputs)) - # set attrs - for n, fn in self.gen_functors: - setattr(pytensor.link.utils, n[1:], fn) - - # MLX doesn't support np.ndarray as input - outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs) - - # unset 
attrs - for n, _ in self.gen_functors: - if getattr(pytensor.link.utils, n[1:], False): - delattr(pytensor.link.utils, n[1:]) - - return outs - - def __del__(self): - del self.gen_functors - - inner_fn = wrapper(fn, self.gen_functors) - self.gen_functors = [] - - return inner_fn + return fn def create_thunk_inputs(self, storage_map): """Create inputs for the MLX thunk. From f1941fe1c7951d1208cbc4c82ce7d7da99e8dc75 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:00:29 -0400 Subject: [PATCH 16/71] remove the language --- pytensor/compile/mode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ce58561212..8dc7c742bc 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -602,8 +602,6 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], return ("py",) if isinstance(linker, CLinker): return ("c",) - if isinstance(linker, MLXLinker): - return ("py",) if isinstance(linker, VMLinker | OpWiseCLinker): return ("c", "py") if config.cxx else ("py",) From 7c8eae7aa670d914bb9ad3552f58b4e4bd5279ce Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 13:08:40 -0500 Subject: [PATCH 17/71] Adding operations in pytensor --- pytensor/link/mlx/dispatch/elemwise | 77 +++++++++++++++++++++++++++++ pytensor/link/mlx/dispatch/math.py | 43 ++++++++++++++++ pytensor/link/mlx/dispatch/shape.py | 15 ++++++ 3 files changed, 135 insertions(+) create mode 100644 pytensor/link/mlx/dispatch/elemwise create mode 100644 pytensor/link/mlx/dispatch/shape.py diff --git a/pytensor/link/mlx/dispatch/elemwise b/pytensor/link/mlx/dispatch/elemwise new file mode 100644 index 0000000000..5c938cac10 --- /dev/null +++ b/pytensor/link/mlx/dispatch/elemwise @@ -0,0 +1,77 @@ +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad +from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum + +import mlx.core as mx + +@mlx_funcify.register(DimShuffle) +def mlx_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + res = mx.transpose(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + return mx.reshape(res, shape) + + return dimshuffle + +@mlx_funcify.register(CAReduce) +def mlx_funcify_CAReduce(op, **kwargs): + if isinstance(op.scalar_op, Add): + def sum(x): + return mx.sum(x, axis=op.axis) + + return sum + elif isinstance(op.scalar_op, Mul): + def prod(x): + return mx.prod(x, axis=op.axis) + + return prod + elif isinstance(op.scalar_op, AND): + def all(x): + return mx.all(x, axis=op.axis) + + return all + elif isinstance(op.scalar_op, OR): + def any(x): + return mx.any(x, axis=op.axis) + + return any + elif isinstance(op.scalar_op, ScalarMaximum): + def max(x): + return mx.max(x, axis=op.axis) + + return max + elif isinstance(op.scalar_op, ScalarMinimum): + def min(x): + return mx.min(x, axis=op.axis) + + return min + + else: + raise NotImplementedError(f"MLX does not support {op.scalar_op}") + + +@mlx_funcify.register(Softmax) +def mlx_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + return mx.softmax(x, axis=axis) + + return softmax + + +@mlx_funcify.register(SoftmaxGrad) +def mlx_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return 
dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm + + return softmax_grad \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 1ef7ec4608..842181b046 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,7 +1,10 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify + +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot +from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos @mlx_funcify.register(Dot) @@ -10,3 +13,43 @@ def dot(x, y): return mx.matmul(x, y) return dot + +@mlx_funcify.register(Elemwise) +def mlx_funcify_Elemwise(op, **kwargs): + if isinstance(op.scalar_op, Add): + def add(x, y): + return mx.add(x, y) + + return add + elif isinstance(op.scalar_op, Sub): + def sub(x, y): + return mx.sub(x, y) + + return sub + elif isinstance(op.scalar_op, Mul): + def mul(x, y): + return mx.mul(x, y) + + return mul + elif isinstance(op.scalar_op, Exp): + def exp(x): + return mx.exp(x) + + return exp + elif isinstance(op.scalar_op, Log): + def log(x): + return mx.log(x) + + return log + elif isinstance(op.scalar_op, Sin): + def sin(x): + return mx.sin(x) + + return sin + elif isinstance(op.scalar_op, Cos): + def cos(x): + return mx.cos(x) + + return cos + else: + raise NotImplementedError(f"MLX does not support {op.scalar_op}") \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py new file mode 100644 index 0000000000..c22ecea704 --- /dev/null +++ b/pytensor/link/mlx/dispatch/shape.py @@ -0,0 +1,15 @@ +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.link.mlx.dispatch.basic import mlx_funcify + +@mlx_funcify.register(SpecifyShape) +def mlx_funcify_SpecifyShape(op, node, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + for actual, expected in zip(x.shape, shape, strict=True): + if expected is None: + continue + if actual != expected: + raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") + return x + + return specifyshape \ No newline at end of file From 67a74fb5e5ecb3a32c49a820e939509f0c2556d0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:15:34 -0400 Subject: [PATCH 18/71] add extension --- .../mlx/dispatch/{elemwise => elemwise.py} | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) rename pytensor/link/mlx/dispatch/{elemwise => elemwise.py} (90%) diff --git a/pytensor/link/mlx/dispatch/elemwise b/pytensor/link/mlx/dispatch/elemwise.py similarity index 90% rename from pytensor/link/mlx/dispatch/elemwise rename to pytensor/link/mlx/dispatch/elemwise.py index 5c938cac10..7ec124623e 100644 --- a/pytensor/link/mlx/dispatch/elemwise +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,9 +1,10 @@ +import mlx.core as mx + from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad -from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum +from pytensor.tensor.special import Softmax, SoftmaxGrad -import mlx.core as mx @mlx_funcify.register(DimShuffle) def mlx_funcify_DimShuffle(op, **kwargs): @@ -19,42 +20,49 @@ def dimshuffle(x): return dimshuffle + @mlx_funcify.register(CAReduce) def 
mlx_funcify_CAReduce(op, **kwargs): if isinstance(op.scalar_op, Add): + def sum(x): return mx.sum(x, axis=op.axis) return sum elif isinstance(op.scalar_op, Mul): + def prod(x): return mx.prod(x, axis=op.axis) return prod elif isinstance(op.scalar_op, AND): + def all(x): return mx.all(x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): + def any(x): return mx.any(x, axis=op.axis) return any elif isinstance(op.scalar_op, ScalarMaximum): + def max(x): return mx.max(x, axis=op.axis) return max elif isinstance(op.scalar_op, ScalarMinimum): + def min(x): return mx.min(x, axis=op.axis) return min - + else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") - + @mlx_funcify.register(Softmax) def mlx_funcify_Softmax(op, **kwargs): @@ -74,4 +82,4 @@ def softmax_grad(dy, sm): dy_times_sm = dy * sm return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm - return softmax_grad \ No newline at end of file + return softmax_grad From fb5eb523dbae16ee572261c95ecb03ace969816b Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:38:21 -0400 Subject: [PATCH 19/71] make compare function --- tests/link/mlx/test_basic.py | 74 ++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index f4e5149d67..746aab10bc 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,4 +1,78 @@ +from collections.abc import Callable, Iterable +from functools import partial + +import numpy as np import pytest +from pytensor.compile.function import function +from pytensor.compile.mode import Mode +from pytensor.graph.basic import Variable +from pytensor.link.mlx import MLXLinker + mx = pytest.importorskip("mlx.core") + +mlx_mode = Mode(linker=MLXLinker()) +py_mode = Mode(linker="py", optimizer=None) + + +def compare_mlx_and_py( + graph_inputs: Iterable[Variable], + graph_outputs: Variable | Iterable[Variable], + test_inputs: Iterable, + *, + assert_fn: Callable | None = None, + must_be_device_array: bool = True, + mlx_mode=mlx_mode, + py_mode=py_mode, +): + """Function to compare python function output and mlx compiled output for testing equality + + The inputs and outputs are then passed to this function which then compiles the given function in both + mlx and python, runs the calculation in both and checks if the results are the same + + Parameters + ---------- + graph_inputs: + Symbolic inputs to the graph + outputs: + Symbolic outputs of the graph + test_inputs: iter + Numerical inputs for testing the function. + assert_fn: func, opt + Assert function used to check for equality between python and mlx. If not + provided uses np.testing.assert_allclose + must_be_device_array: Bool + Checks for instance of jax.interpreters.xla.DeviceArray. 
For testing purposes + if this device array is found it indicates if the result was computed by jax + + Returns + ------- + mlx_res + + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + if any(inp.owner is not None for inp in graph_inputs): + raise ValueError("Inputs must be root variables") + + pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode) + mlx_res = pytensor_mlx_fn(*test_inputs) + + if must_be_device_array: + if isinstance(mlx_res, list): + assert all(isinstance(res, mx.array) for res in mlx_res) + else: + assert isinstance(mlx_res, mx.array) + + pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) + + if isinstance(graph_outputs, list | tuple): + for j, p in zip(mlx_res, py_res, strict=True): + assert_fn(j, p) + else: + assert_fn(mlx_res, py_res) + + return pytensor_mlx_fn, mlx_res From 516b5958b32669f6844d301d8a8f971b7611bab5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 14:38:52 -0400 Subject: [PATCH 20/71] rename function --- tests/link/mlx/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 1380e01ca4..0781ea4e22 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -5,7 +5,7 @@ from tests.link.mlx.test_basic import mx -def test_mlx_dot(): +def test_dot(): x = matrix("x") y = matrix("y") From 67bb8da51aadfacbd9913ee6f13308d2857aa2be Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:00:45 -0400 Subject: [PATCH 21/71] correct the function name --- pytensor/link/mlx/dispatch/math.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 842181b046..42f1ec7b72 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,10 +1,9 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify - +from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos @mlx_funcify.register(Dot) @@ -14,42 +13,50 @@ def dot(x, y): return dot + @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): + def add(x, y): return mx.add(x, y) return add elif isinstance(op.scalar_op, Sub): + def sub(x, y): - return mx.sub(x, y) + return mx.subtract(x, y) return sub elif isinstance(op.scalar_op, Mul): + def mul(x, y): - return mx.mul(x, y) + return mx.multiply(x, y) return mul elif isinstance(op.scalar_op, Exp): + def exp(x): return mx.exp(x) return exp elif isinstance(op.scalar_op, Log): + def log(x): return mx.log(x) return log elif isinstance(op.scalar_op, Sin): + def sin(x): return mx.sin(x) - + return sin elif isinstance(op.scalar_op, Cos): + def cos(x): return mx.cos(x) return cos else: - raise NotImplementedError(f"MLX does not support {op.scalar_op}") \ No newline at end of file + raise NotImplementedError(f"MLX does not support {op.scalar_op}") From 82bb9642daa7514c07e233b3486bd75f85503d8c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:01:21 -0400 Subject: [PATCH 22/71] tests for elemwise --- tests/link/mlx/test_math.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_math.py 
b/tests/link/mlx/test_math.py index 0781ea4e22..c4f16d2f28 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -1,13 +1,14 @@ import numpy as np +import pytest import pytensor -from pytensor.tensor.type import matrix -from tests.link.mlx.test_basic import mx +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py, mx def test_dot(): - x = matrix("x") - y = matrix("y") + x = pt.matrix("x") + y = pt.matrix("y") out = x.dot(y) fn = pytensor.function([x, y], out, mode="MLX") @@ -22,3 +23,29 @@ def test_dot(): assert isinstance(actual, mx.array) expected = np.dot(test_x, test_y) np.testing.assert_allclose(actual, expected, rtol=1e-6) + + +@pytest.mark.parametrize( + "op", + [pt.exp, pt.log, pt.sin, pt.cos], + ids=["exp", "log", "sin", "cos"], +) +def test_elemwise_one_input(op) -> None: + x = pt.vector("x") + out = op(x) + x_test = mx.array([1.0, 2.0, 3.0]) + compare_mlx_and_py([x], out, [x_test]) + + +@pytest.mark.parametrize( + "op", + [pt.add, pt.sub, pt.mul], + ids=["add", "sub", "mul"], +) +def test_elemwise_two_inputs(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op(x, y) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) From 877d79fe2981d132194e4a528601d6c1f5b6105b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:05:11 -0500 Subject: [PATCH 23/71] Changes --- pytensor/link/mlx/dispatch/__init__.py | 7 +- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 53 ++++++++++-- pytensor/link/mlx/dispatch/subtensor.py | 110 ++++++++++++++++++++++++ 4 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/subtensor.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 7acb41e1b5..2d7dd19974 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -2,4 +2,9 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify import pytensor.link.mlx.dispatch.math -# isort: on +import pytensor.link.mlx.dispatch.basic +import pytensor.link.mlx.dispatch.elemwise +import pytensor.link.mlx.dispatch.shape +import pytensor.link.mlx.dispatch.subtensor +import pytensor.link.mlx.dispatch.core +# isort: on \ No newline at end of file diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7ec124623e..d4bfaeab51 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,7 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 42f1ec7b72..ce5064bf92 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -4,6 +4,8 @@ from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot +from pytensor.scalar.math import Sigmoid +from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ 
@mlx_funcify.register(Dot) @@ -17,9 +19,11 @@ def dot(x, y): @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): - - def add(x, y): - return mx.add(x, y) + def add(*args): + result = args[0] + for arg in args[1:]: + result = mx.add(result, arg) + return result return add elif isinstance(op.scalar_op, Sub): @@ -29,9 +33,11 @@ def sub(x, y): return sub elif isinstance(op.scalar_op, Mul): - - def mul(x, y): - return mx.multiply(x, y) + def mul(*args): + result = args[0] + for arg in args[1:]: + result = mx.multiply(result, arg) + return result return mul elif isinstance(op.scalar_op, Exp): @@ -58,5 +64,40 @@ def cos(x): return mx.cos(x) return cos + elif isinstance(op.scalar_op, Sigmoid): + def sigmoid(x): + return mx.sigmoid(x) + + return sigmoid + elif isinstance(op.scalar_op, LE): + def le(x, y): + return mx.less_equal(x, y) + + return le + elif isinstance(op.scalar_op, LT): + def lt(x, y): + return mx.less(x, y) + + return lt + elif isinstance(op.scalar_op, GE): + def ge(x, y): + return mx.greater_equal(x, y) + + return ge + elif isinstance(op.scalar_op, GT): + def gt(x, y): + return mx.greater(x, y) + + return gt + elif isinstance(op.scalar_op, EQ): + def eq(x, y): + return mx.equal(x, y) + + return eq + elif isinstance(op.scalar_op, NEQ): + def neq(x, y): + return mx.not_equal(x, y) + + return neq else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py new file mode 100644 index 0000000000..7f8b55f18e --- /dev/null +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -0,0 +1,110 @@ +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + indices_from_subtensor, +) +from pytensor.tensor.type_other import MakeSlice + + +BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean +masks. In some cases, however, it is possible to re-express your model +in a form that MLX can compile: + +>>> import pytensor.tensor as pt +>>> x_pt = pt.vector('x') +>>> y_pt = x_pt[x_pt > 0].sum() + +can be re-expressed as: + +>>> import pytensor.tensor as pt +>>> x_pt = pt.vector('x') +>>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum() +""" + +DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic +slice length. 
+""" + + +@mlx_funcify.register(Subtensor) +@mlx_funcify.register(AdvancedSubtensor) +@mlx_funcify.register(AdvancedSubtensor1) +def mlx_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + def subtensor(x, *ilists): + indices = indices_from_subtensor(ilists, idx_list) + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return subtensor + + +@mlx_funcify.register(IncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) +def mlx_funcify_IncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] += y + return x + + def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) + + return incsubtensor + + +@mlx_funcify.register(AdvancedIncSubtensor) +def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] = y + return x + + else: + + def mlx_fn(x, indices, y): + if not op.inplace: + x = x.copy() + x[indices] += y + return x + + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): + return mlx_fn(x, ilist, y) + + return advancedincsubtensor + + +@mlx_funcify.register(MakeSlice) +def mlx_funcify_MakeSlice(op, **kwargs): + def makeslice(*x): + return slice(*x) + + return makeslice From fafedd66d579636eb3292245a1d5f7c0e874c2da Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:09:35 -0500 Subject: [PATCH 24/71] Toma tu tomate William --- pytensor/link/mlx/dispatch/__init__.py | 2 +- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 49 ++++++++++++++++++++++++-- pytensor/link/mlx/dispatch/shape.py | 5 +-- 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 2d7dd19974..7e835d238b 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -7,4 +7,4 @@ import pytensor.link.mlx.dispatch.shape import pytensor.link.mlx.dispatch.subtensor import pytensor.link.mlx.dispatch.core -# isort: on \ No newline at end of file +# isort: on diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index d4bfaeab51..7ec124623e 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,7 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum, Switch +from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index ce5064bf92..a0f68324d9 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,11 +1,27 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify -from pytensor.scalar.basic import Add, Cos, Exp, Log, Mul, Sin, Sub +from 
pytensor.scalar.basic import ( + EQ, + GE, + GT, + LE, + LT, + NEQ, + Add, + Cos, + Exp, + Log, + Mul, + Pow, + Sin, + Sub, + Switch, + TrueDiv, +) +from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar.math import Sigmoid -from pytensor.scalar.basic import Add, Mul, Sub, Exp, Log, Sin, Cos, LE, LT, GE, GT, EQ, NEQ @mlx_funcify.register(Dot) @@ -19,6 +35,7 @@ def dot(x, y): @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, **kwargs): if isinstance(op.scalar_op, Add): + def add(*args): result = args[0] for arg in args[1:]: @@ -33,6 +50,7 @@ def sub(x, y): return sub elif isinstance(op.scalar_op, Mul): + def mul(*args): result = args[0] for arg in args[1:]: @@ -65,39 +83,64 @@ def cos(x): return cos elif isinstance(op.scalar_op, Sigmoid): + def sigmoid(x): return mx.sigmoid(x) return sigmoid elif isinstance(op.scalar_op, LE): + def le(x, y): return mx.less_equal(x, y) return le elif isinstance(op.scalar_op, LT): + def lt(x, y): return mx.less(x, y) return lt elif isinstance(op.scalar_op, GE): + def ge(x, y): return mx.greater_equal(x, y) return ge elif isinstance(op.scalar_op, GT): + def gt(x, y): return mx.greater(x, y) return gt elif isinstance(op.scalar_op, EQ): + def eq(x, y): return mx.equal(x, y) return eq elif isinstance(op.scalar_op, NEQ): + def neq(x, y): return mx.not_equal(x, y) return neq + elif isinstance(op.scalar_op, Switch): + + def switch(cond, x, y): + return mx.where(cond, x, y) + + return switch + elif isinstance(op.scalar_op, Pow): + + def pow(x, y): + return mx.power(x, y) + + return pow + elif isinstance(op.scalar_op, TrueDiv): + + def true_div(x, y): + return mx.divide(x, y) + + return true_div else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index c22ecea704..1d48eae1f5 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,6 @@ -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.tensor.shape import SpecifyShape + @mlx_funcify.register(SpecifyShape) def mlx_funcify_SpecifyShape(op, node, **kwargs): @@ -12,4 +13,4 @@ def specifyshape(x, *shape): raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") return x - return specifyshape \ No newline at end of file + return specifyshape From 60acb8d7ac60d44e99a7c2d680702134cee446dd Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:15:32 -0500 Subject: [PATCH 25/71] Pushing changes with the core shit. --- .gitignore | 1 - pytensor/link/mlx/dispatch/core.py | 174 +++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/mlx/dispatch/core.py diff --git a/.gitignore b/.gitignore index dfe862b868..ebe8e61bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,6 @@ __pycache__ \#*\# build compiled/*.cpp -core.* cutils_ext.cpp dist doc/.build/ diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py new file mode 100644 index 0000000000..84ae042682 --- /dev/null +++ b/pytensor/link/mlx/dispatch/core.py @@ -0,0 +1,174 @@ +""" +pytensor/link/mlx/dispatch/basic.py +----------------------------------- + +First‑cut MLX translations for the most common tensor Ops. 
+ +The structure intentionally follows pytensor's JAX dispatcher so that +once these kernels stabilise they can be optimised further (e.g. fusing +element‑wise graphs, adding in‑place updates, RNG thinning, etc.). +""" +from __future__ import annotations + +import warnings +import numpy as np + +import mlx.core as mx # MLX +from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX + +from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import ( + Join, Split, ExtractDiag, Eye, MakeVector, + ScalarFromTensor, TensorFromScalar, Tri, + get_scalar_constant_value, +) +from pytensor.tensor.exceptions import NotScalarConstantError + + +# ------------------------------------------------------------------ +# Join +# ------------------------------------------------------------------ +@mlx_funcify.register(Join) # MLX +def mlx_funcify_Join(op, **kwargs): + def join(axis, *tensors): + view = op.view + if (view != -1) and all( + tensors[i].shape[axis] == 0 # MLX + for i in list(range(view)) + list(range(view + 1, len(tensors))) + ): + return tensors[view] + + return mx.concatenate(tensors, axis=axis) # MLX + + return join + + +# ------------------------------------------------------------------ +# Split +# ------------------------------------------------------------------ +@mlx_funcify.register(Split) # MLX +def mlx_funcify_Split(op: Split, node, **kwargs): + _, axis_sym, splits_sym = node.inputs + + try: + constant_axis = get_scalar_constant_value(axis_sym) + except NotScalarConstantError: + constant_axis = None + warnings.warn( + "Split node does not have a constant axis. MLX implementation may fail." + ) + + try: + constant_splits = np.array( + [get_scalar_constant_value(splits_sym[i]) + for i in range(get_vector_length(splits_sym))] + ) + except (ValueError, NotScalarConstantError): + constant_splits = None + warnings.warn( + "Split node does not have constant split positions. MLX implementation may fail." 
+ ) + + def split(x, axis, splits): + # Resolve constants (avoids tracing extra ops) + if constant_axis is not None: + axis = int(constant_axis) + + if constant_splits is not None: + splits = constant_splits + cumsum_splits = np.cumsum(splits[:-1]) + else: + # dynamic ‑– keep in graph + splits_arr = mx.array(splits) # MLX + cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split + + if len(splits) != op.len_splits: + raise ValueError("Length of 'splits' is not equal to n_splits") + if np.sum(np.asarray(splits)) != x.shape[axis]: + raise ValueError("Split sizes do not sum to the input length on the chosen axis.") + if np.any(np.asarray(splits) < 0): + raise ValueError("Split sizes cannot be negative.") + + return mx.split(x, cumsum_splits, axis=axis) # MLX + + return split + + +# ------------------------------------------------------------------ +# ExtractDiag +# ------------------------------------------------------------------ +@mlx_funcify.register(ExtractDiag) # MLX +def mlx_funcify_ExtractDiag(op, **kwargs): + offset, axis1, axis2 = op.offset, op.axis1, op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + + return extract_diag + + +# ------------------------------------------------------------------ +# Eye +# ------------------------------------------------------------------ +@mlx_funcify.register(Eye) # MLX +def mlx_funcify_Eye(op, **kwargs): + dtype = op.dtype + + def eye(N, M, k): + return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + + return eye + + +# ------------------------------------------------------------------ +# MakeVector +# ------------------------------------------------------------------ +@mlx_funcify.register(MakeVector) # MLX +def mlx_funcify_MakeVector(op, **kwargs): + def makevector(*x): + return mx.array(x, dtype=op.dtype) # MLX + + return makevector + + +# ------------------------------------------------------------------ +# TensorFromScalar (identity for MLX) +# ------------------------------------------------------------------ +@mlx_funcify.register(TensorFromScalar) # MLX +def mlx_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return x # already an MLX array / scalar + + return tensor_from_scalar + + +# ------------------------------------------------------------------ +# ScalarFromTensor +# ------------------------------------------------------------------ +@mlx_funcify.register(ScalarFromTensor) # MLX +def mlx_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + return mx.array(x).reshape(-1)[0] # MLX + + return scalar_from_tensor + + +# ------------------------------------------------------------------ +# Tri +# ------------------------------------------------------------------ +@mlx_funcify.register(Tri) # MLX +def mlx_funcify_Tri(op, node, **kwargs): + # node.inputs -> N, M, k + const_args = [getattr(inp, "data", None) for inp in node.inputs] + + def tri(*args): + # Replace args with compile‑time constants when available + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + return mx.tri(*args, dtype=op.dtype) # MLX + + return tri + +## Change the code to use the mlx functions \ No newline at end of file From 242aba75e3f212673faaa9b21d77029d76550029 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:34:02 -0400 Subject: [PATCH 26/71] add more tests --- tests/link/mlx/test_math.py | 21 +++++++++++++++++---- 1 file 
changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index c4f16d2f28..57b717360d 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -27,8 +27,13 @@ def test_dot(): @pytest.mark.parametrize( "op", - [pt.exp, pt.log, pt.sin, pt.cos], - ids=["exp", "log", "sin", "cos"], + [ + pytest.param(pt.exp, id="exp"), + pytest.param(pt.log, id="log"), + pytest.param(pt.sin, id="sin"), + pytest.param(pt.cos, id="cos"), + pytest.param(pt.sigmoid, id="sigmoid"), + ], ) def test_elemwise_one_input(op) -> None: x = pt.vector("x") @@ -39,8 +44,16 @@ def test_elemwise_one_input(op) -> None: @pytest.mark.parametrize( "op", - [pt.add, pt.sub, pt.mul], - ids=["add", "sub", "mul"], + [ + pytest.param(pt.add, id="add"), + pytest.param(pt.sub, id="sub"), + pytest.param(pt.mul, id="mul"), + pytest.param(pt.power, id="power"), + pytest.param(pt.le, id="le"), + pytest.param(pt.lt, id="lt"), + pytest.param(pt.ge, id="ge"), + pytest.param(pt.gt, id="gt"), + ], ) def test_elemwise_two_inputs(op) -> None: x = pt.vector("x") From 6cb47fc61b316fe5d22c1a56bc408d2a8e6c6366 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:43:39 -0400 Subject: [PATCH 27/71] additional tests --- tests/link/mlx/test_math.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 57b717360d..25a0198cc1 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -53,6 +53,9 @@ def test_elemwise_one_input(op) -> None: pytest.param(pt.lt, id="lt"), pytest.param(pt.ge, id="ge"), pytest.param(pt.gt, id="gt"), + pytest.param(pt.eq, id="eq"), + pytest.param(pt.neq, id="neq"), + pytest.param(pt.true_div, id="true_div"), ], ) def test_elemwise_two_inputs(op) -> None: From bc98e09caa1053becbeab2017d3958b6e9bd77ca Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 15:50:17 -0400 Subject: [PATCH 28/71] test for switch with mlx --- tests/link/mlx/test_basic.py | 6 ++++-- tests/link/mlx/test_math.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 746aab10bc..4a6e67f406 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -5,14 +5,16 @@ import pytest from pytensor.compile.function import function -from pytensor.compile.mode import Mode +from pytensor.compile.mode import MLX, Mode +from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Variable from pytensor.link.mlx import MLXLinker mx = pytest.importorskip("mlx.core") -mlx_mode = Mode(linker=MLXLinker()) +optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) +mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) py_mode = Mode(linker="py", optimizer=None) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 25a0198cc1..86fa999451 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -42,6 +42,18 @@ def test_elemwise_one_input(op) -> None: compare_mlx_and_py([x], out, [x_test]) +def test_switch() -> None: + x = pt.vector("x") + y = pt.vector("y") + + out = pt.switch(x > 0, y, x) + + x_test = mx.array([-1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + @pytest.mark.parametrize( "op", [ From 4d5b34b4b324c9f78da83e8d5a94c1f8d692738b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: 
Fri, 18 Apr 2025 14:52:19 -0500 Subject: [PATCH 29/71] Pushing code --- pytensor/link/mlx/dispatch/__init__.py | 2 + pytensor/link/mlx/dispatch/core.py | 95 ++++++++++++------- pytensor/link/mlx/dispatch/shape.py | 10 +- pytensor/link/mlx/dispatch/signal/__init__.py | 0 pytensor/link/mlx/dispatch/signal/conv.py | 14 +++ 5 files changed, 88 insertions(+), 33 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/signal/__init__.py create mode 100644 pytensor/link/mlx/dispatch/signal/conv.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 7e835d238b..2dd4e8a02d 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -7,4 +7,6 @@ import pytensor.link.mlx.dispatch.shape import pytensor.link.mlx.dispatch.subtensor import pytensor.link.mlx.dispatch.core +import pytensor.link.mlx.dispatch.signal +import pytensor.link.mlx.dispatch.signal.conv # isort: on diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 84ae042682..6985c2b656 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -2,24 +2,33 @@ pytensor/link/mlx/dispatch/basic.py ----------------------------------- -First‑cut MLX translations for the most common tensor Ops. +First-cut MLX translations for the most common tensor Ops. The structure intentionally follows pytensor's JAX dispatcher so that once these kernels stabilise they can be optimised further (e.g. fusing -element‑wise graphs, adding in‑place updates, RNG thinning, etc.). +element-wise graphs, adding in-place updates, RNG thinning, etc.). """ + from __future__ import annotations import warnings -import numpy as np -import mlx.core as mx # MLX -from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX +import mlx.core as mx # MLX +import numpy as np +from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX from pytensor.tensor import get_vector_length from pytensor.tensor.basic import ( - Join, Split, ExtractDiag, Eye, MakeVector, - ScalarFromTensor, TensorFromScalar, Tri, + Alloc, + AllocEmpty, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + Tri, get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError @@ -28,17 +37,17 @@ # ------------------------------------------------------------------ # Join # ------------------------------------------------------------------ -@mlx_funcify.register(Join) # MLX +@mlx_funcify.register(Join) # MLX def mlx_funcify_Join(op, **kwargs): def join(axis, *tensors): view = op.view if (view != -1) and all( - tensors[i].shape[axis] == 0 # MLX + tensors[i].shape[axis] == 0 # MLX for i in list(range(view)) + list(range(view + 1, len(tensors))) ): return tensors[view] - return mx.concatenate(tensors, axis=axis) # MLX + return mx.concatenate(tensors, axis=axis) # MLX return join @@ -46,7 +55,7 @@ def join(axis, *tensors): # ------------------------------------------------------------------ # Split # ------------------------------------------------------------------ -@mlx_funcify.register(Split) # MLX +@mlx_funcify.register(Split) # MLX def mlx_funcify_Split(op: Split, node, **kwargs): _, axis_sym, splits_sym = node.inputs @@ -60,8 +69,10 @@ def mlx_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( - [get_scalar_constant_value(splits_sym[i]) - for i in range(get_vector_length(splits_sym))] + [ + get_scalar_constant_value(splits_sym[i]) + for i in 
range(get_vector_length(splits_sym)) + ] ) except (ValueError, NotScalarConstantError): constant_splits = None @@ -78,18 +89,22 @@ def split(x, axis, splits): splits = constant_splits cumsum_splits = np.cumsum(splits[:-1]) else: - # dynamic ‑– keep in graph - splits_arr = mx.array(splits) # MLX - cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split + # dynamic - keep in graph + splits_arr = mx.array(splits) # MLX + cumsum_splits = mx.cumsum( + splits_arr[:-1] + ).tolist() # python list for mx.split if len(splits) != op.len_splits: raise ValueError("Length of 'splits' is not equal to n_splits") if np.sum(np.asarray(splits)) != x.shape[axis]: - raise ValueError("Split sizes do not sum to the input length on the chosen axis.") + raise ValueError( + "Split sizes do not sum to the input length on the chosen axis." + ) if np.any(np.asarray(splits) < 0): raise ValueError("Split sizes cannot be negative.") - return mx.split(x, cumsum_splits, axis=axis) # MLX + return mx.split(x, cumsum_splits, axis=axis) # MLX return split @@ -97,12 +112,12 @@ def split(x, axis, splits): # ------------------------------------------------------------------ # ExtractDiag # ------------------------------------------------------------------ -@mlx_funcify.register(ExtractDiag) # MLX +@mlx_funcify.register(ExtractDiag) # MLX def mlx_funcify_ExtractDiag(op, **kwargs): offset, axis1, axis2 = op.offset, op.axis1, op.axis2 def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): - return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX return extract_diag @@ -110,12 +125,12 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): # ------------------------------------------------------------------ # Eye # ------------------------------------------------------------------ -@mlx_funcify.register(Eye) # MLX +@mlx_funcify.register(Eye) # MLX def mlx_funcify_Eye(op, **kwargs): dtype = op.dtype def eye(N, M, k): - return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX return eye @@ -123,10 +138,10 @@ def eye(N, M, k): # ------------------------------------------------------------------ # MakeVector # ------------------------------------------------------------------ -@mlx_funcify.register(MakeVector) # MLX +@mlx_funcify.register(MakeVector) # MLX def mlx_funcify_MakeVector(op, **kwargs): def makevector(*x): - return mx.array(x, dtype=op.dtype) # MLX + return mx.array(x, dtype=op.dtype) # MLX return makevector @@ -134,10 +149,10 @@ def makevector(*x): # ------------------------------------------------------------------ # TensorFromScalar (identity for MLX) # ------------------------------------------------------------------ -@mlx_funcify.register(TensorFromScalar) # MLX +@mlx_funcify.register(TensorFromScalar) # MLX def mlx_funcify_TensorFromScalar(op, **kwargs): def tensor_from_scalar(x): - return x # already an MLX array / scalar + return x # already an MLX array / scalar return tensor_from_scalar @@ -145,10 +160,10 @@ def tensor_from_scalar(x): # ------------------------------------------------------------------ # ScalarFromTensor # ------------------------------------------------------------------ -@mlx_funcify.register(ScalarFromTensor) # MLX +@mlx_funcify.register(ScalarFromTensor) # MLX def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] # MLX + return mx.array(x).reshape(-1)[0] # MLX 
return scalar_from_tensor @@ -156,19 +171,35 @@ def scalar_from_tensor(x): # ------------------------------------------------------------------ # Tri # ------------------------------------------------------------------ -@mlx_funcify.register(Tri) # MLX +@mlx_funcify.register(Tri) # MLX def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] def tri(*args): - # Replace args with compile‑time constants when available + # Replace args with compile-time constants when available args = [ arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=op.dtype) # MLX + return mx.tri(*args, dtype=op.dtype) # MLX return tri -## Change the code to use the mlx functions \ No newline at end of file + +@mlx_funcify.register(AllocEmpty) +def mlx_funcify_AllocEmpty(op, **kwargs): + def allocempty(*shape): + return mx.zeros(shape, dtype=op.dtype) + + return allocempty + + +@mlx_funcify.register(Alloc) +def mlx_funcify_Alloc(op, node, **kwargs): + def alloc(x, *shape): + res = mx.broadcast_to(x, shape) + Alloc._check_runtime_broadcast(node, mx.array(x), res.shape) + return res + + return alloc diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index 1d48eae1f5..a0b8193b42 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,5 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.shape import SpecifyShape +from pytensor.tensor.shape import Shape_i, SpecifyShape @mlx_funcify.register(SpecifyShape) @@ -14,3 +14,11 @@ def specifyshape(x, *shape): return x return specifyshape + + +@mlx_funcify.register(Shape_i) +def mlx_funcify_Shape_i(op, node, **kwargs): + def shape_i(x, i): + return x.shape[op.i] + + return shape_i diff --git a/pytensor/link/mlx/dispatch/signal/__init__.py b/pytensor/link/mlx/dispatch/signal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py new file mode 100644 index 0000000000..a383695437 --- /dev/null +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -0,0 +1,14 @@ +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.signal.conv import Conv1d + +import mlx.core as mx + + +@mlx_funcify.register(Conv1d) +def mlx_funcify_Conv1d(op, node, **kwargs): + mode = op.mode + + def conv1d(data, kernel): + return mx.convolve(data, kernel, mode=mode) + + return conv1d From 5abd32d6b6e67a45f1a3f7f50f08131bb23c98eb Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:00:08 -0500 Subject: [PATCH 30/71] Changes --- pytensor/link/mlx/dispatch/blockwise.py | 16 ++++++++++++++++ pytensor/link/mlx/dispatch/signal/conv.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 pytensor/link/mlx/dispatch/blockwise.py diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py new file mode 100644 index 0000000000..240ee1ad21 --- /dev/null +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -0,0 +1,16 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.blockwise import Blockwise + +@mlx_funcify.register(Blockwise) +def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): + core_f = mlx_funcify(op.core_op) + batched_f = core_f + for _ in range(op.batch_ndim(node)): + 
batched_f = mx.vmap(batched_f) + + def wrapped_blockwise_f(*inputs): + return batched_f(*inputs) + + return wrapped_blockwise_f diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py index a383695437..d3725b7f3e 100644 --- a/pytensor/link/mlx/dispatch/signal/conv.py +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -1,8 +1,8 @@ +import mlx.core as mx + from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.signal.conv import Conv1d -import mlx.core as mx - @mlx_funcify.register(Conv1d) def mlx_funcify_Conv1d(op, node, **kwargs): From 12daeacfd356807fef2c1a0c9e4902bf4f49fa56 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:23:13 -0500 Subject: [PATCH 31/71] A lot of new code --- pytensor/link/mlx/dispatch/__init__.py | 1 + pytensor/link/mlx/dispatch/blockwise.py | 68 +++++++++++++++++--- pytensor/link/mlx/dispatch/elemwise.py | 24 +++++++- pytensor/link/mlx/dispatch/math.py | 75 ++++++++++++++++++++++- pytensor/link/mlx/dispatch/signal/conv.py | 2 +- 5 files changed, 157 insertions(+), 13 deletions(-) diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 2dd4e8a02d..f039263a37 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -9,4 +9,5 @@ import pytensor.link.mlx.dispatch.core import pytensor.link.mlx.dispatch.signal import pytensor.link.mlx.dispatch.signal.conv +import pytensor.link.mlx.dispatch.blockwise # isort: on diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 240ee1ad21..5a5ed8584a 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -1,16 +1,66 @@ import mlx.core as mx +from pytensor.graph import FunctionGraph from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise + @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): - core_f = mlx_funcify(op.core_op) - batched_f = core_f - for _ in range(op.batch_ndim(node)): - batched_f = mx.vmap(batched_f) - - def wrapped_blockwise_f(*inputs): - return batched_f(*inputs) - - return wrapped_blockwise_f + # Create a function graph for the core operation + core_node = op._create_dummy_core_node(node.inputs) + core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) + + # Convert the core function graph to an MLX function + tuple_core_fn = mlx_funcify(core_fgraph, **kwargs) + + # If there's only one output, unwrap it from the tuple + if len(node.outputs) == 1: + + def core_fn(*inputs): + return tuple_core_fn(*inputs)[0] + else: + core_fn = tuple_core_fn + + # Apply vmap for each batch dimension + batch_ndims = op.batch_ndim(node) + vmap_fn = core_fn + for _ in range(batch_ndims): + vmap_fn = mx.vmap(vmap_fn) + + def blockwise_fn(*inputs): + # Check for runtime broadcasting compatibility + op._check_runtime_broadcast(node, inputs) + + # Handle broadcasting for batched dimensions + if batch_ndims > 0: + # Get batch shapes for broadcasting + batch_shapes = [inp.shape[:batch_ndims] for inp in inputs] + + # Calculate the broadcasted batch shape + from functools import reduce + + def broadcast_shapes(shape1, shape2): + return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True)) + + if batch_shapes: + broadcasted_shape = reduce(broadcast_shapes, batch_shapes) + + # Broadcast inputs to the common batch shape + 
broadcasted_inputs = [] + for inp in inputs: + if inp.shape[:batch_ndims] != broadcasted_shape: + # Create the full target shape + target_shape = broadcasted_shape + inp.shape[batch_ndims:] + # Broadcast the input + broadcasted_inputs.append(mx.broadcast_to(inp, target_shape)) + else: + broadcasted_inputs.append(inp) + + # Apply the vectorized function to the broadcasted inputs + return vmap_fn(*broadcasted_inputs) + + # No broadcasting needed + return vmap_fn(*inputs) + + return blockwise_fn diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7ec124623e..a5374b8cb4 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,6 +1,7 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scalar import Softplus from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad @@ -59,9 +60,8 @@ def min(x): return mx.min(x, axis=op.axis) return min - else: - raise NotImplementedError(f"MLX does not support {op.scalar_op}") + raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") @mlx_funcify.register(Softmax) @@ -83,3 +83,23 @@ def softmax_grad(dy, sm): return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm return softmax_grad + + +@mlx_funcify.register(Softplus) +def mlx_funcify_Softplus(op, **kwargs): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), + mx.where( + x < 33.3, + x + mx.exp(-x), + x, + ), + ), + ) + + return softplus diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index a0f68324d9..1adf547e3f 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -8,6 +8,7 @@ LE, LT, NEQ, + Abs, Add, Cos, Exp, @@ -15,14 +16,21 @@ Mul, Pow, Sin, + Sqr, + Sqrt, Sub, Switch, TrueDiv, + Neg, + AND, + OR, + ScalarMaximum, + ScalarMinimum, ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot - +from pytensor.scalar import Softplus @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): @@ -142,5 +150,70 @@ def true_div(x, y): return mx.divide(x, y) return true_div + elif isinstance(op.scalar_op, Sqr): + + def sqr(x): + return mx.square(x) + + return sqr + elif isinstance(op.scalar_op, Sqrt): + + def sqrt(x): + return mx.sqrt(x) + + return sqrt + elif isinstance(op.scalar_op, Abs): + + def abs(x): + return mx.abs(x) + + return abs + elif isinstance(op.scalar_op, Softplus): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), + mx.where( + x < 33.3, + x + mx.exp(-x), + x, + ), + ), + ) + + return softplus + elif isinstance(op.scalar_op, Neg): + + def neg(x): + return mx.negative(x) + + return neg + elif isinstance(op.scalar_op, AND): + + def all(x): + return mx.all(x, axis=op.axis) + + return all + elif isinstance(op.scalar_op, OR): + + def any(x): + return mx.any(x, axis=op.axis) + + return any + elif isinstance(op.scalar_op, ScalarMaximum): + + def max(x): + return mx.max(x, axis=op.axis) + + return max + elif isinstance(op.scalar_op, ScalarMinimum): + + def min(x): + return mx.min(x, axis=op.axis) + + return min else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/signal/conv.py 
b/pytensor/link/mlx/dispatch/signal/conv.py index d3725b7f3e..8f84ebb42f 100644 --- a/pytensor/link/mlx/dispatch/signal/conv.py +++ b/pytensor/link/mlx/dispatch/signal/conv.py @@ -5,7 +5,7 @@ @mlx_funcify.register(Conv1d) -def mlx_funcify_Conv1d(op, node, **kwargs): +def mlx_funcify_Conv1d(op, node=None, **kwargs): mode = op.mode def conv1d(data, kernel): From ac93949d4b50f6eff5c165781b8fa9a444265e3b Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:20:17 -0500 Subject: [PATCH 32/71] almost there baby william --- pytensor/link/mlx/dispatch/basic.py | 16 +++++++ pytensor/link/mlx/dispatch/blockwise.py | 62 +++---------------------- pytensor/link/mlx/dispatch/math.py | 37 ++++++++++++--- pytensor/link/mlx/dispatch/shape.py | 2 +- pytensor/link/mlx/dispatch/subtensor.py | 32 +++++-------- 5 files changed, 66 insertions(+), 83 deletions(-) diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index 9cbb92118d..a99772dba3 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -1,3 +1,4 @@ +import warnings from functools import singledispatch from types import NoneType @@ -7,6 +8,7 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python +from pytensor.raise_op import Assert, CheckAndRaise @singledispatch @@ -59,3 +61,17 @@ def deepcopyop(x): return x.copy() return deepcopyop + + +@mlx_funcify.register(Assert) +@mlx_funcify.register(CheckAndRaise) +def mlx_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 5a5ed8584a..550a1c9616 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -1,66 +1,18 @@ import mlx.core as mx -from pytensor.graph import FunctionGraph from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): - # Create a function graph for the core operation core_node = op._create_dummy_core_node(node.inputs) - core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) + core_f = mlx_funcify(op.core_op, core_node) + blockwise_f = core_f + for i in range(op.batch_ndim(node)): + blockwise_f = mx.vmap(blockwise_f) - # Convert the core function graph to an MLX function - tuple_core_fn = mlx_funcify(core_fgraph, **kwargs) + def blockwise_fun(*inputs): + return blockwise_f(*inputs) - # If there's only one output, unwrap it from the tuple - if len(node.outputs) == 1: - - def core_fn(*inputs): - return tuple_core_fn(*inputs)[0] - else: - core_fn = tuple_core_fn - - # Apply vmap for each batch dimension - batch_ndims = op.batch_ndim(node) - vmap_fn = core_fn - for _ in range(batch_ndims): - vmap_fn = mx.vmap(vmap_fn) - - def blockwise_fn(*inputs): - # Check for runtime broadcasting compatibility - op._check_runtime_broadcast(node, inputs) - - # Handle broadcasting for batched dimensions - if batch_ndims > 0: - # Get batch shapes for broadcasting - batch_shapes = [inp.shape[:batch_ndims] for inp in inputs] - - # Calculate the broadcasted batch shape - from functools import reduce - - def 
broadcast_shapes(shape1, shape2): - return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True)) - - if batch_shapes: - broadcasted_shape = reduce(broadcast_shapes, batch_shapes) - - # Broadcast inputs to the common batch shape - broadcasted_inputs = [] - for inp in inputs: - if inp.shape[:batch_ndims] != broadcasted_shape: - # Create the full target shape - target_shape = broadcasted_shape + inp.shape[batch_ndims:] - # Broadcast the input - broadcasted_inputs.append(mx.broadcast_to(inp, target_shape)) - else: - broadcasted_inputs.append(inp) - - # Apply the vectorized function to the broadcasted inputs - return vmap_fn(*broadcasted_inputs) - - # No broadcasting needed - return vmap_fn(*inputs) - - return blockwise_fn + return blockwise_fun diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 1adf547e3f..6696635ed5 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,36 +1,40 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.scalar import Softplus from pytensor.scalar.basic import ( + AND, EQ, GE, GT, LE, LT, NEQ, + OR, Abs, Add, + Cast, Cos, Exp, Log, Mul, + Neg, Pow, + ScalarMaximum, + ScalarMinimum, + Sign, Sin, Sqr, Sqrt, Sub, Switch, TrueDiv, - Neg, - AND, - OR, - ScalarMaximum, - ScalarMinimum, + Log1p ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot -from pytensor.scalar import Softplus + @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): @@ -169,6 +173,7 @@ def abs(x): return abs elif isinstance(op.scalar_op, Softplus): + def softplus(x): return mx.where( x < -37.0, @@ -194,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x, axis=op.axis) + return mx.all(x) return all elif isinstance(op.scalar_op, OR): @@ -215,5 +220,23 @@ def min(x): return mx.min(x, axis=op.axis) return min + elif isinstance(op.scalar_op, Cast): + + def cast(x): + return mx.cast(x, op.dtype) + + return cast + elif isinstance(op.scalar_op, Sign): + + def sign(x): + return mx.sign(x) + + return sign + elif isinstance(op.scalar_op, Log1p): + + def log1p(x): + return mx.log1p(x) + + return log1p else: raise NotImplementedError(f"MLX does not support {op.scalar_op}") diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index a0b8193b42..bd5b5941d9 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -18,7 +18,7 @@ def specifyshape(x, *shape): @mlx_funcify.register(Shape_i) def mlx_funcify_Shape_i(op, node, **kwargs): - def shape_i(x, i): + def shape_i(x): return x.shape[op.i] return shape_i diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 7f8b55f18e..b45a10519c 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -11,40 +11,32 @@ from pytensor.tensor.type_other import MakeSlice -BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean -masks. 
In some cases, however, it is possible to re-express your model -in a form that MLX can compile: - ->>> import pytensor.tensor as pt ->>> x_pt = pt.vector('x') ->>> y_pt = x_pt[x_pt > 0].sum() - -can be re-expressed as: +@mlx_funcify.register(Subtensor) +def mlx_funcify_Subtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) ->>> import pytensor.tensor as pt ->>> x_pt = pt.vector('x') ->>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum() -""" + def subtensor(x, *ilists): + indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + if len(indices) == 1: + indices = indices[0] -DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic -slice length. -""" + return x.__getitem__(indices) + return subtensor -@mlx_funcify.register(Subtensor) @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) -def mlx_funcify_Subtensor(op, node, **kwargs): +def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): idx_list = getattr(op, "idx_list", None) - def subtensor(x, *ilists): + def advanced_subtensor(x, *ilists): indices = indices_from_subtensor(ilists, idx_list) if len(indices) == 1: indices = indices[0] return x.__getitem__(indices) - return subtensor + return advanced_subtensor @mlx_funcify.register(IncSubtensor) From a19cbc87180adf5bc96cad8c91be9ad19ab0856d Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:28:04 -0500 Subject: [PATCH 33/71] Another push small --- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index a5374b8cb4..81cdf2b2ca 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -39,7 +39,7 @@ def prod(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x, axis=op.axis) + return mx.all(a=x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 6696635ed5..5398183f29 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -199,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(x) + return mx.all(a=x, axis=op.axis) return all elif isinstance(op.scalar_op, OR): From 5c97bc8d31185d82b7dd97c962cc60664873e301 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 19:34:11 -0400 Subject: [PATCH 34/71] fix for all --- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/math.py | 4 ++-- tests/link/mlx/test_elemwise.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 tests/link/mlx/test_elemwise.py diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 81cdf2b2ca..57103c12ff 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -39,7 +39,7 @@ def prod(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(a=x, axis=op.axis) + return x.all(axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 5398183f29..305f86c90b 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -17,6 +17,7 @@ Cos, Exp, Log, + Log1p, Mul, Neg, Pow, @@ -29,7 +30,6 @@ 
Sub, Switch, TrueDiv, - Log1p ) from pytensor.scalar.math import Sigmoid from pytensor.tensor.elemwise import Elemwise @@ -199,7 +199,7 @@ def neg(x): elif isinstance(op.scalar_op, AND): def all(x): - return mx.all(a=x, axis=op.axis) + return x.all(axis=op.axis) return all elif isinstance(op.scalar_op, OR): diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py new file mode 100644 index 0000000000..d7e17b6654 --- /dev/null +++ b/tests/link/mlx/test_elemwise.py @@ -0,0 +1,12 @@ +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py, mx + + +def test_all() -> None: + x = pt.vector("x") + + out = pt.all(x > 0) + + x_test = mx.array([-1.0, 2.0, 3.0]) + + compare_mlx_and_py([x], out, [x_test]) From 2fc81bc8701949f555d2c85e8b3c668313d745c4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 19:45:54 -0400 Subject: [PATCH 35/71] fix for carlos --- tests/link/mlx/test_elemwise.py | 11 ++++++----- tests/link/mlx/test_math.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py index d7e17b6654..7819df06be 100644 --- a/tests/link/mlx/test_elemwise.py +++ b/tests/link/mlx/test_elemwise.py @@ -1,12 +1,13 @@ +import pytest + import pytensor.tensor as pt from tests.link.mlx.test_basic import compare_mlx_and_py, mx -def test_all() -> None: +@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min]) +def test_input(op) -> None: x = pt.vector("x") - - out = pt.all(x > 0) - - x_test = mx.array([-1.0, 2.0, 3.0]) + out = op(x > 0) + x_test = mx.array([1.0, 2.0, 3.0]) compare_mlx_and_py([x], out, [x_test]) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 86fa999451..850d9e754d 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -54,6 +54,16 @@ def test_switch() -> None: compare_mlx_and_py([x, y], out, [x_test, y_test]) +@pytest.mark.parametrize("op", [pt.sum, pt.prod]) +def test_input(op) -> None: + x = pt.vector("x") + y = pt.vector("y") + out = op([x, y, x + y]) + x_test = mx.array([1.0, 2.0, 3.0]) + y_test = mx.array([4.0, 5.0, 6.0]) + compare_mlx_and_py([x, y], out, [x_test, y_test]) + + @pytest.mark.parametrize( "op", [ From e6437cc33d4ea165adec28a9900086e38f6e2bd1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:09:21 -0400 Subject: [PATCH 36/71] just return the compiled func --- pytensor/link/mlx/linker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index e057bb942c..1dc06aefc6 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -35,6 +35,8 @@ def jit_compile(self, fn): inner_fn = mx.compile(fn) + return inner_fn + def fn(*inputs, inner_fn=inner_fn): return inner_fn(*(mlx_typify(inp) for inp in inputs)) From c3a3e1a81d1b436282eabf390c1f3f7c66c20e47 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:19:39 -0500 Subject: [PATCH 37/71] A change for willy may! 
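
Inside an `Elemwise`, `AND` and `OR` are binary scalar ops applied
pointwise, so they have to map to two-argument MLX functions
(mx.bitwise_and / mx.bitwise_or); the reductions pt.all / pt.any reach
the backend as `CAReduce` and keep their axis handling there. A minimal
sketch of the distinction (illustrative only; both helper names are
invented for this note):

    import mlx.core as mx

    # Elemwise: two arrays in, one array out, no axis involved.
    def elemwise_and(x, y):
        return mx.bitwise_and(x, y)

    # CAReduce: one array in, reduced along an axis.
    def careduce_all(x, axis=None):
        return mx.all(x, axis=axis)
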
--- pytensor/link/mlx/dispatch/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 305f86c90b..293a7cfa0a 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -198,14 +198,14 @@ def neg(x): return neg elif isinstance(op.scalar_op, AND): - def all(x): - return x.all(axis=op.axis) + def all(x, y): + return mx.bitwise_and(x, y) return all elif isinstance(op.scalar_op, OR): - def any(x): - return mx.any(x, axis=op.axis) + def any(x, y): + return mx.bitwise_or(x, y) return any elif isinstance(op.scalar_op, ScalarMaximum): From e7cf10ea0ed889c1ed91e80243f825087c30d1dd Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:24:10 -0500 Subject: [PATCH 38/71] FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs) --- pytensor/link/mlx/linker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 1dc06aefc6..e057bb942c 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -35,8 +35,6 @@ def jit_compile(self, fn): inner_fn = mx.compile(fn) - return inner_fn - def fn(*inputs, inner_fn=inner_fn): return inner_fn(*(mlx_typify(inp) for inp in inputs)) From 880dd5cf3beb875f955b3c5b7d6bbe39560322f6 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:51:53 -0400 Subject: [PATCH 39/71] refactor to use getattr --- pytensor/link/mlx/dispatch/elemwise.py | 86 ++++++++++++++------------ 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 57103c12ff..c71de48b12 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -2,7 +2,6 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.scalar import Softplus -from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad @@ -24,44 +23,53 @@ def dimshuffle(x): @mlx_funcify.register(CAReduce) def mlx_funcify_CAReduce(op, **kwargs): - if isinstance(op.scalar_op, Add): - - def sum(x): - return mx.sum(x, axis=op.axis) - - return sum - elif isinstance(op.scalar_op, Mul): - - def prod(x): - return mx.prod(x, axis=op.axis) - - return prod - elif isinstance(op.scalar_op, AND): - - def all(x): - return x.all(axis=op.axis) - - return all - elif isinstance(op.scalar_op, OR): - - def any(x): - return mx.any(x, axis=op.axis) - - return any - elif isinstance(op.scalar_op, ScalarMaximum): - - def max(x): - return mx.max(x, axis=op.axis) - - return max - elif isinstance(op.scalar_op, ScalarMinimum): - - def min(x): - return mx.min(x, axis=op.axis) - - return min - else: - raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") + axis = op.axis + op_nfunc_spec = getattr(op, "nfunc_spec", None) + scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) + scalar_op_name = getattr(op.scalar_op, "name", None) + scalar_op_identity = getattr(op.scalar_op, "identity", None) + acc_dtype = getattr(op, "acc_dtype", None) + + def careduce(x): + nonlocal \ + axis, \ + op_nfunc_spec, \ + scalar_nfunc_spec, \ + scalar_op_name, \ + scalar_op_identity, \ + acc_dtype + + if axis is None: + axis = list(range(x.ndim)) + + if acc_dtype is None: + acc_dtype = x.dtype.type + + 
if op_nfunc_spec: + mlx_op = getattr(mx, op_nfunc_spec[0]) + return mlx_op(x, axis=axis) + return mlx_op(x, axis=axis).astype(acc_dtype) + + # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or + # there isn't one), so we use this fallback approach + if scalar_nfunc_spec: + scalar_fn_name = scalar_nfunc_spec[0] + elif scalar_op_name: + scalar_fn_name = scalar_op_name + + to_reduce = sorted(axis, reverse=True) + + if to_reduce: + raise NotImplementedError("Not implemented yet") + # In this case, we need to use the `jax.lax` function (if there + # is one), and not the `jnp` version. + mlx_op = getattr(mx, scalar_fn_name) + init_value = mx.array(scalar_op_identity, dtype=acc_dtype) + return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype) + else: + return x + + return careduce @mlx_funcify.register(Softmax) From 1e6addd79c2c839bb2a3f42a309e10d9a2585b2c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 20:52:17 -0400 Subject: [PATCH 40/71] bring argmax test --- tests/link/mlx/test_math.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 850d9e754d..2c08d986c9 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -3,6 +3,7 @@ import pytensor import pytensor.tensor as pt +from pytensor.tensor.math import Argmax, Max from tests.link.mlx.test_basic import compare_mlx_and_py, mx @@ -87,3 +88,14 @@ def test_elemwise_two_inputs(op) -> None: x_test = mx.array([1.0, 2.0, 3.0]) y_test = mx.array([4.0, 5.0, 6.0]) compare_mlx_and_py([x, y], out, [x_test, y_test]) + + +@pytest.mark.xfail(reason="Argmax not implemented yet") +def test_mlx_max_and_argmax(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = pt.dvector() + mx = Max([0])(x) + amx = Argmax([0])(x) + out = mx * amx + compare_mlx_and_py([x], [out], [np.r_[1, 2]]) From aabbb788ff70668d37ea8523e72be4032e5430e0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 21:10:21 -0400 Subject: [PATCH 41/71] use deepcopy --- pytensor/link/mlx/dispatch/basic.py | 3 ++- pytensor/link/mlx/dispatch/subtensor.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index a99772dba3..d0b3d451f5 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -1,4 +1,5 @@ import warnings +from copy import deepcopy from functools import singledispatch from types import NoneType @@ -58,7 +59,7 @@ def mlx_funcify_FunctionGraph( @mlx_funcify.register(DeepCopyOp) def mlx_funcify_DeepCopyOp(op, **kwargs): def deepcopyop(x): - return x.copy() + return deepcopy(x) return deepcopyop diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index b45a10519c..ce14d08246 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -24,6 +26,7 @@ def subtensor(x, *ilists): return subtensor + @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): @@ -48,7 +51,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] = y return x @@ -56,7 +59,7 @@ 
def mlx_fn(x, indices, y): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] += y return x @@ -76,7 +79,7 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] = y return x @@ -84,7 +87,7 @@ def mlx_fn(x, indices, y): def mlx_fn(x, indices, y): if not op.inplace: - x = x.copy() + x = deepcopy(x) x[indices] += y return x From 0812c55398f19c4ed95a3d7cca4891ff75f68b27 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 21:11:30 -0400 Subject: [PATCH 42/71] move some tests --- tests/link/mlx/test_shape.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/link/mlx/test_shape.py diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py new file mode 100644 index 0000000000..7a548df8f8 --- /dev/null +++ b/tests/link/mlx/test_shape.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.tensor.shape import Shape, Shape_i, reshape +from pytensor.tensor.type import iscalar, vector +from tests.link.mlx.test_basic import compare_mlx_and_py + + +@pytest.mark.xfail(reason="Shape Op is not supported yet") +def test_mlx_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + x = Shape_i(1)(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], [], must_be_device_array=False) + + +@pytest.mark.xfail(reason="Shape Op is not supported yet") +def test_mlx_specify_shape(): + in_pt = pt.matrix("in") + x = pt.specify_shape(in_pt, (4, None)) + compare_mlx_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) + + # When used to assert two arrays have similar shapes + in_pt = pt.matrix("in") + shape_pt = pt.matrix("shape") + x = pt.specify_shape(in_pt, shape_pt.shape) + + compare_mlx_and_py( + [in_pt, shape_pt], + [x], + [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], + ) + + +@pytest.mark.xfail(reason="Reshape Op is not supported yet") +def test_mlx_Reshape_constant(): + a = vector("a") + x = reshape(a, (2, 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail(reason="Reshape Op is not supported yet") +def test_mlx_Reshape_concrete_shape(): + """MLX should compile when a concrete value is passed for the `shape` parameter.""" + a = vector("a") + x = reshape(a, a.shape) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) + compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument") +def test_mlx_Reshape_shape_graph_input(): + a = vector("a") + shape_pt = iscalar("b") + x = reshape(a, (shape_pt, shape_pt)) + compare_mlx_and_py( + [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] + ) + + +@pytest.mark.xfail(reason="ViewOp Op is not supported yet") +def test_mlx_compile_ops(): + x = DeepCopyOp()(pt.as_tensor_variable(1.1)) + compare_mlx_and_py([], [x], []) + + x_np = np.zeros((20, 1, 1)) + x = ViewOp()(pt.as_tensor_variable(x_np)) + + compare_mlx_and_py([], [x], []) From 294c271ca2258186fa8d3ce68c9fabc1a9a0c261 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo 
<59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 20:13:34 -0500 Subject: [PATCH 43/71] THE SUPER BLOCKWISEE YA YA YA YA JUUUUU --- pytensor/link/mlx/dispatch/blockwise.py | 32 ++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 550a1c9616..378fe861a8 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -2,11 +2,41 @@ from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.signal.conv import Conv1d +def blockwise_conv1d(op, node): + if op.core_op.mode != "valid": + raise NotImplementedError("Only 'valid' mode is supported for conv1d") + batches_ndim = op.batch_ndim(node) + if batches_ndim != 1: + raise NotImplementedError("Only 1D batches are supported for conv1d") + + _, kernel = node.inputs + if not all(kernel.type.broadcastable[:batches_ndim]): + raise NotImplementedError("Only 1D batches are supported for conv1d") + + def inner_f(x, kernel): + x_reshaped = x.reshape(-1, x.shape[-1]).T # shape equals to (N, B) -> N Time as batches all together + b = x_reshaped.shape[1] # + kernel_squeeze = kernel.reshape(-1) + f = kernel_squeeze.shape[0] # Number of filters + kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b)) + conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) + _, conv_shape, _ = conv_result.shape + return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,)) + return inner_f @mlx_funcify.register(Blockwise) -def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): +def funcify_Blockwise(op: Blockwise, node, **kwargs): + if isinstance(op.core_op, Conv1d): + return blockwise_conv1d(op, node, **kwargs) + + core_f = mlx_funcify(op.core_op) + + def blockwise_f(*inputs): + return blockwise_f(*inputs) core_node = op._create_dummy_core_node(node.inputs) + core_f = mlx_funcify(op.core_op, core_node) blockwise_f = core_f for i in range(op.batch_ndim(node)): From 9f31ab109c870a05a0207c9e4f9019e301129601 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 20:45:51 -0500 Subject: [PATCH 44/71] Guys, I'm getting sad. We need help yisus!!!!! 
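
The goal of this change is to support arbitrary leading batch
dimensions: broadcast x and the kernel to a common batch shape,
collapse the batch axes into one, convolve, and restore the batch shape
on the way out. The shape bookkeeping, as a sketch (NumPy stands in for
the MLX kernel purely to illustrate the shapes; the diff below does not
loop over rows):

    import numpy as np

    x = np.zeros((4, 100))     # (*batch, T)
    kernel = np.zeros((4, 8))  # (*batch, K)

    b = np.broadcast_shapes(x.shape[:-1], kernel.shape[:-1])
    t, h = x.shape[-1], kernel.shape[-1]

    out = np.stack(
        [np.convolve(xi, ki, mode="valid")
         for xi, ki in zip(x.reshape(-1, t), kernel.reshape(-1, h))]
    ).reshape(b + (t - h + 1,))
    assert out.shape == (4, 93)
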
---
 pytensor/link/mlx/dispatch/blockwise.py | 36 ++++++++++++++---------
 1 file changed, 23 insertions(+), 13 deletions(-)

diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
index 378fe861a8..9c8d67b69a 100644
--- a/pytensor/link/mlx/dispatch/blockwise.py
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -4,26 +4,36 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.signal.conv import Conv1d
 
-def blockwise_conv1d(op, node):
+import numpy as np
+
+def blockwise_conv1d(op, node, **kwargs):
     if op.core_op.mode != "valid":
         raise NotImplementedError("Only 'valid' mode is supported for conv1d")
-    batches_ndim = op.batch_ndim(node)
-    if batches_ndim != 1:
-        raise NotImplementedError("Only 1D batches are supported for conv1d")
+    # batches_ndim = op.batch_ndim(node)
+    # if batches_ndim != 1:
+    #     raise NotImplementedError("Only 1D batches are supported for conv1d")
 
-    _, kernel = node.inputs
-    if not all(kernel.type.broadcastable[:batches_ndim]):
-        raise NotImplementedError("Only 1D batches are supported for conv1d")
+    # _, kernel = node.inputs
+    # if not all(kernel.type.broadcastable[:batches_ndim]):
+    #     raise NotImplementedError("Only 1D batches are supported for conv1d")
 
     def inner_f(x, kernel):
-        x_reshaped = x.reshape(-1, x.shape[-1]).T  # shape equals to (N, B) -> N Time as batches all together
-        b = x_reshaped.shape[1]  #
-        kernel_squeeze = kernel.reshape(-1)
-        f = kernel_squeeze.shape[0]  # Number of filters
-        kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b))
+        *bx, t = x.shape
+        *bk, h = kernel.shape
+
+        b = np.broadcast_shapes(bx, bk)
+
+        x = x.reshape(b + (t,))
+        kernel = kernel.reshape(b + (h,))
+
+        x_reshaped = x.reshape(-1, t).T  # shape equals to (N, B) -> N Time as batches all together
+        kernel_squeeze = kernel.reshape(-1, h)
+        b_prod = kernel_squeeze.shape[0]
+
+        kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod))
         conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
         _, conv_shape, _ = conv_result.shape
-        return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,))
+        return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
     return inner_f
 
 @mlx_funcify.register(Blockwise)

From 37440ff1a8a959ae8f1edbb573e665334239d8e8 Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Fri, 18 Apr 2025 21:00:15 -0500
Subject: [PATCH 45/71] Use broadcast_to to materialize the common batch shape in blockwise conv1d
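
reshape only rearranges the elements an array already has, so it cannot
take an input from its own batch shape to the broadcasted one; the
arrays have to be materialized with broadcast_to first. A small
illustration of the failure mode this fixes (shapes are hypothetical):

    import mlx.core as mx

    x = mx.zeros((1, 100))  # batch dim of size 1
    b = (4,)                # common batch shape after broadcasting

    # x.reshape(b + (100,)) would fail: 100 elements cannot fill 400.
    x = mx.broadcast_to(x, b + (100,))  # OK, gives shape (4, 100)
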
---
 pytensor/link/mlx/dispatch/blockwise.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
index 9c8d67b69a..cdabaf8315 100644
--- a/pytensor/link/mlx/dispatch/blockwise.py
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -23,17 +23,20 @@ def inner_f(x, kernel):
 
         b = np.broadcast_shapes(bx, bk)
 
-        x = x.reshape(b + (t,))
-        kernel = kernel.reshape(b + (h,))
+        x = mx.broadcast_to(x, b + (t,))
+        kernel = mx.broadcast_to(kernel, b + (h,))
 
         x_reshaped = x.reshape(-1, t).T  # shape equals to (N, B) -> N Time as batches all together
         kernel_squeeze = kernel.reshape(-1, h)
         b_prod = kernel_squeeze.shape[0]
 
-        kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod))
+        print(kernel_squeeze.shape)
+
+        print(b_prod, h, b_prod)
+        kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod))
         conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
         _, conv_shape, _ = conv_result.shape
-        return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
+        mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
     return inner_f
 
 @mlx_funcify.register(Blockwise)

From 4e4923fa6d02a24b25f473690bee5aaf839abfaf Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Fri, 18 Apr 2025 21:03:02 -0500
Subject: [PATCH 46/71] Restore the missing return in blockwise conv1d

---
 pytensor/link/mlx/dispatch/blockwise.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
index cdabaf8315..5393483f20 100644
--- a/pytensor/link/mlx/dispatch/blockwise.py
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -36,7 +36,7 @@ def inner_f(x, kernel):
         kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod))
         conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
         _, conv_shape, _ = conv_result.shape
-        mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
+        return mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
     return inner_f
 
 @mlx_funcify.register(Blockwise)

From 6b27dc4bbebd787b102b1f28b3a6af8d7013cdd9 Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Fri, 18 Apr 2025 22:15:05 -0500
Subject: [PATCH 47/71] Rewrite blockwise conv1d as a single grouped mx.conv1d call
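
mx.conv1d takes input of shape (N, H, C_in) and weights of shape
(C_out, K, C_in / groups), so with groups=B each of the B sequences is
convolved with its own kernel in a single call: the sequences become
the B input channels and each kernel becomes a single-channel output
filter. A standalone check of that layout (a sketch with assumed
shapes):

    import mlx.core as mx

    B, T, K = 2, 100, 8
    x = mx.ones((B, T))       # B sequences of length T
    kernel = mx.ones((B, K))  # one kernel per sequence

    x_in = x.T[None, :, :]    # (1, T, B): batch of 1, B channels
    w = kernel[:, :, None]    # (B, K, 1): B single-channel filters
    y = mx.conv1d(x_in, w, stride=1, padding=0, dilation=1, groups=B)
    assert y.shape == (1, T - K + 1, B)
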
--- pytensor/link/mlx/dispatch/blockwise.py | 34 ++++++++++++++----------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 5393483f20..cf9b0fa830 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -18,25 +18,29 @@ def blockwise_conv1d(op, node, **kwargs): # raise NotImplementedError("Only 1D batches are supported for conv1d") def inner_f(x, kernel): - *bx, t = x.shape - *bk, h = kernel.shape + # 1) Validate shapes + B, T = x.shape + Bk, K = kernel.shape + if B != Bk: + raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - b = np.broadcast_shapes(bx, bk) + # 2) Reshape x so that 'channels' = B, batch size = 1 + # → input shape (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] # shape (1, T, B) - x = mx.broadcast_to(x, b + (t,)) - kernel = mx.broadcast_to(kernel, b + (h,)) + # 3) Build weight array of shape (C_out=B, H_f=K, C_in=1) + # groups = B will slice C_in into B single-channel groups + w = kernel[:, :, None] # shape (B, K, 1) - x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together - kernel_squeeze = kernel.reshape(-1, h) - b_prod = kernel_squeeze.shape[0] + # 4) Convolve with one group per sequence + y = mx.conv1d(x_in, w, + stride=1, + padding=0, + dilation=1, + groups=B) - print(kernel_squeeze.shape) - - print(b_prod, h, b_prod) - kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod)) - conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1) - _, conv_shape, _ = conv_result.shape - return mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,)) + # 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose + return y[0].T # final shape (B, T - K + 1) return inner_f @mlx_funcify.register(Blockwise) From e308f838086042ef8e6551617962699c9cbfcdd3 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:28:31 -0500 Subject: [PATCH 48/71] AI RULES BABY MY MATE --- pytensor/link/mlx/dispatch/blockwise.py | 31 +++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index cf9b0fa830..a30203f17f 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -18,29 +18,30 @@ def blockwise_conv1d(op, node, **kwargs): # raise NotImplementedError("Only 1D batches are supported for conv1d") def inner_f(x, kernel): - # 1) Validate shapes B, T = x.shape Bk, K = kernel.shape if B != Bk: raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - # 2) Reshape x so that 'channels' = B, batch size = 1 - # → input shape (N=1, H=T, C_in=B) - x_in = x.T[None, :, :] # shape (1, T, B) + # 1) Flip each kernel for true convolution + kernels_flipped = kernel[:, ::-1] # shape (B, K) - # 3) Build weight array of shape (C_out=B, H_f=K, C_in=1) - # groups = B will slice C_in into B single-channel groups - w = kernel[:, :, None] # shape (B, K, 1) + # 2) Reshape input into (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] - # 4) Convolve with one group per sequence - y = mx.conv1d(x_in, w, - stride=1, - padding=0, - dilation=1, - groups=B) + # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) + w = kernels_flipped[:, :, None] - # 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose 
- return y[0].T # final shape (B, T - K + 1) + # 4) Convolve with one group per channel → valid mode + y = mx.conv1d( + x_in, w, + stride=1, + padding=0, + dilation=1, + groups=B + ) + # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) + return y[0].T return inner_f @mlx_funcify.register(Blockwise) From 3744a180db6c84474168fe3819f9ddb8054d2099 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 18 Apr 2025 23:31:18 -0400 Subject: [PATCH 49/71] test conv1d case --- tests/link/mlx/test_blockwise.py | 64 ++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/link/mlx/test_blockwise.py diff --git a/tests/link/mlx/test_blockwise.py b/tests/link/mlx/test_blockwise.py new file mode 100644 index 0000000000..9b271186c9 --- /dev/null +++ b/tests/link/mlx/test_blockwise.py @@ -0,0 +1,64 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.tensor import tensor +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.math import Dot +from tests.link.mlx.test_basic import compare_mlx_and_py + + +# Equivalent blockwise to matmul but with dumb signature +odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)") + + +# @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul)) +# def test_matmul(matmul_op): +# rng = np.random.default_rng(14) +# a = tensor("a", shape=(2, 3, 5)) +# b = tensor("b", shape=(2, 5, 3)) +# test_values = [ +# rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b) +# ] +# +# out = matmul_op(a, b) +# assert isinstance(out.owner.op, Blockwise) +# fn, _ = compare_mlx_and_py([a, b], [out], test_values) +# +## Check we are not adding any unnecessary stuff +# jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) +# jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul") +# expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values)) +# assert jaxpr == expected_jaxpr + + +# conv1d +# (2, 100) +# (8, 100) +# mode = valid + + +def test_blockwise_conv1d(): + rng = np.random.default_rng(14) + a = tensor("a", shape=(2, 100)) + b = tensor("b", shape=(2, 8)) + + # a_test = np.broadcast_to(np.arange(100), (2, 100)) + a_test = rng.normal(size=(2, 100)) + b_test = rng.normal(size=(2, 8)) + # b_test = np.concatenate( + # [ + # np.ones((1, 8)), + # np.zeros((1, 8)), + # np.zeros((1, 8)), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, 8), + # ], + # axis=0, + # ) + + test_values = [a_test, b_test] + + out = pt.signal.convolve1d(a, b, mode="valid") + + # assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True) From b41cab00f96ce8af6de4a642daef5ce3b7590c52 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:40:40 -0500 Subject: [PATCH 50/71] I'm going for pizzas, it was an incredible day! 
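
batched_conv1d keeps the grouped-conv trick but adds "full" mode by
zero-padding before the convolution: with kernel length K and dilation
d, both sides get (K - 1) * d of padding, which for d = 1 turns the
valid-mode output length T - K + 1 into T + K - 1. The length
arithmetic, as a sketch (pure Python; the helper name is invented for
this note):

    def conv1d_out_len(t, k, mode="valid", dilation=1):
        pad = 0 if mode == "valid" else (k - 1) * dilation
        effective_k = (k - 1) * dilation + 1
        return t + 2 * pad - effective_k + 1

    assert conv1d_out_len(100, 8, "valid") == 93
    assert conv1d_out_len(100, 8, "full") == 107  # 100 + 8 - 1
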
--- pytensor/link/mlx/dispatch/blockwise.py | 105 ++++++++++++++++++------ 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index a30203f17f..95fc9d0f9a 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -7,49 +7,108 @@ import numpy as np def blockwise_conv1d(op, node, **kwargs): - if op.core_op.mode != "valid": - raise NotImplementedError("Only 'valid' mode is supported for conv1d") - # batches_ndim = op.batch_ndim(node) - # if batches_ndim != 1: - # raise NotImplementedError("Only 1D batches are supported for conv1d") + # if op.core_op.mode != "valid": + # raise NotImplementedError("Only 'valid' mode is supported for conv1d") - # _, kernel = node.inputs - # if not all(kernel.type.broadcastable[:batches_ndim]): - # raise NotImplementedError("Only 1D batches are supported for conv1d") + # def inner_f(x, kernel): + # B, T = x.shape + # Bk, K = kernel.shape + # if B != Bk: + # raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") + + # # 1) Flip each kernel for true convolution + # kernels_flipped = kernel[:, ::-1] # shape (B, K) + + # # 2) Reshape input into (N=1, H=T, C_in=B) + # x_in = x.T[None, :, :] + + # # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) + # w = kernels_flipped[:, :, None] + + # # 4) Convolve with one group per channel → valid mode + # y = mx.conv1d( + # x_in, w, + # stride=1, + # padding=0, + # dilation=1, + # groups=B + # ) + # # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) + # return y[0].T - def inner_f(x, kernel): + def batched_conv1d( + x: mx.array, + kernels: mx.array, + mode: str = op.core_op.mode, + stride: int = 1, + dilation: int = 1) -> mx.array: + """ + Apply B separate 1D convolutions (full or valid) to B sequences in parallel. + + Parameters + ---------- + x : array of shape (B, T) + B sequences of length T. + kernels : array of shape (B, K) + B kernels of length K. 
+ mode : {"valid", "full"} + "valid" → no padding, output length = T - K + 1 + "full" → zero‑pad so output length = T + K - 1 + stride : int, convolution stride (default=1) + dilation : int, convolution dilation (default=1) + + Returns + ------- + out : array of shape (B, L) + where L = + - T - K + 1 if mode="valid" + - T + K - 1 if mode="full" + """ + # --- 1) shape checks --- B, T = x.shape - Bk, K = kernel.shape + Bk, K = kernels.shape if B != Bk: raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - # 1) Flip each kernel for true convolution - kernels_flipped = kernel[:, ::-1] # shape (B, K) + # --- 2) flip kernels for convolution --- + kernels_flipped = kernels[:, ::-1] # shape (B, K) + + # --- 3) decide padding --- + if mode == "valid": + pad = 0 + elif mode == "full": + pad = (K - 1) * dilation + else: + raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'") - # 2) Reshape input into (N=1, H=T, C_in=B) - x_in = x.T[None, :, :] + # --- 4) reshape into MLX conv1d form --- + # input: (N=1, H=T, C_in=B) + x_in = x.T[None, :, :] - # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) - w = kernels_flipped[:, :, None] + # weight: (C_out=B, H_f=K, C_in=1) + w = kernels_flipped[:, :, None] - # 4) Convolve with one group per channel → valid mode + # --- 5) run grouped conv1d --- y = mx.conv1d( x_in, w, - stride=1, - padding=0, - dilation=1, + stride=stride, + padding=pad, + dilation=dilation, groups=B ) - # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) + # y shape: (1, H_out, B) + + # --- 6) return shape (B, H_out) --- return y[0].T - return inner_f + + return batched_conv1d @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, **kwargs): if isinstance(op.core_op, Conv1d): return blockwise_conv1d(op, node, **kwargs) - core_f = mlx_funcify(op.core_op) + core_f = mlx_funcify(op.core_op, node) def blockwise_f(*inputs): return blockwise_f(*inputs) From 9766975453a5d43518b9b63567d1eac722e04d71 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:32:07 -0500 Subject: [PATCH 51/71] SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY A shout out for the fathers of the day! 
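Editorial sketch (not part of this patch) of the in_axes mechanism the
reworked funcify_Blockwise relies on: batched arguments are mapped along
their leading axis, while arguments given axis None are passed unchanged to
every call. The names below are invented for illustration:

    import mlx.core as mx

    def core(x, w):  # core op applied to a single "sample"
        return x @ w

    # Map axis 0 of x only; w is shared across the batch.
    batched = mx.vmap(core, in_axes=(0, None))

    xs = mx.ones((4, 2, 3))  # batch of four (2, 3) matrices
    w = mx.ones((3, 5))      # one unbatched weight
    assert batched(xs, w).shape == (4, 2, 5)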
Co-Authored-By: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Co-Authored-By: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/link/mlx/dispatch/blockwise.py | 20 ++++++++++++-------- pytensor/link/mlx/dispatch/elemwise.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 95fc9d0f9a..00f774fe08 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -105,20 +105,24 @@ def batched_conv1d( @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, **kwargs): + # 1) If it's a Conv1d Blockwise, use the custom implementation if isinstance(op.core_op, Conv1d): return blockwise_conv1d(op, node, **kwargs) - - core_f = mlx_funcify(op.core_op, node) - def blockwise_f(*inputs): - return blockwise_f(*inputs) + # 2) Otherwise, get the core python function for this Blockwise core_node = op._create_dummy_core_node(node.inputs) - core_f = mlx_funcify(op.core_op, core_node) - blockwise_f = core_f - for i in range(op.batch_ndim(node)): - blockwise_f = mx.vmap(blockwise_f) + # 3) Determine how many inputs correspond to batch dimensions + n_batch = op.batch_ndim(node) + + # 4) Build in_axes: map only the first n_batch args, keep the rest static + in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs))) + + # 5) Vectorize (vmap) with in_axes + blockwise_f = mx.vmap(core_f, in_axes=in_axes) + + # 6) Return the mapped function def blockwise_fun(*inputs): return blockwise_f(*inputs) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index c71de48b12..926da572e0 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -43,12 +43,12 @@ def careduce(x): axis = list(range(x.ndim)) if acc_dtype is None: - acc_dtype = x.dtype.type + acc_dtype = x.dtype if op_nfunc_spec: mlx_op = getattr(mx, op_nfunc_spec[0]) return mlx_op(x, axis=axis) - return mlx_op(x, axis=axis).astype(acc_dtype) + # return mlx_op(x, axis=axis).astype(acc_dtype) # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or # there isn't one), so we use this fallback approach From 5ffc5ef8fc6a933ff55038ce1f84c59be7876ba0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:33:44 -0500 Subject: [PATCH 52/71] pre-commit --- pytensor/link/mlx/dispatch/blockwise.py | 58 ++++++------------------- 1 file changed, 14 insertions(+), 44 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index 00f774fe08..74bb018a68 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -4,44 +4,19 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.signal.conv import Conv1d -import numpy as np def blockwise_conv1d(op, node, **kwargs): - # if op.core_op.mode != "valid": - # raise NotImplementedError("Only 'valid' mode is supported for conv1d") - - # def inner_f(x, kernel): - # B, T = x.shape - # Bk, K = kernel.shape - # if B != Bk: - # raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") - - # # 1) Flip each kernel for true convolution - # kernels_flipped = kernel[:, ::-1] # shape (B, K) - - # # 2) Reshape input into (N=1, H=T, C_in=B) - # x_in = x.T[None, :, :] - - # # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) - # w = 
kernels_flipped[:, :, None] - - # # 4) Convolve with one group per channel → valid mode - # y = mx.conv1d( - # x_in, w, - # stride=1, - # padding=0, - # dilation=1, - # groups=B - # ) - # # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) - # return y[0].T - + """ + Custom implementation of Blockwise.conv1d for MLX. + """ + def batched_conv1d( - x: mx.array, - kernels: mx.array, - mode: str = op.core_op.mode, - stride: int = 1, - dilation: int = 1) -> mx.array: + x: mx.array, + kernels: mx.array, + mode: str = op.core_op.mode, + stride: int = 1, + dilation: int = 1, + ) -> mx.array: """ Apply B separate 1D convolutions (full or valid) to B sequences in parallel. @@ -53,14 +28,14 @@ def batched_conv1d( B kernels of length K. mode : {"valid", "full"} "valid" → no padding, output length = T - K + 1 - "full" → zero‑pad so output length = T + K - 1 + "full" → zero-pad so output length = T + K - 1 stride : int, convolution stride (default=1) dilation : int, convolution dilation (default=1) Returns ------- out : array of shape (B, L) - where L = + where L = - T - K + 1 if mode="valid" - T + K - 1 if mode="full" """ @@ -89,13 +64,7 @@ def batched_conv1d( w = kernels_flipped[:, :, None] # --- 5) run grouped conv1d --- - y = mx.conv1d( - x_in, w, - stride=stride, - padding=pad, - dilation=dilation, - groups=B - ) + y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B) # y shape: (1, H_out, B) # --- 6) return shape (B, H_out) --- @@ -103,6 +72,7 @@ def batched_conv1d( return batched_conv1d + @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, **kwargs): # 1) If it's a Conv1d Blockwise, use the custom implementation From 597f84ef2140b9e818c67aed6520796d9a25b2b7 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Sat, 19 Apr 2025 11:32:15 -0500 Subject: [PATCH 53/71] Almost working --- pytensor/link/mlx/dispatch/elemwise.py | 125 +++++++++++++++---------- pytensor/link/mlx/dispatch/math.py | 4 +- 2 files changed, 80 insertions(+), 49 deletions(-) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 926da572e0..7e7a27c5ab 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -5,6 +5,35 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad +from pytensor.scalar.basic import ( + AND, + EQ, + GE, + GT, + LE, + LT, + NEQ, + OR, + Abs, + Add, + Cast, + Cos, + Exp, + Log, + Log1p, + Mul, + Neg, + Pow, + ScalarMaximum, + ScalarMinimum, + Sign, + Sin, + Sqr, + Sqrt, + Sub, + Switch, + TrueDiv, +) @mlx_funcify.register(DimShuffle) def mlx_funcify_DimShuffle(op, **kwargs): @@ -21,55 +50,57 @@ def dimshuffle(x): return dimshuffle +@mlx_funcify.register(DimShuffle) +def mlx_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + res = mx.transpose(x, op.transposition) + shape = list(res.shape[: len(op.shuffle)]) + for augm in op.augment: + shape.insert(augm, 1) + return mx.reshape(res, shape) + return dimshuffle + @mlx_funcify.register(CAReduce) def mlx_funcify_CAReduce(op, **kwargs): - axis = op.axis - op_nfunc_spec = getattr(op, "nfunc_spec", None) - scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) - scalar_op_name = getattr(op.scalar_op, "name", None) - scalar_op_identity = getattr(op.scalar_op, "identity", None) - acc_dtype = getattr(op, "acc_dtype", None) - - def careduce(x): - nonlocal \ - axis, \ - op_nfunc_spec, \ - scalar_nfunc_spec, 
\ - scalar_op_name, \ - scalar_op_identity, \ - acc_dtype - - if axis is None: - axis = list(range(x.ndim)) - - if acc_dtype is None: - acc_dtype = x.dtype - - if op_nfunc_spec: - mlx_op = getattr(mx, op_nfunc_spec[0]) - return mlx_op(x, axis=axis) - # return mlx_op(x, axis=axis).astype(acc_dtype) - - # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or - # there isn't one), so we use this fallback approach - if scalar_nfunc_spec: - scalar_fn_name = scalar_nfunc_spec[0] - elif scalar_op_name: - scalar_fn_name = scalar_op_name - - to_reduce = sorted(axis, reverse=True) - - if to_reduce: - raise NotImplementedError("Not implemented yet") - # In this case, we need to use the `jax.lax` function (if there - # is one), and not the `jnp` version. - mlx_op = getattr(mx, scalar_fn_name) - init_value = mx.array(scalar_op_identity, dtype=acc_dtype) - return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype) - else: - return x - - return careduce + if isinstance(op.scalar_op, Add): + + def sum(x): + return mx.sum(x, axis=op.axis) + + return sum + elif isinstance(op.scalar_op, Mul): + + def prod(x): + return mx.prod(x, axis=op.axis) + + return prod + elif isinstance(op.scalar_op, AND): + + def all(x): + return x.all(axis=op.axis) + + return all + elif isinstance(op.scalar_op, OR): + + def any(x): + return mx.any(x, axis=op.axis) + + return any + elif isinstance(op.scalar_op, ScalarMaximum): + + def max(x): + return x.max(axis=op.axis) + + return max + elif isinstance(op.scalar_op, ScalarMinimum): + + def min(x): + return x.min(axis=op.axis) + + return min + else: + raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") + @mlx_funcify.register(Softmax) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 293a7cfa0a..890a8db601 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -211,13 +211,13 @@ def any(x, y): elif isinstance(op.scalar_op, ScalarMaximum): def max(x): - return mx.max(x, axis=op.axis) + return x.max(axis=op.axis) return max elif isinstance(op.scalar_op, ScalarMinimum): def min(x): - return mx.min(x, axis=op.axis) + return x.min(axis=op.axis) return min elif isinstance(op.scalar_op, Cast): From fb8fd2f12ca6061d5b9002bdab155a382b5a842c Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 23 Apr 2025 11:15:11 +0300 Subject: [PATCH 54/71] Last PR sampling working Working --- pytensor/link/mlx/dispatch/core.py | 60 +++++++++++++++++++--- pytensor/link/mlx/dispatch/elemwise.py | 70 +++++++------------------- pytensor/link/mlx/dispatch/math.py | 20 +++++--- 3 files changed, 87 insertions(+), 63 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 6985c2b656..3a0b279cd3 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -127,7 +127,7 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): # ------------------------------------------------------------------ @mlx_funcify.register(Eye) # MLX def mlx_funcify_Eye(op, **kwargs): - dtype = op.dtype + dtype = convert_dtype_to_mlx(op.dtype) def eye(N, M, k): return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX @@ -135,13 +135,56 @@ def eye(N, M, k): return eye +def convert_dtype_to_mlx(dtype_str): + """Convert PyTensor dtype strings to MLX dtype objects. + + MLX expects dtype objects rather than string literals for type conversion. 
+ This function maps common dtype strings to their MLX equivalents. + """ + if isinstance(dtype_str, str): + if dtype_str == "bool": + return mx.bool_ + elif dtype_str == "int8": + return mx.int8 + elif dtype_str == "int16": + return mx.int16 + elif dtype_str == "int32": + return mx.int32 + elif dtype_str == "int64": + return mx.int64 + elif dtype_str == "uint8": + return mx.uint8 + elif dtype_str == "uint16": + return mx.uint16 + elif dtype_str == "uint32": + return mx.uint32 + elif dtype_str == "uint64": + return mx.uint64 + elif dtype_str == "float16": + return mx.float16 + elif dtype_str == "float32": + return mx.float32 + elif dtype_str == "float64": + return mx.float64 + elif dtype_str == "bfloat16": + return mx.bfloat16 + elif dtype_str == "complex64": + return mx.complex64 + elif dtype_str == "complex128": + return mx.complex128 + # Return as is if it's already an MLX dtype or not a recognized string + return dtype_str + + # ------------------------------------------------------------------ # MakeVector # ------------------------------------------------------------------ @mlx_funcify.register(MakeVector) # MLX def mlx_funcify_MakeVector(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + def makevector(*x): - return mx.array(x, dtype=op.dtype) # MLX + return mx.array(x, dtype=dtype) # MLX return makevector @@ -175,6 +218,7 @@ def scalar_from_tensor(x): def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] + dtype = convert_dtype_to_mlx(op.dtype) def tri(*args): # Replace args with compile-time constants when available @@ -182,15 +226,17 @@ def tri(*args): arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=op.dtype) # MLX + return mx.tri(*args, dtype=dtype) # MLX return tri @mlx_funcify.register(AllocEmpty) def mlx_funcify_AllocEmpty(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + def allocempty(*shape): - return mx.zeros(shape, dtype=op.dtype) + return mx.zeros(shape, dtype=dtype) return allocempty @@ -198,8 +244,10 @@ def allocempty(*shape): @mlx_funcify.register(Alloc) def mlx_funcify_Alloc(op, node, **kwargs): def alloc(x, *shape): - res = mx.broadcast_to(x, shape) - Alloc._check_runtime_broadcast(node, mx.array(x), res.shape) + # Convert x to an MLX array with the correct dtype if it's a scalar + x_array = mx.array(x) + res = mx.broadcast_to(x_array, shape) + Alloc._check_runtime_broadcast(node, x_array, res.shape) return res return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 7e7a27c5ab..aaf04968de 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,65 +1,37 @@ import mlx.core as mx +import numpy as np from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx from pytensor.scalar import Softplus -from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.special import Softmax, SoftmaxGrad - from pytensor.scalar.basic import ( AND, - EQ, - GE, - GT, - LE, - LT, - NEQ, OR, - Abs, Add, Cast, - Cos, - Exp, - Log, - Log1p, Mul, - Neg, - Pow, - ScalarMaximum, - ScalarMinimum, - Sign, - Sin, - Sqr, - Sqrt, - Sub, - Switch, - TrueDiv, ) +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.special import Softmax, SoftmaxGrad + @mlx_funcify.register(DimShuffle) def mlx_funcify_DimShuffle(op, **kwargs): 
def dimshuffle(x): + # Convert scalar to array if needed + if isinstance(x, int | float) or ( + isinstance(x, np.number) and not isinstance(x, np.ndarray) + ): + x = mx.array(x) res = mx.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - for augm in op.augment: shape.insert(augm, 1) - return mx.reshape(res, shape) return dimshuffle -@mlx_funcify.register(DimShuffle) -def mlx_funcify_DimShuffle(op, **kwargs): - def dimshuffle(x): - res = mx.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - for augm in op.augment: - shape.insert(augm, 1) - return mx.reshape(res, shape) - return dimshuffle - @mlx_funcify.register(CAReduce) def mlx_funcify_CAReduce(op, **kwargs): if isinstance(op.scalar_op, Add): @@ -86,23 +58,10 @@ def any(x): return mx.any(x, axis=op.axis) return any - elif isinstance(op.scalar_op, ScalarMaximum): - - def max(x): - return x.max(axis=op.axis) - - return max - elif isinstance(op.scalar_op, ScalarMinimum): - - def min(x): - return x.min(axis=op.axis) - - return min else: raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") - @mlx_funcify.register(Softmax) def mlx_funcify_Softmax(op, **kwargs): axis = op.axis @@ -142,3 +101,12 @@ def softplus(x): ) return softplus + + +@mlx_funcify.register(Cast) +def mlx_funcify_Cast(op, **kwargs): + def cast(x): + dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) + return x.astype(dtype) + + return cast diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 890a8db601..153f049b0e 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,6 +1,7 @@ import mlx.core as mx -from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify +from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx from pytensor.scalar import Softplus from pytensor.scalar.basic import ( AND, @@ -36,6 +37,12 @@ from pytensor.tensor.math import Dot +@mlx_typify.register(int) +@mlx_typify.register(float) +def mlx_typify_python_scalar(data, **kwargs): + return mx.array(data) + + @mlx_funcify.register(Dot) def mlx_funcify_Dot(op, **kwargs): def dot(x, y): @@ -210,20 +217,21 @@ def any(x, y): return any elif isinstance(op.scalar_op, ScalarMaximum): - def max(x): - return x.max(axis=op.axis) + def max(x, y): + return mx.maximum(x, y) return max elif isinstance(op.scalar_op, ScalarMinimum): - def min(x): - return x.min(axis=op.axis) + def min(x, y): + return mx.minimum(x, y) return min elif isinstance(op.scalar_op, Cast): def cast(x): - return mx.cast(x, op.dtype) + dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) + return x.astype(dtype) return cast elif isinstance(op.scalar_op, Sign): From 6a2b7743dff60229d58a7794c87c38d96aa2a6b9 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:45:57 +0200 Subject: [PATCH 55/71] Requested changes by Ricardo --- pytensor/link/mlx/dispatch/basic.py | 2 +- pytensor/link/mlx/dispatch/core.py | 68 ++--- pytensor/link/mlx/dispatch/elemwise.py | 73 +++-- pytensor/link/mlx/dispatch/math.py | 391 +++++++++++++++---------- tests/link/mlx/test_subtensor.py | 276 +++++++++++++++++ 5 files changed, 593 insertions(+), 217 deletions(-) create mode 100644 tests/link/mlx/test_subtensor.py diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index d0b3d451f5..c9f7f2fa3d 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ 
b/pytensor/link/mlx/dispatch/basic.py @@ -18,7 +18,6 @@ def mlx_typify(data, **kwargs): @mlx_typify.register(np.ndarray) -@mlx_typify.register(mx.array) def mlx_typify_tensor(data, dtype=None, **kwargs): return mx.array(data, dtype=dtype) @@ -26,6 +25,7 @@ def mlx_typify_tensor(data, dtype=None, **kwargs): @mlx_typify.register(slice) @mlx_typify.register(NoneType) @mlx_typify.register(np.number) +@mlx_typify.register(mx.array) def mlx_typify_no_conversion_needed(data, **kwargs): return data diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 3a0b279cd3..a1f128218b 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -13,10 +13,10 @@ import warnings -import mlx.core as mx # MLX +import mlx.core as mx import numpy as np -from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX +from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor import get_vector_length from pytensor.tensor.basic import ( Alloc, @@ -34,28 +34,22 @@ from pytensor.tensor.exceptions import NotScalarConstantError -# ------------------------------------------------------------------ -# Join -# ------------------------------------------------------------------ -@mlx_funcify.register(Join) # MLX +@mlx_funcify.register(Join) def mlx_funcify_Join(op, **kwargs): def join(axis, *tensors): view = op.view if (view != -1) and all( - tensors[i].shape[axis] == 0 # MLX + tensors[i].shape[axis] == 0 for i in list(range(view)) + list(range(view + 1, len(tensors))) ): return tensors[view] - return mx.concatenate(tensors, axis=axis) # MLX + return mx.concatenate(tensors, axis=axis) return join -# ------------------------------------------------------------------ -# Split -# ------------------------------------------------------------------ -@mlx_funcify.register(Split) # MLX +@mlx_funcify.register(Split) def mlx_funcify_Split(op: Split, node, **kwargs): _, axis_sym, splits_sym = node.inputs @@ -90,7 +84,7 @@ def split(x, axis, splits): cumsum_splits = np.cumsum(splits[:-1]) else: # dynamic - keep in graph - splits_arr = mx.array(splits) # MLX + splits_arr = mx.array(splits) cumsum_splits = mx.cumsum( splits_arr[:-1] ).tolist() # python list for mx.split @@ -104,33 +98,29 @@ def split(x, axis, splits): if np.any(np.asarray(splits) < 0): raise ValueError("Split sizes cannot be negative.") - return mx.split(x, cumsum_splits, axis=axis) # MLX + return mx.split(x, cumsum_splits, axis=axis) return split -# ------------------------------------------------------------------ -# ExtractDiag -# ------------------------------------------------------------------ -@mlx_funcify.register(ExtractDiag) # MLX + +@mlx_funcify.register(ExtractDiag) def mlx_funcify_ExtractDiag(op, **kwargs): offset, axis1, axis2 = op.offset, op.axis1, op.axis2 def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): - return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) return extract_diag -# ------------------------------------------------------------------ -# Eye -# ------------------------------------------------------------------ -@mlx_funcify.register(Eye) # MLX + +@mlx_funcify.register(Eye) def mlx_funcify_Eye(op, **kwargs): dtype = convert_dtype_to_mlx(op.dtype) def eye(N, M, k): - return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + return mx.eye(int(N), int(M), int(k), dtype=dtype) return eye @@ -176,23 +166,19 @@ def convert_dtype_to_mlx(dtype_str): return dtype_str -# 
------------------------------------------------------------------ -# MakeVector -# ------------------------------------------------------------------ -@mlx_funcify.register(MakeVector) # MLX + +@mlx_funcify.register(MakeVector) def mlx_funcify_MakeVector(op, **kwargs): dtype = convert_dtype_to_mlx(op.dtype) def makevector(*x): - return mx.array(x, dtype=dtype) # MLX + return mx.array(x, dtype=dtype) return makevector -# ------------------------------------------------------------------ -# TensorFromScalar (identity for MLX) -# ------------------------------------------------------------------ -@mlx_funcify.register(TensorFromScalar) # MLX + +@mlx_funcify.register(TensorFromScalar) def mlx_funcify_TensorFromScalar(op, **kwargs): def tensor_from_scalar(x): return x # already an MLX array / scalar @@ -200,21 +186,17 @@ def tensor_from_scalar(x): return tensor_from_scalar -# ------------------------------------------------------------------ -# ScalarFromTensor -# ------------------------------------------------------------------ -@mlx_funcify.register(ScalarFromTensor) # MLX + +@mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] # MLX + return mx.array(x).reshape(-1)[0] return scalar_from_tensor -# ------------------------------------------------------------------ -# Tri -# ------------------------------------------------------------------ -@mlx_funcify.register(Tri) # MLX + +@mlx_funcify.register(Tri) def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] @@ -226,7 +208,7 @@ def tri(*args): arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=dtype) # MLX + return mx.tri(*args, dtype=dtype) return tri diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index aaf04968de..d2e8274c68 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,5 +1,6 @@ import mlx.core as mx import numpy as np +from functools import singledispatch from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx @@ -10,6 +11,8 @@ Add, Cast, Mul, + ScalarMaximum, + ScalarMinimum, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.special import Softmax, SoftmaxGrad @@ -32,34 +35,64 @@ def dimshuffle(x): return dimshuffle -@mlx_funcify.register(CAReduce) -def mlx_funcify_CAReduce(op, **kwargs): - if isinstance(op.scalar_op, Add): +# Second-level dispatch for scalar operations in CAReduce +@singledispatch +def mlx_funcify_CAReduce_scalar_op(scalar_op): + raise NotImplementedError(f"MLX does not support CAReduce with scalar op {scalar_op}") + + +@mlx_funcify_CAReduce_scalar_op.register(Add) +def _(scalar_op): + def sum_reduce(x, axis): + return mx.sum(x, axis=axis) + return sum_reduce + + +@mlx_funcify_CAReduce_scalar_op.register(Mul) +def _(scalar_op): + def prod_reduce(x, axis): + return mx.prod(x, axis=axis) + return prod_reduce - def sum(x): - return mx.sum(x, axis=op.axis) - return sum - elif isinstance(op.scalar_op, Mul): +@mlx_funcify_CAReduce_scalar_op.register(AND) +def _(scalar_op): + def all_reduce(x, axis): + return x.all(axis=axis) + return all_reduce - def prod(x): - return mx.prod(x, axis=op.axis) - return prod - elif isinstance(op.scalar_op, AND): +@mlx_funcify_CAReduce_scalar_op.register(OR) 
+def _(scalar_op): + def any_reduce(x, axis): + return mx.any(x, axis=axis) + return any_reduce - def all(x): - return x.all(axis=op.axis) - return all - elif isinstance(op.scalar_op, OR): +@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum) +def _(scalar_op): + def max_reduce(x, axis): + return mx.max(x, axis=axis) + return max_reduce - def any(x): - return mx.any(x, axis=op.axis) - return any - else: - raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}") +@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum) +def _(scalar_op): + def min_reduce(x, axis): + return mx.min(x, axis=axis) + return min_reduce + + +@mlx_funcify.register(CAReduce) +def mlx_funcify_CAReduce(op, **kwargs): + # Dispatch to the appropriate scalar op handler + scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op) + axis = op.axis + + def reduce(x): + return scalar_reduce_fn(x, axis) + + return reduce @mlx_funcify.register(Softmax) diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 153f049b0e..840344bfbc 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,4 +1,5 @@ import mlx.core as mx +from functools import singledispatch from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx @@ -17,6 +18,7 @@ Cast, Cos, Exp, + Invert, Log, Log1p, Mul, @@ -51,200 +53,283 @@ def dot(x, y): return dot -@mlx_funcify.register(Elemwise) -def mlx_funcify_Elemwise(op, **kwargs): - if isinstance(op.scalar_op, Add): - - def add(*args): - result = args[0] - for arg in args[1:]: - result = mx.add(result, arg) - return result - - return add - elif isinstance(op.scalar_op, Sub): +# Second-level dispatch for scalar operations in Elemwise +@singledispatch +def mlx_funcify_Elemwise_scalar_op(scalar_op): + """Default implementation that tries to use getattr(mx, func_name) similar to JAX.""" + + # Try to get the function name from nfunc_spec (like JAX does) + nfunc_spec = getattr(scalar_op, "nfunc_spec", None) + if nfunc_spec is not None: + func_name = nfunc_spec[0] + try: + mlx_func = getattr(mx, func_name) + # Handle variadic functions + if len(scalar_op.inputs) > nfunc_spec[1]: + # For operations like Add that can take multiple inputs + def variadic_func(*args): + result = args[0] + for arg in args[1:]: + result = mlx_func(result, arg) + return result + return variadic_func + else: + return mlx_func + except AttributeError: + pass + + # Try using the operation name directly + op_name = getattr(scalar_op, "name", None) + if op_name is not None: + try: + return getattr(mx, op_name) + except AttributeError: + pass + + raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}") + + +@mlx_funcify_Elemwise_scalar_op.register(Add) +def _(scalar_op): + def add(*args): + result = args[0] + for arg in args[1:]: + result = mx.add(result, arg) + return result + return add + + +@mlx_funcify_Elemwise_scalar_op.register(Sub) +def _(scalar_op): + def sub(x, y): + return mx.subtract(x, y) + return sub + + +@mlx_funcify_Elemwise_scalar_op.register(Mul) +def _(scalar_op): + def mul(*args): + result = args[0] + for arg in args[1:]: + result = mx.multiply(result, arg) + return result + return mul + + +@mlx_funcify_Elemwise_scalar_op.register(TrueDiv) +def _(scalar_op): + def true_div(x, y): + return mx.divide(x, y) + return true_div + + +@mlx_funcify_Elemwise_scalar_op.register(Pow) +def _(scalar_op): + def pow(x, y): + return mx.power(x, y) + return pow + 
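# Editorial aside, not lines of this patch: the surrounding hunk is plain
# functools.singledispatch keyed on the scalar op type, so each Elemwise
# lowering becomes one small registration instead of a branch in a long
# isinstance ladder. A self-contained sketch with stand-in op classes:
#
#     from functools import singledispatch
#
#     class Add: ...
#     class Mul: ...
#
#     @singledispatch
#     def lower(scalar_op):
#         raise NotImplementedError(type(scalar_op).__name__)
#
#     @lower.register(Add)
#     def _(scalar_op):
#         return lambda *args: sum(args)
#
#     @lower.register(Mul)
#     def _(scalar_op):
#         def mul(*args):
#             out = args[0]
#             for a in args[1:]:
#                 out = out * a
#             return out
#         return mul
#
#     assert lower(Add())(1, 2, 3) == 6
#     assert lower(Mul())(2, 3, 4) == 24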
+ +@mlx_funcify_Elemwise_scalar_op.register(Exp) +def _(scalar_op): + def exp(x): + return mx.exp(x) + return exp + + +@mlx_funcify_Elemwise_scalar_op.register(Log) +def _(scalar_op): + def log(x): + return mx.log(x) + return log + + +@mlx_funcify_Elemwise_scalar_op.register(Log1p) +def _(scalar_op): + def log1p(x): + return mx.log1p(x) + return log1p + + +@mlx_funcify_Elemwise_scalar_op.register(Sin) +def _(scalar_op): + def sin(x): + return mx.sin(x) + return sin + + +@mlx_funcify_Elemwise_scalar_op.register(Cos) +def _(scalar_op): + def cos(x): + return mx.cos(x) + return cos - def sub(x, y): - return mx.subtract(x, y) - return sub - elif isinstance(op.scalar_op, Mul): +@mlx_funcify_Elemwise_scalar_op.register(Sqrt) +def _(scalar_op): + def sqrt(x): + return mx.sqrt(x) + return sqrt - def mul(*args): - result = args[0] - for arg in args[1:]: - result = mx.multiply(result, arg) - return result - return mul - elif isinstance(op.scalar_op, Exp): +@mlx_funcify_Elemwise_scalar_op.register(Sqr) +def _(scalar_op): + def sqr(x): + return mx.square(x) + return sqr - def exp(x): - return mx.exp(x) - return exp - elif isinstance(op.scalar_op, Log): +@mlx_funcify_Elemwise_scalar_op.register(Abs) +def _(scalar_op): + def abs(x): + return mx.abs(x) + return abs - def log(x): - return mx.log(x) - return log - elif isinstance(op.scalar_op, Sin): +@mlx_funcify_Elemwise_scalar_op.register(Neg) +def _(scalar_op): + def neg(x): + return mx.negative(x) + return neg - def sin(x): - return mx.sin(x) - return sin - elif isinstance(op.scalar_op, Cos): +@mlx_funcify_Elemwise_scalar_op.register(Sign) +def _(scalar_op): + def sign(x): + return mx.sign(x) + return sign - def cos(x): - return mx.cos(x) - return cos - elif isinstance(op.scalar_op, Sigmoid): +@mlx_funcify_Elemwise_scalar_op.register(LE) +def _(scalar_op): + def le(x, y): + return mx.less_equal(x, y) + return le - def sigmoid(x): - return mx.sigmoid(x) - return sigmoid - elif isinstance(op.scalar_op, LE): +@mlx_funcify_Elemwise_scalar_op.register(LT) +def _(scalar_op): + def lt(x, y): + return mx.less(x, y) + return lt - def le(x, y): - return mx.less_equal(x, y) - return le - elif isinstance(op.scalar_op, LT): +@mlx_funcify_Elemwise_scalar_op.register(GE) +def _(scalar_op): + def ge(x, y): + return mx.greater_equal(x, y) + return ge - def lt(x, y): - return mx.less(x, y) - return lt - elif isinstance(op.scalar_op, GE): +@mlx_funcify_Elemwise_scalar_op.register(GT) +def _(scalar_op): + def gt(x, y): + return mx.greater(x, y) + return gt - def ge(x, y): - return mx.greater_equal(x, y) - return ge - elif isinstance(op.scalar_op, GT): +@mlx_funcify_Elemwise_scalar_op.register(EQ) +def _(scalar_op): + def eq(x, y): + return mx.equal(x, y) + return eq - def gt(x, y): - return mx.greater(x, y) - return gt - elif isinstance(op.scalar_op, EQ): +@mlx_funcify_Elemwise_scalar_op.register(NEQ) +def _(scalar_op): + def neq(x, y): + return mx.not_equal(x, y) + return neq - def eq(x, y): - return mx.equal(x, y) - return eq - elif isinstance(op.scalar_op, NEQ): +@mlx_funcify_Elemwise_scalar_op.register(Switch) +def _(scalar_op): + def switch(cond, x, y): + return mx.where(cond, x, y) + return switch - def neq(x, y): - return mx.not_equal(x, y) - return neq - elif isinstance(op.scalar_op, Switch): +@mlx_funcify_Elemwise_scalar_op.register(AND) +def _(scalar_op): + def bitwise_and(x, y): + return mx.bitwise_and(x, y) + return bitwise_and - def switch(cond, x, y): - return mx.where(cond, x, y) - return switch - elif isinstance(op.scalar_op, Pow): 
+@mlx_funcify_Elemwise_scalar_op.register(OR) +def _(scalar_op): + def bitwise_or(x, y): + return mx.bitwise_or(x, y) + return bitwise_or - def pow(x, y): - return mx.power(x, y) - return pow - elif isinstance(op.scalar_op, TrueDiv): +@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum) +def _(scalar_op): + def maximum(x, y): + return mx.maximum(x, y) + return maximum - def true_div(x, y): - return mx.divide(x, y) - return true_div - elif isinstance(op.scalar_op, Sqr): +@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum) +def _(scalar_op): + def minimum(x, y): + return mx.minimum(x, y) + return minimum - def sqr(x): - return mx.square(x) - return sqr - elif isinstance(op.scalar_op, Sqrt): +@mlx_funcify_Elemwise_scalar_op.register(Cast) +def _(scalar_op): + def cast(x): + dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype) + return x.astype(dtype) + return cast - def sqrt(x): - return mx.sqrt(x) - return sqrt - elif isinstance(op.scalar_op, Abs): +@mlx_funcify_Elemwise_scalar_op.register(Sigmoid) +def _(scalar_op): + def sigmoid(x): + return mx.sigmoid(x) + return sigmoid - def abs(x): - return mx.abs(x) - return abs - elif isinstance(op.scalar_op, Softplus): - - def softplus(x): - return mx.where( - x < -37.0, - mx.exp(x), +@mlx_funcify_Elemwise_scalar_op.register(Softplus) +def _(scalar_op): + def softplus(x): + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, + mx.log1p(mx.exp(x)), mx.where( - x < 18.0, - mx.log1p(mx.exp(x)), - mx.where( - x < 33.3, - x + mx.exp(-x), - x, - ), + x < 33.3, + x + mx.exp(-x), + x, ), - ) - - return softplus - elif isinstance(op.scalar_op, Neg): - - def neg(x): - return mx.negative(x) - - return neg - elif isinstance(op.scalar_op, AND): - - def all(x, y): - return mx.bitwise_and(x, y) - - return all - elif isinstance(op.scalar_op, OR): + ), + ) + return softplus - def any(x, y): - return mx.bitwise_or(x, y) - return any - elif isinstance(op.scalar_op, ScalarMaximum): +@mlx_funcify_Elemwise_scalar_op.register(Invert) +def _(scalar_op): + def invert(x): + return ~x + return invert - def max(x, y): - return mx.maximum(x, y) - return max - elif isinstance(op.scalar_op, ScalarMinimum): - - def min(x, y): - return mx.minimum(x, y) - - return min - elif isinstance(op.scalar_op, Cast): - - def cast(x): - dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) - return x.astype(dtype) - - return cast - elif isinstance(op.scalar_op, Sign): - - def sign(x): - return mx.sign(x) - - return sign - elif isinstance(op.scalar_op, Log1p): - - def log1p(x): - return mx.log1p(x) - - return log1p - else: - raise NotImplementedError(f"MLX does not support {op.scalar_op}") +@mlx_funcify.register(Elemwise) +def mlx_funcify_Elemwise(op, node=None, **kwargs): + # Dispatch to the appropriate scalar op handler + scalar_func = mlx_funcify_Elemwise_scalar_op(op.scalar_op) + + def elemwise(*inputs): + # Enforce runtime broadcast checks (same as JAX and PyTorch implementations) + if node is not None: + # Convert inputs to MLX arrays for broadcast checking + mlx_inputs = tuple(mx.array(inp) if not hasattr(inp, 'shape') else inp for inp in inputs) + Elemwise._check_runtime_broadcast(node, mlx_inputs) + + return scalar_func(*inputs) + + return elemwise diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py new file mode 100644 index 0000000000..1ecd441f55 --- /dev/null +++ b/tests/link/mlx/test_subtensor.py @@ -0,0 +1,276 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.configdefaults import config +from 
pytensor.tensor import subtensor as pt_subtensor +from pytensor.tensor import tensor +from test_basic import compare_mlx_and_py + +mx = pytest.importorskip("mlx.core") + + +def test_mlx_Subtensor_basic(): + """Test basic subtensor operations with constant indices.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Basic indexing with single elements + out_pt = x_pt[1, 2, 0] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Basic indexing with slices + out_pt = x_pt[1:, 1, :] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[:2, 1, :] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[1:2, 1, :] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Negative indexing + out_pt = x_pt[-1, -1, -1] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Step slicing + out_pt = x_pt[::2, ::2, ::2] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Reverse indexing + out_pt = x_pt[::-1, ::-1, ::-1] + assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +def test_mlx_AdvancedSubtensor(): + """Test advanced subtensor operations.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Advanced indexing with array indices + out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Multi-dimensional advanced indexing + out_pt = x_pt[[1, 2], [2, 3]] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Mixed advanced and basic indexing + out_pt = x_pt[[1, 2], :] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + out_pt = x_pt[[1, 2], :, [3, 4]] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +@pytest.mark.xfail(reason="MLX does not support boolean indexing yet") +def test_mlx_AdvancedSubtensor_boolean(): + """Test advanced subtensor operations with boolean indexing.""" + shape = (3, 4, 5) + x_pt = tensor("x", shape=shape, dtype="float32") + x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # Boolean indexing with constant mask + bool_mask = np.array([True, False, True]) + out_pt = x_pt[bool_mask] + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_IncSubtensor_set(): + """Test set operations using IncSubtensor (set_instead_of_inc=True).""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Set single element + st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) + out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) + assert isinstance(out_pt.owner.op, 
pt_subtensor.IncSubtensor) + assert out_pt.owner.op.set_instead_of_inc == True + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_IncSubtensor_increment(): + """Test increment operations using IncSubtensor (set_instead_of_inc=False).""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Increment single element + st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + assert out_pt.owner.op.set_instead_of_inc == False + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor_set(): + """Test advanced set operations using AdvancedIncSubtensor.""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Set with advanced indexing - this actually works in MLX! + st_pt = pt.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) + ) + out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert out_pt.owner.op.set_instead_of_inc == True + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor_increment(): + """Test advanced increment operations using AdvancedIncSubtensor.""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Increment with advanced indexing - this actually works in MLX! + st_pt = pt.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) + ) + out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert out_pt.owner.op.set_instead_of_inc == False + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_AdvancedIncSubtensor1_operations(): + """Test AdvancedIncSubtensor1 operations (handled by IncSubtensor dispatcher).""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + x_pt = pt.constant(x_np) + + # Test set operation - this actually works in MLX! 
+ st_pt = pt.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) + ) + indices = [1, 2] + + # Create AdvancedIncSubtensor1 manually for set operation + out_pt = pt_subtensor.advanced_set_subtensor1(x_pt, st_pt, indices) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) + assert out_pt.owner.op.set_instead_of_inc == True + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode") +def test_mlx_inplace_variants(): + """Test inplace variants of all subtensor operations.""" + rng = np.random.default_rng(213234) + + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Test inplace IncSubtensor (set) + st_pt = pt.as_tensor_variable(np.array([-1.0, -2.0], dtype=np.float32)) + out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], st_pt, inplace=True) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + assert out_pt.owner.op.inplace == True + assert out_pt.owner.op.set_instead_of_inc == True + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="MLX slice indices must be integers or None, dynamic slices not supported") +def test_mlx_MakeSlice(): + """Test MakeSlice operation.""" + # Test slice creation + start = pt.iscalar("start") + stop = pt.iscalar("stop") + step = pt.iscalar("step") + + # Create a slice using MakeSlice + slice_op = pt_subtensor.MakeSlice() + slice_pt = slice_op(start, stop, step) + + # Use simple constant array instead of arange + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[slice_pt] + + compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2]) + + +def test_mlx_subtensor_edge_cases(): + """Test edge cases and boundary conditions.""" + # Empty slices - use constant array + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[5:5] # Empty slice + compare_mlx_and_py([], [out_pt], []) + + # Single element arrays + x_pt = pt.tensor(shape=(1,), dtype="float32", name="x") + x_np = np.array([42.0], dtype=np.float32) + out_pt = x_pt[0] + compare_mlx_and_py([x_pt], [out_pt], [x_np]) + + # Large step sizes - use constant array + x_pt = pt.constant(np.arange(20, dtype=np.float32)) + out_pt = x_pt[::5] + compare_mlx_and_py([], [out_pt], []) + + # Negative steps - use constant array + x_pt = pt.constant(np.arange(10, dtype=np.float32)) + out_pt = x_pt[::-2] + compare_mlx_and_py([], [out_pt], []) + + +@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") +def test_mlx_subtensor_with_variables(): + """Test subtensor operations with PyTensor variables as inputs.""" + # Test with variable arrays (not constants) + x_pt = pt.matrix("x", dtype="float32") + y_pt = pt.vector("y", dtype="float32") + + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + y_np = np.array([-1.0, -2.0], dtype=np.float32) + + # Set operation with variables + out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) + compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_subtensor_working_operations_summary(): + """Summary test showing which subtensor operations currently work in MLX.""" + + # Operations that work: + # 1. Basic Subtensor with constant indices ✅ + # 2. Advanced Subtensor with array indices ✅ + # 3. MakeSlice ✅ + # 4. Edge cases with constant arrays ✅ + + # Operations that don't work yet: + # 1. Boolean indexing ❌ ("boolean indices are not yet supported") + # 2. 
IncSubtensor/AdvancedIncSubtensor ❌ ("Cannot index mlx array using the given type yet") + # 3. Inplace operations ❌ (require special MLX handling) + # 4. Variable indexing ❌ (tuples not supported in MLX indexing) + + # This test documents the current state + assert True # Just a documentation test \ No newline at end of file From 602f0ede876f2362038bb02619807f08e72c8ec4 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:51:41 +0200 Subject: [PATCH 56/71] Pre commit changes --- pytensor/link/mlx/dispatch/core.py | 44 +++++++-------- pytensor/link/mlx/dispatch/elemwise.py | 17 ++++-- pytensor/link/mlx/dispatch/math.py | 50 ++++++++++++++--- tests/link/mlx/test_subtensor.py | 75 ++++++++------------------ 4 files changed, 97 insertions(+), 89 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index a1f128218b..785aca2811 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -13,10 +13,10 @@ import warnings -import mlx.core as mx +import mlx.core as mx import numpy as np -from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor import get_vector_length from pytensor.tensor.basic import ( Alloc, @@ -34,22 +34,22 @@ from pytensor.tensor.exceptions import NotScalarConstantError -@mlx_funcify.register(Join) +@mlx_funcify.register(Join) def mlx_funcify_Join(op, **kwargs): def join(axis, *tensors): view = op.view if (view != -1) and all( - tensors[i].shape[axis] == 0 + tensors[i].shape[axis] == 0 for i in list(range(view)) + list(range(view + 1, len(tensors))) ): return tensors[view] - return mx.concatenate(tensors, axis=axis) + return mx.concatenate(tensors, axis=axis) return join -@mlx_funcify.register(Split) +@mlx_funcify.register(Split) def mlx_funcify_Split(op: Split, node, **kwargs): _, axis_sym, splits_sym = node.inputs @@ -84,7 +84,7 @@ def split(x, axis, splits): cumsum_splits = np.cumsum(splits[:-1]) else: # dynamic - keep in graph - splits_arr = mx.array(splits) + splits_arr = mx.array(splits) cumsum_splits = mx.cumsum( splits_arr[:-1] ).tolist() # python list for mx.split @@ -98,29 +98,27 @@ def split(x, axis, splits): if np.any(np.asarray(splits) < 0): raise ValueError("Split sizes cannot be negative.") - return mx.split(x, cumsum_splits, axis=axis) + return mx.split(x, cumsum_splits, axis=axis) return split - -@mlx_funcify.register(ExtractDiag) +@mlx_funcify.register(ExtractDiag) def mlx_funcify_ExtractDiag(op, **kwargs): offset, axis1, axis2 = op.offset, op.axis1, op.axis2 def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): - return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) return extract_diag - -@mlx_funcify.register(Eye) +@mlx_funcify.register(Eye) def mlx_funcify_Eye(op, **kwargs): dtype = convert_dtype_to_mlx(op.dtype) def eye(N, M, k): - return mx.eye(int(N), int(M), int(k), dtype=dtype) + return mx.eye(int(N), int(M), int(k), dtype=dtype) return eye @@ -166,19 +164,17 @@ def convert_dtype_to_mlx(dtype_str): return dtype_str - -@mlx_funcify.register(MakeVector) +@mlx_funcify.register(MakeVector) def mlx_funcify_MakeVector(op, **kwargs): dtype = convert_dtype_to_mlx(op.dtype) def makevector(*x): - return mx.array(x, dtype=dtype) + return mx.array(x, dtype=dtype) return makevector - -@mlx_funcify.register(TensorFromScalar) +@mlx_funcify.register(TensorFromScalar) 
def mlx_funcify_TensorFromScalar(op, **kwargs): def tensor_from_scalar(x): return x # already an MLX array / scalar @@ -186,17 +182,15 @@ def tensor_from_scalar(x): return tensor_from_scalar - -@mlx_funcify.register(ScalarFromTensor) +@mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] + return mx.array(x).reshape(-1)[0] return scalar_from_tensor - -@mlx_funcify.register(Tri) +@mlx_funcify.register(Tri) def mlx_funcify_Tri(op, node, **kwargs): # node.inputs -> N, M, k const_args = [getattr(inp, "data", None) for inp in node.inputs] @@ -208,7 +202,7 @@ def tri(*args): arg if const_a is None else const_a for arg, const_a in zip(args, const_args, strict=True) ] - return mx.tri(*args, dtype=dtype) + return mx.tri(*args, dtype=dtype) return tri diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index d2e8274c68..476d32f805 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -1,6 +1,7 @@ +from functools import singledispatch + import mlx.core as mx import numpy as np -from functools import singledispatch from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx @@ -38,13 +39,16 @@ def dimshuffle(x): # Second-level dispatch for scalar operations in CAReduce @singledispatch def mlx_funcify_CAReduce_scalar_op(scalar_op): - raise NotImplementedError(f"MLX does not support CAReduce with scalar op {scalar_op}") + raise NotImplementedError( + f"MLX does not support CAReduce with scalar op {scalar_op}" + ) @mlx_funcify_CAReduce_scalar_op.register(Add) def _(scalar_op): def sum_reduce(x, axis): return mx.sum(x, axis=axis) + return sum_reduce @@ -52,6 +56,7 @@ def sum_reduce(x, axis): def _(scalar_op): def prod_reduce(x, axis): return mx.prod(x, axis=axis) + return prod_reduce @@ -59,6 +64,7 @@ def prod_reduce(x, axis): def _(scalar_op): def all_reduce(x, axis): return x.all(axis=axis) + return all_reduce @@ -66,6 +72,7 @@ def all_reduce(x, axis): def _(scalar_op): def any_reduce(x, axis): return mx.any(x, axis=axis) + return any_reduce @@ -73,6 +80,7 @@ def any_reduce(x, axis): def _(scalar_op): def max_reduce(x, axis): return mx.max(x, axis=axis) + return max_reduce @@ -80,6 +88,7 @@ def max_reduce(x, axis): def _(scalar_op): def min_reduce(x, axis): return mx.min(x, axis=axis) + return min_reduce @@ -88,10 +97,10 @@ def mlx_funcify_CAReduce(op, **kwargs): # Dispatch to the appropriate scalar op handler scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op) axis = op.axis - + def reduce(x): return scalar_reduce_fn(x, axis) - + return reduce diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 840344bfbc..f876a58671 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -1,6 +1,7 @@ -import mlx.core as mx from functools import singledispatch +import mlx.core as mx + from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx from pytensor.scalar import Softplus @@ -57,7 +58,7 @@ def dot(x, y): @singledispatch def mlx_funcify_Elemwise_scalar_op(scalar_op): """Default implementation that tries to use getattr(mx, func_name) similar to JAX.""" - + # Try to get the function name from nfunc_spec (like JAX does) nfunc_spec = getattr(scalar_op, "nfunc_spec", None) if nfunc_spec is not None: @@ -72,12 +73,13 @@ def 
variadic_func(*args): for arg in args[1:]: result = mlx_func(result, arg) return result + return variadic_func else: return mlx_func except AttributeError: pass - + # Try using the operation name directly op_name = getattr(scalar_op, "name", None) if op_name is not None: @@ -85,7 +87,7 @@ def variadic_func(*args): return getattr(mx, op_name) except AttributeError: pass - + raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}") @@ -96,6 +98,7 @@ def add(*args): for arg in args[1:]: result = mx.add(result, arg) return result + return add @@ -103,6 +106,7 @@ def add(*args): def _(scalar_op): def sub(x, y): return mx.subtract(x, y) + return sub @@ -113,6 +117,7 @@ def mul(*args): for arg in args[1:]: result = mx.multiply(result, arg) return result + return mul @@ -120,6 +125,7 @@ def mul(*args): def _(scalar_op): def true_div(x, y): return mx.divide(x, y) + return true_div @@ -127,6 +133,7 @@ def true_div(x, y): def _(scalar_op): def pow(x, y): return mx.power(x, y) + return pow @@ -134,6 +141,7 @@ def pow(x, y): def _(scalar_op): def exp(x): return mx.exp(x) + return exp @@ -141,6 +149,7 @@ def exp(x): def _(scalar_op): def log(x): return mx.log(x) + return log @@ -148,6 +157,7 @@ def log(x): def _(scalar_op): def log1p(x): return mx.log1p(x) + return log1p @@ -155,6 +165,7 @@ def log1p(x): def _(scalar_op): def sin(x): return mx.sin(x) + return sin @@ -162,6 +173,7 @@ def sin(x): def _(scalar_op): def cos(x): return mx.cos(x) + return cos @@ -169,6 +181,7 @@ def cos(x): def _(scalar_op): def sqrt(x): return mx.sqrt(x) + return sqrt @@ -176,6 +189,7 @@ def sqrt(x): def _(scalar_op): def sqr(x): return mx.square(x) + return sqr @@ -183,6 +197,7 @@ def sqr(x): def _(scalar_op): def abs(x): return mx.abs(x) + return abs @@ -190,6 +205,7 @@ def abs(x): def _(scalar_op): def neg(x): return mx.negative(x) + return neg @@ -197,6 +213,7 @@ def neg(x): def _(scalar_op): def sign(x): return mx.sign(x) + return sign @@ -204,6 +221,7 @@ def sign(x): def _(scalar_op): def le(x, y): return mx.less_equal(x, y) + return le @@ -211,6 +229,7 @@ def le(x, y): def _(scalar_op): def lt(x, y): return mx.less(x, y) + return lt @@ -218,6 +237,7 @@ def lt(x, y): def _(scalar_op): def ge(x, y): return mx.greater_equal(x, y) + return ge @@ -225,6 +245,7 @@ def ge(x, y): def _(scalar_op): def gt(x, y): return mx.greater(x, y) + return gt @@ -232,6 +253,7 @@ def gt(x, y): def _(scalar_op): def eq(x, y): return mx.equal(x, y) + return eq @@ -239,6 +261,7 @@ def eq(x, y): def _(scalar_op): def neq(x, y): return mx.not_equal(x, y) + return neq @@ -246,6 +269,7 @@ def neq(x, y): def _(scalar_op): def switch(cond, x, y): return mx.where(cond, x, y) + return switch @@ -253,6 +277,7 @@ def switch(cond, x, y): def _(scalar_op): def bitwise_and(x, y): return mx.bitwise_and(x, y) + return bitwise_and @@ -260,6 +285,7 @@ def bitwise_and(x, y): def _(scalar_op): def bitwise_or(x, y): return mx.bitwise_or(x, y) + return bitwise_or @@ -267,6 +293,7 @@ def bitwise_or(x, y): def _(scalar_op): def maximum(x, y): return mx.maximum(x, y) + return maximum @@ -274,6 +301,7 @@ def maximum(x, y): def _(scalar_op): def minimum(x, y): return mx.minimum(x, y) + return minimum @@ -282,6 +310,7 @@ def _(scalar_op): def cast(x): dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype) return x.astype(dtype) + return cast @@ -289,6 +318,7 @@ def cast(x): def _(scalar_op): def sigmoid(x): return mx.sigmoid(x) + return sigmoid @@ -308,6 +338,7 @@ def softplus(x): ), ), ) + return softplus @@ -315,6 +346,7 @@ def softplus(x): def 
_(scalar_op): def invert(x): return ~x + return invert @@ -322,14 +354,16 @@ def invert(x): def mlx_funcify_Elemwise(op, node=None, **kwargs): # Dispatch to the appropriate scalar op handler scalar_func = mlx_funcify_Elemwise_scalar_op(op.scalar_op) - + def elemwise(*inputs): # Enforce runtime broadcast checks (same as JAX and PyTorch implementations) if node is not None: # Convert inputs to MLX arrays for broadcast checking - mlx_inputs = tuple(mx.array(inp) if not hasattr(inp, 'shape') else inp for inp in inputs) + mlx_inputs = tuple( + mx.array(inp) if not hasattr(inp, "shape") else inp for inp in inputs + ) Elemwise._check_runtime_broadcast(node, mlx_inputs) - + return scalar_func(*inputs) - + return elemwise diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 1ecd441f55..cfc5a07baa 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -1,11 +1,11 @@ import numpy as np import pytest +from test_basic import compare_mlx_and_py import pytensor.tensor as pt -from pytensor.configdefaults import config from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import tensor -from test_basic import compare_mlx_and_py + mx = pytest.importorskip("mlx.core") @@ -93,8 +93,6 @@ def test_mlx_AdvancedSubtensor_boolean(): @pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_IncSubtensor_set(): """Test set operations using IncSubtensor (set_instead_of_inc=True).""" - rng = np.random.default_rng(213234) - # Test data x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) x_pt = pt.constant(x_np) @@ -103,15 +101,13 @@ def test_mlx_IncSubtensor_set(): st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - assert out_pt.owner.op.set_instead_of_inc == True + assert out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) @pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_IncSubtensor_increment(): """Test increment operations using IncSubtensor (set_instead_of_inc=False).""" - rng = np.random.default_rng(213234) - # Test data x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) x_pt = pt.constant(x_np) @@ -120,72 +116,64 @@ def test_mlx_IncSubtensor_increment(): st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32)) out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - assert out_pt.owner.op.set_instead_of_inc == False + assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) def test_mlx_AdvancedIncSubtensor_set(): """Test advanced set operations using AdvancedIncSubtensor.""" rng = np.random.default_rng(213234) - + # Test data x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) x_pt = pt.constant(x_np) # Set with advanced indexing - this actually works in MLX! 
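The assertions in these tests all hinge on `set_instead_of_inc`: `set_subtensor` overwrites the indexed slice while `inc_subtensor` accumulates into it. A minimal NumPy sketch of the two semantics (illustrative only, not part of the test suite):

    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(3, 2)
    st = np.full((2, 2), -10.0, dtype=np.float32)

    set_result = x.copy()
    set_result[np.r_[0, 2]] = st    # set_instead_of_inc=True
    inc_result = x.copy()
    inc_result[np.r_[0, 2]] += st   # set_instead_of_inc=False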
- st_pt = pt.as_tensor_variable( - rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) - ) + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - assert out_pt.owner.op.set_instead_of_inc == True + assert out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) def test_mlx_AdvancedIncSubtensor_increment(): """Test advanced increment operations using AdvancedIncSubtensor.""" rng = np.random.default_rng(213234) - + # Test data x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) x_pt = pt.constant(x_np) # Increment with advanced indexing - this actually works in MLX! - st_pt = pt.as_tensor_variable( - rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) - ) + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - assert out_pt.owner.op.set_instead_of_inc == False + assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) def test_mlx_AdvancedIncSubtensor1_operations(): """Test AdvancedIncSubtensor1 operations (handled by IncSubtensor dispatcher).""" rng = np.random.default_rng(213234) - + # Test data x_np = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) x_pt = pt.constant(x_np) # Test set operation - this actually works in MLX! - st_pt = pt.as_tensor_variable( - rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32) - ) + st_pt = pt.as_tensor_variable(rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)) indices = [1, 2] - + # Create AdvancedIncSubtensor1 manually for set operation out_pt = pt_subtensor.advanced_set_subtensor1(x_pt, st_pt, indices) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) - assert out_pt.owner.op.set_instead_of_inc == True + assert out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) @pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode") def test_mlx_inplace_variants(): """Test inplace variants of all subtensor operations.""" - rng = np.random.default_rng(213234) - # Test data x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) x_pt = pt.constant(x_np) @@ -194,27 +182,29 @@ def test_mlx_inplace_variants(): st_pt = pt.as_tensor_variable(np.array([-1.0, -2.0], dtype=np.float32)) out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], st_pt, inplace=True) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - assert out_pt.owner.op.inplace == True - assert out_pt.owner.op.set_instead_of_inc == True + assert out_pt.owner.op.inplace + assert out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX slice indices must be integers or None, dynamic slices not supported") +@pytest.mark.xfail( + reason="MLX slice indices must be integers or None, dynamic slices not supported" +) def test_mlx_MakeSlice(): """Test MakeSlice operation.""" # Test slice creation start = pt.iscalar("start") stop = pt.iscalar("stop") step = pt.iscalar("step") - + # Create a slice using MakeSlice slice_op = pt_subtensor.MakeSlice() slice_pt = slice_op(start, stop, step) - + # Use simple constant array instead of arange x_pt = pt.constant(np.arange(10, dtype=np.float32)) out_pt = x_pt[slice_pt] - + compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2]) @@ -248,29 +238,10 @@ def 
test_mlx_subtensor_with_variables(): # Test with variable arrays (not constants) x_pt = pt.matrix("x", dtype="float32") y_pt = pt.vector("y", dtype="float32") - + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) y_np = np.array([-1.0, -2.0], dtype=np.float32) # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) - - -def test_mlx_subtensor_working_operations_summary(): - """Summary test showing which subtensor operations currently work in MLX.""" - - # Operations that work: - # 1. Basic Subtensor with constant indices ✅ - # 2. Advanced Subtensor with array indices ✅ - # 3. MakeSlice ✅ - # 4. Edge cases with constant arrays ✅ - - # Operations that don't work yet: - # 1. Boolean indexing ❌ ("boolean indices are not yet supported") - # 2. IncSubtensor/AdvancedIncSubtensor ❌ ("Cannot index mlx array using the given type yet") - # 3. Inplace operations ❌ (require special MLX handling) - # 4. Variable indexing ❌ (tuples not supported in MLX indexing) - - # This test documents the current state - assert True # Just a documentation test \ No newline at end of file From 8a2aea96b13e89ccd5afec81e91e230085ec2b35 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Sun, 8 Jun 2025 23:08:41 +0300 Subject: [PATCH 57/71] More changes from Ricardo --- pytensor/link/mlx/dispatch/basic.py | 6 +++ pytensor/link/mlx/dispatch/core.py | 41 +++++++---------- pytensor/link/mlx/dispatch/elemwise.py | 24 +++++----- pytensor/link/mlx/dispatch/math.py | 63 ++++---------------------- pytensor/link/mlx/dispatch/shape.py | 10 +++- tests/link/mlx/test_shape.py | 2 - 6 files changed, 52 insertions(+), 94 deletions(-) diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index c9f7f2fa3d..7b9a7e68f9 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -30,6 +30,12 @@ def mlx_typify_no_conversion_needed(data, **kwargs): return data +@mlx_typify.register(int) +@mlx_typify.register(float) +def mlx_typify_python_scalar(data, **kwargs): + return mx.array(data) + + @singledispatch def mlx_funcify(op, node=None, storage_map=None, **kwargs): """Create a MLX compatible function from an PyTensor `Op`.""" diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 785aca2811..c0130ce1cc 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -37,13 +37,6 @@ @mlx_funcify.register(Join) def mlx_funcify_Join(op, **kwargs): def join(axis, *tensors): - view = op.view - if (view != -1) and all( - tensors[i].shape[axis] == 0 - for i in list(range(view)) + list(range(view + 1, len(tensors))) - ): - return tensors[view] - return mx.concatenate(tensors, axis=axis) return join @@ -57,9 +50,6 @@ def mlx_funcify_Split(op: Split, node, **kwargs): constant_axis = get_scalar_constant_value(axis_sym) except NotScalarConstantError: constant_axis = None - warnings.warn( - "Split node does not have a constant axis. MLX implementation may fail." - ) try: constant_splits = np.array( @@ -70,12 +60,9 @@ def mlx_funcify_Split(op: Split, node, **kwargs): ) except (ValueError, NotScalarConstantError): constant_splits = None - warnings.warn( - "Split node does not have constant split positions. MLX implementation may fail." 
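Resolving the axis and split sizes at funcify time matters because constants let the cut points be computed once, when the function is built, instead of being retraced on every call. A sketch of the same pattern in plain NumPy, assuming the sizes are known at build time:

    import numpy as np

    def make_split(sizes, axis):
        # Cut points computed once, outside the per-call path
        cut_points = np.cumsum(sizes[:-1])

        def split(x):
            return np.split(x, cut_points, axis=axis)

        return split

    split_fn = make_split([2, 3], axis=0)
    parts = split_fn(np.arange(10).reshape(5, 2))  # shapes (2, 2) and (3, 2)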
- ) def split(x, axis, splits): - # Resolve constants (avoids tracing extra ops) + # Resolve constants for significant performance improvement (14x speedup) if constant_axis is not None: axis = int(constant_axis) @@ -83,12 +70,11 @@ def split(x, axis, splits): splits = constant_splits cumsum_splits = np.cumsum(splits[:-1]) else: - # dynamic - keep in graph + # Dynamic case - use MLX operations splits_arr = mx.array(splits) - cumsum_splits = mx.cumsum( - splits_arr[:-1] - ).tolist() # python list for mx.split + cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() + # Validation checks if len(splits) != op.len_splits: raise ValueError("Length of 'splits' is not equal to n_splits") if np.sum(np.asarray(splits)) != x.shape[axis]: @@ -114,10 +100,18 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): @mlx_funcify.register(Eye) -def mlx_funcify_Eye(op, **kwargs): +def mlx_funcify_Eye(op, node, **kwargs): + # Extract constants for performance optimization + const_args = [getattr(inp, "data", None) for inp in node.inputs] dtype = convert_dtype_to_mlx(op.dtype) - def eye(N, M, k): + def eye(*args): + # Replace args with compile-time constants when available for better performance + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + N, M, k = args return mx.eye(int(N), int(M), int(k), dtype=dtype) return eye @@ -185,7 +179,7 @@ def tensor_from_scalar(x): @mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] + return x.reshape(-1)[0] return scalar_from_tensor @@ -220,10 +214,7 @@ def allocempty(*shape): @mlx_funcify.register(Alloc) def mlx_funcify_Alloc(op, node, **kwargs): def alloc(x, *shape): - # Convert x to an MLX array with the correct dtype if it's a scalar - x_array = mx.array(x) - res = mx.broadcast_to(x_array, shape) - Alloc._check_runtime_broadcast(node, x_array, res.shape) + res = mx.broadcast_to(x, shape) return res return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 476d32f805..ab92358862 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -44,6 +44,18 @@ def mlx_funcify_CAReduce_scalar_op(scalar_op): ) +@mlx_funcify.register(CAReduce) +def mlx_funcify_CAReduce(op, **kwargs): + # Dispatch to the appropriate scalar op handler + scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op) + axis = op.axis + + def reduce(x): + return scalar_reduce_fn(x, axis) + + return reduce + + @mlx_funcify_CAReduce_scalar_op.register(Add) def _(scalar_op): def sum_reduce(x, axis): @@ -92,18 +104,6 @@ def min_reduce(x, axis): return min_reduce -@mlx_funcify.register(CAReduce) -def mlx_funcify_CAReduce(op, **kwargs): - # Dispatch to the appropriate scalar op handler - scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op) - axis = op.axis - - def reduce(x): - return scalar_reduce_fn(x, axis) - - return reduce - - @mlx_funcify.register(Softmax) def mlx_funcify_Softmax(op, **kwargs): axis = op.axis diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index f876a58671..8324ee2b64 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -40,12 +40,6 @@ from pytensor.tensor.math import Dot -@mlx_typify.register(int) -@mlx_typify.register(float) -def mlx_typify_python_scalar(data, **kwargs): - return mx.array(data) - - @mlx_funcify.register(Dot) def 
mlx_funcify_Dot(op, **kwargs): def dot(x, y): @@ -57,37 +51,26 @@ def dot(x, y): # Second-level dispatch for scalar operations in Elemwise @singledispatch def mlx_funcify_Elemwise_scalar_op(scalar_op): - """Default implementation that tries to use getattr(mx, func_name) similar to JAX.""" - - # Try to get the function name from nfunc_spec (like JAX does) - nfunc_spec = getattr(scalar_op, "nfunc_spec", None) - if nfunc_spec is not None: - func_name = nfunc_spec[0] + """Simplified implementation for MLX scalar operations.""" + + # Try using the operation name directly (most common case) + op_name = getattr(scalar_op, "name", None) + if op_name is not None: try: - mlx_func = getattr(mx, func_name) - # Handle variadic functions - if len(scalar_op.inputs) > nfunc_spec[1]: - # For operations like Add that can take multiple inputs + mlx_func = getattr(mx, op_name) + # Handle variadic functions like Add + if hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2: def variadic_func(*args): result = args[0] for arg in args[1:]: result = mlx_func(result, arg) return result - return variadic_func else: return mlx_func except AttributeError: pass - # Try using the operation name directly - op_name = getattr(scalar_op, "name", None) - if op_name is not None: - try: - return getattr(mx, op_name) - except AttributeError: - pass - raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}") @@ -322,26 +305,6 @@ def sigmoid(x): return sigmoid -@mlx_funcify_Elemwise_scalar_op.register(Softplus) -def _(scalar_op): - def softplus(x): - return mx.where( - x < -37.0, - mx.exp(x), - mx.where( - x < 18.0, - mx.log1p(mx.exp(x)), - mx.where( - x < 33.3, - x + mx.exp(-x), - x, - ), - ), - ) - - return softplus - - @mlx_funcify_Elemwise_scalar_op.register(Invert) def _(scalar_op): def invert(x): @@ -351,19 +314,11 @@ def invert(x): @mlx_funcify.register(Elemwise) -def mlx_funcify_Elemwise(op, node=None, **kwargs): +def mlx_funcify_Elemwise(op, node, **kwargs): # Dispatch to the appropriate scalar op handler scalar_func = mlx_funcify_Elemwise_scalar_op(op.scalar_op) def elemwise(*inputs): - # Enforce runtime broadcast checks (same as JAX and PyTorch implementations) - if node is not None: - # Convert inputs to MLX arrays for broadcast checking - mlx_inputs = tuple( - mx.array(inp) if not hasattr(inp, "shape") else inp for inp in inputs - ) - Elemwise._check_runtime_broadcast(node, mlx_inputs) - return scalar_func(*inputs) return elemwise diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index bd5b5941d9..8c5973f543 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,13 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.shape import Shape_i, SpecifyShape +from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape + + +@mlx_funcify.register(Shape) +def mlx_funcify_Shape(op, **kwargs): + def shape(x): + return x.shape + + return shape @mlx_funcify.register(SpecifyShape) diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py index 7a548df8f8..999d12e6c7 100644 --- a/tests/link/mlx/test_shape.py +++ b/tests/link/mlx/test_shape.py @@ -9,7 +9,6 @@ from tests.link.mlx.test_basic import compare_mlx_and_py -@pytest.mark.xfail(reason="Shape Op is not supported yet") def test_mlx_shape_ops(): x_np = np.zeros((20, 3)) x = Shape()(pt.as_tensor_variable(x_np)) @@ -21,7 +20,6 @@ def test_mlx_shape_ops(): compare_mlx_and_py([], [x], [], must_be_device_array=False) 
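`must_be_device_array=False` is needed here because MLX, like NumPy, reports shapes as plain Python tuples rather than device arrays. A quick check, assuming `mlx.core` is available:

    import mlx.core as mx

    x = mx.zeros((20, 3))
    assert x.shape == (20, 3)
    assert isinstance(x.shape, tuple)  # a host-side tuple, not an mx.array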
-@pytest.mark.xfail(reason="Shape Op is not supported yet") def test_mlx_specify_shape(): in_pt = pt.matrix("in") x = pt.specify_shape(in_pt, (4, None)) From 845561c28865060e532a57a278d3b4074a2d7da6 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Sun, 8 Jun 2025 23:10:08 +0300 Subject: [PATCH 58/71] Pre Commit RUN --- pytensor/link/mlx/dispatch/core.py | 2 -- pytensor/link/mlx/dispatch/math.py | 7 ++++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index c0130ce1cc..10af6aca80 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -11,8 +11,6 @@ from __future__ import annotations -import warnings - import mlx.core as mx import numpy as np diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index 8324ee2b64..c21dda8d22 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -2,9 +2,8 @@ import mlx.core as mx -from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify +from pytensor.link.mlx.dispatch import mlx_funcify from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx -from pytensor.scalar import Softplus from pytensor.scalar.basic import ( AND, EQ, @@ -52,7 +51,7 @@ def dot(x, y): @singledispatch def mlx_funcify_Elemwise_scalar_op(scalar_op): """Simplified implementation for MLX scalar operations.""" - + # Try using the operation name directly (most common case) op_name = getattr(scalar_op, "name", None) if op_name is not None: @@ -60,11 +59,13 @@ def mlx_funcify_Elemwise_scalar_op(scalar_op): mlx_func = getattr(mx, op_name) # Handle variadic functions like Add if hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2: + def variadic_func(*args): result = args[0] for arg in args[1:]: result = mlx_func(result, arg) return result + return variadic_func else: return mlx_func From 6ab74281b877142d35865c92d968b8a8d7703779 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 00:38:46 +0300 Subject: [PATCH 59/71] Adding more operations for complex model --- pytensor/link/mlx/dispatch/core.py | 2 +- pytensor/link/mlx/dispatch/math.py | 54 ++++++++++++- pytensor/link/mlx/dispatch/shape.py | 12 ++- tests/link/mlx/test_basic.py | 56 +++++++++++++- tests/link/mlx/test_elemwise.py | 37 +++++++++ tests/link/mlx/test_math.py | 114 ++++++++++++++++++++++++++++ tests/link/mlx/test_shape.py | 37 ++++++++- 7 files changed, 303 insertions(+), 9 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 10af6aca80..d4f5059da5 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -177,7 +177,7 @@ def tensor_from_scalar(x): @mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return x.reshape(-1)[0] + return mx.array(x).reshape(-1)[0] return scalar_from_tensor diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index c21dda8d22..e765bea6ea 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -18,7 +18,9 @@ Cast, Cos, Exp, + IntDiv, Invert, + IsNan, Log, Log1p, Mul, @@ -34,7 +36,7 @@ Switch, TrueDiv, ) -from pytensor.scalar.math import Sigmoid +from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus from pytensor.tensor.elemwise import Elemwise from 
pytensor.tensor.math import Dot @@ -113,6 +115,14 @@ def true_div(x, y): return true_div +@mlx_funcify_Elemwise_scalar_op.register(IntDiv) +def _(scalar_op): + def int_div(x, y): + return mx.floor_divide(x, y) + + return int_div + + @mlx_funcify_Elemwise_scalar_op.register(Pow) def _(scalar_op): def pow(x, y): @@ -309,11 +319,51 @@ def sigmoid(x): @mlx_funcify_Elemwise_scalar_op.register(Invert) def _(scalar_op): def invert(x): - return ~x + return mx.bitwise_invert(x) return invert +@mlx_funcify_Elemwise_scalar_op.register(IsNan) +def _(scalar_op): + def isnan(x): + return mx.isnan(x) + + return isnan + + +@mlx_funcify_Elemwise_scalar_op.register(Erfc) +def _(scalar_op): + def erfc(x): + return 1.0 - mx.erf(x) + + return erfc + + +@mlx_funcify_Elemwise_scalar_op.register(Erfcx) +def _(scalar_op): + def erfcx(x): + return mx.exp(x * x) * (1.0 - mx.erf(x)) + + return erfcx + + +@mlx_funcify_Elemwise_scalar_op.register(Softplus) +def _(scalar_op): + def softplus(x): + # Numerically stable implementation of log(1 + exp(x)) + # Following the same logic as the original PyTensor implementation + return mx.where( + x < -37.0, + mx.exp(x), + mx.where( + x < 18.0, mx.log1p(mx.exp(x)), mx.where(x < 33.3, x + mx.exp(-x), x) + ), + ) + + return softplus + + @mlx_funcify.register(Elemwise) def mlx_funcify_Elemwise(op, node, **kwargs): # Dispatch to the appropriate scalar op handler diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index 8c5973f543..5441c8a363 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -1,5 +1,7 @@ +import mlx.core as mx + from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @mlx_funcify.register(Shape) @@ -30,3 +32,11 @@ def shape_i(x): return x.shape[op.i] return shape_i + + +@mlx_funcify.register(Reshape) +def mlx_funcify_Reshape(op, **kwargs): + def reshape(x, shp): + return mx.reshape(x, shp) + + return reshape diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 4a6e67f406..f1e20892e7 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,18 +1,19 @@ from collections.abc import Callable, Iterable from functools import partial +import mlx.core as mx import numpy as np -import pytest +import pytensor +from pytensor import tensor as pt from pytensor.compile.function import function from pytensor.compile.mode import MLX, Mode from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Variable from pytensor.link.mlx import MLXLinker +from pytensor.link.mlx.dispatch.core import mlx_funcify_ScalarFromTensor -mx = pytest.importorskip("mlx.core") - optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) py_mode = Mode(linker="py", optimizer=None) @@ -78,3 +79,52 @@ def compare_mlx_and_py( assert_fn(mlx_res, py_res) return pytensor_mlx_fn, mlx_res + + +def test_scalar_from_tensor_with_scalars(): + """Test ScalarFromTensor works with both MLX arrays and Python/NumPy scalars. + + This addresses the AttributeError that occurred when Python integers were + passed to ScalarFromTensor instead of MLX arrays. 
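The fix is to wrap the input before reshaping, which makes the extraction total over Python scalars, NumPy scalars, and MLX arrays alike. A minimal sketch of the wrapped form:

    import mlx.core as mx

    def scalar_from_tensor(x):
        # mx.array accepts existing arrays as well as plain scalars
        return mx.array(x).reshape(-1)[0]

    assert scalar_from_tensor(42).item() == 42
    assert scalar_from_tensor(mx.array([42])).item() == 42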
+ """ + scalar_from_tensor_func = mlx_funcify_ScalarFromTensor(None) + + # Test with MLX array + mlx_array = mx.array([42]) + result = scalar_from_tensor_func(mlx_array) + assert result == 42 + + # Test with Python int (this used to fail) + python_int = 42 + result = scalar_from_tensor_func(python_int) + assert result == 42 + + # Test with Python float + python_float = 3.14 + result = scalar_from_tensor_func(python_float) + assert abs(result - 3.14) < 1e-6 + + # Test with NumPy scalar + numpy_scalar = np.int32(123) + result = scalar_from_tensor_func(numpy_scalar) + assert result == 123 + + # Test with NumPy float scalar + numpy_float = np.float32(2.71) + result = scalar_from_tensor_func(numpy_float) + assert abs(result - 2.71) < 1e-6 + + +def test_scalar_from_tensor_pytensor_integration(): + """Test ScalarFromTensor in a PyTensor graph context.""" + # Create a 0-d tensor (scalar tensor) + x = pt.as_tensor_variable(42) + + # Apply ScalarFromTensor + scalar_result = pt.scalar_from_tensor(x) + + # Create function and test + f = pytensor.function([], scalar_result, mode="MLX") + result = f() + + assert result == 42 diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py index 7819df06be..ab6e64e0cb 100644 --- a/tests/link/mlx/test_elemwise.py +++ b/tests/link/mlx/test_elemwise.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import pytensor.tensor as pt @@ -11,3 +12,39 @@ def test_input(op) -> None: x_test = mx.array([1.0, 2.0, 3.0]) compare_mlx_and_py([x], out, [x_test]) + + +def test_new_elemwise_operations() -> None: + """Test new elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context""" + x = pt.vector("x") + y = pt.vector("y") + + # Test int_div in an elemwise expression + out_int_div = pt.int_div(x, y) + 1 + x_test = mx.array([10.0, 15.0, 20.0]) + y_test = mx.array([3.0, 4.0, 6.0]) + compare_mlx_and_py([x, y], out_int_div, [x_test, y_test]) + + # Test isnan in an elemwise expression + z = pt.vector("z") + out_isnan = pt.isnan(z).astype("float32") * 10 + z_test = mx.array([1.0, np.nan, 3.0]) + compare_mlx_and_py([z], out_isnan, [z_test]) + + # Test erfc in an elemwise expression + w = pt.vector("w") + out_erfc = pt.erfc(w) * 2.0 + w_test = mx.array([0.0, 0.5, 1.0]) + compare_mlx_and_py([w], out_erfc, [w_test]) + + # Test erfcx in an elemwise expression + v = pt.vector("v") + out_erfcx = pt.erfcx(v) + 0.1 + v_test = mx.array([0.0, 1.0, 2.0]) + compare_mlx_and_py([v], out_erfcx, [v_test]) + + # Test softplus in an elemwise expression + u = pt.vector("u") + out_softplus = pt.softplus(u) - 0.5 + u_test = mx.array([0.0, 1.0, -1.0]) + compare_mlx_and_py([u], out_softplus, [u_test]) diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py index 2c08d986c9..d35cb27654 100644 --- a/tests/link/mlx/test_math.py +++ b/tests/link/mlx/test_math.py @@ -79,6 +79,7 @@ def test_input(op) -> None: pytest.param(pt.eq, id="eq"), pytest.param(pt.neq, id="neq"), pytest.param(pt.true_div, id="true_div"), + pytest.param(pt.int_div, id="int_div"), ], ) def test_elemwise_two_inputs(op) -> None: @@ -90,6 +91,119 @@ def test_elemwise_two_inputs(op) -> None: compare_mlx_and_py([x, y], out, [x_test, y_test]) +def test_int_div_specific() -> None: + """Test integer division with specific test cases""" + x = pt.vector("x") + y = pt.vector("y") + out = pt.int_div(x, y) + + # Test with integers that demonstrate floor division behavior + x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0]) + y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0]) + + compare_mlx_and_py([x, y], 
out, [x_test, y_test]) + + +def test_isnan() -> None: + """Test IsNan operation with various inputs including NaN values""" + x = pt.vector("x") + out = pt.isnan(x) + + # Test with mix of normal values, NaN, and infinity + x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_isnan_edge_cases() -> None: + """Test IsNan with edge cases""" + x = pt.scalar("x") + out = pt.isnan(x) + + # Test individual cases + test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10] + + for test_val in test_cases: + x_test = test_val + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfc() -> None: + """Test complementary error function""" + x = pt.vector("x") + out = pt.erfc(x) + + # Test with various values including negative, positive, and zero + x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfc_extreme_values() -> None: + """Test erfc with extreme values""" + x = pt.vector("x") + out = pt.erfc(x) + + # Test with larger values where erfc approaches 0 or 2 + x_test = mx.array([-3.0, -2.5, 2.5, 3.0]) + + # Use relaxed tolerance for extreme values due to numerical precision differences + from functools import partial + + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6) + + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) + + +def test_erfcx() -> None: + """Test scaled complementary error function""" + x = pt.vector("x") + out = pt.erfcx(x) + + # Test with positive values where erfcx is most numerically stable + x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_erfcx_small_values() -> None: + """Test erfcx with small values""" + x = pt.vector("x") + out = pt.erfcx(x) + + # Test with small values + x_test = mx.array([0.001, 0.01, 0.1, 0.2]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_softplus() -> None: + """Test softplus (log(1 + exp(x))) function""" + x = pt.vector("x") + out = pt.softplus(x) + + # Test with normal range values + x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0]) + + compare_mlx_and_py([x], out, [x_test]) + + +def test_softplus_extreme_values() -> None: + """Test softplus with extreme values to verify numerical stability""" + x = pt.vector("x") + out = pt.softplus(x) + + # Test with extreme values where different branches of the implementation are used + x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0]) + + # Use relaxed tolerance for extreme values due to numerical precision differences + from functools import partial + + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8) + + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) + + @pytest.mark.xfail(reason="Argmax not implemented yet") def test_mlx_max_and_argmax(): # Test that a single output of a multi-output `Op` can be used as input to diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py index 999d12e6c7..19c3dd220b 100644 --- a/tests/link/mlx/test_shape.py +++ b/tests/link/mlx/test_shape.py @@ -37,14 +37,47 @@ def test_mlx_specify_shape(): ) -@pytest.mark.xfail(reason="Reshape Op is not supported yet") def test_mlx_Reshape_constant(): a = vector("a") x = reshape(a, (2, 2)) compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) -@pytest.mark.xfail(reason="Reshape Op is not supported yet") +def test_mlx_Reshape_various_shapes(): + """Test reshape with various different shapes to ensure 
robustness.""" + # 1D to 2D + a = vector("a") + x = reshape(a, (2, 3)) + compare_mlx_and_py([a], [x], [np.arange(6, dtype=config.floatX)]) + + # 2D to 1D + b = pt.matrix("b") + y = reshape(b, (6,)) + compare_mlx_and_py([b], [y], [np.arange(6, dtype=config.floatX).reshape(2, 3)]) + + # 2D to 3D + c = pt.matrix("c") + z = reshape(c, (2, 2, 3)) + compare_mlx_and_py([c], [z], [np.arange(12, dtype=config.floatX).reshape(4, 3)]) + + # 3D to 2D + d = pt.tensor3("d") + w = reshape(d, (3, 4)) + compare_mlx_and_py([d], [w], [np.arange(12, dtype=config.floatX).reshape(2, 2, 3)]) + + +def test_mlx_Reshape_negative_one(): + """Test reshape with -1 dimension (infer dimension).""" + a = vector("a") + # Use -1 to infer the second dimension + x = reshape(a, (2, -1)) + compare_mlx_and_py([a], [x], [np.arange(8, dtype=config.floatX)]) + + # Use -1 to infer the first dimension + y = reshape(a, (-1, 4)) + compare_mlx_and_py([a], [y], [np.arange(8, dtype=config.floatX)]) + + def test_mlx_Reshape_concrete_shape(): """MLX should compile when a concrete value is passed for the `shape` parameter.""" a = vector("a") From b6292f1bc0cd68c87303ea496ee08830ff3ffeb7 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:10:25 +0300 Subject: [PATCH 60/71] Working with simple model --- pytensor/link/mlx/dispatch/core.py | 69 +++++++- pytensor/link/mlx/dispatch/elemwise.py | 21 ++- pytensor/link/mlx/dispatch/math.py | 21 ++- pytensor/link/mlx/linker.py | 10 +- tests/link/mlx/test_basic.py | 217 ++++++++++++++++++++++++- 5 files changed, 329 insertions(+), 9 deletions(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index d4f5059da5..bdddfa034d 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -115,12 +115,25 @@ def eye(*args): return eye -def convert_dtype_to_mlx(dtype_str): +def convert_dtype_to_mlx(dtype_str, auto_cast_unsupported=True): """Convert PyTensor dtype strings to MLX dtype objects. MLX expects dtype objects rather than string literals for type conversion. This function maps common dtype strings to their MLX equivalents. + + Parameters + ---------- + dtype_str : str or MLX dtype + The dtype to convert + auto_cast_unsupported : bool + If True, automatically cast unsupported dtypes to supported ones with warnings + + Returns + ------- + MLX dtype object """ + import warnings + if isinstance(dtype_str, str): if dtype_str == "bool": return mx.bool_ @@ -145,13 +158,35 @@ def convert_dtype_to_mlx(dtype_str): elif dtype_str == "float32": return mx.float32 elif dtype_str == "float64": - return mx.float64 + if auto_cast_unsupported: + warnings.warn( + "MLX does not support float64 on GPU. Automatically casting to float32. " + "This may result in reduced precision. To avoid this warning, " + "explicitly use float32 in your code or set floatX='float32' in PyTensor config.", + UserWarning, + stacklevel=3, + ) + return mx.float32 + else: + return mx.float64 elif dtype_str == "bfloat16": return mx.bfloat16 elif dtype_str == "complex64": return mx.complex64 elif dtype_str == "complex128": - return mx.complex128 + if auto_cast_unsupported: + warnings.warn( + "MLX does not support complex128. Automatically casting to complex64. " + "This may result in reduced precision. 
To avoid this warning, " + "explicitly use complex64 in your code.", + UserWarning, + stacklevel=3, + ) + return mx.complex64 + else: + # Return the original even though it might fail + # This allows users to opt out of auto-casting if needed + return mx.complex64 # MLX doesn't have complex128, so fallback # Return as is if it's already an MLX dtype or not a recognized string return dtype_str @@ -212,7 +247,31 @@ def allocempty(*shape): @mlx_funcify.register(Alloc) def mlx_funcify_Alloc(op, node, **kwargs): def alloc(x, *shape): - res = mx.broadcast_to(x, shape) - return res + try: + # Convert shape elements to Python ints for MLX compatibility + # MLX requires shape dimensions to be Python integers, not MLX arrays + shape_ints = tuple( + int(s.item()) if hasattr(s, "item") else int(s) for s in shape + ) + return mx.broadcast_to(x, shape_ints) + except ValueError as e: + if ( + "[eval] Attempting to eval an array during function transformations" + in str(e) + ): + # This is the MLX compilation limitation - provide helpful error + raise ValueError( + "MLX compilation limitation: Alloc operations with dynamic shapes " + "cannot be used inside compiled functions. This is because MLX " + "compilation forbids evaluating arrays to extract shape values. " + "\n\nWorkarounds:" + "\n1. Avoid using Alloc with dynamic shapes in compiled contexts" + "\n2. Use static shapes when possible" + "\n3. Move Alloc operations outside compiled functions" + "\n\nOriginal error: " + str(e) + ) from e + else: + # Re-raise other ValueError exceptions + raise return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index ab92358862..0bbc98cf81 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -149,6 +149,25 @@ def softplus(x): def mlx_funcify_Cast(op, **kwargs): def cast(x): dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype) - return x.astype(dtype) + try: + return x.astype(dtype) + except ValueError as e: + if "is not supported on the GPU" in str(e): + # MLX GPU limitation - try auto-casting with warning + import warnings + + warnings.warn( + f"MLX GPU limitation: {e}. Attempting automatic fallback casting.", + UserWarning, + stacklevel=2, + ) + # Get the auto-cast version + fallback_dtype = convert_dtype_to_mlx( + op.scalar_op.o_type.dtype, auto_cast_unsupported=True + ) + return x.astype(fallback_dtype) + else: + # Re-raise other ValueError exceptions + raise return cast diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py index e765bea6ea..0d8e1ee7b1 100644 --- a/pytensor/link/mlx/dispatch/math.py +++ b/pytensor/link/mlx/dispatch/math.py @@ -303,7 +303,26 @@ def minimum(x, y): def _(scalar_op): def cast(x): dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype) - return x.astype(dtype) + try: + return x.astype(dtype) + except ValueError as e: + if "is not supported on the GPU" in str(e): + # MLX GPU limitation - try auto-casting with warning + import warnings + + warnings.warn( + f"MLX GPU limitation: {e}. 
Attempting automatic fallback casting.", + UserWarning, + stacklevel=2, + ) + # Get the auto-cast version + fallback_dtype = convert_dtype_to_mlx( + scalar_op.o_type.dtype, auto_cast_unsupported=True + ) + return x.astype(fallback_dtype) + else: + # Re-raise other ValueError exceptions + raise return cast diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index e057bb942c..9a4d1ac2c1 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -4,9 +4,10 @@ class MLXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" - def __init__(self, *args, **kwargs): + def __init__(self, use_compile=True, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] + self.use_compile = use_compile def fgraph_convert(self, fgraph, **kwargs): """Convert a PyTensor FunctionGraph to an MLX-compatible function. @@ -33,6 +34,13 @@ def jit_compile(self, fn): from pytensor.link.mlx.dispatch import mlx_typify + if not self.use_compile: + # Skip compilation and just return the function with MLX typification + def fn_no_compile(*inputs): + return fn(*(mlx_typify(inp) for inp in inputs)) + + return fn_no_compile + inner_fn = mx.compile(fn) def fn(*inputs, inner_fn=inner_fn): diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index f1e20892e7..36fcadd7a7 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -11,11 +11,17 @@ from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Variable from pytensor.link.mlx import MLXLinker -from pytensor.link.mlx.dispatch.core import mlx_funcify_ScalarFromTensor +from pytensor.link.mlx.dispatch.core import ( + mlx_funcify_Alloc, + mlx_funcify_ScalarFromTensor, +) +from pytensor.tensor.basic import Alloc optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) +mlx_mode_no_compile = Mode(linker=MLXLinker(use_compile=False), optimizer=optimizer) +compile_mode = Mode(linker=MLXLinker(use_compile=True), optimizer=optimizer) py_mode = Mode(linker="py", optimizer=None) @@ -128,3 +134,212 @@ def test_scalar_from_tensor_pytensor_integration(): result = f() assert result == 42 + + +def test_alloc_with_different_shape_types(): + """Test Alloc works with different types of shape parameters. + + This addresses the TypeError that occurred when shape parameters + contained MLX arrays instead of Python integers. 
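The underlying fix normalizes every shape entry to a Python int, since MLX expects shape dimensions as plain integers rather than arrays. The coercion in isolation:

    import mlx.core as mx

    def as_int_shape(*shape):
        # 0-d arrays expose .item(); everything else goes through int()
        return tuple(int(s.item()) if hasattr(s, "item") else int(s) for s in shape)

    out = mx.broadcast_to(mx.array(5.0), as_int_shape(3, mx.array(4)))
    assert out.shape == (3, 4)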
+ """ + + # Create a mock node (we don't need a real node for this test) + class MockNode: + pass + + alloc_func = mlx_funcify_Alloc(Alloc(), MockNode()) + x = mx.array(5.0) + + # Test with Python ints + result = alloc_func(x, 3, 4) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + # Test with MLX arrays (this used to fail) + result = alloc_func(x, mx.array(3), mx.array(4)) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + # Test with mixed types + result = alloc_func(x, 3, mx.array(4)) + assert result.shape == (3, 4) + assert float(result[0, 0]) == 5.0 + + +def test_alloc_pytensor_integration(): + """Test Alloc in a PyTensor graph context.""" + # Test basic constant shape allocation + x = pt.scalar("x", dtype="float32") + result = pt.alloc(x, 3, 4) + + # Use MLX mode + from pytensor.compile import mode + + mlx_mode = mode.get_mode("MLX") + + f = pytensor.function([x], result, mode=mlx_mode) + output = f(5.0) + + assert output.shape == (3, 4) + assert float(output[0, 0]) == 5.0 + + +def test_alloc_compilation_limitation(): + """Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts.""" + import pytest + + # Create variables + x = pt.scalar("x", dtype="float32") + s1 = pt.scalar("s1", dtype="int64") + s2 = pt.scalar("s2", dtype="int64") + + # Create Alloc operation with dynamic shapes + result = pt.alloc(x, s1, s2) + + # Create function with non-compiled MLX mode + f = pytensor.function([x, s1, s2], result, mode=mlx_mode_no_compile) + + # Test that it works with concrete values (non-compiled context) + output = f(5.0, 3, 4) + assert output.shape == (3, 4) + assert np.allclose(output, 5.0) + + # Test that compilation fails with helpful error + compiled_f = pytensor.function([x, s1, s2], result, mode=compile_mode) + + with pytest.raises(ValueError) as exc_info: + compiled_f(5.0, 3, 4) + + error_msg = str(exc_info.value) + assert "MLX compilation limitation" in error_msg + assert "Alloc operations with dynamic shapes" in error_msg + assert "cannot be used inside compiled functions" in error_msg + assert "Workarounds:" in error_msg + assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg + assert "Use static shapes when possible" in error_msg + assert "Move Alloc operations outside compiled functions" in error_msg + + +def test_alloc_static_shapes_compilation(): + """Test that Alloc operations with static shapes work fine in compiled contexts.""" + # Create a scenario with static shapes that should work + x = pt.scalar("x", dtype="float32") + + # Use constant shape - this should work even in compilation + result = pt.alloc(x, 3, 4) # Static shapes + + # Test both compiled and non-compiled modes + f_normal = pytensor.function([x], result, mode=mlx_mode_no_compile) + f_compiled = pytensor.function([x], result, mode=compile_mode) + + # Both should work + output_normal = f_normal(5.0) + output_compiled = f_compiled(5.0) + + assert output_normal.shape == (3, 4) + assert output_compiled.shape == (3, 4) + assert np.allclose(output_normal, 5.0) + assert np.allclose(output_compiled, 5.0) + assert np.allclose(output_normal, output_compiled) + + +def test_mlx_float64_auto_casting(): + """Test MLX automatic casting of float64 to float32 with warnings.""" + import warnings + + # Test 1: Direct Cast operation with warning + x = pt.scalar("x", dtype="float32") + y = pt.cast(x, "float64") + + # Capture warnings + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + 
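        # catch_warnings(record=True) routes every warning raised in this block
        # into warning_list, and simplefilter("always") disables the stdlib's
        # once-per-location deduplication, so repeated dtype warnings are all
        # captured for the assertions below.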
+ f = pytensor.function([x], y, mode=mlx_mode, allow_input_downcast=True) + result = f(3.14) + + # Check that the operation succeeded + assert result.dtype == mx.float32 # Should be auto-cast to float32 + assert abs(float(result) - 3.14) < 1e-6 + + # Check that a warning was issued + warning_messages = [str(w.message) for w in warning_list] + dtype_warnings = [ + msg for msg in warning_messages if "float64" in msg and "float32" in msg + ] + assert ( + len(dtype_warnings) > 0 + ), f"Expected dtype warning, got warnings: {warning_messages}" + + +def test_mlx_float64_complex_operations(): + """Test float64 casting in more complex operations.""" + import warnings + + # Test with vector operations + x = pt.vector("x", dtype="float32") + y = pt.cast(x, "float64") + z = pt.exp(y) + pt.sin(y) # Multiple operations on float64 + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + f = pytensor.function([x], z, mode=mlx_mode, allow_input_downcast=True) + result = f([1.0, 2.0, 3.0]) + + # Should work and return float32 results + assert result.dtype == mx.float32 + assert result.shape == (3,) + + # Should have issued warnings + warning_messages = [str(w.message) for w in warning_list] + dtype_warnings = [ + msg + for msg in warning_messages + if "float64" in msg or "MLX GPU limitation" in msg + ] + assert len(dtype_warnings) > 0 + + +def test_mlx_float64_no_warning_when_disabled(): + """Test that auto-casting can be controlled.""" + import warnings + + from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx + + # Test that we can disable auto-casting + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + # This should not issue warnings when auto_cast_unsupported=False + dtype = convert_dtype_to_mlx("float64", auto_cast_unsupported=False) + assert dtype == mx.float64 # Should return the original dtype + + # No warnings should be issued for proactive conversion when disabled + dtype_warnings = [ + str(w.message) for w in warning_list if "float64" in str(w.message) + ] + assert len(dtype_warnings) == 0 + + +def test_mlx_complex128_auto_casting(): + """Test automatic casting of complex128 to complex64.""" + import warnings + + from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + + # This should trigger a warning and return complex64 + dtype = convert_dtype_to_mlx("complex128", auto_cast_unsupported=True) + assert dtype == mx.complex64 + + # Should have issued a warning + warning_messages = [str(w.message) for w in warning_list] + complex_warnings = [ + msg + for msg in warning_messages + if "complex128" in msg and "complex64" in msg + ] + assert len(complex_warnings) > 0 From f2d9d1b09b8e70e4a7ca1432f513b10ab792003f Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:19:46 +0300 Subject: [PATCH 61/71] Change bad name --- tests/link/mlx/test_elemwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py index ab6e64e0cb..44d0d58dcb 100644 --- a/tests/link/mlx/test_elemwise.py +++ b/tests/link/mlx/test_elemwise.py @@ -14,8 +14,8 @@ def test_input(op) -> None: compare_mlx_and_py([x], out, [x_test]) -def test_new_elemwise_operations() -> None: - """Test new elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context""" +def 
test_elemwise_operations() -> None: + """Test elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context""" x = pt.vector("x") y = pt.vector("y") From 5c759b9e6485ab608725bb1b5251ebd1aeb5c007 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:21:28 +0300 Subject: [PATCH 62/71] Correcting test by Ricardo --- tests/link/mlx/test_basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 36fcadd7a7..4ae665fdb9 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -123,15 +123,15 @@ def test_scalar_from_tensor_with_scalars(): def test_scalar_from_tensor_pytensor_integration(): """Test ScalarFromTensor in a PyTensor graph context.""" - # Create a 0-d tensor (scalar tensor) - x = pt.as_tensor_variable(42) + # Create a symbolic scalar input to actually test MLX execution + x = pt.scalar("x", dtype="int64") # Apply ScalarFromTensor scalar_result = pt.scalar_from_tensor(x) # Create function and test - f = pytensor.function([], scalar_result, mode="MLX") - result = f() + f = pytensor.function([x], scalar_result, mode="MLX") + result = f(42) assert result == 42 From 97b2e310a01d60d74bd02ec4c3f24e16bb9ec6af Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:46:55 +0300 Subject: [PATCH 63/71] Changing synth test --- tests/link/mlx/test_basic.py | 111 +++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 30 deletions(-) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 4ae665fdb9..b09b822187 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,3 +1,6 @@ +""" +Basic tests for the MLX backend. +""" from collections.abc import Callable, Iterable from functools import partial @@ -13,7 +16,6 @@ from pytensor.link.mlx import MLXLinker from pytensor.link.mlx.dispatch.core import ( mlx_funcify_Alloc, - mlx_funcify_ScalarFromTensor, ) from pytensor.tensor.basic import Alloc @@ -87,53 +89,102 @@ def compare_mlx_and_py( return pytensor_mlx_fn, mlx_res -def test_scalar_from_tensor_with_scalars(): - """Test ScalarFromTensor works with both MLX arrays and Python/NumPy scalars. +def test_scalar_from_tensor_matrix_indexing(): + """Test ScalarFromTensor with matrix element extraction.""" + # Matrix element extraction is a common real-world scenario + matrix = pt.matrix("matrix", dtype="float32") + element = matrix[0, 0] # Creates 0-d tensor - This addresses the AttributeError that occurred when Python integers were - passed to ScalarFromTensor instead of MLX arrays. 
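The replacement tests below go through `pytensor.function` rather than calling the dispatch functions directly, so the MLX linker and the compiled graph are actually on the test path. Their shared pattern, in miniature:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x", dtype="float32")
    f = pytensor.function([x], x[0, 0], mode="MLX")  # element extraction gives a 0-d result
    assert float(f(np.eye(2, dtype=np.float32))) == 1.0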
- """ - scalar_from_tensor_func = mlx_funcify_ScalarFromTensor(None) + f = pytensor.function([matrix], element, mode="MLX") - # Test with MLX array - mlx_array = mx.array([42]) - result = scalar_from_tensor_func(mlx_array) - assert result == 42 + test_matrix = np.array([[42.0, 1.0], [2.0, 3.0]], dtype=np.float32) + result = f(test_matrix) + + assert float(result) == 42.0 + assert isinstance(result, mx.array) + + +def test_scalar_from_tensor_reduction_operations(): + """Test ScalarFromTensor with reduction operations that produce scalars.""" + # Test vector sum reduction + vector = pt.vector("vector", dtype="float32") + sum_result = pt.sum(vector) + + f = pytensor.function([vector], sum_result, mode="MLX") + test_vector = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + result = f(test_vector) + + assert float(result) == 10.0 + + # Test matrix mean reduction + matrix = pt.matrix("matrix", dtype="float32") + mean_result = pt.mean(matrix) + + f2 = pytensor.function([matrix], mean_result, mode="MLX") + test_matrix = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32) + result = f2(test_matrix) + + assert float(result) == 5.0 - # Test with Python int (this used to fail) - python_int = 42 - result = scalar_from_tensor_func(python_int) - assert result == 42 - # Test with Python float - python_float = 3.14 - result = scalar_from_tensor_func(python_float) - assert abs(result - 3.14) < 1e-6 +def test_scalar_from_tensor_conditional_operations(): + """Test ScalarFromTensor with conditional operations.""" + x = pt.scalar("x", dtype="float32") + y = pt.scalar("y", dtype="float32") + + # Switch operation may create 0-d tensors + max_val = pt.switch(x > y, x, y) + + f = pytensor.function([x, y], max_val, mode="MLX") - # Test with NumPy scalar - numpy_scalar = np.int32(123) - result = scalar_from_tensor_func(numpy_scalar) - assert result == 123 + # Test both branches + result1 = f(5.0, 3.0) + assert float(result1) == 5.0 - # Test with NumPy float scalar - numpy_float = np.float32(2.71) - result = scalar_from_tensor_func(numpy_float) - assert abs(result - 2.71) < 1e-6 + result2 = f(2.0, 7.0) + assert float(result2) == 7.0 + + +def test_scalar_from_tensor_multiple_dtypes(): + """Test ScalarFromTensor with different data types.""" + # Test different dtypes that might require scalar extraction + for dtype in ["float32", "int32", "int64"]: + x = pt.vector("x", dtype=dtype) + # Use max reduction to create 0-d tensor + max_val = pt.max(x) + + f = pytensor.function([x], max_val, mode="MLX", allow_input_downcast=True) + + if dtype.startswith("float"): + test_data = np.array([1.5, 3.7, 2.1], dtype=dtype) + expected = 3.7 + else: + test_data = np.array([10, 30, 20], dtype=dtype) + expected = 30 + + result = f(test_data) + assert abs(float(result) - expected) < 1e-5 def test_scalar_from_tensor_pytensor_integration(): - """Test ScalarFromTensor in a PyTensor graph context.""" + """Test ScalarFromTensor in a complete PyTensor graph context. + + This test uses symbolic variables (not constants) to ensure the MLX backend + actually executes the ScalarFromTensor operation rather than having it + optimized away during compilation. 
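For contrast, a constant input is folded during rewriting and the backend never executes the op; a sketch of the folded case:

    import pytensor
    import pytensor.tensor as pt

    g = pt.scalar_from_tensor(pt.as_tensor_variable(42))
    f = pytensor.function([], g, mode="MLX")
    assert f() == 42  # the value is baked in at compile time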
+ """ # Create a symbolic scalar input to actually test MLX execution x = pt.scalar("x", dtype="int64") - # Apply ScalarFromTensor + # Apply ScalarFromTensor - this creates a graph that forces execution scalar_result = pt.scalar_from_tensor(x) - # Create function and test + # Create function and test with actual MLX backend execution f = pytensor.function([x], scalar_result, mode="MLX") result = f(42) assert result == 42 + assert isinstance(result, mx.array) def test_alloc_with_different_shape_types(): From dd83e0fb8c4438af870877e49bbf2570b54f1eea Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:40:13 +0300 Subject: [PATCH 64/71] Optimizing reshape --- pytensor/link/mlx/dispatch/core.py | 12 +++++++++++- tests/link/mlx/test_basic.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index bdddfa034d..b8270e146c 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -212,7 +212,17 @@ def tensor_from_scalar(x): @mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - return mx.array(x).reshape(-1)[0] + arr = mx.array(x) + try: + # Try .item() first (cleaner and faster when possible) + return arr.item() + except ValueError as e: + if "eval" in str(e): + # Fall back to reshape approach for compiled contexts + return arr.reshape(-1)[0] + else: + # Re-raise if it's a different error + raise return scalar_from_tensor diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index b09b822187..871f694c49 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,6 +1,7 @@ """ Basic tests for the MLX backend. """ + from collections.abc import Callable, Iterable from functools import partial From 662b4f2704987561f8d5bc9b188a1e19faa66197 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:58:49 +0300 Subject: [PATCH 65/71] Comment --- pytensor/link/mlx/dispatch/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index b8270e146c..82b8f0a5de 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -274,6 +274,7 @@ def alloc(x, *shape): "MLX compilation limitation: Alloc operations with dynamic shapes " "cannot be used inside compiled functions. This is because MLX " "compilation forbids evaluating arrays to extract shape values. " + # Just a note! TODO: remove this once we have a better solution "\n\nWorkarounds:" "\n1. Avoid using Alloc with dynamic shapes in compiled contexts" "\n2. 
Use static shapes when possible" From bcf7f8d0cfbe0a9846f10784e251b8dc8da7ad6a Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 9 Jun 2025 19:44:12 +0300 Subject: [PATCH 66/71] Small changes and adding small benchmark --- .../benchmark_mlx_v_jax_corrected.ipynb | 443 ++++++++++++++++++ pytensor/link/mlx/dispatch/blockwise.py | 12 +- pytensor/link/mlx/dispatch/math.py | 2 +- 3 files changed, 454 insertions(+), 3 deletions(-) create mode 100644 doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb new file mode 100644 index 0000000000..84f78b5caf --- /dev/null +++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb @@ -0,0 +1,443 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pytensor.compile.function import function\n", + "from pytensor.compile.mode import Mode\n", + "from pytensor.graph import RewriteDatabaseQuery\n", + "from pytensor.link.jax import JAXLinker\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure JAX to use float32 for consistency with MLX\n", + "jax.config.update(\"jax_enable_x64\", False)\n", + "\n", + "# Set up PyTensor JAX mode\n", + "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n", + "pytensor_jax_mode = Mode(linker=JAXLinker(), optimizer=jax_optimizer)\n", + "\n", + "# Try to set up MLX mode\n", + "try:\n", + " from pytensor.link.mlx import MLXLinker\n", + " import mlx.core as mx\n", + " mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n", + " pytensor_mlx_mode = Mode(linker=MLXLinker(), optimizer=mlx_optimizer)\n", + " MLX_AVAILABLE = True\n", + "except ImportError:\n", + " MLX_AVAILABLE = False\n", + "\n", + "def timer_jax(func, N=1000):\n", + " \"\"\"Time function execution with proper JAX synchronization, repeated N times\"\"\"\n", + " def wrapper(*args, **kwargs):\n", + " times = []\n", + " for _ in range(N):\n", + " start = time.perf_counter()\n", + " result = func(*args, **kwargs)\n", + " if hasattr(result, 'block_until_ready'):\n", + " result.block_until_ready()\n", + " elif isinstance(result, (list, tuple)):\n", + " for r in result:\n", + " if hasattr(r, 'block_until_ready'):\n", + " r.block_until_ready()\n", + " end = time.perf_counter()\n", + " times.append(end - start)\n", + " \n", + " mean_time = np.mean(times)\n", + " std_time = np.std(times)\n", + " return result, mean_time, std_time\n", + " return wrapper\n", + "\n", + "def timer_mlx(func, N=1000):\n", + " \"\"\"Time function execution with proper MLX synchronization, repeated N times\"\"\"\n", + " def wrapper(*args, **kwargs):\n", + " times = []\n", + " for _ in range(N):\n", + " start = time.perf_counter()\n", + " result = func(*args, **kwargs)\n", + " # For MLX, we need to use mx.eval() to force computation\n", + " if MLX_AVAILABLE:\n", + " if isinstance(result, (list, tuple)):\n", + " mx.eval(*result)\n", + " else:\n", + " mx.eval(result)\n", + " end = time.perf_counter()\n", + " times.append(end - start)\n", + " \n", + " mean_time = np.mean(times)\n", + " std_time = np.std(times)\n", + " return result, mean_time, std_time\n", + " return wrapper\n", + "\n", + "def 
From bcf7f8d0cfbe0a9846f10784e251b8dc8da7ad6a Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Mon, 9 Jun 2025 19:44:12 +0300
Subject: [PATCH 66/71] Small changes and adding small benchmark

---
 .../benchmark_mlx_v_jax_corrected.ipynb | 443 ++++++++++++++++++
 pytensor/link/mlx/dispatch/blockwise.py |  12 +-
 pytensor/link/mlx/dispatch/math.py      |   2 +-
 3 files changed, 454 insertions(+), 3 deletions(-)
 create mode 100644 doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
new file mode 100644
index 0000000000..84f78b5caf
--- /dev/null
+++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
@@ -0,0 +1,443 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "from pytensor.compile.function import function\n",
+ "from pytensor.compile.mode import Mode\n",
+ "from pytensor.graph import RewriteDatabaseQuery\n",
+ "from pytensor.link.jax import JAXLinker\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Configure JAX to use float32 for consistency with MLX\n",
+ "jax.config.update(\"jax_enable_x64\", False)\n",
+ "\n",
+ "# Set up PyTensor JAX mode\n",
+ "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n",
+ "pytensor_jax_mode = Mode(linker=JAXLinker(), optimizer=jax_optimizer)\n",
+ "\n",
+ "# Try to set up MLX mode\n",
+ "try:\n",
+ "    from pytensor.link.mlx import MLXLinker\n",
+ "    import mlx.core as mx\n",
+ "    mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n",
+ "    pytensor_mlx_mode = Mode(linker=MLXLinker(), optimizer=mlx_optimizer)\n",
+ "    MLX_AVAILABLE = True\n",
+ "except ImportError:\n",
+ "    MLX_AVAILABLE = False\n",
+ "\n",
+ "def timer_jax(func, N=1000):\n",
+ "    \"\"\"Time function execution with proper JAX synchronization, repeated N times\"\"\"\n",
+ "    def wrapper(*args, **kwargs):\n",
+ "        times = []\n",
+ "        for _ in range(N):\n",
+ "            start = time.perf_counter()\n",
+ "            result = func(*args, **kwargs)\n",
+ "            if hasattr(result, 'block_until_ready'):\n",
+ "                result.block_until_ready()\n",
+ "            elif isinstance(result, (list, tuple)):\n",
+ "                for r in result:\n",
+ "                    if hasattr(r, 'block_until_ready'):\n",
+ "                        r.block_until_ready()\n",
+ "            end = time.perf_counter()\n",
+ "            times.append(end - start)\n",
+ "        \n",
+ "        mean_time = np.mean(times)\n",
+ "        std_time = np.std(times)\n",
+ "        return result, mean_time, std_time\n",
+ "    return wrapper\n",
+ "\n",
+ "def timer_mlx(func, N=1000):\n",
+ "    \"\"\"Time function execution with proper MLX synchronization, repeated N times\"\"\"\n",
+ "    def wrapper(*args, **kwargs):\n",
+ "        times = []\n",
+ "        for _ in range(N):\n",
+ "            start = time.perf_counter()\n",
+ "            result = func(*args, **kwargs)\n",
+ "            # For MLX, we need to use mx.eval() to force computation\n",
+ "            if MLX_AVAILABLE:\n",
+ "                if isinstance(result, (list, tuple)):\n",
+ "                    mx.eval(*result)\n",
+ "                else:\n",
+ "                    mx.eval(result)\n",
+ "            end = time.perf_counter()\n",
+ "            times.append(end - start)\n",
+ "        \n",
+ "        mean_time = np.mean(times)\n",
+ "        std_time = np.std(times)\n",
+ "        return result, mean_time, std_time\n",
+ "    return wrapper\n",
+ "\n",
+ "def run_benchmark(N=1000):\n",
+ "    \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
+ "    import pandas as pd\n",
+ "    \n",
+ "    sizes = [128, 256, 512, 1024]\n",
+ "    results = []\n",
+ "    \n",
+ "    print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
+ "    \n",
+ "    for size in sizes:\n",
+ "        print(f\"Testing {size}x{size} matrices...\")\n",
+ "        \n",
+ "        # Generate test matrices with fixed seed for reproducibility\n",
+ "        np.random.seed(42)\n",
+ "        A = np.random.randn(size, size).astype(np.float32)\n",
+ "        B = np.random.randn(size, size).astype(np.float32)\n",
+ "        C = np.random.randn(size, size).astype(np.float32)\n",
+ "        \n",
+ "        # === TEST 1: Matrix Multiplication Chain ===\n",
+ "        # PyTensor + JAX backend\n",
+ "        @timer_jax\n",
+ "        def pytensor_jax_matmul():\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32') \n",
+ "            pt_C = pt.matrix('C', dtype='float32')\n",
+ "            result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
+ "            f = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode)\n",
+ "            return f(A, B, C)\n",
+ "        \n",
+ "        # PyTensor + MLX backend\n",
+ "        @timer_mlx\n",
+ "        def pytensor_mlx_matmul():\n",
+ "            if not MLX_AVAILABLE:\n",
+ "                return None, float('inf'), 0\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32')\n",
+ "            pt_C = pt.matrix('C', dtype='float32')\n",
+ "            result = pt_A @ pt_B @ pt_C\n",
+ "            f = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode)\n",
+ "            return f(A, B, C)\n",
+ "        \n",
+ "        # Run matrix multiplication test\n",
+ "        _, jax_mean, jax_std = pytensor_jax_matmul()\n",
+ "        try:\n",
+ "            _, mlx_mean, mlx_std = pytensor_mlx_matmul()\n",
+ "        except Exception as e:\n",
+ "            print(f\"MLX matmul error: {e}\")\n",
+ "            mlx_mean, mlx_std = float('inf'), 0\n",
+ "        \n",
+ "        results.append({\n",
+ "            'Size': f'{size}x{size}',\n",
+ "            'Operation': 'Matrix Chain (A @ B @ C)',\n",
+ "            'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
+ "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
+ "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
+ "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
+ "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "        })\n",
+ "        \n",
+ "        # === TEST 2: Element-wise Operations ===\n",
+ "        # PyTensor + JAX\n",
+ "        @timer_jax\n",
+ "        def pytensor_jax_elemwise():\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32')\n",
+ "            result = pt.sin(pt_A) + pt.cos(pt_B)\n",
+ "            f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
+ "            return f(A, B)\n",
+ "        \n",
+ "        # PyTensor + MLX\n",
+ "        @timer_mlx\n",
+ "        def pytensor_mlx_elemwise():\n",
+ "            if not MLX_AVAILABLE:\n",
+ "                return None, float('inf'), 0\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32')\n",
+ "            result = pt.sin(pt_A) + pt.cos(pt_B)\n",
+ "            f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
+ "            return f(A, B)\n",
+ "        \n",
+ "        # Run element-wise test\n",
+ "        _, jax_mean, jax_std = pytensor_jax_elemwise()\n",
+ "        try:\n",
+ "            _, mlx_mean, mlx_std = pytensor_mlx_elemwise()\n",
+ "        except Exception as e:\n",
+ "            print(f\"MLX elemwise error: {e}\")\n",
+ "            mlx_mean, mlx_std = float('inf'), 0\n",
+ "        \n",
+ "        results.append({\n",
+ "            'Size': f'{size}x{size}',\n",
+ "            'Operation': 'Element-wise (sin(A) + cos(B))',\n",
+ "            'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
+ "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
+ "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
+ "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
+ "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "        })\n",
+ "        \n",
+ "        # === TEST 3: Matrix Addition with Broadcasting ===\n",
+ "        # PyTensor + JAX\n",
+ "        @timer_jax\n",
+ "        def pytensor_jax_broadcast():\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32')\n",
+ "            result = pt_A + pt_B.T\n",
+ "            f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
+ "            return f(A, B)\n",
+ "        \n",
+ "        # PyTensor + MLX\n",
+ "        @timer_mlx\n",
+ "        def pytensor_mlx_broadcast():\n",
+ "            if not MLX_AVAILABLE:\n",
+ "                return None, float('inf'), 0\n",
+ "            pt_A = pt.matrix('A', dtype='float32')\n",
+ "            pt_B = pt.matrix('B', dtype='float32')\n",
+ "            result = pt_A + pt_B.T\n",
+ "            f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
+ "            return f(A, B)\n",
+ "        \n",
+ "        # Run broadcasting test\n",
+ "        _, jax_mean, jax_std = pytensor_jax_broadcast()\n",
+ "        try:\n",
+ "            _, mlx_mean, mlx_std = pytensor_mlx_broadcast()\n",
+ "        except Exception as e:\n",
+ "            print(f\"MLX broadcast error: {e}\")\n",
+ "            mlx_mean, mlx_std = float('inf'), 0\n",
+ "        \n",
+ "        results.append({\n",
+ "            'Size': f'{size}x{size}',\n",
+ "            'Operation': 'Broadcasting (A + B.T)',\n",
+ "            'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
+ "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
+ "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
+ "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
+ "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "        })\n",
+ "    \n",
+ "    # Create and display results table\n",
+ "    df = pd.DataFrame(results)\n",
+ "    return df\n",
+ "\n",
+ "def verify_computation_correctness():\n",
+ "    \"\"\"Verify that JAX and MLX backends produce the same results\"\"\"\n",
+ "    if not MLX_AVAILABLE:\n",
+ "        print(\"MLX not available, skipping correctness check\")\n",
+ "        return\n",
+ "    \n",
+ "    print(\"Verifying computational correctness...\")\n",
+ "    \n",
+ "    # Test with small matrices\n",
+ "    np.random.seed(42)\n",
+ "    A = np.random.randn(4, 4).astype(np.float32)\n",
+ "    B = np.random.randn(4, 4).astype(np.float32)\n",
+ "    C = np.random.randn(4, 4).astype(np.float32)\n",
+ "    \n",
+ "    # Test matrix multiplication\n",
+ "    pt_A = pt.matrix('A', dtype='float32')\n",
+ "    pt_B = pt.matrix('B', dtype='float32')\n",
+ "    pt_C = pt.matrix('C', dtype='float32')\n",
+ "    result_expr = pt_A @ pt_B @ pt_C\n",
+ "    \n",
+ "    f_jax = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_jax_mode)\n",
+ "    f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
+ "    \n",
+ "    result_jax = f_jax(A, B, C)\n",
+ "    result_mlx = f_mlx(A, B, C)\n",
+ "    \n",
+ "    # Force MLX evaluation\n",
+ "    mx.eval(result_mlx)\n",
+ "    \n",
+ "    # Convert to numpy for comparison\n",
+ "    if hasattr(result_jax, 'block_until_ready'):\n",
+ "        result_jax.block_until_ready()\n",
+ "    \n",
+ "    diff = np.abs(np.array(result_jax) - np.array(result_mlx)).max()\n",
+ "    print(f\"Max difference between JAX and MLX results: {diff:.2e}\")\n",
+ "    \n",
+ "    if diff < 1e-5:\n",
+ "        print(\"✅ Results match within tolerance\")\n",
+ "    else:\n",
+ "        print(\"❌ Results differ significantly\")\n",
+ "    \n",
+ "    return diff\n",
+ "\n",
+ "def main(N=1000):\n",
+ "    \"\"\"Main benchmark execution\"\"\"\n",
+ "    # Display system info\n",
+ "    system_info = {\n",
+ "        'JAX version': jax.__version__,\n",
+ "        'PyTensor version': pytensor.__version__,\n",
+ "        'MLX Available': 'Yes' if MLX_AVAILABLE else 'No',\n",
+ "        'Platform': 'Apple Silicon' if MLX_AVAILABLE else 'Generic',\n",
+ "        'Repetitions (N)': N\n",
+ "    }\n",
+ "    \n",
+ "    if MLX_AVAILABLE:\n",
+ "        system_info['MLX version'] = mx.__version__\n",
+ "    \n",
+ "    import pandas as pd\n",
+ "    info_df = pd.DataFrame([system_info])\n",
+ "    \n",
+ "    # First verify correctness\n",
+ "    verify_computation_correctness()\n",
+ "    \n",
+ "    # Then run benchmarks\n",
+ "    results_df = run_benchmark(N=N)\n",
+ "    \n",
+ "    return info_df, results_df\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Verifying computational correctness...\n",
+ "Max difference between JAX and MLX results: 0.00e+00\n",
+ "✅ Results match within tolerance\n",
+ "Running benchmarks with N=1000 repetitions per test...\n",
+ "Testing 128x128 matrices...\n",
+ "Testing 256x256 matrices...\n",
+ "Testing 512x512 matrices...\n",
+ "Testing 1024x1024 matrices...\n"
+ ]
+ }
+ ],
+ "source": [
+ "_, results = main()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Benchmark Results over 1000 repetitions:\n",
+ " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Speedup\n",
+ " 128x128 Matrix Chain (A @ B @ C) 0.005700 0.002127 0.001215 0.000497 4.69x\n",
+ " 128x128 Element-wise (sin(A) + cos(B)) 0.008280 0.002158 0.000876 0.000451 9.45x\n",
+ " 128x128 Broadcasting (A + B.T) 0.008083 0.002485 0.000861 0.000207 9.39x\n",
+ " 256x256 Matrix Chain (A @ B @ C) 0.005705 0.002307 0.001085 0.000210 5.26x\n",
+ " 256x256 Element-wise (sin(A) + cos(B)) 0.009794 0.001994 0.000998 0.001895 9.82x\n",
+ " 256x256 Broadcasting (A + B.T) 0.010467 0.002573 0.001056 0.000578 9.91x\n",
+ " 512x512 Matrix Chain (A @ B @ C) 0.006898 0.002576 0.001300 0.000391 5.31x\n",
+ " 512x512 Element-wise (sin(A) + cos(B)) 0.010997 0.002435 0.000976 0.000584 11.27x\n",
+ " 512x512 Broadcasting (A + B.T) 0.009730 0.002690 0.000968 0.000315 10.05x\n",
+ "1024x1024 Matrix Chain (A @ B @ C) 0.010941 0.002035 0.001735 0.000302 6.31x\n",
+ "1024x1024 Element-wise (sin(A) + cos(B)) 0.013936 0.003774 0.001103 0.000253 12.64x\n",
+ "1024x1024 Broadcasting (A + B.T) 0.011153 0.002297 0.001084 0.000242 10.29x\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"\\nBenchmark Results over 1000 repetitions:\")\n",
+ "print(results.to_string(index=False))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "=== Detailed MLX Timing Analysis ===\n",
+ "Compilation time: 0.0020s\n",
+ "First execution: 0.0061s\n",
+ "Average execution (5 runs): 0.0004s ± 0.0001s\n",
+ "Individual execution times: ['0.0007', '0.0005', '0.0006', '0.0008', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0005', '0.0005', '0.0005', '0.0004', 
'0.0004', '0.0005', '0.0006', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0005', '0.0004', '0.0004', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0007', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0005', '0.0004', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0005', '0.0004', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0012', '0.0010', '0.0004', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0005', '0.0005', '0.0005', '0.0006', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0008', '0.0006', '0.0006', '0.0005', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0005', '0.0005', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0006', '0.0006', '0.0006', '0.0006', '0.0007', '0.0006', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0006', '0.0005', '0.0004', '0.0005', '0.0005', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0005', '0.0005', '0.0006', '0.0005', '0.0005', '0.0004', '0.0005', '0.0005', '0.0008', '0.0006', '0.0006', '0.0007', '0.0006', '0.0006', '0.0006', '0.0006', '0.0007', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0009', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0008', '0.0007', '0.0005', '0.0005', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0006', '0.0007', '0.0007', '0.0006', '0.0008', '0.0007', '0.0006', '0.0006', '0.0007', '0.0007', '0.0006', '0.0006', '0.0007', '0.0007', '0.0007', '0.0007', '0.0006', '0.0007', '0.0007', '0.0007', '0.0007', '0.0008', '0.0007', '0.0007', '0.0007', '0.0006', '0.0007', '0.0008', '0.0007', '0.0008', '0.0007', '0.0007', '0.0008', '0.0007', '0.0007', '0.0007', '0.0007', '0.0008', '0.0006', '0.0007', '0.0007', '0.0008', '0.0007', '0.0007', '0.0007', '0.0007', '0.0007', '0.0007', '0.0006', '0.0007', '0.0007', '0.0007', '0.0008', '0.0007', '0.0007', '0.0007', '0.0007', '0.0007', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0009', '0.0007', '0.0006', '0.0006', '0.0006', '0.0009', '0.0005', 
'0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0006', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0004', '0.0005', '0.0007', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0003', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0007', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0003', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0003', '0.0004', '0.0004', '0.0004', '0.0004', '0.0003', '0.0004', '0.0004', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0005', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', 
'0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0005', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0005', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0005', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0004', '0.0004', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0005', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0004', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003', '0.0003']\n" + ] + } + ], + "source": [ + "# Additional timing analysis - separate compilation vs execution time\n", + "if MLX_AVAILABLE:\n", + " print(\"\\n=== Detailed MLX Timing Analysis ===\")\n", + " \n", + " # Test with medium-sized matrix\n", + " np.random.seed(42)\n", + " A = np.random.randn(512, 512).astype(np.float32)\n", + " B = np.random.randn(512, 512).astype(np.float32)\n", + " C = np.random.randn(512, 512).astype(np.float32)\n", + " \n", + " # Create PyTensor function (compilation time)\n", + " start = time.perf_counter()\n", + " pt_A = pt.matrix('A', dtype='float32')\n", + " pt_B = pt.matrix('B', dtype='float32')\n", + " pt_C = pt.matrix('C', dtype='float32')\n", + " result_expr = pt_A @ pt_B @ pt_C\n", + " f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n", + " compilation_time = time.perf_counter() - start\n", + " \n", + " # First execution (may include additional compilation/optimization)\n", + " start = time.perf_counter()\n", + " result = 
f_mlx(A, B, C)\n",
+ "    mx.eval(result)  # Force evaluation\n",
+ "    first_exec_time = time.perf_counter() - start\n",
+ "    \n",
+ "    # Subsequent executions (should be faster)\n",
+ "    exec_times = []\n",
+ "    for _ in range(1000):\n",
+ "        start = time.perf_counter()\n",
+ "        result = f_mlx(A, B, C)\n",
+ "        mx.eval(result)\n",
+ "        exec_times.append(time.perf_counter() - start)\n",
+ "    \n",
+ "    avg_exec_time = np.mean(exec_times)\n",
+ "    std_exec_time = np.std(exec_times)\n",
+ "    \n",
+ "    print(f\"Compilation time: {compilation_time:.4f}s\")\n",
+ "    print(f\"First execution: {first_exec_time:.4f}s\")\n",
+ "    print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
+ "    print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "mlx_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
index 74bb018a68..7a6bda8a66 100644
--- a/pytensor/link/mlx/dispatch/blockwise.py
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -89,10 +89,18 @@ def funcify_Blockwise(op: Blockwise, node, **kwargs):
     # 4) Build in_axes: map only the first n_batch args, keep the rest static
     in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))
 
-    # 5) Vectorize (vmap) with in_axes
+    # 5) Handle case where no vectorization is needed
+    if n_batch == 0 or all(axis is None for axis in in_axes):
+        # No batch dimensions, just return the core function
+        def blockwise_fun(*inputs):
+            return core_f(*inputs)
+
+        return blockwise_fun
+
+    # 6) Vectorize (vmap) with in_axes
     blockwise_f = mx.vmap(core_f, in_axes=in_axes)
 
-    # 6) Return the mapped function
+    # 7) Return the mapped function
     def blockwise_fun(*inputs):
         return blockwise_f(*inputs)
 
diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py
index 0d8e1ee7b1..9c7f94d491 100644
--- a/pytensor/link/mlx/dispatch/math.py
+++ b/pytensor/link/mlx/dispatch/math.py
@@ -42,7 +42,7 @@
 
 @mlx_funcify.register(Dot)
-def mlx_funcify_Dot(op, **kwargs):
+def mlx_funcify_Dot(op, node=None, **kwargs):
     def dot(x, y):
         return mx.matmul(x, y)
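The Blockwise change above short-circuits when there are no batch dimensions, instead of handing mx.vmap an in_axes made entirely of None. A small sketch of the two paths (editorial, with made-up shapes):

    import mlx.core as mx

    def core_f(a, b):
        return a @ b

    # Batched path: map over axis 0 of `a`, keep `b` fixed.
    batched = mx.vmap(core_f, in_axes=(0, None))
    a = mx.ones((5, 2, 3))
    b = mx.ones((3, 4))
    print(batched(a, b).shape)  # (5, 2, 4)

    # Unbatched path: in_axes would be (None, None), so the dispatch now
    # calls core_f directly rather than asking vmap to map over nothing.
    print(core_f(mx.ones((2, 3)), b).shape)  # (2, 4)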
"execution_count": 23, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -296,7 +296,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -306,16 +306,14 @@ "Verifying computational correctness...\n", "Max difference between JAX and MLX results: 0.00e+00\n", "✅ Results match within tolerance\n", - "Running benchmarks with N=1000 repetitions per test...\n", - "Testing 128x128 matrices...\n", - "Testing 256x256 matrices...\n", - "Testing 512x512 matrices...\n", - "Testing 1024x1024 matrices...\n" + "Running benchmarks with N=20 repetitions per test...\n", + "Testing 128x128 matrices...\n" ] } ], "source": [ - "_, results = main()" + "iteration=20\n", + "_, results = main(N=iteration)" ] }, { @@ -346,7 +344,7 @@ } ], "source": [ - "print(\"\\nBenchmark Results over 1000 repetitions:\")\n", + "print(f\"\\nBenchmark Results over {iteration} repetitions:\")\n", "print(results.to_string(index=False))" ] }, diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py index 82b8f0a5de..f1d960e760 100644 --- a/pytensor/link/mlx/dispatch/core.py +++ b/pytensor/link/mlx/dispatch/core.py @@ -212,17 +212,8 @@ def tensor_from_scalar(x): @mlx_funcify.register(ScalarFromTensor) def mlx_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): - arr = mx.array(x) - try: - # Try .item() first (cleaner and faster when possible) - return arr.item() - except ValueError as e: - if "eval" in str(e): - # Fall back to reshape approach for compiled contexts - return arr.reshape(-1)[0] - else: - # Re-raise if it's a different error - raise + "We can't not return a scalar in MLX without trigger evaluation" + return x return scalar_from_tensor diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py index 5441c8a363..8e530d468d 100644 --- a/pytensor/link/mlx/dispatch/shape.py +++ b/pytensor/link/mlx/dispatch/shape.py @@ -7,7 +7,7 @@ @mlx_funcify.register(Shape) def mlx_funcify_Shape(op, **kwargs): def shape(x): - return x.shape + return mx.array(x.shape, dtype=mx.int64) return shape diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 871f694c49..e52f42b425 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -1,6 +1,7 @@ """ Basic tests for the MLX backend. 
""" +import pytest from collections.abc import Callable, Iterable from functools import partial @@ -224,12 +225,7 @@ def test_alloc_pytensor_integration(): x = pt.scalar("x", dtype="float32") result = pt.alloc(x, 3, 4) - # Use MLX mode - from pytensor.compile import mode - - mlx_mode = mode.get_mode("MLX") - - f = pytensor.function([x], result, mode=mlx_mode) + f = pytensor.function([x], result, mode="MLX") output = f(5.0) assert output.shape == (3, 4) @@ -238,7 +234,6 @@ def test_alloc_pytensor_integration(): def test_alloc_compilation_limitation(): """Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts.""" - import pytest # Create variables x = pt.scalar("x", dtype="float32") From 929630bf5341e2f3b25392f7e66562e210b28ac6 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:04:17 +0300 Subject: [PATCH 68/71] improving benchmark --- .../benchmark_mlx_v_jax_corrected.ipynb | 218 +++++++----------- 1 file changed, 81 insertions(+), 137 deletions(-) diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb index 876e982567..680be8b455 100644 --- a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb +++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -30,14 +30,14 @@ "\n", "# Set up PyTensor JAX mode\n", "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n", - "pytensor_jax_mode = Mode(linker=JAXLinker(), optimizer=jax_optimizer)\n", + "pytensor_jax_mode = \"JAX\"\n", "\n", "# Try to set up MLX mode\n", "try:\n", " from pytensor.link.mlx import MLXLinker\n", " import mlx.core as mx\n", " mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n", - " pytensor_mlx_mode = Mode(linker=MLXLinker(), optimizer=mlx_optimizer)\n", + " pytensor_mlx_mode = \"MLX\"\n", " MLX_AVAILABLE = True\n", "except ImportError:\n", " MLX_AVAILABLE = False\n", @@ -101,29 +101,28 @@ " A = np.random.randn(size, size).astype(np.float32)\n", " B = np.random.randn(size, size).astype(np.float32)\n", " C = np.random.randn(size, size).astype(np.float32)\n", + "\n", + " pt_A = pt.matrix('A', dtype='float32')\n", + " pt_B = pt.matrix('B', dtype='float32') \n", + " pt_C = pt.matrix('C', dtype='float32')\n", + " result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n", + "\n", + "\n", + " f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n", + " f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n", " \n", " # === TEST 1: Matrix Multiplication Chain ===\n", " # PyTensor + JAX backend\n", " @timer_jax\n", " def pytensor_jax_matmul():\n", - " pt_A = pt.matrix('A', dtype='float32')\n", - " pt_B = pt.matrix('B', dtype='float32') \n", - " pt_C = pt.matrix('C', dtype='float32')\n", - " result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n", - " f = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode)\n", - " return f(A, B, C)\n", + " return f_jax(A, B, C)\n", " \n", " # PyTensor + MLX backend\n", " @timer_mlx\n", " def pytensor_mlx_matmul():\n", " if not MLX_AVAILABLE:\n", " return None, float('inf'), 0\n", - " pt_A = pt.matrix('A', dtype='float32')\n", - " pt_B = pt.matrix('B', dtype='float32')\n", - " pt_C = 
From 8f2982d3dc6fa9115ce9b4cdee76cfbf97fdeefe Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Tue, 10 Jun 2025 17:30:37 +0300
Subject: [PATCH 69/71] pre commit

---
 tests/link/mlx/test_basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py
index e52f42b425..8d6999e55f 100644
--- a/tests/link/mlx/test_basic.py
+++ b/tests/link/mlx/test_basic.py
@@ -1,13 +1,13 @@
 """
 Basic tests for the MLX backend.
 """
-import pytest
 
 from collections.abc import Callable, Iterable
 from functools import partial
 
 import mlx.core as mx
 import numpy as np
+import pytest
 
 import pytensor
 from pytensor import tensor as pt
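The next patch adds warm-up calls (`f_jax(A, B, C)` / `f_mlx(A, B, C)` before the timers). That follows a standard rule for benchmarking JIT-backed functions: the first call can absorb one-time tracing and kernel-compilation costs, so it is executed once untimed. The idea, sketched as a small helper (editorial; not from the patch):

    import time

    def bench(f, *args, warmup=1, reps=100):
        for _ in range(warmup):
            f(*args)  # absorb one-time tracing/compilation cost
        start = time.perf_counter()
        for _ in range(reps):
            f(*args)
        return (time.perf_counter() - start) / reps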
From 02ed254221816bc464372e512f6dd8da9292ba38 Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Tue, 10 Jun 2025 22:01:14 +0300
Subject: [PATCH 70/71] benchs

---
 .../benchmark_mlx_v_jax_corrected.ipynb | 100 ++++++++++--------
 1 file changed, 57 insertions(+), 43 deletions(-)

diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
index 680be8b455..a5e1b2dfa4 100644
--- a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
+++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
@@ -2,7 +2,7 @@
  "cells": [
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -21,7 +21,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -88,7 +88,7 @@
  "    \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
  "    import pandas as pd\n",
  "    \n",
- "    sizes = [128, 256, 512, 1024]\n",
+ "    sizes = [2, 4, 2000, 4000]\n",
  "    results = []\n",
  "    \n",
  "    print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
@@ -101,6 +101,8 @@
  "\n",
  "        f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n",
  "        f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n",
+ "        f_jax(A, B, C)\n",
+ "        f_mlx(A, B, C)\n",
  "        \n",
  "        # === TEST 1: Matrix Multiplication Chain ===\n",
  "        # PyTensor + JAX backend\n",
@@ -132,6 +134,13 @@
  "            print(f\"MLX matmul error: {e}\")\n",
  "            mlx_mean, mlx_std = float('inf'), 0\n",
  "        \n",
+ "        # Calculate percentage improvement (positive = MLX is faster, negative = MLX is slower)\n",
+ "        if mlx_mean != float('inf') and mlx_mean > 0:\n",
+ "            speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
+ "            speedup_str = f'{speedup_percentage:+.1f}%'\n",
+ "        else:\n",
+ "            speedup_str = 'N/A'\n",
+ "        \n",
  "        results.append({\n",
  "            'Size': f'{size}x{size}',\n",
  "            'Operation': 'Matrix Chain (A @ B @ C)',\n",
@@ -139,7 +148,7 @@
  "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
  "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
  "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
- "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "            'MLX Performance': speedup_str\n",
  "        })\n",
  "        \n",
  "        # === TEST 2: Element-wise Operations ===\n",
@@ -147,6 +156,8 @@
  "        result = pt.sin(pt_A) + pt.cos(pt_B)\n",
  "        f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
  "        f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
+ "        f_jax(A, B)\n",
+ "        f_mlx(A, B)\n",
  "\n",
  "        @timer_jax\n",
  "        def pytensor_jax_elemwise():\n",
@@ -167,6 +178,13 @@
  "            print(f\"MLX elemwise error: {e}\")\n",
  "            mlx_mean, mlx_std = float('inf'), 0\n",
  "        \n",
+ "        # Calculate percentage improvement\n",
+ "        if mlx_mean != float('inf') and mlx_mean > 0:\n",
+ "            speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
+ "            speedup_str = f'{speedup_percentage:+.1f}%'\n",
+ "        else:\n",
+ "            speedup_str = 'N/A'\n",
+ "        \n",
  "        results.append({\n",
  "            'Size': f'{size}x{size}',\n",
  "            'Operation': 'Element-wise (sin(A) + cos(B))',\n",
@@ -174,7 +192,7 @@
  "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
  "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
  "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
- "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "            'MLX Performance': speedup_str\n",
  "        })\n",
  "        \n",
  "        # === TEST 3: Matrix Addition with Broadcasting ===\n",
@@ -182,6 +200,8 @@
  "        result = pt_A + pt_B.T\n",
  "        f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
  "        f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
+ "        f_jax(A, B)\n",
+ "        f_mlx(A, B)\n",
  "        @timer_jax\n",
  "        def pytensor_jax_broadcast():\n",
  "            return f_jax(A, B)\n",
@@ -201,6 +221,13 @@
  "            print(f\"MLX broadcast error: {e}\")\n",
  "            mlx_mean, mlx_std = float('inf'), 0\n",
  "        \n",
+ "        # Calculate percentage improvement\n",
+ "        if mlx_mean != float('inf') and mlx_mean > 0:\n",
+ "            speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
+ "            speedup_str = f'{speedup_percentage:+.1f}%'\n",
+ "        else:\n",
+ "            speedup_str = 'N/A'\n",
+ "        \n",
  "        results.append({\n",
  "            'Size': f'{size}x{size}',\n",
  "            'Operation': 'Broadcasting (A + B.T)',\n",
@@ -208,7 +235,7 @@
  "            'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
  "            'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
  "            'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
- "            'MLX Speedup': f'{jax_mean/mlx_mean:.2f}x' if mlx_mean != float('inf') and mlx_mean > 0 else 'N/A'\n",
+ "            'MLX Performance': speedup_str\n",
  "        })\n",
  "        \n",
  "    # Create and display results table\n",
@@ -240,29 +267,29 @@
  },
  {
  "cell_type": "code",
- "execution_count": 10,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Running benchmarks with N=100 repetitions per test...\n",
- "Testing 128x128 matrices...\n",
- "Testing 256x256 matrices...\n",
- "Testing 512x512 matrices...\n",
- "Testing 1024x1024 matrices...\n"
+ "Running benchmarks with N=150 repetitions per test...\n",
+ "Testing 2x2 matrices...\n",
+ "Testing 4x4 matrices...\n",
+ "Testing 2000x2000 matrices...\n",
+ "Testing 4000x4000 matrices...\n"
  ]
  }
  ],
  "source": [
- "iteration=100\n",
+ "iteration=150\n",
  "_, results = main(N=iteration)"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 11,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [
  {
@@ 
-270,20 +297,20 @@ "output_type": "stream", "text": [ "\n", - "Benchmark Results over 100 repetitions:\n", - " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Speedup\n", - " 128x128 Matrix Chain (A @ B @ C) 0.000131 0.000300 0.000283 0.000216 0.46x\n", - " 128x128 Element-wise (sin(A) + cos(B)) 0.000104 0.000304 0.000209 0.000145 0.50x\n", - " 128x128 Broadcasting (A + B.T) 0.000037 0.000296 0.000215 0.000153 0.17x\n", - " 256x256 Matrix Chain (A @ B @ C) 0.000394 0.000372 0.000441 0.000239 0.89x\n", - " 256x256 Element-wise (sin(A) + cos(B)) 0.000247 0.000389 0.000255 0.000168 0.97x\n", - " 256x256 Broadcasting (A + B.T) 0.000063 0.000329 0.000217 0.000153 0.29x\n", - " 512x512 Matrix Chain (A @ B @ C) 0.001004 0.000255 0.000399 0.000188 2.51x\n", - " 512x512 Element-wise (sin(A) + cos(B)) 0.000664 0.000328 0.000263 0.000163 2.53x\n", - " 512x512 Broadcasting (A + B.T) 0.000115 0.000339 0.000254 0.000156 0.45x\n", - "1024x1024 Matrix Chain (A @ B @ C) 0.005281 0.000359 0.000993 0.000342 5.32x\n", - "1024x1024 Element-wise (sin(A) + cos(B)) 0.002595 0.000359 0.000408 0.000220 6.36x\n", - "1024x1024 Broadcasting (A + B.T) 0.000501 0.000346 0.000385 0.000155 1.30x\n" + "Benchmark Results over 150 repetitions:\n", + " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n", + " 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000311 0.000266 -3277.7%\n", + " 2x2 Element-wise (sin(A) + cos(B)) 0.000008 0.000003 0.000233 0.000105 -2830.3%\n", + " 2x2 Broadcasting (A + B.T) 0.000007 0.000003 0.000253 0.000151 -3429.1%\n", + " 4x4 Matrix Chain (A @ B @ C) 0.000011 0.000008 0.000285 0.000111 -2537.7%\n", + " 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000235 0.000124 -3217.0%\n", + " 4x4 Broadcasting (A + B.T) 0.000007 0.000002 0.000202 0.000077 -2755.8%\n", + "2000x2000 Matrix Chain (A @ B @ C) 0.024714 0.000919 0.004166 0.003531 +83.1%\n", + "2000x2000 Element-wise (sin(A) + cos(B)) 0.009464 0.000417 0.000844 0.000284 +91.1%\n", + "2000x2000 Broadcasting (A + B.T) 0.000690 0.000022 0.000821 0.000093 -19.0%\n", + "4000x4000 Matrix Chain (A @ B @ C) 0.196587 0.008780 0.027411 0.001132 +86.1%\n", + "4000x4000 Element-wise (sin(A) + cos(B)) 0.037744 0.001247 0.003355 0.000467 +91.1%\n", + "4000x4000 Broadcasting (A + B.T) 0.012233 0.000421 0.003323 0.000370 +72.8%\n" ] } ], @@ -294,22 +321,9 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "=== Detailed MLX Timing Analysis ===\n", - "Compilation time: 0.0020s\n", - "First execution: 0.0061s\n", - "Average execution (5 runs): 0.0004s ± 0.0001s\n", - "Individual execution times: ['0.0007', '0.0005', '0.0006', '0.0008', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0005', '0.0004', '0.0004', '0.0005', '0.0005', '0.0005', '0.0004', '0.0004', '0.0005', '0.0006', '0.0006', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0005', '0.0004', '0.0004', '0.0004', '0.0004', '0.0004', 
-      "[... remainder of the removed per-iteration timing dump elided: hundreds of wall-clock samples, all in the '0.0003'-'0.0012' second range ...]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# # Additional timing analysis - separate compilation vs execution time\n",
     "# if MLX_AVAILABLE:\n",

From 03a2094a5367743a5971b773d4e1e0fef5cd09ba Mon Sep 17 00:00:00 2001
From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Date: Fri, 11 Jul 2025 16:14:38 +0300
Subject: [PATCH 71/71] Changes on the branch

---
 .../benchmark_mlx_v_jax_corrected.ipynb       | 69 ++++++++++++++-----
 pytensor/link/mlx/dispatch/blockwise.py       |  2 +-
 pytensor/link/mlx/dispatch/signal/conv.py     |  2 +-
 3 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
index a5e1b2dfa4..4d5b9f296b 100644
--- a/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
+++ b/doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb
@@ -1,5 +1,38 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor\n",
+      "  Installing build dependencies ... done\n",
+      "  Checking if build backend supports build_editable ... done\n",
+      "  Getting requirements to build editable ... done\n",
+      "  Preparing editable metadata (pyproject.toml) ... done\n",
+      "Building wheels for collected packages: pytensor\n",
+      "  Building editable for pytensor (pyproject.toml) ... done\n",
+      "  Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee\n",
+      "  Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05\n",
+      "Successfully built pytensor\n",
+      "Installing collected packages: pytensor\n",
+      "  Attempting uninstall: pytensor\n",
+      "    Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty\n",
+      "    Uninstalling pytensor-2.31.7+80.g06ccf91ba.dirty:\n",
+      "      Successfully uninstalled pytensor-2.31.7+80.g06ccf91ba.dirty\n",
+      "Successfully installed pytensor-2.31.7+80.g06ccf91ba.dirty\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
+   "source": [
+    "%pip install -e ../.. --no-deps"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
@@ -88,7 +121,7 @@
     "    \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
     "    import pandas as pd\n",
     "    \n",
-    "    sizes = [2, 4, 2000, 4000]\n",
+    "    sizes = [2, 4, 1080, 2080, 3080]\n",
     "    results = []\n",
     "    \n",
     "    print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
@@ -277,8 +310,9 @@
      "Running benchmarks with N=150 repetitions per test...\n",
      "Testing 2x2 matrices...\n",
      "Testing 4x4 matrices...\n",
-     "Testing 2000x2000 matrices...\n",
-     "Testing 4000x4000 matrices...\n"
+     "Testing 1080x1080 matrices...\n",
+     "Testing 2080x2080 matrices...\n",
+     "Testing 3080x3080 matrices...\n"
     ]
    }
   ],
@@ -299,18 +333,21 @@
     "\n",
     "Benchmark Results over 150 repetitions:\n",
     "Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n",
-    "2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000311 0.000266 -3277.7%\n",
-    "2x2 Element-wise (sin(A) + cos(B)) 0.000008 0.000003 0.000233 0.000105 -2830.3%\n",
-    "2x2 Broadcasting (A + B.T) 0.000007 0.000003 0.000253 0.000151 -3429.1%\n",
-    "4x4 Matrix Chain (A @ B @ C) 0.000011 0.000008 0.000285 0.000111 -2537.7%\n",
-    "4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000235 0.000124 -3217.0%\n",
-    "4x4 Broadcasting (A + B.T) 0.000007 0.000002 0.000202 0.000077 -2755.8%\n",
-    "2000x2000 Matrix Chain (A @ B @ C) 0.024714 0.000919 0.004166 0.003531 +83.1%\n",
-    "2000x2000 Element-wise (sin(A) + cos(B)) 0.009464 0.000417 0.000844 0.000284 +91.1%\n",
-    "2000x2000 Broadcasting (A + B.T) 0.000690 0.000022 0.000821 0.000093 -19.0%\n",
-    "4000x4000 Matrix Chain (A @ B @ C) 0.196587 0.008780 0.027411 0.001132 +86.1%\n",
-    "4000x4000 Element-wise (sin(A) + cos(B)) 0.037744 0.001247 0.003355 0.000467 +91.1%\n",
-    "4000x4000 Broadcasting (A + B.T) 0.012233 0.000421 0.003323 0.000370 +72.8%\n"
+    "2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000305 0.000299 -3213.5%\n",
+    "2x2 Element-wise (sin(A) + cos(B)) 0.000007 0.000002 0.000352 0.003757 -5078.0%\n",
+    "2x2 Broadcasting (A + B.T) 0.000007 0.000001 0.000188 0.000153 -2721.1%\n",
+    "4x4 Matrix Chain (A @ B @ C) 0.000009 0.000001 0.000209 0.000063 -2126.2%\n",
+    "4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000180 0.000066 -2449.5%\n",
+    "4x4 Broadcasting (A + B.T) 0.000007 0.000003 0.000181 0.000065 -2564.1%\n",
+    "1080x1080 Matrix Chain (A @ B @ C) 0.005951 0.000356 0.001355 0.000392 +77.2%\n",
+    "1080x1080 Element-wise (sin(A) + cos(B)) 0.002820 0.000107 0.000432 0.000207 +84.7%\n",
+    "1080x1080 Broadcasting (A + B.T) 0.000212 0.000035 0.000428 0.000206 -102.0%\n",
+    "2080x2080 Matrix Chain (A @ B @ C) 0.027609 0.001255 0.004550 0.002528 +83.5%\n",
+    "2080x2080 Element-wise (sin(A) + cos(B)) 0.010086 0.000417 0.001175 0.000350 +88.3%\n",
+    "2080x2080 Broadcasting (A + B.T) 0.000856 0.000068 0.001124 0.000241 -31.2%\n",
+    "3080x3080 Matrix Chain (A @ B @ C) 0.093115 0.003823 0.013649 0.000513 +85.3%\n",
+    "3080x3080 Element-wise (sin(A) + cos(B)) 0.022586 0.000756 0.001930 0.000287 +91.5%\n",
+    "3080x3080 Broadcasting (A + B.T) 0.002580 0.000161 0.001937 0.000257 +24.9%\n"
     ]
    }
   ],
   "source": [
  },
  {
   "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
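The run_benchmark source is only partially visible in the hunks above, but the comparison it performs can be sketched. The key detail (presumably the "corrected" in this notebook's filename) is forcing results to materialize before stopping the clock, since MLX evaluates lazily and JAX dispatches asynchronously. A minimal sketch, assuming the "JAX" and "MLX" mode strings are both registered; the helper name time_backend, the float32 inputs, and the NumPy-conversion synchronization barrier are illustrative assumptions, not the notebook's exact code:

import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(42)

# MLX has no float64, so build the graph in float32 to keep both backends comparable.
x = pt.matrix("x", dtype="float32")
y = pt.matrix("y", dtype="float32")
graph = pt.sin(x) + pt.cos(y)  # the "Element-wise" case from the table above


def time_backend(mode, size, n_reps=150):
    f = pytensor.function([x, y], graph, mode=mode)
    a = rng.normal(size=(size, size)).astype("float32")
    b = rng.normal(size=(size, size)).astype("float32")
    np.asarray(f(a, b))  # warm-up call: JIT compilation happens outside the timed region
    # Converting to NumPy forces lazy MLX arrays (and asynchronous JAX arrays)
    # to materialize, so each timed call measures real compute, not just dispatch.
    return timeit.timeit(lambda: np.asarray(f(a, b)), number=n_reps) / n_reps


for mode in ("JAX", "MLX"):
    print(mode, time_backend(mode, size=2080))

The NumPy round-trip is a blunt but backend-agnostic synchronization barrier; a finer-grained variant would call mx.eval on MLX outputs and .block_until_ready() on JAX outputs instead.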
diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
index 7a6bda8a66..ef68dca1fa 100644
--- a/pytensor/link/mlx/dispatch/blockwise.py
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -2,7 +2,7 @@
 
 from pytensor.link.mlx.dispatch import mlx_funcify
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d as Conv1d
 
 
 def blockwise_conv1d(op, node, **kwargs):
diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py
index 8f84ebb42f..b2adf009ab 100644
--- a/pytensor/link/mlx/dispatch/signal/conv.py
+++ b/pytensor/link/mlx/dispatch/signal/conv.py
@@ -1,7 +1,7 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch import mlx_funcify
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d as Conv1d
 
 
 @mlx_funcify.register(Conv1d)
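For context on the two one-line fixes above: the Op was evidently renamed upstream from Conv1d to Convolve1d in pytensor.tensor.signal.conv, so both dispatch modules now import it under the old alias. The registration that this hunk cuts off plausibly forwards to MLX's 1D convolution; a minimal sketch, assuming the Op exposes its "full"/"valid" convolution mode as op.mode and that mx.convolve's numpy-style modes line up with it (the body below is illustrative, not the file's verbatim contents):

import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.signal.conv import Convolve1d as Conv1d


@mlx_funcify.register(Conv1d)
def mlx_funcify_Conv1d(op, node=None, **kwargs):
    mode = op.mode  # assumed to be "full" or "valid", as in np.convolve

    def conv1d(data, kernel):
        # mx.convolve performs the 1D convolution of two unbatched arrays
        return mx.convolve(data, kernel, mode=mode)

    return conv1d

This core handler would only cover unbatched 1D inputs; the separate blockwise_conv1d in blockwise.py presumably exists so that the common Blockwise(Conv1d) pattern can be lowered to a batched MLX convolution rather than a Python loop over the batch dimension.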