diff --git a/Project.toml b/Project.toml
index 507b6b0..4b0e56f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -15,7 +15,8 @@ ScopedValues = "1"
 julia = "1"
 
 [extras]
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Random", "Test"]
diff --git a/src/MLX.jl b/src/MLX.jl
index 0d4b07a..4daf147 100644
--- a/src/MLX.jl
+++ b/src/MLX.jl
@@ -10,6 +10,8 @@ export MLXArray, MLXException, MLXMatrix, MLXVecOrMat, MLXVector
 
 include(joinpath(@__DIR__, "Wrapper.jl"))
 
+include(joinpath(@__DIR__, "Private.jl"))
+
 include(joinpath(@__DIR__, "utils.jl"))
 
 include(joinpath(@__DIR__, "array.jl"))
@@ -18,6 +20,8 @@ include(joinpath(@__DIR__, "error_handling.jl"))
 include(joinpath(@__DIR__, "metal.jl"))
 include(joinpath(@__DIR__, "stream.jl"))
 
+include(joinpath(@__DIR__, "ops.jl"))
+
 function __init__()
     register_error_handler()
     return nothing
diff --git a/src/Private.jl b/src/Private.jl
new file mode 100644
index 0000000..ef34fec
--- /dev/null
+++ b/src/Private.jl
@@ -0,0 +1,286 @@
+module Private
+
+using ..Wrapper
+
+function return_input_type(::Type{TIn}) where {TIn}
+    return TIn
+end
+
+function return_float_type(::Type{TIn}) where {TIn}
+    return TIn <: Complex{<:AbstractFloat} ? TIn : Float32 # TODO: Float64 unsupported by MLX C 0.1.1
+end
+
+function get_unary_scalar_ops()
+    RealExceptBool = Union{AbstractFloat, Signed, Unsigned}
+    return Dict(
+        :abs => (
+            mlx_fn = Wrapper.mlx_abs,
+            TIn = Real, # in testing, abs differs from mlx_abs wrt. Complex{<:AbstractFloat}
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :acos => (
+            mlx_fn = Wrapper.mlx_arccos,
+            TIn = RealExceptBool, # in testing, acos differs from mlx_arccos wrt. Bool, Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(floor.(a ./ maximum(a))),
+        ),
+        :acosh => (
+            mlx_fn = Wrapper.mlx_arccosh,
+            TIn = Union{AbstractFloat, Complex}, # in testing, acosh differs from mlx_arccosh wrt. Integer
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a .+ 1,
+        ),
+        :asin => (
+            mlx_fn = Wrapper.mlx_arcsin,
+            TIn = AbstractFloat, # in testing, asin differs from mlx_arcsin wrt. Integer, normalize fails for Complex{<:AbstractFloat}
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> TIn.(floor.(a ./ maximum(a))),
+        ),
+        :asinh => (
+            mlx_fn = Wrapper.mlx_arcsinh,
+            TIn = Union{AbstractFloat, Complex}, # in testing, asinh differs from mlx_arcsinh wrt. Integer
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :atan => (
+            mlx_fn = Wrapper.mlx_arctan,
+            TIn = Real, # testing fails for atan wrt. Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :atanh => (
+            mlx_fn = Wrapper.mlx_arctanh,
+            TIn = AbstractFloat, # in testing, atanh differs from mlx_arctanh wrt. Integer, normalize fails for Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(floor.(a ./ maximum(a))),
+        ),
+        # mlx_atleast_1d
+        # mlx_atleast_2d
+        # mlx_atleast_3d
+        :~ => (
+            mlx_fn = Wrapper.mlx_bitwise_invert,
+            TIn = Integer,
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :ceil => (
+            mlx_fn = Wrapper.mlx_ceil,
+            TIn = Real, # MLX: [floor] Not supported for complex64
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :conj => ( # TODO: conj is also defined for AbstractArray
+            mlx_fn = Wrapper.mlx_conjugate,
+            TIn = Number,
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :cos => (
+            mlx_fn = Wrapper.mlx_cos,
+            TIn = AbstractFloat, # in testing, cos differs from mlx_cos wrt. Signed, Unsigned, Complex{<:AbstractFloat}, Bool fails: conversion to pointer not defined for BitArray
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) ->
+                TIn.(round.(map(x -> iszero(x % π) ? x + eps(Float32) : x, a))),
+        ),
+        :cosh => (
+            mlx_fn = Wrapper.mlx_cosh,
+            TIn = Real, # testing fails for cosh wrt. Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :rad2deg => (
+            mlx_fn = Wrapper.mlx_degrees,
+            TIn = Real, # testing fails for rad2deg wrt. Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        # mlx_erf
+        # mlx_erfinv
+        :exp => (
+            mlx_fn = Wrapper.mlx_exp,
+            TIn = Real, # testing fails for exp wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32: `copyto!(dest::MLXArray{Float32, 3}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(-), Tuple{MLXArray{Float32, 3}, Array{ComplexF32, 3}}})`
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :expm1 => (
+            mlx_fn = Wrapper.mlx_expm1,
+            TIn = Real, # testing fails for expm1 wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :floor => (
+            mlx_fn = Wrapper.mlx_floor,
+            TIn = Real, # MLX: [floor] Not supported for complex64
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :imag => (
+            mlx_fn = Wrapper.mlx_imag,
+            TIn = Real, # testing segfaults wrt. Complex{<:AbstractFloat}
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :isfinite => (
+            mlx_fn = Wrapper.mlx_isfinite,
+            TIn = Number,
+            output_type = (::Type) -> Bool,
+            preserves_type = false,
+            normalize = (a, TIn) -> a, # TODO should provide some finite/Infs
+        ),
+        :isinf => (
+            mlx_fn = Wrapper.mlx_isinf,
+            TIn = Number,
+            output_type = (::Type) -> Bool,
+            preserves_type = false,
+            normalize = (a, TIn) -> a, # TODO should provide some Infs
+        ),
+        :isnan => (
+            mlx_fn = Wrapper.mlx_isnan,
+            TIn = Number,
+            output_type = (::Type) -> Bool,
+            preserves_type = false,
+            normalize = (a, TIn) -> a, # TODO should provide some NaNs
+        ),
+        # mlx_isneginf
+        # mlx_isposinf
+        :log => (
+            mlx_fn = Wrapper.mlx_log,
+            TIn = RealExceptBool, # Bool fails: conversion to pointer not defined for BitArray. Complex{<:AbstractFloat} fails: MethodError: no method matching isless(::ComplexF32, ::Float32)
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(ceil.(max.(eps(Float32), a))),
+        ),
+        :log10 => (
+            mlx_fn = Wrapper.mlx_log10,
+            TIn = RealExceptBool, # Bool fails: conversion to pointer not defined for BitArray. Complex{<:AbstractFloat} fails: MethodError: no method matching isless(::ComplexF32, ::Float32)
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(ceil.(max.(eps(Float32), a))),
+        ),
+        :log1p => (
+            mlx_fn = Wrapper.mlx_log1p,
+            TIn = RealExceptBool, # Bool fails: conversion to pointer not defined for BitArray. Complex{<:AbstractFloat} fails: MethodError: no method matching isless(::ComplexF32, ::Float32)
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(ceil.(max.(eps(Float32), a))),
+        ),
+        :log2 => (
+            mlx_fn = Wrapper.mlx_log2,
+            TIn = RealExceptBool, # Bool fails: conversion to pointer not defined for BitArray. Complex{<:AbstractFloat} fails: MethodError: no method matching isless(::ComplexF32, ::Float32)
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(ceil.(max.(eps(Float32), a))),
+        ),
+        :! => (
+            mlx_fn = Wrapper.mlx_logical_not,
+            TIn = Bool,
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :- => (
+            mlx_fn = Wrapper.mlx_negative,
+            TIn = Union{RealExceptBool, Complex{<:AbstractFloat}}, # MLX: [negative] Not supported for bool, use logical_not instead.
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        # mlx_ones_like
+        :deg2rad => (
+            mlx_fn = Wrapper.mlx_radians,
+            TIn = Real, # testing fails for deg2rad wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :real => (
+            mlx_fn = Wrapper.mlx_real,
+            TIn = Real, # testing fails for real wrt. Complex{<:AbstractFloat}, likely due to array storage order.
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :inv => (
+            mlx_fn = Wrapper.mlx_reciprocal, # TODO check if this is correct, notably wrt. mlx_linalg_inv
+            TIn = Real, # testing fails for inv wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        # mlx_rsqrt
+        # mlx_sigmoid
+        :sign => (
+            mlx_fn = Wrapper.mlx_sign,
+            TIn = Union{AbstractFloat, Bool, Signed, Complex}, # TIn = Number \ Unsigned: sign broken on CPU for Unsigned on MLX <= 0.24.1, cf. https://github.com/ml-explore/mlx/issues/2023
+            output_type = return_input_type,
+            preserves_type = true,
+            normalize = (a, TIn) -> a,
+        ),
+        :sin => (
+            mlx_fn = Wrapper.mlx_sin,
+            TIn = AbstractFloat, # in testing, sin differs from mlx_sin wrt. Signed, Unsigned, Complex{<:AbstractFloat}, Bool fails: conversion to pointer not defined for BitArray
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) ->
+                TIn.(round.(map(x -> iszero(x % π / 2) ? x + eps(Float32) : x, a))),
+        ),
+        :sinh => (
+            mlx_fn = Wrapper.mlx_sinh,
+            TIn = Real, # testing fails for sinh wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :sqrt => (
+            mlx_fn = Wrapper.mlx_sqrt,
+            TIn = RealExceptBool, # Bool fails: conversion to pointer not defined for BitArray. Complex{<:AbstractFloat} fails: MethodError: no method matching isless(::ComplexF32, ::Float32)
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> TIn.(ceil.(max.(eps(Float32), a))),
+        ),
+        # mlx_square
+        # mlx_stop_gradient
+        :tan => (
+            mlx_fn = Wrapper.mlx_tan,
+            TIn = Union{AbstractFloat, Bool}, # in testing, tan differs from mlx_tan wrt. Signed, Unsigned, Complex{<:AbstractFloat}
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        :tanh => (
+            mlx_fn = Wrapper.mlx_tanh,
+            TIn = Real, # testing fails for tanh wrt. Complex{<:AbstractFloat}. TODO: Needs broadcast across Float32 and ComplexF32
+            output_type = return_float_type,
+            preserves_type = false,
+            normalize = (a, TIn) -> a,
+        ),
+        # mlx_linalg_inv
+        # :pinv => ( # TODO using LinearAlgebra
+        #     mlx_fn = Wrapper.mlx_linalg_pinv,
+        #     TIn = Number,
+        #     output_type = return_input_type,
+        #     preserves_type = true,
+        #     normalize = (a, TIn) -> a,
+        # ),
+    )
+end
+
+end
diff --git a/src/array.jl b/src/array.jl
index 0e09b46..bb4deef 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -89,10 +89,17 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N}
     return array
 end
 
+function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N}
+    stream = get_stream()
+    result_ref = Ref(Wrapper.mlx_array_new())
+    Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream)
+    return MLXArray{T, N}(result_ref[])
+end
+
 # Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays
 function Base.strides(array::MLXArray)
-    return Tuple(
+    array_strides = Tuple(
         Int.(
             unsafe_wrap(
                 Vector{Csize_t},
@@ -101,6 +108,13 @@ function Base.strides(array::MLXArray)
             ),
         ),
     )
+    if any(iszero, array_strides) # Workaround for MLX issue where strides may be zero for dims of size 1: https://github.com/ml-explore/mlx/issues/2501
+        non_zero_strides = map(s -> iszero(s) ? 1 : s, array_strides)
+        @debug "Some strides are zero in $array_strides - returning strides $non_zero_strides for array of size $(size(array))"
+        return non_zero_strides
+    end
+
+    return array_strides
 end
 
 function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
@@ -157,3 +171,21 @@ function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
         return PermutedDimsArray(wrapped_array, reverse(1:ndims(array)))
     end
 end
+
+# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
+
+Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}()
+
+function Base.similar(
+    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement}
+) where {TElement}
+    first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args)
+    function first_mlx_array(args::Tuple)
+        return first_mlx_array(first_mlx_array(args[1]), Base.tail(args))
+    end
+    first_mlx_array(x) = x
+    first_mlx_array(::Tuple{}) = nothing
+    first_mlx_array(a::MLXArray, _) = a
+    first_mlx_array(::Any, rest) = first_mlx_array(rest)
+    return similar(first_mlx_array(bc))
+end
diff --git a/src/ops.jl b/src/ops.jl
new file mode 100644
index 0000000..a8b2ee0
--- /dev/null
+++ b/src/ops.jl
@@ -0,0 +1,95 @@
+function Base.copy(a::MLXArray{T, N}) where {T, N}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    Wrapper.mlx_copy(result, a.mlx_array, s.mlx_stream)
+    return MLXArray{T, N}(result[])
+end
+
+"""
+    dropdims(a::MLXArray; dims::Union{Dims, Integer, Nothing} = nothing)
+
+Return an array with singleton dimensions removed. If `dims` is not specified,
+all singleton dimensions are removed.
+"""
+function Base.dropdims(
+    a::MLXArray{T, N}; dims::Union{Dims, Integer, Nothing} = nothing
+) where {T, N}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    if isnothing(dims)
+        Wrapper.mlx_squeeze_all(result, a.mlx_array, s.mlx_stream)
+    else
+        if dims isa Integer
+            dims = Dims(dims)
+        end
+        axes = collect(Cint.(dims) .- one(Cint))
+        Wrapper.mlx_squeeze(result, a.mlx_array, axes, length(axes), s.mlx_stream)
+    end
+    remaining_dims = Int(Wrapper.mlx_array_ndim(result[]))
+    return MLXArray{T, remaining_dims}(result[])
+end
+
+function Base.sort(v::MLXVector{T}) where {T}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    Wrapper.mlx_sort_all(result, v.mlx_array, s.mlx_stream)
+    return MLXVector{T}(result[])
+end
+
+function Base.sort(a::MLXArray{T, N}; dims::Integer) where {T, N}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    axis = Cint(dims) - one(Cint)
+    Wrapper.mlx_sort(result, a.mlx_array, axis, s.mlx_stream)
+    return MLXArray{T, N}(result[])
+end
+
+function Base.sortperm(v::MLXVector{T}) where {T}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    Wrapper.mlx_argsort_all(result, v.mlx_array, s.mlx_stream)
+    return MLXVector{T}(result[])
+end
+
+function Base.sortperm(a::MLXArray{T, N}; dims::Integer) where {T, N}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    axis = Cint(dims) - one(Cint)
+    Wrapper.mlx_argsort(result, a.mlx_array, axis, s.mlx_stream)
+    Wrapper.mlx_add(result, result[], Wrapper.mlx_array_new_int(1), s.mlx_stream)
+    return MLXArray{T, N}(result[])
+end
+
+function Base.permutedims(
+    a::MLXArray{T, N}; dims::Union{Dims, Integer, Nothing} = nothing
+) where {T, N}
+    s = get_stream()
+    result = Ref(Wrapper.mlx_array_new())
+    if isnothing(dims)
+        Wrapper.mlx_transpose_all(result, a.mlx_array, s.mlx_stream)
+    else
+        if dims isa Integer
+            dims = Dims(dims)
+        end
+        axes = collect(Cint.(dims) .- one(Cint))
+        Wrapper.mlx_transpose(result, a.mlx_array, axes, length(axes), s.mlx_stream)
+    end
+    return MLXArray{T, N}(result[])
+end
+
+for (fn, fn_def) in Private.get_unary_scalar_ops()
+    TOut = fn_def.output_type(fn_def.TIn)
+
+    @eval function Broadcast.broadcasted(
+        ::Broadcast.ArrayStyle{MLXArray}, ::typeof($fn), a::MLXArray{T, N}
+    ) where {T <: $(fn_def.TIn), N}
+        s = get_stream()
+        result_ref = Ref(Wrapper.mlx_array_new())
+        $(fn_def.mlx_fn)(result_ref, a.mlx_array, s.mlx_stream)
+        @static if $(fn_def.preserves_type)
+            return MLXArray{T, N}(result_ref[])
+        else
+            return MLXArray{$TOut, N}(result_ref[])
+        end
+    end
+end
diff --git a/test/array_tests.jl b/test/array_tests.jl
index 025688a..4b2843b 100644
--- a/test/array_tests.jl
+++ b/test/array_tests.jl
@@ -1,7 +1,21 @@
+@static if VERSION < v"1.11"
+    using ScopedValues
+else
+    using Base.ScopedValues
+end
+
 using MLX
+using Random
 using Test
 
 @testset "MLXArray" begin
+    Random.seed!(42)
+
+    device_types = [MLX.DeviceTypeCPU]
+    if MLX.metal_is_available()
+        push!(device_types, MLX.DeviceTypeGPU)
+    end
+
     @test IndexStyle(MLXArray) == IndexLinear()
 
     array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)]
@@ -31,6 +45,22 @@
                     array[1] = T(1)
                     @test setindex!(mlx_array, T(1), 1) == array
                 end
+
+                @testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
+                    for device_type in device_types
+                        if T ∉ MLX.supported_number_types(device_type)
+                            continue
+                        end
+                        @testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin
+                            with(MLX.device => MLX.Device(; device_type)) do
+                                similar_mlx_array = similar(mlx_array)
+                                @test typeof(similar_mlx_array) == typeof(mlx_array)
+                                @test size(similar_mlx_array) == size(mlx_array)
+                                @test similar_mlx_array !== mlx_array
+                            end
+                        end
+                    end
+                end
             end
         end
     end
@@ -62,7 +92,28 @@
             @test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0})
         end
     end
+
     @testset "Unsupported Number types" begin
         @test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
     end
+
+    @testset "Broadcasting interface" begin
+        for device_type in device_types,
+            T in MLX.supported_number_types(device_type),
+            array_size in array_sizes
+
+            N = length(array_size)
+            @testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin
+                array = rand(T, array_size)
+                mlx_array = MLXArray(array)
+
+                with(MLX.device => MLX.Device(; device_type)) do
+                    result = identity.(mlx_array)
+                    @test result isa MLXArray
+                    @test result == mlx_array
+                    @test result !== mlx_array
+                end
+            end
+        end
+    end
 end
diff --git a/test/device_tests.jl b/test/device_tests.jl
index 4051cab..3a7dbab 100644
--- a/test/device_tests.jl
+++ b/test/device_tests.jl
@@ -3,6 +3,7 @@
 else
     using Base.ScopedValues
 end
+
 using MLX
 using Test
 
diff --git a/test/ops_tests.jl b/test/ops_tests.jl
new file mode 100644
index 0000000..cf42bd4
--- /dev/null
+++ b/test/ops_tests.jl
@@ -0,0 +1,143 @@
+using MLX
+using Random
+using Test
+
+@testset "ops" begin
+    Random.seed!(42)
+
+    element_types = MLX.supported_number_types(MLX.DeviceTypeGPU) # TODO Excluding Float64
+
+    array_sizes = [
+        # (), # TODO: Excluded broadcasting over 0-dimensional Array: Yields scalar result
+        (1,),
+        (2,),
+        (1, 1),
+        (2, 1),
+        (2, 2),
+        (1, 1, 1),
+    ]
+
+    @testset "copy" begin
+        for T in element_types, array_size in array_sizes
+            N = length(array_size)
+            @testset "copy(::$MLXArray{$T, $N}), $array_size" begin
+                array = rand(T, array_size)
+                if N > 2 || N == 0
+                    mlx_array = MLXArray(array)
+                elseif N > 1
+                    mlx_array = MLXMatrix(array)
+                else
+                    mlx_array = MLXVector(array)
+                end
+                actual = copy(mlx_array)
+                expected = copy(array)
+                @test actual == expected
+            end
+        end
+    end
+
+    @testset "dropdims" begin
+        for T in element_types, array_size in array_sizes
+            N = length(array_size)
+            dims = Dims(unique(rand(1:N, rand(1:N))))
+            if !all([array_size[d] == 1 for d in dims])
+                continue
+            end
+            @testset "dropdims(::$MLXArray{$T, $N}; dims = $dims), $array_size" begin
+                array = rand(T, array_size)
+                if N > 2 || N == 0
+                    mlx_array = MLXArray(array)
+                elseif N > 1
+                    mlx_array = MLXMatrix(array)
+                else
+                    mlx_array = MLXVector(array)
+                end
+                actual = dropdims(mlx_array; dims)
+                expected = dropdims(array; dims)
+                @test actual == expected
+            end
+        end
+    end
+
+    for fn in [:sort] # TODO sortperm is broken
+        @testset "$fn" begin
+            for T in filter(T -> T != ComplexF32, element_types), # isless is not defined for ComplexF32
+                array_size in array_sizes
+
+                N = length(array_size)
+                dims = rand(1:N)
+                @testset "$fn(::$MLXArray{$T, $N}), $array_size" begin
+                    array = rand(T, array_size)
+                    if N > 2 || N == 0
+                        mlx_array = MLXArray(array)
+                    elseif N > 1
+                        mlx_array = MLXMatrix(array)
+                    else
+                        mlx_array = MLXVector(array)
+                    end
+                    if N == 1
+                        actual = @eval $fn($mlx_array)
+                        expected = @eval $fn($array)
+                    else
+                        actual = @eval $fn($mlx_array; dims = $dims)
+                        expected = @eval $fn($array; dims = $dims)
+                    end
+                    @test actual == expected
+                end
+            end
+        end
+    end
+
+    @testset "permutedims" begin
+        for T in element_types, array_size in array_sizes
+            N = length(array_size)
+            @testset "permutedims(::$MLXArray{$T, $N}), $array_size" begin
+                array = rand(T, array_size)
+                if N > 2 || N == 0
+                    mlx_array = MLXArray(array)
+                elseif N > 1
+                    mlx_array = MLXMatrix(array)
+                else
+                    mlx_array = MLXVector(array)
+                end
+                actual = permutedims(mlx_array)
+                expected = permutedims(array, reverse(1:ndims(array)))
+                @test actual == expected
+            end
+        end
+    end
+
+    for (fn, fn_def) in MLX.Private.get_unary_scalar_ops()
+        @testset "$fn" begin
+            for T in element_types, array_size in array_sizes
+                N = length(array_size)
+                if !(T <: fn_def.TIn)
+                    continue
+                end
+                @testset "$fn.(::$MLXArray{$T, $N}), $array_size" begin
+                    array = rand(T, array_size)
+                    array = fn_def.normalize(array, T)
+                    if N > 2 || N == 0
+                        mlx_array = MLXArray(array)
+                    elseif N > 1
+                        mlx_array = MLXMatrix(array)
+                    else
+                        mlx_array = MLXVector(array)
+                    end
+                    TOut = fn_def.output_type(T)
+                    if TOut == Float32 # TODO fn.(array) may return a Float64 array for a Float32 array
+                        expected = @eval $TOut.($fn.($array))
+                    else
+                        expected = @eval $fn.($array)
+                    end
+                    actual = @eval $fn.($mlx_array)
+                    if TOut <: Integer
+                        @test actual == expected
+                    else
+                        @test actual ≈ expected
+                    end
+                end
+            end
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1af8702..b9d7a0d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,5 +4,6 @@ using Test
     include(joinpath(@__DIR__, "array_tests.jl"))
     include(joinpath(@__DIR__, "device_tests.jl"))
     include(joinpath(@__DIR__, "error_handling_tests.jl"))
+    include(joinpath(@__DIR__, "ops_tests.jl"))
     include(joinpath(@__DIR__, "stream_tests.jl"))
 end
diff --git a/test/stream_tests.jl b/test/stream_tests.jl
index 28b106f..b5a1429 100644
--- a/test/stream_tests.jl
+++ b/test/stream_tests.jl
@@ -3,6 +3,7 @@
 else
     using Base.ScopedValues
 end
+
 using MLX
 using Test