Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ ScopedValues = "1"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Random", "Test"]
34 changes: 33 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,17 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N}
return array
end

function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N}
stream = get_stream()
result_ref = Ref(Wrapper.mlx_array_new())
Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream)
return MLXArray{T, N}(result_ref[])
end

# Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays

function Base.strides(array::MLXArray)
return Tuple(
array_strides = Tuple(
Int.(
unsafe_wrap(
Vector{Csize_t},
Expand All @@ -101,6 +108,13 @@ function Base.strides(array::MLXArray)
),
),
)
if any(iszero, array_strides) # Workaround for MLX issue where strides may be zero for dims of size 1: https://github.com/ml-explore/mlx/issues/2501
non_zero_strides = map(s -> iszero(s) ? 1 : s, array_strides)
@debug "Some strides are zero in $array_strides - returning strides $non_zero_strides for array of size $(size(array))"
return non_zero_strides
end

return array_strides
end

function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
Expand Down Expand Up @@ -157,3 +171,21 @@ function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
return PermutedDimsArray(wrapped_array, reverse(1:ndims(array)))
end
end

# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting

Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}()

function Base.similar(
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement}
) where {TElement}
first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args)
function first_mlx_array(args::Tuple)
return first_mlx_array(first_mlx_array(args[1]), Base.tail(args))
end
first_mlx_array(x) = x
first_mlx_array(::Tuple{}) = nothing
first_mlx_array(a::MLXArray, _) = a
first_mlx_array(::Any, rest) = first_mlx_array(rest)
return similar(first_mlx_array(bc))
end
51 changes: 51 additions & 0 deletions test/array_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
@static if VERSION < v"1.11"
using ScopedValues
else
using Base.ScopedValues
end

using MLX
using Random
using Test

@testset "MLXArray" begin
Random.seed!(42)

device_types = [MLX.DeviceTypeCPU]
if MLX.metal_is_available()
push!(device_types, MLX.DeviceTypeGPU)
end

@test IndexStyle(MLXArray) == IndexLinear()

array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)]
Expand Down Expand Up @@ -31,6 +45,22 @@ using Test
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array
end

@testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
for device_type in device_types
if T ∉ MLX.supported_number_types(device_type)
continue
end
@testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin
with(MLX.device => MLX.Device(; device_type)) do
similar_mlx_array = similar(mlx_array)
@test typeof(similar_mlx_array) == typeof(mlx_array)
@test size(similar_mlx_array) == size(mlx_array)
@test similar_mlx_array !== mlx_array
end
end
end
end
end
end
end
Expand Down Expand Up @@ -62,7 +92,28 @@ using Test
@test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0})
end
end

@testset "Unsupported Number types" begin
@test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
end

@testset "Broadcasting interface" begin
for device_type in device_types,
T in MLX.supported_number_types(device_type),
array_size in array_sizes

N = length(array_size)
@testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin
array = rand(T, array_size)
mlx_array = MLXArray(array)

with(MLX.device => MLX.Device(; device_type)) do
result = identity.(mlx_array)
@test result isa MLXArray
@test result == mlx_array
@test result !== mlx_array
end
end
end
end
end
1 change: 1 addition & 0 deletions test/device_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down
1 change: 1 addition & 0 deletions test/stream_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down
Loading