Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@ uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
julia = "1.6"
StaticArraysCore = "1"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Zygote", "ForwardDiff"]
Comment on lines +21 to +25

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regardless of where these rules end up living, they should be tested with ChainRulesTestUtils and not ad-hoc via Zygote/ForwardDiff/what have you. That library does far more robust testing than most of us would think to write by hand.

2 changes: 2 additions & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Statistics: mean

using Random
import Random: rand, randn, randexp, rand!, randn!, randexp!
using ChainRulesCore
using Core.Compiler: return_type
import Base: sqrt, exp, log, float, real
using LinearAlgebra
Expand Down Expand Up @@ -129,6 +130,7 @@ include("io.jl")
include("pinv.jl")

include("precompile.jl")
include("chainrule.jl")
_precompile_()

end # module
23 changes: 23 additions & 0 deletions src/chainrule.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here
function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the MWE that hits this?

And might it be simpler just to do something like (p::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) = p(Tuple(dx)), correct the type & then take the same path as other tuples?

dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return ChainRulesCore.project_type(project)(dz...)
end

### Project SArray to SArray
function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T}
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If anyone is relying on ProjectTo(x).axes actually containing axes(x), then this will be very surprising. Maybe this Tuple{size...} thing ought to have a different name:

Suggested change
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=axes(x), static_size=S)

end

function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M}
return SArray{project.axes}(dx)
Comment on lines +13 to +14
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To match what projection does for Array, I think this ought to use the element projector to correct eltype.

The CRC one also checks that the size differs only in trivial ways (i.e. size-1 trailing dimensions), otherwise errors. Sometimes this is helpful for finding bugs in AD rules. I think this will accept any shape with the right length, which most bugs will still hit...

end

### Adjoint for SArray constructor

function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:SArray}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end
39 changes: 38 additions & 1 deletion test/abstractarray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StaticArrays, Test, LinearAlgebra
using StaticArrays, Test, LinearAlgebra, Zygote, ForwardDiff

@testset "AbstractArray interface" begin
@testset "size and length" begin
Expand Down Expand Up @@ -243,6 +243,43 @@ using StaticArrays, Test, LinearAlgebra
@test rs == Base.reduced_indices(axes(a), i)
end
end

@testset "AutoDiff" begin
u0 = @SVector rand(2)
p = @SVector rand(4)

function lotka(u, p, svec=true)
du1 = p[1]*u[1] - p[2]*u[1]*u[2]
du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
if svec
@SVector [du1, du2]
else
@SMatrix [du1 du2 du1; du2 du1 du1]
end
end

#SVector constructor adjoint
function loss(p)
u = lotka(u0, p)
sum(1 .- u)
end
Comment on lines +251 to +265

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this seems a rather not minimal and domain-specific test?


grad = Zygote.gradient(loss, p)
@test typeof(grad[1]) <: SArray
grad2 = ForwardDiff.gradient(loss, p)
@test grad[1] ≈ grad2 rtol=1e-12

#SMatrix constructor adjoint
function loss_mat(p)
u = lotka(u0, p, false)
sum(1 .- u)
end

grad = Zygote.gradient(loss_mat, p)
@test typeof(grad[1]) <: SArray
grad2 = ForwardDiff.gradient(loss_mat, p)
@test grad[1] ≈ grad2 rtol=1e-12
end
end

@testset "permutedims" begin
Expand Down