Adjoint and Projections #1068

    @@ -0,0 +1,23 @@
    ### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here
    function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)

Collaborator
What's the MWE that hits this? And might it be simpler just to do something like

        dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
        dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
        return ChainRulesCore.project_type(project)(dz...)
    end
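
For context on the default behaviour that the overload above works around, here is a minimal sketch using plain arrays in place of an SMatrix. It assumes a ChainRulesCore version that ships the tuple projector (the `ProjectTo{<:Tangent{<:Tuple}}` methods); the variable names are illustrative only and are not part of this PR.

```julia
using ChainRulesCore

# Projector for a 4-tuple; each element gets its own scalar projector.
# (Assumes a ChainRulesCore release that includes tuple projection.)
proj = ProjectTo((1.0, 2.0, 3.0, 4.0))

# A length-4 vector matches the tuple's axes, so it is projected elementwise
# and wrapped in a Tangent for the tuple type.
t = proj([10.0, 20.0, 30.0, 40.0])
@show t isa Tangent

# A 2×2 matrix has the right number of elements but the wrong shape,
# so the default array method is expected to reject it.
threw = try
    proj(ones(2, 2))
    false
catch err
    true
end
@show threw
```

The 2×2 case is the analogue of handing an SMatrix cotangent to a tuple projector, which is presumably why the PR adds a dedicated method for `dx::SArray`.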

    ### Project SArray to SArray
    function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T}
        return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)

Collaborator
If anyone is relying on

Suggested change

    end

    function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M}
        return SArray{project.axes}(dx)

Comment on lines +13 to +14

Collaborator

To match what projection does for

The CRC one also checks that the size differs only in trivial ways (i.e. size-1 trailing dimensions), and otherwise errors. Sometimes this is helpful for finding bugs in AD rules. I think this will accept any shape with the right length, which most bugs will still hit...

    end
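
The shape check mentioned in the review comment can be observed with ChainRulesCore's array projector on ordinary arrays. This is a sketch of that behaviour only, not of the StaticArrays code in this PR: a trailing size-1 dimension is dropped, while any other change of shape raises an error.

```julia
using ChainRulesCore

proj = ProjectTo(ones(4))   # projector for a length-4 Vector

# A 4×1 matrix differs only by a trailing size-1 dimension: accepted and reshaped.
ok = proj(ones(4, 1))
@show size(ok)

# A 2×2 matrix has the right length but a genuinely different shape: rejected.
rejected = try
    proj(ones(2, 2))
    false
catch err
    true
end
@show rejected
```

A projector that simply calls a sized constructor on `dx` would, as the reviewer suspects, check only the length, so the second case would pass through silently.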

    ### Adjoint for SArray constructor

    function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:SArray}
        project_x = ProjectTo(x)
        Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
        return T(x), Array_pullback
    end

    @@ -1,4 +1,4 @@
    -using StaticArrays, Test, LinearAlgebra
    +using StaticArrays, Test, LinearAlgebra, Zygote, ForwardDiff

    @testset "AbstractArray interface" begin
        @testset "size and length" begin

    @@ -243,6 +243,43 @@ using StaticArrays, Test, LinearAlgebra
                @test rs == Base.reduced_indices(axes(a), i)
            end
        end

        @testset "AutoDiff" begin
            u0 = @SVector rand(2)
            p = @SVector rand(4)

            function lotka(u, p, svec=true)
                du1 = p[1]*u[1] - p[2]*u[1]*u[2]
                du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
                if svec
                    @SVector [du1, du2]
                else
                    @SMatrix [du1 du2 du1; du2 du1 du1]
                end
            end

            #SVector constructor adjoint
            function loss(p)
                u = lotka(u0, p)
                sum(1 .- u)
            end

Comment on lines +251 to +265

Also, this seems like a rather non-minimal and domain-specific test?

            grad = Zygote.gradient(loss, p)
            @test typeof(grad[1]) <: SArray
            grad2 = ForwardDiff.gradient(loss, p)
            @test grad[1] ≈ grad2 rtol=1e-12

            #SMatrix constructor adjoint
            function loss_mat(p)
                u = lotka(u0, p, false)
                sum(1 .- u)
            end

            grad = Zygote.gradient(loss_mat, p)
            @test typeof(grad[1]) <: SArray
            grad2 = ForwardDiff.gradient(loss_mat, p)
            @test grad[1] ≈ grad2 rtol=1e-12
        end
    end

    @testset "permutedims" begin

Regardless of where these rules end up living, they should be tested with ChainRulesTestUtils and not ad-hoc via Zygote/ForwardDiff/what have you. That library does far more robust testing than most of us would think to write by hand.
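
As a rough illustration of what that suggestion looks like in practice, here is a minimal sketch of `test_rrule` from ChainRulesTestUtils applied to a toy rule. The function `sumsq` and its rule are hypothetical and exist only for this example; they are not part of the PR, and testing the `SArray` constructor rule itself may need extra care around tuple tangents.

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical function, defined here only to have something to test.
sumsq(x::AbstractVector) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractVector)
    project_x = ProjectTo(x)
    # d(sum(x.^2))/dx = 2x, scaled by the incoming cotangent ȳ.
    sumsq_pullback(ȳ) = (NoTangent(), project_x(2 .* x .* unthunk(ȳ)))
    return sumsq(x), sumsq_pullback
end

# Compares the pullback against finite differences and checks the tangent types.
test_rrule(sumsq, randn(3))
```

Compared with the Zygote/ForwardDiff comparison in the diff above, this exercises the rule directly, without depending on how a particular AD backend happens to call it.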