Skip to content

Commit 16dd578

Browse files
Merge pull request #80 from JuliaDiffEq/concrete
fix GPU concretizations
2 parents c0c6731 + 6340a2f commit 16dd578

File tree

6 files changed

+32
-10
lines changed

6 files changed

+32
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "1.2.1"
4+
version = "2.0.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ module RecursiveArrayTools
55
using Requires, RecipesBase, StaticArrays, Statistics,
66
ArrayInterface
77

8-
abstract type AbstractVectorOfArray{T, N} <: AbstractArray{T, N} end
9-
abstract type AbstractDiffEqArray{T, N} <: AbstractVectorOfArray{T, N} end
8+
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
9+
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1010

1111
include("utils.jl")
1212
include("vector_of_array.jl")

src/init.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,13 @@ function __init__()
1010
RecursiveArrayTools.recursive_unitless_bottom_eltype(a::ApproxFun.Fun) = recursive_unitless_bottom_eltype(ApproxFun.coefficients(a))
1111
RecursiveArrayTools.recursive_bottom_eltype(a::ApproxFun.Fun) = recursive_bottom_eltype(ApproxFun.coefficients(a))
1212
end
13+
14+
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
15+
function CuArrays.CuArray(VA::AbstractVectorOfArray)
16+
vecs = vec.(VA.u)
17+
return CuArrays.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
18+
end
19+
20+
Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA)
21+
end
1322
end

src/vector_of_array.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
# Based on code from M. Bauman Stackexchange answer + Gitter discussion
2-
mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N}
2+
mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
33
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
44
end
55
# VectorOfArray with an added series for time
6-
mutable struct DiffEqArray{T, N, A, B} <: AbstractDiffEqArray{T, N}
6+
mutable struct DiffEqArray{T, N, A, B} <: AbstractDiffEqArray{T, N, A}
77
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
88
t::B
99
end
1010

11+
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:AbstractVector}} = reduce(hcat,VA.u)
12+
function Base.Array(VA::AbstractVectorOfArray)
13+
vecs = vec.(VA.u)
14+
Array(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
15+
end
16+
1117
VectorOfArray(vec::AbstractVector{T}, dims::NTuple{N}) where {T, N} = VectorOfArray{eltype(T), N, typeof(vec)}(vec)
1218
# Assume that the first element is representative all all other elements
1319
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))

test/basic_indexing.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@ testa = cat(recs..., dims=2)
66
testva = VectorOfArray(recs)
77
t = [1,2,3]
88
diffeq = DiffEqArray(recs,t)
9+
@test Array(testva) == [1 4 7
10+
2 5 8
11+
3 6 9]
912

10-
testa[1:2, 1:2] == [1 4; 2 5]
11-
testva[1:2, 1:2] == [1 4; 2 5]
12-
testa[1:2, 1:2] == [1 4; 2 5]
13+
@test testa[1:2, 1:2] == [1 4; 2 5]
14+
@test testva[1:2, 1:2] == [1 4; 2 5]
15+
@test testa[1:2, 1:2] == [1 4; 2 5]
1316

1417
# # ndims == 2
1518
recs = [rand(8) for i in 1:10]
@@ -73,3 +76,7 @@ testva[1:2, 1:2]
7376
# Test broadcast
7477
a = testva .+ rand(3,3)
7578
@test_broken a.= testva
79+
80+
recs = [rand(2,2) for i in 1:5]
81+
testva = VectorOfArray(recs)
82+
@test Array(testva) isa Array{Float64,3}

test/interface_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ testva = VectorOfArray(recs)
3939
testa = cat(recs...,dims=3)
4040
@test convert(Array,testva) == testa
4141

42-
recs = [[1, 2, 3], [3 5; 6 7], [8, 9, 10, 11]]
42+
recs = [[1 2; 3 4], [3 5; 6 7], [8 9; 10 11]]
4343
testva = VectorOfArray(recs)
44-
@test size(convert(Array,testva)) == (3,3)
44+
@test size(convert(Array,testva)) == (2,2,3)
4545

4646
# create similar VectorOfArray
4747
recs = [rand(6) for i = 1:4]

0 commit comments

Comments
 (0)