Skip to content

Commit c054aee

Browse files
Merge pull request #217 from SciML/staticarrayscore
Drop StaticArrays by using StaticArraysCore
2 parents 945a9d1 + 611b0a1 commit c054aee

File tree

8 files changed

+93
-32
lines changed

8 files changed

+93
-32
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ jobs:
3434
${{ runner.os }}-
3535
- uses: julia-actions/julia-buildpkg@v1
3636
- uses: julia-actions/julia-runtest@v1
37+
env:
38+
GROUP: ${{ matrix.group }}
3739
- uses: julia-actions/julia-processcoverage@v1
3840
- uses: codecov/codecov-action@v1
3941
with:

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@ version = "2.30.0"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
9-
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
9+
ArrayInterfaceStaticArraysCore = "dd5226c6-a4d4-4bc7-8575-46859f9c95b9"
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
16-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
16+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1717
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1919

2020
[compat]
2121
Adapt = "3"
2222
ArrayInterfaceCore = "0.1.1"
23-
ArrayInterfaceStaticArrays = "0.1"
23+
ArrayInterfaceStaticArraysCore = "0.1"
2424
ChainRulesCore = "0.10.7, 1"
2525
DocStringExtensions = "0.8, 0.9"
2626
FillArrays = "0.11, 0.12, 0.13"
2727
GPUArraysCore = "0.1"
2828
RecipesBase = "0.7, 0.8, 1.0"
29-
StaticArrays = "0.12, 1.0"
29+
StaticArraysCore = "1"
3030
ZygoteRules = "0.2"
3131
julia = "1.6"
3232

@@ -36,10 +36,11 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
3636
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3737
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3838
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
39+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3940
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
4041
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4142
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4243
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4344

4445
[targets]
45-
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"]
46+
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ $(DocStringExtensions.README)
55
module RecursiveArrayTools
66

77
using DocStringExtensions
8-
using RecipesBase, StaticArrays, Statistics,
8+
using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterfaceCore, LinearAlgebra
1010

1111
import ChainRulesCore
@@ -15,7 +15,7 @@ import ZygoteRules, Adapt
1515
# Required for the downstream_events.jl test
1616
# Since `ismutable` on an ArrayPartition needs
1717
# to know static arrays are not mutable
18-
import ArrayInterfaceStaticArrays
18+
import ArrayInterfaceStaticArraysCore
1919

2020
using FillArrays
2121

src/array_partition.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)
9090

9191
## Array
9292

93-
Base.Array(A::ArrayPartition) = ArrayPartition(Array.(A.x))
93+
Base.Array(A::ArrayPartition) = reduce(vcat,Array.(A.x))
9494
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u))
9595

9696
## ones
@@ -390,13 +390,13 @@ end
390390
# [U11 U12 U13] [ b1 ]
391391
# [ 0 U22 U23] \ [ b2 ]
392392
# [ 0 0 U33] [ b3 ]
393-
function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular}
393+
function LinearAlgebra.ldiv!(A::UnitUpperTriangular, bb::ArrayPartition)
394394
A = A.data
395395
n = npartitions(bb)
396396
b = bb.x
397397
lens = map(length, b)
398398
@inbounds for j in n:-1:1
399-
Ajj = T(getblock(A, lens, j, j))
399+
Ajj = UnitUpperTriangular(getblock(A, lens, j, j))
400400
xj = ldiv!(Ajj, vec(b[j]))
401401
for i in j-1:-1:1
402402
Aij = getblock(A, lens, i, j)
@@ -407,13 +407,30 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperT
407407
return bb
408408
end
409409

410-
function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerTriangular,LowerTriangular}
410+
function LinearAlgebra.ldiv!(A::UpperTriangular, bb::ArrayPartition)
411+
A = A.data
412+
n = npartitions(bb)
413+
b = bb.x
414+
lens = map(length, b)
415+
@inbounds for j in n:-1:1
416+
Ajj = UpperTriangular(getblock(A, lens, j, j))
417+
xj = ldiv!(Ajj, vec(b[j]))
418+
for i in j-1:-1:1
419+
Aij = getblock(A, lens, i, j)
420+
# bi = -Aij * xj + bi
421+
mul!(vec(b[i]), Aij, xj, -1, true)
422+
end
423+
end
424+
return bb
425+
end
426+
427+
function LinearAlgebra.ldiv!(A::UnitLowerTriangular, bb::ArrayPartition)
411428
A = A.data
412429
n = npartitions(bb)
413430
b = bb.x
414431
lens = map(length, b)
415432
@inbounds for j in 1:n
416-
Ajj = T(getblock(A, lens, j, j))
433+
Ajj = UnitLowerTriangular(getblock(A, lens, j, j))
417434
xj = ldiv!(Ajj, vec(b[j]))
418435
for i in j+1:n
419436
Aij = getblock(A, lens, i, j)
@@ -423,6 +440,24 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerT
423440
end
424441
return bb
425442
end
443+
444+
function LinearAlgebra.ldiv!(A::LowerTriangular, bb::ArrayPartition)
445+
A = A.data
446+
n = npartitions(bb)
447+
b = bb.x
448+
lens = map(length, b)
449+
@inbounds for j in 1:n
450+
Ajj = LowerTriangular(getblock(A, lens, j, j))
451+
xj = ldiv!(Ajj, vec(b[j]))
452+
for i in j+1:n
453+
Aij = getblock(A, lens, i, j)
454+
# bi = -Aij * xj + b[i]
455+
mul!(vec(b[i]), Aij, xj, -1, true)
456+
end
457+
end
458+
return bb
459+
end
460+
426461
# TODO: optimize
427462
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)
428463
for i = order

