From 80ad9ccbde09b8e4b34b5013c1ebc34c548a71e1 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Mon, 1 Aug 2022 18:17:57 +0530 Subject: [PATCH 1/4] initial changes --- Project.toml | 3 ++- src/StaticArrays.jl | 1 + src/chainrule.jl | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 src/chainrule.jl diff --git a/Project.toml b/Project.toml index 673df031..6200fd4b 100644 --- a/Project.toml +++ b/Project.toml @@ -3,14 +3,15 @@ uuid = "90137ffa-7385-5640-81b9-e52037218182" version = "1.5.2" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -julia = "1.6" StaticArraysCore = "1" +julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index e2e3fd83..196026d7 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -12,6 +12,7 @@ import Statistics: mean using Random import Random: rand, randn, randexp, rand!, randn!, randexp! +using ChainRulesCore using Core.Compiler: return_type import Base: sqrt, exp, log, float, real using LinearAlgebra diff --git a/src/chainrule.jl b/src/chainrule.jl new file mode 100644 index 00000000..31107ca3 --- /dev/null +++ b/src/chainrule.jl @@ -0,0 +1,39 @@ +##### +##### constructors +##### + +ChainRulesCore.@non_differentiable (::Type{T} where {T<:Union{SArray, SizedArray}})(::UndefInitializer, args...) + +function ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Tuple) where {T<:Union{SArray, SizedArray}} + return T(x), T(ẋ) +end + +function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:Union{SArray, SizedArray}} + project_x = ProjectTo(x) + Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return T(x), Array_pullback +end + +function (project::ChainRulesCore.ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} + # First deal with shape. The rule is that we reshape to add or remove trivial dimensions + # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc. + dy = if axes(dx) === project.axes + dx + else + for d in 1:max(M, length(project.axes)) + if size(dx, d) != length(get(project.axes, d, 1)) + throw(_projection_mismatch(project.axes, size(dx))) + end + end + reshape(dx, project.axes) + end + # Then deal with the elements. One projector if AbstractArray{<:Number}, + # or one per element for arrays of anything else, including arrays of arrays: + dz = if hasproperty(project, :element) + T = project_type(project.element) + S <: T ? dy : map(project.element, dy) + else + map((f, y) -> f(y), project.elements, dy) + end + return dz +end \ No newline at end of file From bf2ae51bebe981395f41234bb5f676826b6ff3cb Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Tue, 2 Aug 2022 17:04:14 +0530 Subject: [PATCH 2/4] custom projection methods --- src/StaticArrays.jl | 1 + src/chainrule.jl | 54 ++++++++++++++++++++------------------------- 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 196026d7..7b353f84 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -130,6 +130,7 @@ include("io.jl") include("pinv.jl") include("precompile.jl") +include("chainrule.jl") _precompile_() end # module diff --git a/src/chainrule.jl b/src/chainrule.jl index 31107ca3..d546e592 100644 --- a/src/chainrule.jl +++ b/src/chainrule.jl @@ -1,39 +1,33 @@ -##### -##### constructors -##### +### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here +function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) + # for d in 1:ndims(dx) + # if size(dx, d) != get(length(project.elements), d, 1) + # throw(ChainRulesCore._projection_mismatch(axes(project.elements), size(dx))) + # end + # end + dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return ChainRulesCore.project_type(project)(dz...) +end +### Project SArray to SArray +function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T} + return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S) +end + +function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M} + return SArray{project.axes}(dx) +end -ChainRulesCore.@non_differentiable (::Type{T} where {T<:Union{SArray, SizedArray}})(::UndefInitializer, args...) +### Adjoint for SArray constructor -function ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Tuple) where {T<:Union{SArray, SizedArray}} +ChainRulesCore.@non_differentiable (::Type{T} where {T<:SArray})(::UndefInitializer, args...) + +function ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Tuple) where {T<:SArray} return T(x), T(ẋ) end -function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:Union{SArray, SizedArray}} +function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:SArray} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) return T(x), Array_pullback -end - -function (project::ChainRulesCore.ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} - # First deal with shape. The rule is that we reshape to add or remove trivial dimensions - # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc. - dy = if axes(dx) === project.axes - dx - else - for d in 1:max(M, length(project.axes)) - if size(dx, d) != length(get(project.axes, d, 1)) - throw(_projection_mismatch(project.axes, size(dx))) - end - end - reshape(dx, project.axes) - end - # Then deal with the elements. One projector if AbstractArray{<:Number}, - # or one per element for arrays of anything else, including arrays of arrays: - dz = if hasproperty(project, :element) - T = project_type(project.element) - S <: T ? dy : map(project.element, dy) - else - map((f, y) -> f(y), project.elements, dy) - end - return dz end \ No newline at end of file From 49f41844038a010c34ce8d47f21025f4f9c40a7d Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 3 Aug 2022 16:50:04 +0530 Subject: [PATCH 3/4] Tests --- Project.toml | 4 +++- src/chainrule.jl | 6 +----- test/abstractarray.jl | 39 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 6200fd4b..eb8e049f 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [targets] -test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"] +test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Zygote", "ForwardDiff"] diff --git a/src/chainrule.jl b/src/chainrule.jl index d546e592..68bf1226 100644 --- a/src/chainrule.jl +++ b/src/chainrule.jl @@ -1,14 +1,10 @@ ### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) - # for d in 1:ndims(dx) - # if size(dx, d) != get(length(project.elements), d, 1) - # throw(ChainRulesCore._projection_mismatch(axes(project.elements), size(dx))) - # end - # end dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) return ChainRulesCore.project_type(project)(dz...) end + ### Project SArray to SArray function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T} return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index d8aff256..8fc6c917 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1,4 +1,4 @@ -using StaticArrays, Test, LinearAlgebra +using StaticArrays, Test, LinearAlgebra, Zygote, ForwardDiff @testset "AbstractArray interface" begin @testset "size and length" begin @@ -243,6 +243,43 @@ using StaticArrays, Test, LinearAlgebra @test rs == Base.reduced_indices(axes(a), i) end end + + @testset "AutoDiff" begin + u0 = @SVector rand(2) + p = @SVector rand(4) + + function lotka(u, p, svec=true) + du1 = p[1]*u[1] - p[2]*u[1]*u[2] + du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + if svec + @SVector [du1, du2] + else + @SMatrix [du1 du2 du1; du2 du1 du1] + end + end + + #SVector constructor adjoint + function loss(p) + u = lotka(u0, p) + sum(1 .- u) + end + + grad = Zygote.gradient(loss, p) + @test typeof(grad[1]) <: SArray + grad2 = ForwardDiff.gradient(loss, p) + @test grad[1] ≈ grad2 rtol=1e-12 + + #SMatrix constructor adjoint + function loss_mat(p) + u = lotka(u0, p, false) + sum(1 .- u) + end + + grad = Zygote.gradient(loss_mat, p) + @test typeof(grad[1]) <: SArray + grad2 = ForwardDiff.gradient(loss_mat, p) + @test grad[1] ≈ grad2 rtol=1e-12 + end end @testset "permutedims" begin From 92ca72fefce97b08ac98da5879b94ce70744b4a6 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Mon, 8 Aug 2022 11:48:40 +0530 Subject: [PATCH 4/4] Removed ambiguous methods --- src/chainrule.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/chainrule.jl b/src/chainrule.jl index 68bf1226..f177081c 100644 --- a/src/chainrule.jl +++ b/src/chainrule.jl @@ -16,12 +16,6 @@ end ### Adjoint for SArray constructor -ChainRulesCore.@non_differentiable (::Type{T} where {T<:SArray})(::UndefInitializer, args...) - -function ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Tuple) where {T<:SArray} - return T(x), T(ẋ) -end - function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:SArray} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))