Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ steps:
julia:
- "1.10"
- "1.11"
- "1.12"
- "nightly"
adjustments:
- with:
Expand Down
8 changes: 4 additions & 4 deletions lib/level-zero/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Base.eltype(::Type{<:ZePtr{T}}) where {T} = T
Base.convert(::Type{T}, x::ZePtr) where {T<:Integer} = T(UInt(x))
## integer to pointer
Base.convert(::Type{ZePtr{T}}, x::Union{Int,UInt}) where {T} = ZePtr{T}(x)
Int(x::ZePtr) = Base.bitcast(Int, x)
UInt(x::ZePtr) = Base.bitcast(UInt, x)
Base.Int(x::ZePtr) = Base.bitcast(Int, x)
Base.UInt(x::ZePtr) = Base.bitcast(UInt, x)

# between regular and oneAPI pointers
Base.convert(::Type{<:Ptr}, p::ZePtr) =
Expand Down Expand Up @@ -71,8 +71,8 @@ Base.:(==)(x::ZePtr, y::ZePtr) = UInt(x) == UInt(y)
Base.:(<)(x::ZePtr, y::ZePtr) = UInt(x) < UInt(y)
Base.:(-)(x::ZePtr, y::ZePtr) = UInt(x) - UInt(y)

Base.:(+)(x::ZePtr, y::Integer) = oftype(x, Base.add_ptr(UInt(x), (y % UInt) % UInt))
Base.:(-)(x::ZePtr, y::Integer) = oftype(x, Base.sub_ptr(UInt(x), (y % UInt) % UInt))
Base.:(+)(x::ZePtr, y::Integer) = oftype(x, UInt(x) + (y % UInt) % UInt)
Base.:(-)(x::ZePtr, y::Integer) = oftype(x, UInt(x) - (y % UInt) % UInt)
Base.:(+)(x::Integer, y::ZePtr) = y + x


Expand Down
34 changes: 22 additions & 12 deletions lib/mkl/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,42 @@

using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, MulAddMul

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasFloat
# legacy methods with final MulAddMul argument
LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
return sparse_gemv!(tA, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, alpha::Number, beta::Number) where {T <: BlasReal}
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
return sparse_gemv!(tA, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
return sparse_gemm!(tA, tB, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, alpha::Number, beta::Number) where {T <: BlasReal}
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
return sparse_gemm!(tA, tB, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where {T <: BlasFloat}
return sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
end

function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where {T <: BlasFloat}
return sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
end
95 changes: 70 additions & 25 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function LinearAlgebra.generic_matvecmul!(Y::oneVector, tA::AbstractChar, A::one
end
end
end
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
return LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta)
end

# triangular
Expand All @@ -120,46 +120,71 @@ LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::F
# BLAS 3
#

LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, _add::MulAddMul=MulAddMul()) =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, a::Number, b::Number)
if VERSION >= v"1.12-"
# Otherwise dispatches onto:
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/4e7c3f40316a956119ac419a97c4b8aad7a17e6c/src/matmul.jl#L490
for blas_flag in (LinearAlgebra.BlasFlag.SyrkHerkGemm, LinearAlgebra.BlasFlag.SymmHemmGeneric)
@eval LinearAlgebra.generic_matmatmul_wrapper!(
C::oneStridedMatrix, tA::AbstractChar, tB::AbstractChar, A::oneStridedVecOrMat, B::oneStridedVecOrMat,
alpha::Number, beta::Number, ::$blas_flag
) =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
end

LinearAlgebra.generic_matmatmul!(
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
B::oneStridedVecOrMat, _add::MulAddMul,
) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
function LinearAlgebra.generic_matmatmul!(
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
B::oneStridedVecOrMat, alpha::Number, beta::Number,
)
T = eltype(C)
alpha, beta = promote(a, b, zero(T))
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)

if nA != mB
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
end

if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
nA != mB && throw(
DimensionMismatch(
"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"
)
)
(C === A || B === C) && throw(
ArgumentError(
"output matrix must not be aliased with input matrix"
)
)

if mA == 0 || nA == 0 || nB == 0
if size(C) != (mA, nB)
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
end
size(C) != (mA, nB) && throw(
DimensionMismatch(
"C has dimensions $(size(C)), should have ($mA,$nB)"
)
)
return LinearAlgebra.rmul!(C, 0)
end

if all(in(('N', 'T', 'C')), (tA, tB))
if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T
return gemm!(tA, tB, alpha, A, B, beta, C)
end
end
T = eltype(C)

if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
# TODO: should the gemm part above be included in this branch?
if (tA == 'S' || tA == 's') && tB == 'N'
return symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
α, β = T(alpha), T(beta)
if (
all(in(('N', 'T', 'C')), (tA, tB)) && T <: Union{onemklFloat, onemklComplex, onemklHalf} &&
A isa oneStridedArray{T} && B isa oneStridedArray{T}
)
return gemm!(tA, tB, α, A, B, β, C)
elseif (tA == 'S' || tA == 's') && tB == 'N'
return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C)
elseif (tB == 'S' || tB == 's') && tA == 'N'
return symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C)
elseif (tA == 'H' || tA == 'h') && tB == 'N'
return hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
return hemm!('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C)
elseif (tB == 'H' || tB == 'h') && tA == 'N'
return hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
return hemm!('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C)
end
end

GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

Expand All @@ -172,3 +197,23 @@ LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::F
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))

