-
Notifications
You must be signed in to change notification settings - Fork 4
Add frule and rrule for bsplinebasisall
#267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #267 +/- ##
==========================================
+ Coverage 98.87% 98.88% +0.01%
==========================================
Files 17 17
Lines 1507 1522 +15
==========================================
+ Hits 1490 1505 +15
Misses 17 17
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
|
Before this PR julia> using Zygote, BasicBSpline, BenchmarkTools, Random
julia> Random.seed!(42)
TaskLocalRNG()
julia> P = BSplineSpace{3}(KnotVector(rand(30)))
BSplineSpace{3, Float64}(KnotVector([0.02698475249996979, 0.030185550039208087, 0.05529822631508752, 0.07933388944734077, 0.09443144857141617, 0.16643864408566544, 0.1735745757945074, 0.18342915487003242, 0.2112283719951349, 0.25858546995315457, 0.32166161915780656, 0.3475083889757653, 0.3723466293304377, 0.3906633890864917, 0.4339142195043123, 0.48302213696845187, 0.49205817172341115, 0.5270150071089016, 0.5447580773835046, 0.6226472050312687, 0.6349805161807821, 0.6599599239949302, 0.6637579955379869, 0.7219827350406843, 0.802762551279973, 0.8120185311491867, 0.837335454222923, 0.9786385378290976, 0.9807576556964709, 0.9893116874624179]))
julia> n = dim(P)
26
julia> a = rand(n,n);
julia> M = BSplineManifold(a,(P,P));
julia> gradient(M, 0.5, 0.4)
ERROR: Need an adjoint for constructor StaticArraysCore.SVector{4, Float64}. Gradient is of type StaticArraysCore.SizedVector{4, Float64, Vector{Float64}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{StaticArraysCore.SVector{4, Float64}, Nothing, false})(Δ::StaticArraysCore.SizedVector{4, Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/DRjAT/src/lib/lib.jl:327
[3] (::Zygote.var"#1948#back#224"{Zygote.Jnew{StaticArraysCore.SVector{4, Float64}, Nothing, false}})(Δ::StaticArraysCore.SizedVector{4, Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ~/.julia/packages/StaticArraysCore/HoT7O/src/StaticArraysCore.jl:107 [inlined]
[5] (::typeof(∂(StaticArraysCore.SVector{4, Float64})))(Δ::StaticArraysCore.SizedVector{4, Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/dev/StaticArrays/src/convert.jl:163 [inlined]
[7] macro expansion
@ ~/.julia/dev/BasicBSpline/src/_BSplineBasis.jl:225 [inlined]
[8] Pullback
@ ~/.julia/dev/BasicBSpline/src/_BSplineBasis.jl:225 [inlined]
[9] macro expansion
@ ~/.julia/dev/BasicBSpline/src/_BSplineManifold.jl:77 [inlined]
[10] Pullback
@ ~/.julia/dev/BasicBSpline/src/_BSplineManifold.jl:77 [inlined]
[11] (::typeof(∂(unsafe_mapping)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0
[12] macro expansion
@ ~/.julia/dev/BasicBSpline/src/_BSplineManifold.jl:117 [inlined]
[13] Pullback
@ ~/.julia/dev/BasicBSpline/src/_BSplineManifold.jl:117 [inlined]
[14] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface.jl:45
[15] gradient(::BSplineManifold{2, (3, 3), Float64, Tuple{BSplineSpace{3, Float64}, BSplineSpace{3, Float64}}}, ::Float64, ::Vararg{Float64})
@ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface.jl:97
[16] top-level scope
@ REPL[7]:1After this PR julia> using Zygote, BasicBSpline, BenchmarkTools, Random
julia> Random.seed!(42)
TaskLocalRNG()
julia> P = BSplineSpace{3}(KnotVector(rand(30)))
BSplineSpace{3, Float64}(KnotVector([0.02698475249996979, 0.030185550039208087, 0.05529822631508752, 0.07933388944734077, 0.09443144857141617, 0.16643864408566544, 0.1735745757945074, 0.18342915487003242, 0.2112283719951349, 0.25858546995315457, 0.32166161915780656, 0.3475083889757653, 0.3723466293304377, 0.3906633890864917, 0.4339142195043123, 0.48302213696845187, 0.49205817172341115, 0.5270150071089016, 0.5447580773835046, 0.6226472050312687, 0.6349805161807821, 0.6599599239949302, 0.6637579955379869, 0.7219827350406843, 0.802762551279973, 0.8120185311491867, 0.837335454222923, 0.9786385378290976, 0.9807576556964709, 0.9893116874624179]))
julia> n = dim(P)
26
julia> a = rand(n,n);
julia> M = BSplineManifold(a,(P,P));
julia> gradient(M, 0.5, 0.4)
(-3.7331250515839387, 2.813433752317371)
julia> @benchmark gradient(M, 0.5, 0.4)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 129.736 μs … 6.399 ms ┊ GC (min … max): 0.00% … 92.22%
Time (median): 134.560 μs ┊ GC (median): 0.00%
Time (mean ± σ): 144.093 μs ± 210.939 μs ┊ GC (mean ± σ): 5.03% ± 3.36%
▁▄▆█▅▂
▁▂▂▄▆██████▇▅▄▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
130 μs Histogram: frequency by time 162 μs <
Memory estimate: 62.45 KiB, allocs estimate: 1073.
julia> @benchmark gradient($M, 0.5, 0.4)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 128.693 μs … 6.351 ms ┊ GC (min … max): 0.00% … 92.97%
Time (median): 134.986 μs ┊ GC (median): 0.00%
Time (mean ± σ): 144.185 μs ± 223.955 μs ┊ GC (mean ± σ): 5.57% ± 3.50%
▁▄▆█▇▇▆▃
▂▂▂▂▃▄▅██████████▇▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂ ▃
129 μs Histogram: frequency by time 158 μs <
Memory estimate: 62.48 KiB, allocs estimate: 1074.The performance on speed seems able to be improved, but avoiding error is enough in this PR. |
|
With JuliaArrays/StaticArrays.jl#1068, the julia> using Zygote, BasicBSpline, BenchmarkTools, Random
julia> Random.seed!(42)
TaskLocalRNG()
julia> P = BSplineSpace{3}(KnotVector(rand(30)))
BSplineSpace{3, Float64}(KnotVector([0.02698475249996979, 0.030185550039208087, 0.05529822631508752, 0.07933388944734077, 0.09443144857141617, 0.16643864408566544, 0.1735745757945074, 0.18342915487003242, 0.2112283719951349, 0.25858546995315457, 0.32166161915780656, 0.3475083889757653, 0.3723466293304377, 0.3906633890864917, 0.4339142195043123, 0.48302213696845187, 0.49205817172341115, 0.5270150071089016, 0.5447580773835046, 0.6226472050312687, 0.6349805161807821, 0.6599599239949302, 0.6637579955379869, 0.7219827350406843, 0.802762551279973, 0.8120185311491867, 0.837335454222923, 0.9786385378290976, 0.9807576556964709, 0.9893116874624179]))
julia> n = dim(P)
26
julia> a = rand(n,n);
julia> M = BSplineManifold(a,(P,P));
julia> gradient(M, 0.5, 0.4)
(-3.7331250515839374, 2.8134337523173722)
julia> @benchmark gradient(M, 0.5, 0.4)
BenchmarkTools.Trial: 7001 samples with 1 evaluation.
Range (min … max): 653.094 μs … 8.426 ms ┊ GC (min … max): 0.00% … 86.14%
Time (median): 669.645 μs ┊ GC (median): 0.00%
Time (mean ± σ): 710.238 μs ± 526.281 μs ┊ GC (mean ± σ): 4.98% ± 6.13%
▂▇█▅▁
▂▃▆█████▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▂▁▂▁▂▂▁▂▁▂▂▂ ▃
653 μs Histogram: frequency by time 832 μs <
Memory estimate: 299.77 KiB, allocs estimate: 4304.
julia> @benchmark gradient($M, 0.5, 0.4)
BenchmarkTools.Trial: 6722 samples with 1 evaluation.
Range (min … max): 658.474 μs … 8.945 ms ┊ GC (min … max): 0.00% … 85.83%
Time (median): 680.636 μs ┊ GC (median): 0.00%
Time (mean ± σ): 739.736 μs ± 539.914 μs ┊ GC (mean ± σ): 4.85% ± 6.08%
▂▇█▇▅▂▁
▁▂▃▆███████▆▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▂▂▄▆███▇▆▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁ ▃
658 μs Histogram: frequency by time 787 μs <
Memory estimate: 299.80 KiB, allocs estimate: 4305. |
|
Before this PR julia> using Zygote, BasicBSpline, BenchmarkTools, Random
julia> Random.seed!(42)
TaskLocalRNG()
julia> P = BSplineSpace{3}(KnotVector(rand(30)))
BSplineSpace{3, Float64}(KnotVector([0.02698475249996979, 0.030185550039208087, 0.05529822631508752, 0.07933388944734077, 0.09443144857141617, 0.16643864408566544, 0.1735745757945074, 0.18342915487003242, 0.2112283719951349, 0.25858546995315457, 0.32166161915780656, 0.3475083889757653, 0.3723466293304377, 0.3906633890864917, 0.4339142195043123, 0.48302213696845187, 0.49205817172341115, 0.5270150071089016, 0.5447580773835046, 0.6226472050312687, 0.6349805161807821, 0.6599599239949302, 0.6637579955379869, 0.7219827350406843, 0.802762551279973, 0.8120185311491867, 0.837335454222923, 0.9786385378290976, 0.9807576556964709, 0.9893116874624179]))
julia> jacobian(t->bsplinebasisall(P,1,t), 0.2)
([-880.8310662693823, 1661.8743208018375, -1133.499200955729, 352.4559464232739],)
julia> @benchmark jacobian(t->bsplinebasisall(P,1,t), 0.2)
BenchmarkTools.Trial: 6775 samples with 1 evaluation.
Range (min … max): 679.325 μs … 6.652 ms ┊ GC (min … max): 0.00% … 83.42%
Time (median): 701.707 μs ┊ GC (median): 0.00%
Time (mean ± σ): 734.366 μs ± 406.467 μs ┊ GC (mean ± σ): 3.75% ± 6.01%
▁▃▅▆▇███▇▆▅▄▃ ▁ ▁ ▂
▅███████████████▇███████▇█▆▆▆▆▇▇██▇█▇█▆▆▇▆▃▆▅▃▅▄▃▁▄▅▃▁▁▁▁▁▁▁▃ █
679 μs Histogram: log(frequency) by time 833 μs <
Memory estimate: 226.14 KiB, allocs estimate: 4168.After this PR julia> using Zygote, BasicBSpline, BenchmarkTools, Random
julia> Random.seed!(42)
TaskLocalRNG()
julia> P = BSplineSpace{3}(KnotVector(rand(30)))
BSplineSpace{3, Float64}(KnotVector([0.02698475249996979, 0.030185550039208087, 0.05529822631508752, 0.07933388944734077, 0.09443144857141617, 0.16643864408566544, 0.1735745757945074, 0.18342915487003242, 0.2112283719951349, 0.25858546995315457, 0.32166161915780656, 0.3475083889757653, 0.3723466293304377, 0.3906633890864917, 0.4339142195043123, 0.48302213696845187, 0.49205817172341115, 0.5270150071089016, 0.5447580773835046, 0.6226472050312687, 0.6349805161807821, 0.6599599239949302, 0.6637579955379869, 0.7219827350406843, 0.802762551279973, 0.8120185311491867, 0.837335454222923, 0.9786385378290976, 0.9807576556964709, 0.9893116874624179]))
julia> jacobian(t->bsplinebasisall(P,1,t), 0.2)
([-880.8310662693823, 1661.8743208018377, -1133.4992009557293, 352.45594642327393],)
julia> @benchmark jacobian(t->bsplinebasisall(P,1,t), 0.2)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 40.166 μs … 6.136 ms ┊ GC (min … max): 0.00% … 98.45%
Time (median): 42.010 μs ┊ GC (median): 0.00%
Time (mean ± σ): 43.630 μs ± 85.562 μs ┊ GC (mean ± σ): 2.76% ± 1.40%
▃▅▇█▇▆▅▃
▁▁▂▃▄▄▇███████████▆▆▆▅▅▄▃▃▃▃▂▂▂▂▂▂▁▂▁▁▁▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
40.2 μs Histogram: frequency by time 48.4 μs <
Memory estimate: 8.61 KiB, allocs estimate: 277. |
This PR fixes #265.