Skip to content

Support for delta #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 13, 2025
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.27"
version = "0.1.28"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -16,22 +16,25 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"

[extensions]
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
KroneckerArraysTensorAlgebraExt = "TensorAlgebra"
KroneckerArraysTensorProductsExt = "TensorProducts"

[compat]
Adapt = "4.3"
BlockArrays = "1.6"
BlockSparseArrays = "0.9"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3.5"
DerivableInterfaces = "0.5.3"
DiagonalArrays = "0.3.11"
FillArrays = "1.13"
GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.9"
MatrixAlgebraKit = "0.2"
TensorAlgebra = "0.3.10"
TensorProducts = "0.1.7"
julia = "1.10"
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using KroneckerArrays:
_similar

function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
return mortar_axis(arg1.(eachblockaxis(r)))
end
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
Expand All @@ -56,15 +56,14 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe
return block_axes(ax, Tuple(I)...)
end

## TODO: Is this needed?
function Base.getindex(
a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2}
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
ax_a1 = arg1.(a.parentaxes)
ax_a1 = map(arg1, a.parentaxes)
a1 = ZeroBlocks{2,A}(ax_a1)[I...]

ax_a2 = arg2.(a.parentaxes)
ax_a2 = map(arg2, a.parentaxes)
a2 = ZeroBlocks{2,B}(ax_a2)[I...]

return a1 ⊗ a2
end
function Base.getindex(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module KroneckerArraysTensorAlgebraExt

using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2
using TensorAlgebra:
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize

struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle
a::A
b::B
end
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
function TensorAlgebra.FusionStyle(a::KroneckerArray)
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
end
function matricize_kronecker(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm)
end
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize_kronecker(style, a, biperm)
end
# Fix ambiguity error.
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
)
return matricize_kronecker(style, a, biperm)
end
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗
unmatricize(arg2(style), arg2(a), arg2.(ax))
end
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize_kronecker(style, a, ax)
end

end
6 changes: 6 additions & 0 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange)
prod = cartesianproduct(a)[cartesianproduct(i)]
range = unproduct(a)[unproduct(i)]
return cartesianrange(prod, range)
end

function Base.show(io::IO, a::CartesianProductUnitRange)
show(io, unproduct(a))
return nothing
Expand Down
93 changes: 89 additions & 4 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using FillArrays: FillArrays, Zeros
using FillArrays: FillArrays, Ones, Zeros
function FillArrays.fillsimilar(
a::Zeros{T},
ax::Tuple{
Expand All @@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

using DiagonalArrays: Delta
const DeltaKronecker{T,N,A<:Delta{T,N},B<:AbstractArray{T,N}} = KroneckerArray{T,N,A,B}
const KroneckerDelta{T,N,A<:AbstractArray{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}
const DeltaDelta{T,N,A<:Delta{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
Expand All @@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
_view(a::Eye, I1::Colon, I2::Base.Slice) = a

function _getindex(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
return a
end
function _view(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
return a
end

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a
_adapt(to, a::Delta) = a

# Allows customizing for `FillArrays.Eye`.
function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T}
_convert(AbstractMatrix{T}, a)
return _convert(AbstractMatrix{T}, a)
end
function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
end

# Like `similar` but preserves `Eye`, `Ones`, etc.
Expand All @@ -61,8 +74,33 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange})
return Eye{eltype(arrayt)}((only(unique(axs)),))
end

# Like `copy` but preserves `Eye`.
function _similar(a::Delta, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
return Delta{elt}(axs)
end
function _similar(arrayt::Type{<:Delta}, axs::Tuple{Vararg{AbstractUnitRange}})
return Delta{eltype(arrayt)}(axs)
end

# Like `copy` but preserves `Eye`/`Delta`.
_copy(a::Eye) = a
_copy(a::Delta) = a

function _copyto!!(dest::Eye{<:Any,N}, src::Eye{<:Any,N}) where {N}
size(dest) == size(src) ||
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
return dest
end
function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N}
size(dest) == size(src) ||
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
return dest
end

function _permutedims!!(dest::Delta, src::Delta, perm)
Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) ||
throw(ArgumentError("Permuted axes do not match."))
return dest
end

using Base.Broadcast:
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
Expand All @@ -75,10 +113,16 @@ end
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()
Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2

function _copyto!!(dest::Eye, src::Broadcasted{<:EyeStyle,<:Any,typeof(identity)})
axes(dest) == axes(src) || error("Dimension mismatch.")
return dest
end

function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
return Eye{elt}(axes(bc))
end

# TODO: Define in terms of `_copyto!!` that is called on each argument.
function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}})
dest2 = arg2(dest)
f = LinearCombination(a)
Expand All @@ -99,6 +143,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye
return error("Can't write in-place to `Eye ⊗ Eye`.")
end