src/utils.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ like `copy` on arrays of scalars.
99
function recursivecopy(a)
1010
deepcopy(a)
1111
end
12-
recursivecopy(a::Union{SVector,SMatrix,SArray,Number}) = copy(a)
12+
recursivecopy(a::Union{StaticArraysCore.SVector,StaticArraysCore.SMatrix,
13+
StaticArraysCore.SArray,Number}) = copy(a)
1314
function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N}
1415
copy(a)
1516
end
@@ -33,7 +34,7 @@ like `copy!` on arrays of scalars.
3334
"""
3435
function recursivecopy! end
3536

36-
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArray,T2<:StaticArray,N}
37+
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
3738
@inbounds for i in eachindex(a)
3839
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
3940
b[i] = copy(a[i])
@@ -68,13 +69,13 @@ A recursive `fill!` function.
6869
"""
6970
function recursivefill! end
7071

71-
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArray,T2<:StaticArray,N}
72+
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
7273
@inbounds for i in eachindex(b)
7374
b[i] = copy(a)
7475
end
7576
end
7677

77-
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:SArray,T2<:Union{Number,Bool},N}
78+
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.SArray,T2<:Union{Number,Bool},N}
7879
@inbounds for i in eachindex(b)
7980
b[i] = fill(a, typeof(b[i]))
8081
end
@@ -88,7 +89,7 @@ function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:Union{Number,Bool
8889
fill!(b, a)
8990
end
9091

91-
function recursivefill!(b::AbstractArray{T,N},a) where {T<:MArray,N}
92+
function recursivefill!(b::AbstractArray{T,N},a) where {T<:StaticArraysCore.MArray,N}
9293
@inbounds for i in eachindex(b)
9394
if isassigned(b,i)
9495
recursivefill!(b[i],a)
@@ -151,7 +152,7 @@ If `i<length(x)`, it's simply a `recursivecopy!` to the `i`th element. Otherwise
151152
function copyat_or_push!(a::AbstractVector{T},i::Int,x,nc::Type{Val{perform_copy}}=Val{true}) where {T,perform_copy}
152153
@inbounds if length(a) >= i
153154
if !ArrayInterfaceCore.ismutable(T) || !perform_copy
154-
# TODO: Check for `setindex!`` if T <: StaticArray and use `copy!(b[i],a[i])`
155+
# TODO: Check for `setindex!`` if T <: StaticArraysCore.StaticArray and use `copy!(b[i],a[i])`
155156
# or `b[i] = a[i]`, see https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19
156157
a[i] = x
157158
else
@@ -208,7 +209,15 @@ ones has a `Array{Array{Float64,N},N}`, this will return `Array{Float64,N}`.
208209
"""
209210
recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a))
210211
recursive_unitless_eltype(a::Type{Any}) = Any
211-
recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))
212+
213+
# Should be:
214+
# recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))
215+
# But missing from StaticArraysCore
216+
recursive_unitless_eltype(a::Type{StaticArraysCore.SArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.SArray{S, typeof(one(T)), N, L}
217+
recursive_unitless_eltype(a::Type{StaticArraysCore.MArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.MArray{S, typeof(one(T)), N, L}
218+
recursive_unitless_eltype(a::Type{StaticArraysCore.SizedArray{S, T, N, M, TData}}) where {
219+
S, T, N, M, TData} = StaticArraysCore.SizedArray{S, typeof(one(T)), N, M, TData}
220+
212221
recursive_unitless_eltype(a::Type{T}) where {T<:Array} = Array{recursive_unitless_eltype(eltype(a)),ndims(a)}
213222
recursive_unitless_eltype(a::Type{T}) where {T<:Number} = typeof(one(eltype(a)))
214223
recursive_unitless_eltype(::Type{<:Enum{T}}) where T = T

test/linalg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using LinearAlgebra
44
n, m = 5, 6
55
bb = rand(n), rand(m)
66
b = ArrayPartition(bb)
7+
@test Array(b) isa Array
78
@test Array(b) == collect(b) == vcat(bb...)
89
A = randn(MersenneTwister(123), n+m, n+m)
910

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
@time begin
2121

22-
if !is_APPVEYOR && GROUP == "Core"
22+
if GROUP == "Core" || GROUP == "All"
2323
@time @testset "Utils Tests" begin include("utils_test.jl") end
2424
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
2525
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end

test/upstream.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ dyn(u, p, t) = ArrayPartition(
3636
ArrayPartition(zeros(1), [0.0])
3737
)
3838

39-
solve(
39+
@test solve(
4040
ODEProblem(
4141
dyn,
4242
ArrayPartition(
@@ -45,15 +45,28 @@ solve(
4545
),
4646
(0.0, 1.0)
4747
),AutoTsit5(Rodas5())
48-
)
49-
50-
@test_broken solve(
51-
ODEProblem(
52-
dyn,
53-
ArrayPartition(
54-
ArrayPartition(zeros(1), [-1.0]),
55-
ArrayPartition(zeros(1), [0.75])
56-
),
57-
(0.0, 1.0)
58-
),Rodas5()
5948
).retcode == :Success
49+
50+
if VERSION < v"1.7"
51+
@test solve(
52+
ODEProblem(
53+
dyn,
54+
ArrayPartition(
55+
ArrayPartition(zeros(1), [-1.0]),
56+
ArrayPartition(zeros(1), [0.75])
57+
),
58+
(0.0, 1.0)
59+
),Rodas5()
60+
).retcode == :Success
61+
else
62+
@test_broken solve(
63+
ODEProblem(
64+
dyn,
65+
ArrayPartition(
66+
ArrayPartition(zeros(1), [-1.0]),
67+
ArrayPartition(zeros(1), [0.75])
68+
),
69+
(0.0, 1.0)
70+
),Rodas5()
71+
).retcode == :Success
72+
end

0 commit comments

Comments
 (0)