#
# BLAS extensions
#

# Extend LinearAlgebra.BLAS.herk! to dispatch to oneAPI implementation
for (elty) in ([Float32, ComplexF32], [Float64, ComplexF64])
@eval begin
LinearAlgebra.BLAS.herk!(uplo::Char, trans::Char, alpha::$elty[1], A::oneStridedVecOrMat{$elty[2]}, beta::$elty[1], C::oneStridedMatrix{$elty[2]}) =
herk!(uplo, trans, alpha, A, beta, C)
end
end

# Extend LinearAlgebra.BLAS.syrk! to dispatch to oneAPI implementation
for (elty) in (Float32, Float64, ComplexF32, ComplexF64)
@eval begin
LinearAlgebra.BLAS.syrk!(uplo::Char, trans::Char, alpha::$elty, A::oneStridedVecOrMat{$elty}, beta::$elty, C::oneStridedMatrix{$elty}) =
syrk!(uplo, trans, alpha, A, beta, C)
end
end
6 changes: 3 additions & 3 deletions lib/mkl/wrappers_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
return dA
end

function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
A_csc = SparseMatrixCSC(At |> transpose)
Expand All @@ -51,7 +51,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
return dA
end

function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
return A_csc
Expand Down Expand Up @@ -84,7 +84,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
return dA
end

function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
return A
Expand Down
14 changes: 7 additions & 7 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# broadcasting

using Base.Broadcast: BroadcastStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, Broadcasted

struct oneArrayStyle{N,B} <: AbstractGPUArrayStyle{N} end
oneArrayStyle{M,B}(::Val{N}) where {N,M,B} = oneArrayStyle{N,B}()

# identify the broadcast style of a (wrapped) oneArray
BroadcastStyle(::Type{<:oneArray{T,N,B}}) where {T,N,B} = oneArrayStyle{N,B}()
BroadcastStyle(W::Type{<:oneWrappedArray{T,N}}) where {T,N} =
BroadcastStyle(::Type{<:oneArray{T, N, B}}) where {T, N, B} = oneArrayStyle{N, B}()
BroadcastStyle(W::Type{<:oneWrappedArray{T, N}}) where {T, N} =
oneArrayStyle{N, buftype(Adapt.unwrap_type(W))}()

# when we are dealing with different buffer styles, we cannot know
# which one is better, so use shared memory
BroadcastStyle(::oneArrayStyle{N, B1},
::oneArrayStyle{N, B2}) where {N,B1,B2} =
BroadcastStyle(
::oneArrayStyle{N, B1},
::oneArrayStyle{N, B2},
) where {N,B1,B2} =
oneArrayStyle{N, oneL0.SharedBuffer}()

# allocation of output arrays
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export @device_code_lowered, @device_code_typed, @device_code_warntype,
#

"""
Metal.return_type(f, tt) -> r::Type
return_type(f, tt) -> r::Type

Return a type `r` such that `f(args...)::r` where `args::tt`.
"""
Expand All @@ -75,5 +75,5 @@ function return_type(@nospecialize(func), @nospecialize(tt))
job = CompilerJob(source, config)
interp = GPUCompiler.get_interpreter(job)
sig = Base.signature_type(func, tt)
Core.Compiler.return_type(interp, sig)
return Core.Compiler._return_type(interp, sig)
end
10 changes: 10 additions & 0 deletions src/device/quirks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,13 @@ end
@print_and_throw "Out-of-bounds access of scalar value"
x
end

# From Metal.jl to avoid widemul and Int128
@static if VERSION >= v"1.12.0-DEV.1736" # Partially reverts JuliaLang/julia PR #56750
let BitInteger64 = Union{Int64, UInt64}
@device_override function Base.checkbounds(::Type{Bool}, v::StepRange{<:BitInteger64, <:BitInteger64}, i::BitInteger64)
@inline
return checkindex(Bool, eachindex(IndexLinear(), v), i)
end
end
end
14 changes: 9 additions & 5 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ end
@oneapi kernel(arr)
@test Array(arr)[] == 1

function kernel(ptr)
function kernel2(ptr)
ptr[] = 2
return
end

@oneapi kernel(arr)
@oneapi kernel2(arr)
@test Array(arr)[] == 2
end

Expand Down Expand Up @@ -611,9 +611,13 @@ end
return
end

a = oneArray(Float32[0])
@oneapi kernel(pointer(a))
@test Array(a) == [42]
if VERSION < v"1.12"
a = oneArray(Float32[0])
@oneapi kernel(pointer(a))
@test Array(a) == [42]
else
@test_broken false
end
end

############################################################################################
Loading