struct DeltaStyle{N} <: AbstractArrayStyle{N} end
DeltaStyle(::Val{N}) where {N} = DeltaStyle{N}()
DeltaStyle{M}(::Val{N}) where {M,N} = DeltaStyle{N}()
function _BroadcastStyle(A::Type{<:Delta})
return DeltaStyle{ndims(A)}()
end
Base.BroadcastStyle(style1::DeltaStyle, style2::DeltaStyle) = DeltaStyle()
Base.BroadcastStyle(style1::DeltaStyle, style2::DefaultArrayStyle) = style2

function _copyto!!(dest::Delta, src::Broadcasted{<:DeltaStyle,<:Any,typeof(identity)})
axes(dest) == axes(src) || error("Dimension mismatch.")
return dest
end

function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type)
return Delta{elt}(axes(bc))
end

# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle})
dest2 = arg2(dest)
f = LinearCombination(a)
args = arguments(a)
arg2s = arg2.(args)
dest2 .= f.(arg2s...)
return dest
end
# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle})
dest1 = arg1(dest)
f = LinearCombination(a)
args = arguments(a)
arg1s = arg1.(args)
dest1 .= f.(arg1s...)
return dest
end
# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle})
return error("Can't write in-place to `Delta ⊗ Delta`.")
end

# Simplification rules similar to those for FillArrays.jl:
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
using FillArrays: Zeros
Expand Down
52 changes: 46 additions & 6 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,26 @@ _copy(a::AbstractArray) = copy(a)
function Base.copy(a::KroneckerArray)
return _copy(arg1(a)) ⊗ _copy(arg2(a))
end
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
copyto!(arg1(dest), arg1(src))
copyto!(arg2(dest), arg2(src))

# Allows extra customization, like for `FillArrays.Eye`.
function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N}
copyto!(dest, src)
return dest
end
function _copyto!!(dest::AbstractArray, src::Broadcasted)
copyto!(dest, src)
return dest
end

function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N}
return copyto!_kronecker(dest, src)
end
function copyto!_kronecker(
dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}
) where {N}
# TODO: Check if neither argument is mutated and if so error.
_copyto!!(arg1(dest), arg1(src))
_copyto!!(arg2(dest), arg2(src))
return dest
end

Expand Down Expand Up @@ -110,6 +127,23 @@ function Base.similar(
return similar(promote_type(A, B), sz)
end

function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm)
permutedims!(dest, src, perm)
return dest
end

using DerivableInterfaces: DerivableInterfaces, permuteddims
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm)
end

function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
# TODO: Error if neither argument is mutable.
_permutedims!!(arg1(dest), arg1(src), perm)
_permutedims!!(arg2(dest), arg2(src), perm)
return dest
end

function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
return (t[1]..., flatten(Base.tail(t))...)
end
Expand All @@ -128,7 +162,7 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N)))
sz = ntuple(i -> size(a, i) * size(b, i), N)
sz = reverse(ntuple(i -> size(a, i) * size(b, i), N))
return permutedims(reshape(c′, sz), reverse(ntuple(identity, N)))
end
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
Expand Down Expand Up @@ -284,6 +318,12 @@ for f in [:transpose, :adjoint, :inv]
end
end

function Base.reshape(
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
)
return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax))
end

# Allows for customizations for FillArrays.
_BroadcastStyle(x) = BroadcastStyle(x)

Expand Down Expand Up @@ -405,8 +445,8 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
Broadcast.broadcastable(a::KroneckerBroadcasted) = a
Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a))
function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted)
copyto!(arg1(dest), copy(arg1(a)))
copyto!(arg2(dest), copy(arg2(a)))
_copyto!!(arg1(dest), arg1(a))
_copyto!!(arg2(dest), arg2(a))
return dest
end
function Base.eltype(a::KroneckerBroadcasted)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Expand All @@ -34,6 +35,7 @@ MatrixAlgebraKit = "0.2"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
TensorAlgebra = "0.3.10"
TensorProducts = "0.1.7"
Test = "1.10"
TestExtras = "0.3"
Loading
Loading