Skip to content
38 changes: 29 additions & 9 deletions src/_ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ const BSPLINESPACE_INFO = """
derivatives of B-spline basis functions with respect to BSplineSpace not implemented currently.
"""

function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real)
# bsplinebasis
function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis(P,i,t)
∂B_∂P = @not_implemented BSPLINESPACE_INFO
# ∂B_∂i = NoTangent()
∂B_∂t = bsplinebasis′(P,i,t)
return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt)
end

function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Integer, t::Real)
function ChainRulesCore.rrule(::typeof(bsplinebasis), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis(P,i,t)
# project_t = ProjectTo(t) # Not sure we need this ProjectTo.
function bsplinebasis_pullback(ΔB)
Expand All @@ -22,15 +22,15 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis), P::BSplineSpace, i::Intege
return (B, bsplinebasis_pullback)
end

function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real)
# bsplinebasis₊₀
function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis₊₀(P,i,t)
∂B_∂P = @not_implemented BSPLINESPACE_INFO
# ∂B_∂i = NoTangent()
∂B_∂t = bsplinebasis′₊₀(P,i,t)
return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt)
end

function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i::Integer, t::Real)
function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis₊₀(P,i,t)
# project_t = ProjectTo(t) # Not sure we need this ProjectTo.
function bsplinebasis_pullback(ΔB)
Expand All @@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₊₀), P::BSplineSpace, i::
return (B, bsplinebasis_pullback)
end

function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real)
# bsplinebasis₋₀
function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis₋₀(P,i,t)
∂B_∂P = @not_implemented BSPLINESPACE_INFO
# ∂B_∂i = NoTangent()
∂B_∂t = bsplinebasis′₋₀(P,i,t)
return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt)
end

function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i::Integer, t::Real)
function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasis₋₀(P,i,t)
# project_t = ProjectTo(t) # Not sure we need this ProjectTo.
function bsplinebasis_pullback(ΔB)
Expand All @@ -61,3 +61,23 @@ function ChainRulesCore.rrule(::typeof(bsplinebasis₋₀), P::BSplineSpace, i::
end
return (B, bsplinebasis_pullback)
end

# bsplinebasisall
function ChainRulesCore.frule((_, ΔP, Δi, Δt), ::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasisall(P,i,t)
∂B_∂P = @not_implemented BSPLINESPACE_INFO
# ∂B_∂i = NoTangent()
∂B_∂t = bsplinebasisall(derivative(P),i,t)
return (B, ∂B_∂P*ΔP + ∂B_∂t*Δt)
end
function ChainRulesCore.rrule(::typeof(bsplinebasisall), P::AbstractFunctionSpace, i::Integer, t::Real)
B = bsplinebasisall(P,i,t)
# project_t = ProjectTo(t) # Not sure we need this ProjectTo.
function bsplinebasis_pullback(ΔB)
P̄ = @not_implemented BSPLINESPACE_INFO
ī = NoTangent()
t̄ = bsplinebasisall(derivative(P),i,t)' * ΔB
return (NoTangent(), P̄, ī, t̄)
end
return (B, bsplinebasis_pullback)
end
2 changes: 2 additions & 0 deletions src/_DerivativeSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ exactdim(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}) where {r,p} =
intervalindex(dP::BSplineDerivativeSpace,t::Real) = intervalindex(bsplinespace(dP),t)
domain(dP::BSplineDerivativeSpace) = domain(bsplinespace(dP))
_lower(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r-1}(_lower(bsplinespace(dP)))
derivative(P::BSplineSpace) = BSplineDerivativeSpace{1}(P)
derivative(dP::BSplineDerivativeSpace{r}) where r = BSplineDerivativeSpace{r+1}(bsplinespace(dP))

function Base.issubset(dP::BSplineDerivativeSpace{r,<:AbstractBSplineSpace{p}}, P′::AbstractBSplineSpace) where {r,p}
k = knotvector(dP)
Expand Down
31 changes: 20 additions & 11 deletions test/test_ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
@testset "ChainRules" begin
k = KnotVector(rand(12))
P = BSplineSpace{2}(k)
k = KnotVector(rand(20))
p = 3
P = BSplineSpace{p}(k)
dP0 = BSplineDerivativeSpace{0}(P)
dP1 = BSplineDerivativeSpace{1}(P)
dP2 = BSplineDerivativeSpace{2}(P)
@testset "bsplinebasis" begin
for _ in 1:10
for _P in (P, dP0, dP1, dP2), i in 1:dim(_P)
t = rand(domain(P))
for i in 1:dim(P)
test_frule(bsplinebasis, P, i, t)
test_rrule(bsplinebasis, P, i, t)
test_frule(bsplinebasis₊₀, P, i, t)
test_rrule(bsplinebasis₊₀, P, i, t)
test_frule(bsplinebasis₋₀, P, i, t)
test_rrule(bsplinebasis₋₀, P, i, t)
end
test_frule(bsplinebasis, _P, i, t)
test_rrule(bsplinebasis, _P, i, t)
test_frule(bsplinebasis₊₀, _P, i, t)
test_rrule(bsplinebasis₊₀, _P, i, t)
test_frule(bsplinebasis₋₀, _P, i, t)
test_rrule(bsplinebasis₋₀, _P, i, t)
end
end
@testset "bsplinebasisall" begin
for _P in (P, dP0, dP1, dP2), i in 1:length(k)-2p-1
t = rand(domain(P))
test_frule(bsplinebasisall, _P, i, t)
test_rrule(bsplinebasisall, _P, i, t)
end
end
end
3 changes: 3 additions & 0 deletions test/test_Derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@
ts = rand(10)
for p in 0:5
P = BSplineSpace{p}(k)
Q = P
for r in 0:5
dP = BSplineDerivativeSpace{r}(P)
@test (dP == Q) || (r == 0)
Q = BasicBSpline.derivative(Q)
for t in ts
j = intervalindex(dP,t)
B = collect(bsplinebasisall(dP,j,t))
Expand Down