From 49460d61082c541d684de8825741af5c345e18b6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 30 Jul 2025 17:09:36 -0400 Subject: [PATCH 01/14] Support for delta --- Project.toml | 9 +- .../KroneckerArraysBlockSparseArraysExt.jl | 23 ++-- src/cartesianproduct.jl | 6 + src/fillarrays/kroneckerarray.jl | 111 +++++++++++++++++- src/kroneckerarray.jl | 88 +++++++++++++- test/test_aqua.jl | 2 +- 6 files changed, 214 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index 2aa9b11..f65e24f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.26" +version = "0.1.27" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -12,23 +12,28 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" +TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] +KroneckerArraysTensorProductsExt = "TensorProducts" [compat] Adapt = "4.3" BlockArrays = "1.6" BlockSparseArrays = "0.8.1" DerivableInterfaces = "0.5" -DiagonalArrays = "0.3.5" +DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.9" MatrixAlgebraKit = "0.2" +TensorAlgebra = "0.3.10" +TensorProducts = "0.1.7" julia = "1.10" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index e67d3aa..0a5c1df 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -39,7 +39,7 @@ using KroneckerArrays: _similar function KroneckerArrays.arg1(r::AbstractBlockedUnitRange) - return mortar_axis(arg2.(eachblockaxis(r))) + return mortar_axis(arg1.(eachblockaxis(r))) end function KroneckerArrays.arg2(r::AbstractBlockedUnitRange) return mortar_axis(arg2.(eachblockaxis(r))) @@ -56,17 +56,16 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe return block_axes(ax, Tuple(I)...) end -function Base.getindex( - a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2} -) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} - ax_a1 = arg1.(a.parentaxes) - a1 = ZeroBlocks{2,A}(ax_a1)[I...] - - ax_a2 = arg2.(a.parentaxes) - a2 = ZeroBlocks{2,B}(ax_a2)[I...] - - return a1 ⊗ a2 -end +## TODO: Is this needed? +## function Base.getindex( +## a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2} +## ) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} +## ax_a1 = map(arg1, a.parentaxes) +## a1 = ZeroBlocks{2,A}(ax_a1)[I...] +## ax_a2 = map(arg2, a.parentaxes) +## a2 = ZeroBlocks{2,B}(ax_a2)[I...] +## return a1 ⊗ a2 +## end function Base.getindex( a::ZeroBlocks{2,EyeKronecker{T,A,B}}, I::Vararg{Int,2} ) where {T,A<:Eye{T},B<:AbstractMatrix{T}} diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index be1c4fa..e8247e8 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -96,6 +96,12 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range) arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange) + prod = cartesianproduct(a)[cartesianproduct(i)] + range = unproduct(a)[unproduct(i)] + return cartesianrange(prod, range) +end + function Base.show(io::IO, a::CartesianProductUnitRange) show(io, unproduct(a)) return nothing diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index f943ce0..772e251 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -1,4 +1,4 @@ -using FillArrays: FillArrays, Zeros +using FillArrays: FillArrays, Ones, Zeros function FillArrays.fillsimilar( a::Zeros{T}, ax::Tuple{ @@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +using DiagonalArrays: Delta +const DeltaKronecker{T,N,A<:Delta{T,N},B<:AbstractArray{T,N}} = KroneckerArray{T,N,A,B} +const KroneckerDelta{T,N,A<:AbstractArray{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B} +const DeltaDelta{T,N,A<:Delta{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B} + _getindex(a::Eye, I1::Colon, I2::Colon) = a _getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a _getindex(a::Eye, I1::Base.Slice, I2::Colon) = a @@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a _view(a::Eye, I1::Base.Slice, I2::Colon) = a _view(a::Eye, I1::Colon, I2::Base.Slice) = a +function _getindex(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...) + return a +end +function _view(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...) + return a +end + # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a +_adapt(to, a::Delta) = a # Allows customizing for `FillArrays.Eye`. function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T} - _convert(AbstractMatrix{T}, a) + return _convert(AbstractMatrix{T}, a) end function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T} - RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) + return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) end # Like `similar` but preserves `Eye`. @@ -74,8 +87,39 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange}) return Eye{eltype(arrayt)}((only(unique(axs)),)) end -# Like `copy` but preserves `Eye`. +function _similar(a::Delta, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) + return Delta{elt}(axs) +end +function _similar(arrayt::Type{<:Delta}, axs::Tuple{Vararg{AbstractUnitRange}}) + return Delta{eltype(arrayt)}(axs) +end + +# Like `copy` but preserves `Eye`/`Delta`. _copy(a::Eye) = a +_copy(a::Delta) = a + +function _copyto!!(dest::Eye{<:Any,N}, src::Eye{<:Any,N}) where {N} + size(dest) == size(src) || + throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src)).")) + return dest +end +function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N} + size(dest) == size(src) || + throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src)).")) + return dest +end + +# TODO: Define `DerivableInterfaces.permuteddims` and overload that instead. +function Base.PermutedDimsArray(a::Delta, perm) + ax_perm = Base.PermutedDimsArrays.genperm(axes(a), perm) + return Delta{eltype(a)}(ax_perm) +end + +function _permutedims!!(dest::Delta, src::Delta, perm) + Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) || + throw(ArgumentError("Permuted axes do not match.")) + return dest +end using DerivableInterfaces: DerivableInterfaces, zero! function DerivableInterfaces.zero!(a::EyeKronecker) @@ -90,6 +134,18 @@ function DerivableInterfaces.zero!(a::EyeEye) return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) end +function DerivableInterfaces.zero!(a::DeltaKronecker) + zero!(a.b) + return a +end +function DerivableInterfaces.zero!(a::KroneckerDelta) + zero!(a.a) + return a +end +function DerivableInterfaces.zero!(a::DeltaDelta) + return throw(ArgumentError("Can't zero out `Delta ⊗ Delta`.")) +end + using Base.Broadcast: AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted @@ -101,10 +157,16 @@ end Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 +function _copyto!!(dest::Eye, src::Broadcasted{<:EyeStyle,<:Any,typeof(identity)}) + axes(dest) == axes(src) || error("Dimension mismatch.") + return dest +end + function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) return Eye{elt}(axes(bc)) end +# TODO: Define in terms of `_copyto!!` that is called on each argument. function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}}) dest2 = arg2(dest) f = LinearCombination(a) @@ -125,6 +187,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye return error("Can't write in-place to `Eye ⊗ Eye`.") end +struct DeltaStyle{N} <: AbstractArrayStyle{N} end +DeltaStyle(::Val{N}) where {N} = DeltaStyle{N}() +DeltaStyle{M}(::Val{N}) where {M,N} = DeltaStyle{N}() +function _BroadcastStyle(A::Type{<:Delta}) + return DeltaStyle{ndims(A)}() +end +Base.BroadcastStyle(style1::DeltaStyle, style2::DeltaStyle) = DeltaStyle() +Base.BroadcastStyle(style1::DeltaStyle, style2::DefaultArrayStyle) = style2 + +function _copyto!!(dest::Delta, src::Broadcasted{<:DeltaStyle,<:Any,typeof(identity)}) + axes(dest) == axes(src) || error("Dimension mismatch.") + return dest +end + +function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type) + return Delta{elt}(axes(bc)) +end + +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle}) + dest2 = arg2(dest) + f = LinearCombination(a) + args = arguments(a) + arg2s = arg2.(args) + dest2 .= f.(arg2s...) + return dest +end +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle}) + dest1 = arg1(dest) + f = LinearCombination(a) + args = arguments(a) + arg1s = arg1.(args) + dest1 .= f.(arg1s...) + return dest +end +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle}) + return error("Can't write in-place to `Delta ⊗ Delta`.") +end + # Simplification rules similar to those for FillArrays.jl: # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl using FillArrays: Zeros diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 0473b1a..886986e 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -43,9 +43,26 @@ _copy(a::AbstractArray) = copy(a) function Base.copy(a::KroneckerArray) return _copy(arg1(a)) ⊗ _copy(arg2(a)) end -function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) - copyto!(arg1(dest), arg1(src)) - copyto!(arg2(dest), arg2(src)) + +# Allows extra customization, like for `FillArrays.Eye`. +function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N} + copyto!(dest, src) + return dest +end +function _copyto!!(dest::AbstractArray, src::Broadcasted) + copyto!(dest, src) + return dest +end + +function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N} + return copyto!_kronecker(dest, src) +end +function copyto!_kronecker( + dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N} +) where {N} + # TODO: Check if neither argument is mutated and if so error. + _copyto!!(arg1(dest), arg1(src)) + _copyto!!(arg2(dest), arg2(src)) return dest end @@ -101,6 +118,23 @@ function Base.similar( return similar(promote_type(A, B), sz) end +function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm) + permutedims!(dest, src, perm) + return dest +end + +# TODO: Define `DerivableInterfaces.permuteddims` and overload that instead. +function Base.PermutedDimsArray(a::KroneckerArray, perm) + return PermutedDimsArray(arg1(a), perm) ⊗ PermutedDimsArray(arg2(a), perm) +end + +function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm) + # TODO: Error if neither argument is mutable. + _permutedims!!(arg1(dest), arg1(src), perm) + _permutedims!!(arg2(dest), arg2(src), perm) + return dest +end + function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) return (t[1]..., flatten(Base.tail(t))...) end @@ -119,7 +153,7 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} a′ = reshape(a, interleave(size(a), ntuple(one, N))) b′ = reshape(b, interleave(ntuple(one, N), size(b))) c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) - sz = ntuple(i -> size(a, i) * size(b, i), N) + sz = reverse(ntuple(i -> size(a, i) * size(b, i), N)) return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) end kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) @@ -265,6 +299,12 @@ for f in [:transpose, :adjoint, :inv] end end +function Base.reshape( + a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} +) + return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) +end + # Allows for customizations for FillArrays. _BroadcastStyle(x) = BroadcastStyle(x) @@ -384,8 +424,8 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) Broadcast.broadcastable(a::KroneckerBroadcasted) = a Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted) - copyto!(arg1(dest), copy(arg1(a))) - copyto!(arg2(dest), copy(arg2(a))) + _copyto!!(arg1(dest), arg1(a)) + _copyto!!(arg2(dest), arg2(a)) return dest end function Base.eltype(a::KroneckerBroadcasted) @@ -433,3 +473,39 @@ function Base.broadcasted( ) return broadcasted(style, /, a, f.args[2]) end + +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize +struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle + a::A + b::B +end +arg1(style::KroneckerFusion) = style.a +arg2(style::KroneckerFusion) = style.b +function TensorAlgebra.FusionStyle(a::KroneckerArray) + return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) +end +function matricize_kronecker( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) +end +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +# Fix ambiguity error. +# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. +using TensorAlgebra: BlockedTrivialPermutation, unmatricize +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ + unmatricize(arg2(style), arg2(a), arg2.(ax)) +end +function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize_kronecker(style, a, ax) +end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 5727e26..ef5144e 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(KroneckerArrays) + # Aqua.test_all(KroneckerArrays) end From aede4cdcc378912651dd8f906dcb67d5dd8bd04d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 10:05:59 -0400 Subject: [PATCH 02/14] Simplify, fix tests --- .../KroneckerArraysBlockSparseArraysExt.jl | 18 ++++++------- src/fillarrays/kroneckerarray.jl | 25 ------------------- 2 files changed, 9 insertions(+), 34 deletions(-) diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 0a5c1df..624b269 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -57,15 +57,15 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe end ## TODO: Is this needed? -## function Base.getindex( -## a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2} -## ) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} -## ax_a1 = map(arg1, a.parentaxes) -## a1 = ZeroBlocks{2,A}(ax_a1)[I...] -## ax_a2 = map(arg2, a.parentaxes) -## a2 = ZeroBlocks{2,B}(ax_a2)[I...] -## return a1 ⊗ a2 -## end +function Base.getindex( + a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2} +) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} + ax_a1 = map(arg1, a.parentaxes) + a1 = ZeroBlocks{2,A}(ax_a1)[I...] + ax_a2 = map(arg2, a.parentaxes) + a2 = ZeroBlocks{2,B}(ax_a2)[I...] + return a1 ⊗ a2 +end function Base.getindex( a::ZeroBlocks{2,EyeKronecker{T,A,B}}, I::Vararg{Int,2} ) where {T,A<:Eye{T},B<:AbstractMatrix{T}} diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index d164a2f..a69db51 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -108,31 +108,6 @@ function _permutedims!!(dest::Delta, src::Delta, perm) return dest end -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::EyeKronecker) - zero!(a.b) - return a -end -function DerivableInterfaces.zero!(a::KroneckerEye) - zero!(a.a) - return a -end -function DerivableInterfaces.zero!(a::EyeEye) - return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) -end - -function DerivableInterfaces.zero!(a::DeltaKronecker) - zero!(a.b) - return a -end -function DerivableInterfaces.zero!(a::KroneckerDelta) - zero!(a.a) - return a -end -function DerivableInterfaces.zero!(a::DeltaDelta) - return throw(ArgumentError("Can't zero out `Delta ⊗ Delta`.")) -end - using Base.Broadcast: AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted From e1a89ac5faaa5e89a345f54f2389332780ab58e8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 10:13:15 -0400 Subject: [PATCH 03/14] Use permuteddims --- Project.toml | 2 +- src/fillarrays/kroneckerarray.jl | 4 ++-- src/kroneckerarray.jl | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 53e9f3d..9f0a927 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ KroneckerArraysTensorProductsExt = "TensorProducts" Adapt = "4.3" BlockArrays = "1.6" BlockSparseArrays = "0.9" -DerivableInterfaces = "0.5" +DerivableInterfaces = "0.5.3" DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index a69db51..1e47a3c 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -96,8 +96,8 @@ function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N} return dest end -# TODO: Define `DerivableInterfaces.permuteddims` and overload that instead. -function Base.PermutedDimsArray(a::Delta, perm) +using DerivableInterfaces: DerivableInterfaces, permuteddims +function DerivableInterfaces.permuteddims(a::Delta, perm) ax_perm = Base.PermutedDimsArrays.genperm(axes(a), perm) return Delta{eltype(a)}(ax_perm) end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index d815e29..bfb2e5b 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -132,9 +132,9 @@ function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm) return dest end -# TODO: Define `DerivableInterfaces.permuteddims` and overload that instead. -function Base.PermutedDimsArray(a::KroneckerArray, perm) - return PermutedDimsArray(arg1(a), perm) ⊗ PermutedDimsArray(arg2(a), perm) +using DerivableInterfaces: DerivableInterfaces, permuteddims +function DerivableInterfaces.permuteddims(a::KroneckerArray, perm) + return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) end function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm) From 84c66af9d0285b8f2033b9827ff4cfcbc0bad8e5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 15:33:30 -0400 Subject: [PATCH 04/14] More tests --- src/fillarrays/kroneckerarray.jl | 6 ------ test/test_aqua.jl | 2 +- test/test_basics.jl | 7 +++++++ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 1e47a3c..0faeb48 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -96,12 +96,6 @@ function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N} return dest end -using DerivableInterfaces: DerivableInterfaces, permuteddims -function DerivableInterfaces.permuteddims(a::Delta, perm) - ax_perm = Base.PermutedDimsArrays.genperm(axes(a), perm) - return Delta{eltype(a)}(ax_perm) -end - function _permutedims!!(dest::Delta, src::Delta, perm) Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) || throw(ArgumentError("Permuted axes do not match.")) diff --git a/test/test_aqua.jl b/test/test_aqua.jl index ef5144e..5727e26 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - # Aqua.test_all(KroneckerArrays) + Aqua.test_all(KroneckerArrays) end diff --git a/test/test_basics.jl b/test/test_basics.jl index e17f8c1..9f89585 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -50,6 +50,13 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) r = cartesianrange((2:3) × (3:4), 2:5) @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + # CartesianProductUnitRange getindex + r1 = cartesianrange((2:4) × (3:5), 2:10) + r2 = cartesianrange((2:3) × (2:3), 2:5) + @test r1[r2] ≡ cartesianrange((3:4) × (4:5), 3:6) + + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + # CartesianProductVector axes r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) From 1a056568d6faa470557d00da9c34868d66a42926 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 15:43:20 -0400 Subject: [PATCH 05/14] KroneckerArraysTensorAlgebraExt extension --- Project.toml | 3 +- .../KroneckerArraysTensorAlgebraExt.jl | 41 +++++++++++++++++++ src/kroneckerarray.jl | 36 ---------------- 3 files changed, 43 insertions(+), 37 deletions(-) create mode 100644 ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl diff --git a/Project.toml b/Project.toml index 9f0a927..c13cd4e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,15 +12,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" -TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] +KroneckerArraysTensorAlgebraExt = "TensorAlgebra" KroneckerArraysTensorProductsExt = "TensorProducts" [compat] diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl new file mode 100644 index 0000000..c248ec2 --- /dev/null +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -0,0 +1,41 @@ +module KroneckerArraysTensorAlgebraExt + +using KroneckerArrays: KroneckerArrays, KroneckerArray +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize + +struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle + a::A + b::B +end +KroneckerArrays.arg1(style::KroneckerFusion) = style.a +KroneckerArrays.arg2(style::KroneckerFusion) = style.b +function TensorAlgebra.FusionStyle(a::KroneckerArray) + return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) +end +function matricize_kronecker( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) +end +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +# Fix ambiguity error. +# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. +using TensorAlgebra: BlockedTrivialPermutation, unmatricize +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ + unmatricize(arg2(style), arg2(a), arg2.(ax)) +end +function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize_kronecker(style, a, ax) +end + +end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index bfb2e5b..9d30a08 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -494,39 +494,3 @@ function Base.broadcasted( ) return broadcasted(style, /, a, f.args[2]) end - -using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize -struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle - a::A - b::B -end -arg1(style::KroneckerFusion) = style.a -arg2(style::KroneckerFusion) = style.b -function TensorAlgebra.FusionStyle(a::KroneckerArray) - return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) -end -function matricize_kronecker( - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} -) - return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) -end -function TensorAlgebra.matricize( - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} -) - return matricize_kronecker(style, a, biperm) -end -# Fix ambiguity error. -# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. -using TensorAlgebra: BlockedTrivialPermutation, unmatricize -function TensorAlgebra.matricize( - style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} -) - return matricize_kronecker(style, a, biperm) -end -function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) - return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ - unmatricize(arg2(style), arg2(a), arg2.(ax)) -end -function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) - return unmatricize_kronecker(style, a, ax) -end From 9a105ee7fa13229296d9e094d4d65d0e49a88b79 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 15:52:13 -0400 Subject: [PATCH 06/14] More tests --- test/test_blocksparsearrays.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 1dcf58a..db492db 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,10 +1,16 @@ using Adapt: adapt -using BlockArrays: Block, BlockRange, mortar +using BlockArrays: Block, BlockRange, blockedrange, blockisequal, mortar using BlockSparseArrays: - BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype + BlockIndexVector, + BlockSparseArray, + BlockSparseMatrix, + blockrange, + blocksparse, + blocktype, + eachblockaxis using FillArrays: Eye, SquareEye using JLArrays: JLArray -using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2 +using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange using LinearAlgebra: norm using MatrixAlgebraKit: svd_compact, svd_trunc using StableRNGs: StableRNG @@ -17,6 +23,15 @@ arrayts = (Array, JLArray) arrayts, elt in elts + # BlockUnitRange with CartesianProduct blocks + r = blockrange([2 × 3, 3 × 4]) + @test r[Block(1)] ≡ cartesianrange(2 × 3, 1:6) + @test r[Block(2)] ≡ cartesianrange(3 × 4, 7:18) + @test eachblockaxis(r)[1] ≡ cartesianrange(2, 3) + @test eachblockaxis(r)[2] ≡ cartesianrange(3, 4) + @test blockisequal(arg1(r), blockedrange([2, 3])) + @test blockisequal(arg2(r), blockedrange([3, 4])) + dev = adapt(arrayt) r = blockrange([2 × 2, 3 × 3]) d = Dict( From 82dce155a4b4579f9c8141d7abeb2c606f37f650 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 15:57:25 -0400 Subject: [PATCH 07/14] More tests --- test/test_basics.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_basics.jl b/test/test_basics.jl index 9f89585..00a8a4c 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -46,6 +46,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test r[2 × 2] == 5 @test r[2 × 3] == 6 + @test sprint(show, "text/plain", cartesianrange(2 × 3)) == + "Base.OneTo(2) × Base.OneTo(3)\nBase.OneTo(6)" + @test sprint(show, cartesianrange(2 × 3)) == "Base.OneTo(6)" + # CartesianProductUnitRange axes r = cartesianrange((2:3) × (3:4), 2:5) @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) From 39ce09352250da2901c2a6e0971cd01b6c349050 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 16:48:44 -0400 Subject: [PATCH 08/14] More tests --- test/test_basics.jl | 11 +++++++++++ test/test_fillarrays.jl | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/test/test_basics.jl b/test/test_basics.jl index 00a8a4c..b621730 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -189,6 +189,17 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test_throws ErrorException imag(a) end + # permutedims + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + @test permutedims(a, (2, 1, 3)) == + permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + + # permutedims! + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + b = similar(a) + permutedims!(b, a, (2, 1, 3)) + @test b == permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + # Adapt a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) a′ = adapt(JLArray, a) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 62e0234..9ad2c36 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -204,6 +204,24 @@ using TestExtras: @constinferred @test fa.b isa Eye @test det(a) ≈ det(collect(a)) ≈ 1 + + # permutedims + a = Eye(2, 2) ⊗ randn(3, 3) + @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + + a = randn(2, 2) ⊗ Eye(3, 3) + @test permutedims(a, (2, 1)) == permutedims(arg1(a), (2, 1)) ⊗ Eye(3, 3) + + # permutedims! + a = Eye(2, 2) ⊗ randn(3, 3) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ Eye(2, 2) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == permutedims(arg1(a), (2, 1)) ⊗ Eye(2, 2) end @testset "FillArrays.Zeros" begin From f5553bde34fa9a2464d690de650351b6bc6f502d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 18:10:38 -0400 Subject: [PATCH 09/14] More tests --- test/test_fillarrays.jl | 48 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 9ad2c36..0702459 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -1,12 +1,15 @@ +using Adapt: adapt using DerivableInterfaces: zero! +using DiagonalArrays: δ using FillArrays: Eye, Zeros +using JLArrays: JLArray, jl using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2 using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG using Test: @test, @test_throws, @testset using TestExtras: @constinferred -@testset "FillArrays.Eye" begin +@testset "FillArrays.Eye, DiagonalArrays.Delta" begin MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS if VERSION < v"1.11-" # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. @@ -15,9 +18,46 @@ using TestExtras: @constinferred a = Eye(2) ⊗ randn(3, 3) @test size(a) == (6, 6) - @test a + a == Eye(2) ⊗ (2a.b) - @test 2a == Eye(2) ⊗ (2a.b) - @test a * a == Eye(2) ⊗ (a.b * a.b) + @test a + a == Eye(2) ⊗ (2 * arg2(a)) + @test 2a == Eye(2) ⊗ (2 * arg2(a)) + @test a * a == Eye(2) ⊗ (arg2(a) * arg2(a)) + @test arg1(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test arg1(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + Eye(2) + @test arg1(adapt(JLArray, a)) ≡ Eye(2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) + @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) + @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + Eye{Float32}(3) + + a = δ(2, 2) ⊗ randn(3, 3) + @test size(a) == (6, 6) + @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) + @test 2a == δ(2, 2) ⊗ (2 * arg2(a)) + @test a * a == δ(2, 2) ⊗ (arg2(a) * arg2(a)) + @test arg1(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg1(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) + @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) + @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + δ(Float32, 3, 3) # Views a = @constinferred(Eye(2) ⊗ randn(3, 3)) From 40b5e4db39d0e1787d222ebb7e1f45ba471d8e04 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 12 Aug 2025 18:18:12 -0400 Subject: [PATCH 10/14] More tests --- test/test_fillarrays.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 0702459..b49f7d0 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -37,6 +37,16 @@ using TestExtras: @constinferred @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye{Float32}(3) + @test arg1(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) a = δ(2, 2) ⊗ randn(3, 3) @test size(a) == (6, 6) @@ -58,6 +68,16 @@ using TestExtras: @constinferred @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(Float32, 3, 3) + @test arg1(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) # Views a = @constinferred(Eye(2) ⊗ randn(3, 3)) From f9545da4031e77c6b76fe8e2e06c4b08621a4930 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Aug 2025 10:07:02 -0400 Subject: [PATCH 11/14] Namespace fix --- test/test_fillarrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index b49f7d0..1159380 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -3,7 +3,7 @@ using DerivableInterfaces: zero! using DiagonalArrays: δ using FillArrays: Eye, Zeros using JLArrays: JLArray, jl -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2 +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG using Test: @test, @test_throws, @testset From c1e073ecf6d9bc84c6cf39159b4c756b242446ab Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Aug 2025 12:00:10 -0400 Subject: [PATCH 12/14] More tests --- test/test_fillarrays.jl | 64 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 1159380..04ef6ad 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -48,6 +48,38 @@ using TestExtras: @constinferred @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + a = randn(3, 3) ⊗ Eye(2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ Eye(2) + @test 2a == (2 * arg1(a)) ⊗ Eye(2) + @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) + @test arg2(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test arg2(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + Eye(2) + @test arg2(adapt(JLArray, a)) ≡ Eye(2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) + @test arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) + @test arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + Eye{Float32}(3) + @test arg2(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + a = δ(2, 2) ⊗ randn(3, 3) @test size(a) == (6, 6) @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) @@ -79,6 +111,38 @@ using TestExtras: @constinferred @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + a = randn(3, 3) ⊗ δ(2, 2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ δ(2, 2) + @test 2a == (2 * arg1(a)) ⊗ δ(2, 2) + @test a * a == (arg1(a) * arg1(a)) ⊗ δ(2, 2) + @test arg2(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg2(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) + @test arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) + @test arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + δ(Float32, (3, 3)) + @test arg2(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + # Views a = @constinferred(Eye(2) ⊗ randn(3, 3)) b = @constinferred(view(a, (:) × (2:3), (:) × (2:3))) From a621896e6155e6c7792ff630032b4988d9487652 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Aug 2025 16:35:35 -0400 Subject: [PATCH 13/14] Test Kronecker matricize --- .../KroneckerArraysTensorAlgebraExt.jl | 5 +++-- test/test_tensoralgebra.jl | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 test/test_tensoralgebra.jl diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index c248ec2..2969ea5 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -1,7 +1,8 @@ module KroneckerArraysTensorAlgebraExt -using KroneckerArrays: KroneckerArrays, KroneckerArray -using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2 +using TensorAlgebra: + TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle a::A diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl new file mode 100644 index 0000000..35cac2a --- /dev/null +++ b/test/test_tensoralgebra.jl @@ -0,0 +1,10 @@ +using TensorAlgebra: matricize, unmatricize +using KroneckerArrays: ⊗, arg1, arg2 +using Test: @test, @testset + +@testset "TensorAlgebraExt" begin + a = randn(2, 2, 2) ⊗ randn(3, 3, 3) + m = matricize(a, (1, 2), (3,)) + @test m == matricize(arg1(a), (1, 2), (3,)) ⊗ matricize(arg2(a), (1, 2), (3,)) + @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a +end From 234cd0b591dc710394d5f4f23e424e4c194f7c6c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Aug 2025 16:47:13 -0400 Subject: [PATCH 14/14] Missing test dep --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index f649d4b..f37978a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -34,6 +35,7 @@ MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" StableRNGs = "1.0" Suppressor = "0.2" +TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" Test = "1.10" TestExtras = "0.3"