diff --git a/src/_ChainRules.jl b/src/_ChainRules.jl index 81501148e..0470dbab6 100644 --- a/src/_ChainRules.jl +++ b/src/_ChainRules.jl @@ -2,15 +2,15 @@ const BSPLINESPACE_INFO = """ derivatives of B-spline basis functions with respect to BSplineSpace not implemented currently. """ -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) +# bsplinebasis +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - -function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -22,15 +22,15 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Intege return (B, bsplinebasis_pullback) end -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) +# bsplinebasis₊₀ +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′₊₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - -function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₊₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i:: return (B, bsplinebasis_pullback) end -function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) +# bsplinebasis₋₀ +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) ∂B_∂P = @not_implemented BSPLINESPACE_INFO # ∂B_∂i = NoTangent() ∂B_∂t = bsplinebasis′₋₀(P,i,t) return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) end - -function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real) +function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real) B = bsplinebasis₋₀(P,i,t) # project_t = ProjectTo(t) # Not sure we need this ProjectTo. function bsplinebasis_pullback(ΔB) @@ -61,3 +61,23 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i:: end return (B, bsplinebasis_pullback) end + +# bsplinebasisall +function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real) + B = bsplinebasisall(P,i,t) + ∂B_∂P = @not_implemented BSPLINESPACE_INFO + # ∂B_∂i = NoTangent() + ∂B_∂t = bsplinebasisall(derivative(P),i,t) + return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt) +end +function ChainRulesCore.rrule(::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real) + B = bsplinebasisall(P,i,t) + # project_t = ProjectTo(t) # Not sure we need this ProjectTo. + function bsplinebasis_pullback(ΔB) + P̄ = @not_implemented BSPLINESPACE_INFO + ī = NoTangent() + t̄ = bsplinebasisall(derivative(P),i,t)' * ΔB + return (NoTangent(), P̄, ī, t̄) + end + return (B, bsplinebasis_pullback) +end diff --git a/src/_DerivativeSpace.jl b/src/_DerivativeSpace.jl index bcdd1c01e..5fc77f292 100644 --- a/src/_DerivativeSpace.jl +++ b/src/_DerivativeSpace.jl @@ -40,6 +40,8 @@ exactdim(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}) where {r,p} = intervalindex(dP::BSplineDerivativeSpace,t::Real) = intervalindex(bsplinespace(dP),t) domain(dP::BSplineDerivativeSpace) = domain(bsplinespace(dP)) _lower(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(_lower(bsplinespace(dP))) +derivative(P::BSplineSpace) = BSplineDerivativeSpace{1}(P) +derivative(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r+1}(bsplinespace(dP)) function Base.issubset(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}, P′::AbstractBSplineSpace) where {r,p} k = knotvector(dP) diff --git a/test/test_ChainRules.jl b/test/test_ChainRules.jl index e50d3f81a..8956a5c2d 100644 --- a/test/test_ChainRules.jl +++ b/test/test_ChainRules.jl @@ -1,17 +1,26 @@ @testset "ChainRules" begin - k = KnotVector(rand(12)) - P = BSplineSpace{2}(k) + k = KnotVector(rand(20)) + p = 3 + P = BSplineSpace{p}(k) + dP0 = BSplineDerivativeSpace{0}(P) + dP1 = BSplineDerivativeSpace{1}(P) + dP2 = BSplineDerivativeSpace{2}(P) @testset "bsplinebasis" begin - for _ in 1:10 + for _P in (P, dP0, dP1, dP2), i in 1:dim(_P) t = rand(domain(P)) - for i in 1:dim(P) - test_frule(bsplinebasis, P, i, t) - test_rrule(bsplinebasis, P, i, t) - test_frule(bsplinebasis₊₀, P, i, t) - test_rrule(bsplinebasis₊₀, P, i, t) - test_frule(bsplinebasis₋₀, P, i, t) - test_rrule(bsplinebasis₋₀, P, i, t) - end + test_frule(bsplinebasis, _P, i, t) + test_rrule(bsplinebasis, _P, i, t) + test_frule(bsplinebasis₊₀, _P, i, t) + test_rrule(bsplinebasis₊₀, _P, i, t) + test_frule(bsplinebasis₋₀, _P, i, t) + test_rrule(bsplinebasis₋₀, _P, i, t) + end + end + @testset "bsplinebasisall" begin + for _P in (P, dP0, dP1, dP2), i in 1:length(k)-2p-1 + t = rand(domain(P)) + test_frule(bsplinebasisall, _P, i, t) + test_rrule(bsplinebasisall, _P, i, t) end end end diff --git a/test/test_Derivative.jl b/test/test_Derivative.jl index ddc08d2f2..2e64bfa4e 100644 --- a/test/test_Derivative.jl +++ b/test/test_Derivative.jl @@ -88,8 +88,11 @@ ts = rand(10) for p in 0:5 P = BSplineSpace{p}(k) + Q = P for r in 0:5 dP = BSplineDerivativeSpace{r}(P) + @test (dP == Q) || (r == 0) + Q = BasicBSpline.derivative(Q) for t in ts j = intervalindex(dP,t) B = collect(bsplinebasisall(dP,j,t))