From 848c81abdefcf25bef91973870d9e69356290d1c Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Sat, 16 Aug 2025 13:40:45 +0200 Subject: [PATCH 1/2] Added workaround for MLX issue where strides may be zero for dims of size 1 Cf. https://github.com/ml-explore/mlx/issues/2501 --- src/array.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/array.jl b/src/array.jl index 0e09b46..0c26f70 100644 --- a/src/array.jl +++ b/src/array.jl @@ -92,7 +92,7 @@ end # Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays function Base.strides(array::MLXArray) - return Tuple( + array_strides = Tuple( Int.( unsafe_wrap( Vector{Csize_t}, @@ -101,6 +101,13 @@ function Base.strides(array::MLXArray) ), ), ) + if any(iszero, array_strides) # Workaround for MLX issue where strides may be zero for dims of size 1: https://github.com/ml-explore/mlx/issues/2501 + non_zero_strides = map(s -> iszero(s) ? 1 : s, array_strides) + @debug "Some strides are zero in $array_strides - returning strides $non_zero_strides for array of size $(size(array))" + return non_zero_strides + end + + return array_strides end function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N} From f455bb3ff57fc575b4d0d8629e44ce7e27ab513a Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Wed, 26 Mar 2025 10:11:47 +0100 Subject: [PATCH 2/2] Implemented broadcasting for MLXArray Also: * Added necessary eval of MLX array data in Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) * Implemented Base.similar for MLXArray. * Added supported_number_types(::DeviceType = DeviceTypeCPU) --- Project.toml | 3 ++- src/array.jl | 25 ++++++++++++++++++++++ test/array_tests.jl | 51 ++++++++++++++++++++++++++++++++++++++++++++ test/device_tests.jl | 1 + test/stream_tests.jl | 1 + 5 files changed, 80 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 507b6b0..4b0e56f 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,8 @@ ScopedValues = "1" julia = "1" [extras] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Random", "Test"] diff --git a/src/array.jl b/src/array.jl index 0c26f70..bb4deef 100644 --- a/src/array.jl +++ b/src/array.jl @@ -89,6 +89,13 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N} return array end +function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N} + stream = get_stream() + result_ref = Ref(Wrapper.mlx_array_new()) + Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream) + return MLXArray{T, N}(result_ref[]) +end + # Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays function Base.strides(array::MLXArray) @@ -164,3 +171,21 @@ function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N} return PermutedDimsArray(wrapped_array, reverse(1:ndims(array))) end end + +# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting + +Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}() + +function Base.similar( + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement} +) where {TElement} + first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args) + function first_mlx_array(args::Tuple) + return first_mlx_array(first_mlx_array(args[1]), Base.tail(args)) + end + first_mlx_array(x) = x + first_mlx_array(::Tuple{}) = nothing + first_mlx_array(a::MLXArray, _) = a + first_mlx_array(::Any, rest) = first_mlx_array(rest) + return similar(first_mlx_array(bc)) +end diff --git a/test/array_tests.jl b/test/array_tests.jl index 025688a..4b2843b 100644 --- a/test/array_tests.jl +++ b/test/array_tests.jl @@ -1,7 +1,21 @@ +@static if VERSION < v"1.11" + using ScopedValues +else + using Base.ScopedValues +end + using MLX +using Random using Test @testset "MLXArray" begin + Random.seed!(42) + + device_types = [MLX.DeviceTypeCPU] + if MLX.metal_is_available() + push!(device_types, MLX.DeviceTypeGPU) + end + @test IndexStyle(MLXArray) == IndexLinear() array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)] @@ -31,6 +45,22 @@ using Test array[1] = T(1) @test setindex!(mlx_array, T(1), 1) == array end + + @testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin + for device_type in device_types + if T ∉ MLX.supported_number_types(device_type) + continue + end + @testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin + with(MLX.device => MLX.Device(; device_type)) do + similar_mlx_array = similar(mlx_array) + @test typeof(similar_mlx_array) == typeof(mlx_array) + @test size(similar_mlx_array) == size(mlx_array) + @test similar_mlx_array !== mlx_array + end + end + end + end end end end @@ -62,7 +92,28 @@ using Test @test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0}) end end + @testset "Unsupported Number types" begin @test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int}) end + + @testset "Broadcasting interface" begin + for device_type in device_types, + T in MLX.supported_number_types(device_type), + array_size in array_sizes + + N = length(array_size) + @testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin + array = rand(T, array_size) + mlx_array = MLXArray(array) + + with(MLX.device => MLX.Device(; device_type)) do + result = identity.(mlx_array) + @test result isa MLXArray + @test result == mlx_array + @test result !== mlx_array + end + end + end + end end diff --git a/test/device_tests.jl b/test/device_tests.jl index 4051cab..3a7dbab 100644 --- a/test/device_tests.jl +++ b/test/device_tests.jl @@ -3,6 +3,7 @@ else using Base.ScopedValues end + using MLX using Test diff --git a/test/stream_tests.jl b/test/stream_tests.jl index 28b106f..b5a1429 100644 --- a/test/stream_tests.jl +++ b/test/stream_tests.jl @@ -3,6 +3,7 @@ else using Base.ScopedValues end + using MLX using Test