From 8db633164cb30081c3701f4c2d87e649904375ab Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 17:30:17 +0900 Subject: [PATCH 01/11] update comments --- src/_ChainRules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/_ChainRules.jl b/src/_ChainRules.jl index 81501148e..5128bc0df 100644 --- a/src/_ChainRules.jl +++ b/src/_ChainRules.jl @@ -2,6 +2,7 @@ const BSPLINESPACE_INFO = """ derivatives of B-spline basis functions with respect to BSplineSpace not implemented currently. """ +# bsplinebasis function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO @@ -9,7 +10,6 @@ function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::BSp ∂B_∂t = bsplinebasis′(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. @@ -22,6 +22,7 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Intege return (B, bsplinebasis_pullback) end +# bsplinebasis₊₀ function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO @@ -29,7 +30,6 @@ function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), ∂B_∂t = bsplinebasis′₊₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. @@ -42,6 +42,7 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i:: return (B, bsplinebasis_pullback) end +# bsplinebasis₋₀ function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO @@ -49,7 +50,6 @@ function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), ∂B_∂t = bsplinebasis′₋₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. From e6de6e5a1f8131cf3a2f48a65fc20ac080925a72 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 17:31:14 +0900 Subject: [PATCH 02/11] add test_frule and test_rrule for bsplinebasisall --- test/test_ChainRules.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_ChainRules.jl b/test/test_ChainRules.jl index e50d3f81a..1272a3163 100644 --- a/test/test_ChainRules.jl +++ b/test/test_ChainRules.jl @@ -1,6 +1,7 @@ @testset "ChainRules" begin k = KnotVector(rand(12)) - P = BSplineSpace{2}(k) + p = 2 + P = BSplineSpace{p}(k) @testset "bsplinebasis" begin for _ in 1:10 t = rand(domain(P)) @@ -12,6 +13,10 @@ test_frule(bsplinebasis₋₀, P, i, t) test_rrule(bsplinebasis₋₀, P, i, t) end + for i in 1:length(k)-2p-1 + test_frule(bsplinebasisall, P, i, t) + test_rrule(bsplinebasisall, P, i, t) + end end end end From c8bbccfb0c193b47d15de3121a95e4c0ca8366dc Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 18:04:55 +0900 Subject: [PATCH 03/11] add derivative function --- src/_DerivativeSpace.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/_DerivativeSpace.jl b/src/_DerivativeSpace.jl index bcdd1c01e..058fbf605 100644 --- a/src/_DerivativeSpace.jl +++ b/src/_DerivativeSpace.jl @@ -40,6 +40,8 @@ exactdim(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}) where {r,p} = intervalindex(dP::BSplineDerivativeSpace,t::Real) = intervalindex(bsplinespace(dP),t) domain(dP::BSplineDerivativeSpace) = domain(bsplinespace(dP)) _lower(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(_lower(bsplinespace(dP))) +derivative(P::BSplineSpace) = BSplineDerivativeSpace{1}(P) +derivative(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(bsplinespace(dP)) function Base.issubset(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}, P′::AbstractBSplineSpace) where {r,p} k = knotvector(dP) From 245ebe3922afd88173ca7e1279e1f0923b191ed6 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 18:05:15 +0900 Subject: [PATCH 04/11] add frule and rrule for bsplinebasisall --- src/_ChainRules.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/_ChainRules.jl b/src/_ChainRules.jl index 5128bc0df..bdfa2bae4 100644 --- a/src/_ChainRules.jl +++ b/src/_ChainRules.jl @@ -61,3 +61,23 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i:: end return (B, bsplinebasis_pullback) end + +# bsplinebasisall +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasisall), P::BSplineSpace, i::Integer, t::Real) + B = bsplinebasisall(P,i,t) + ∂B_∂P = @not_implemented BSPLINESPACE_INFO + # ∂B_∂i = NoTangent() + ∂B_∂t = bsplinebasisall(derivative(P),i,t) + return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) +end +function ChainRulesCore.rrule(::typeof(bsplinebasisall), P::BSplineSpace, i::Integer, t::Real) + B = bsplinebasisall(P,i,t) + # project_t = ProjectTo(t) # Not sure we need this ProjectTo. + function bsplinebasis_pullback(ΔB) + P̄ = @not_implemented BSPLINESPACE_INFO + ī = NoTangent() + t̄ = bsplinebasisall(derivative(P),i,t)' * ΔB + return (NoTangent(), P̄, ī, t̄) + end + return (B, bsplinebasis_pullback) +end From 81ea6d5451a0c807cd9e227c7d515ab9e547e9a7 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 18:19:10 +0900 Subject: [PATCH 05/11] add tests for derivative --- test/test_Derivative.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_Derivative.jl b/test/test_Derivative.jl index ddc08d2f2..5336d5d1c 100644 --- a/test/test_Derivative.jl +++ b/test/test_Derivative.jl @@ -88,8 +88,11 @@ ts = rand(10) for p in 0:5 P = BSplineSpace{p}(k) + Q = P for r in 0:5 + Q = BasicBSpline.derivative(Q) dP = BSplineDerivativeSpace{r}(P) + dP == Q for t in ts j = intervalindex(dP,t) B = collect(bsplinebasisall(dP,j,t)) From bbfb7f7fb92545ba8eaf9e03cc17c11ad14dc66e Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 18:46:18 +0900 Subject: [PATCH 06/11] update chain rule tests for BSplineDerivativeSpace --- test/test_ChainRules.jl | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/test/test_ChainRules.jl b/test/test_ChainRules.jl index 1272a3163..246c9132e 100644 --- a/test/test_ChainRules.jl +++ b/test/test_ChainRules.jl @@ -1,21 +1,29 @@ @testset "ChainRules" begin - k = KnotVector(rand(12)) - p = 2 + k = KnotVector(rand(20)) + p = 3 P = BSplineSpace{p}(k) + dP0 = BSplineDerivativeSpace{0}(P) + dP1 = BSplineDerivativeSpace{1}(P) + dP2 = BSplineDerivativeSpace{2}(P) @testset "bsplinebasis" begin for _ in 1:10 t = rand(domain(P)) - for i in 1:dim(P) - test_frule(bsplinebasis, P, i, t) - test_rrule(bsplinebasis, P, i, t) - test_frule(bsplinebasis₊₀, P, i, t) - test_rrule(bsplinebasis₊₀, P, i, t) - test_frule(bsplinebasis₋₀, P, i, t) - test_rrule(bsplinebasis₋₀, P, i, t) + for _P in (P, dP0, dP1, dP2), i in 1:dim(_P) + test_frule(bsplinebasis, _P, i, t) + test_rrule(bsplinebasis, _P, i, t) + test_frule(bsplinebasis₊₀, _P, i, t) + test_rrule(bsplinebasis₊₀, _P, i, t) + test_frule(bsplinebasis₋₀, _P, i, t) + test_rrule(bsplinebasis₋₀, _P, i, t) end - for i in 1:length(k)-2p-1 - test_frule(bsplinebasisall, P, i, t) - test_rrule(bsplinebasisall, P, i, t) + end + end + @testset "bsplinebasisall" begin + for _ in 1:10 + t = rand(domain(_P)) + for _P in (P, dP0, dP1, dP2), i in 1:length(k)-2p-1 + test_frule(bsplinebasisall, _P, i, t) + test_rrule(bsplinebasisall, _P, i, t) end end end From a60fb5694127e3ccd79bcf00f59304d709cdae1f Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 19:00:10 +0900 Subject: [PATCH 07/11] update tests for chain rules --- test/test_ChainRules.jl | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/test/test_ChainRules.jl b/test/test_ChainRules.jl index 246c9132e..8956a5c2d 100644 --- a/test/test_ChainRules.jl +++ b/test/test_ChainRules.jl @@ -6,25 +6,21 @@ dP1 = BSplineDerivativeSpace{1}(P) dP2 = BSplineDerivativeSpace{2}(P) @testset "bsplinebasis" begin - for _ in 1:10 + for _P in (P, dP0, dP1, dP2), i in 1:dim(_P) t = rand(domain(P)) - for _P in (P, dP0, dP1, dP2), i in 1:dim(_P) - test_frule(bsplinebasis, _P, i, t) - test_rrule(bsplinebasis, _P, i, t) - test_frule(bsplinebasis₊₀, _P, i, t) - test_rrule(bsplinebasis₊₀, _P, i, t) - test_frule(bsplinebasis₋₀, _P, i, t) - test_rrule(bsplinebasis₋₀, _P, i, t) - end + test_frule(bsplinebasis, _P, i, t) + test_rrule(bsplinebasis, _P, i, t) + test_frule(bsplinebasis₊₀, _P, i, t) + test_rrule(bsplinebasis₊₀, _P, i, t) + test_frule(bsplinebasis₋₀, _P, i, t) + test_rrule(bsplinebasis₋₀, _P, i, t) end end @testset "bsplinebasisall" begin - for _ in 1:10 - t = rand(domain(_P)) - for _P in (P, dP0, dP1, dP2), i in 1:length(k)-2p-1 - test_frule(bsplinebasisall, _P, i, t) - test_rrule(bsplinebasisall, _P, i, t) - end + for _P in (P, dP0, dP1, dP2), i in 1:length(k)-2p-1 + t = rand(domain(P)) + test_frule(bsplinebasisall, _P, i, t) + test_rrule(bsplinebasisall, _P, i, t) end end end From 049d50ae20d1dfa52a31f8bee94605ddb305fb2f Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 19:00:44 +0900 Subject: [PATCH 08/11] fix bug in derivative --- src/_DerivativeSpace.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_DerivativeSpace.jl b/src/_DerivativeSpace.jl index 058fbf605..5fc77f292 100644 --- a/src/_DerivativeSpace.jl +++ b/src/_DerivativeSpace.jl @@ -41,7 +41,7 @@ intervalindex(dP::BSplineDerivativeSpace,t::Real) = intervalindex(bsplinespace(d domain(dP::BSplineDerivativeSpace) = domain(bsplinespace(dP)) _lower(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(_lower(bsplinespace(dP))) derivative(P::BSplineSpace) = BSplineDerivativeSpace{1}(P) -derivative(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(bsplinespace(dP)) +derivative(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r+1}(bsplinespace(dP)) function Base.issubset(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}, P′::AbstractBSplineSpace) where {r,p} k = knotvector(dP) From 958266822fe8e9ab1cc88d5d595d6206ec292bf7 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 19:01:27 +0900 Subject: [PATCH 09/11] add support for chain rules with BSplineDerivativeSpace --- src/_ChainRules.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/_ChainRules.jl b/src/_ChainRules.jl index bdfa2bae4..0470dbab6 100644 --- a/src/_ChainRules.jl +++ b/src/_ChainRules.jl @@ -3,14 +3,14 @@ derivatives of B-spline basis functions with respect to BSplineSpace not impleme """ # bsplinebasis -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end -function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -23,14 +23,14 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Intege end # bsplinebasis₊₀ -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′₊₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end -function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -43,14 +43,14 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i:: end # bsplinebasis₋₀ -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′₋₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end -function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -63,14 +63,14 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i:: end # bsplinebasisall -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasisall), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasisall(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasisall(derivative(P),i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end -function ChainRulesCore.rrule(::typeof(bsplinebasisall), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasisall(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) From 9f16b7a5a0816f54cfcc8fa33db507ebee89c706 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 19:04:21 +0900 Subject: [PATCH 10/11] fix test for derivative --- test/test_Derivative.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_Derivative.jl b/test/test_Derivative.jl index 5336d5d1c..7eecbe792 100644 --- a/test/test_Derivative.jl +++ b/test/test_Derivative.jl @@ -92,7 +92,7 @@ for r in 0:5 Q = BasicBSpline.derivative(Q) dP = BSplineDerivativeSpace{r}(P) - dP == Q + @test dP == Q for t in ts j = intervalindex(dP,t) B = collect(bsplinebasisall(dP,j,t)) From d10cd8607355ef3e661c5391d8151c467dc3963e Mon Sep 17 00:00:00 2001 From: hyrodium Date: Thu, 8 Sep 2022 19:08:55 +0900 Subject: [PATCH 11/11] fix test for derivative --- test/test_Derivative.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_Derivative.jl b/test/test_Derivative.jl index 7eecbe792..2e64bfa4e 100644 --- a/test/test_Derivative.jl +++ b/test/test_Derivative.jl @@ -90,9 +90,9 @@ P = BSplineSpace{p}(k) Q = P for r in 0:5 - Q = BasicBSpline.derivative(Q) dP = BSplineDerivativeSpace{r}(P) - @test dP == Q + @test (dP == Q) || (r == 0) + Q = BasicBSpline.derivative(Q) for t in ts j = intervalindex(dP,t) B = collect(bsplinebasisall(dP,j,t))