Skip to content

Commit a931ac0

Browse files
authored
Merge pull request #98 from JuliaGaussianProcesses/tgf/flexible-kernel
Add ApproximatePeriodicKernel
2 parents 7a6dc8f + 718b1bd commit a931ac0

File tree

7 files changed

+182
-28
lines changed

7 files changed

+182
-28
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "TemporalGPs"
22
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
33
authors = ["willtebbutt <[email protected]> and contributors"]
4-
version = "0.6.4"
4+
version = "0.6.5"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
8+
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
89
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/TemporalGPs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TemporalGPs
22

33
using AbstractGPs
4+
using Bessels: besseli
45
using BlockDiagonals
56
using ChainRulesCore
67
import ChainRulesCore: rrule
@@ -31,7 +32,8 @@ module TemporalGPs
3132
checkpointed,
3233
posterior,
3334
logpdf_and_rand,
34-
Separable
35+
Separable,
36+
ApproxPeriodicKernel
3537

3638
# Various bits-and-bobs. Often commiting some type piracy.
3739
include(joinpath("util", "harmonise.jl"))

src/gp/lti_sde.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,93 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T<
241241
return Gaussian(m, P)
242242
end
243243

244+
# Cosine
245+
246+
function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T}
247+
F = SMatrix{2, 2, T}(0, 1, -1, 0)
248+
q = zero(T)
249+
H = SVector{2, T}(1, 0)
250+
return F, q, H
251+
end
252+
253+
function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:Real}
254+
m = SVector{2, T}(0, 0)
255+
P = SMatrix{2, 2, T}(1, 0, 0, 1)
256+
return Gaussian(m, P)
257+
end
258+
259+
# Approximate Periodic Kernel
260+
# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
261+
struct ApproxPeriodicKernel{N,K<:PeriodicKernel} <: KernelFunctions.SimpleKernel
262+
kernel::K
263+
function ApproxPeriodicKernel{N,K}(kernel::K) where {N,K<:PeriodicKernel}
264+
length(kernel.r) == 1 || error("ApproxPeriodicKernel only supports a single lengthscale")
265+
return new{N,K}(kernel)
266+
end
267+
end
268+
# We follow "State Space approximation of Gaussian Processes for time series forecasting"
269+
# by Alessio Benavoli and Giorgio Corani and use a default of 7 Cosine Kernel terms
270+
ApproxPeriodicKernel(;r::Real=1.0) = ApproxPeriodicKernel{7}(PeriodicKernel(;r=[r]))
271+
ApproxPeriodicKernel{N}(;r::Real=1.0) where {N} = ApproxPeriodicKernel{N}(PeriodicKernel(;r=[r]))
272+
ApproxPeriodicKernel(kernel::PeriodicKernel) = ApproxPeriodicKernel{7}(kernel)
273+
ApproxPeriodicKernel{N}(kernel::K) where {N,K<:PeriodicKernel} = ApproxPeriodicKernel{N,K}(kernel)
274+
275+
KernelFunctions.kappa(k::ApproxPeriodicKernel, x) = KernelFunctions.kappa(k.kernel, x)
276+
KernelFunctions.metric(k::ApproxPeriodicKernel) = KernelFunctions.metric(k.kernel)
277+
278+
function Base.show(io::IO, κ::ApproxPeriodicKernel{N}) where {N}
279+
return print(io, "Approximate Periodic Kernel, (r = $(only.kernel.r))) approximated with $N cosine kernels")
280+
end
281+
282+
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::Union{StepRangeLen, RegularSpacing}, storage::StorageType{T}) where {N,T<:Real}
283+
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
284+
nt = length(t)
285+
As = map(F -> Fill(time_exp(F, T(step(t))), nt), Fs)
286+
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
287+
end
288+
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::AbstractVector{<:Real}, storage::StorageType{T}) where {N,T<:Real}
289+
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
290+
t = vcat([first(t) - 1], t)
291+
nt = length(diff(t))
292+
As = _map(F -> _map(Δt -> time_exp(F, T(Δt)), diff(t)), Fs)
293+
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
294+
end
295+
296+
function _init_periodic_kernel_lgssm(kernel::PeriodicKernel, storage, N::Int=7)
297+
r = kernel.r
298+
l⁻² = inv(4 * only(r)^2)
299+
300+
F, _, H = to_sde(CosineKernel(), storage)
301+
Fs = ntuple(N) do i
302+
2π * (i - 1) * F
303+
end
304+
Hs = Fill(H, N)
305+
306+
x0 = stationary_distribution(CosineKernel(), storage)
307+
ms = Fill(x0.m, N)
308+
P = x0.P
309+
Ps = ntuple(N) do j
310+
qⱼ = (1 + (j !== 1) ) * besseli(j - 1, l⁻²) / exp(l⁻²)
311+
qⱼ * P
312+
end
313+
314+
Fs, Hs, ms, Ps
315+
end
316+
317+
function _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
318+
as = Fill(Fill(Zeros{T}(size(first(first(As)), 1)), nt), N)
319+
Qs = _map((P, A) -> _map(A -> Symmetric(P) - A * Symmetric(P) * A', A), Ps, As)
320+
Hs = Fill(vcat(Hs...), nt)
321+
h = Fill(zero(T), nt)
322+
As = _map(block_diagonal, As...)
323+
as = -map(vcat, as...)
324+
Qs = _map(block_diagonal, Qs...)
325+
m = reduce(vcat, ms)
326+
P = block_diagonal(Ps...)
327+
x0 = Gaussian(m, P)
328+
return As, as, Qs, (Hs, h), x0
329+
end
330+
244331
# Constant
245332

246333
function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}

src/models/lgssm.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,7 @@ ChainRulesCore.@non_differentiable ident_eps(args...)
267267

268268
_collect(U::Adjoint{<:Any, <:Matrix}) = collect(U)
269269
_collect(U::SMatrix) = U
270-
271-
270+
_collect(U::BlockDiagonal) = U
272271

273272
# AD stuff. No need to understand this unless you're really plumbing the depths...
274273

src/util/chainrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ function rrule(::Type{<:Fill}, x, sz)
147147
Fill_rrule::Union{Fill,Thunk}) = NoTangent(), FillArrays.getindex_value(unthunk(Δ)), NoTangent()
148148
Fill_rrule::Tangent{T,<:NamedTuple{(:value, :axes)}}) where {T} = NoTangent(), Δ.value, NoTangent()
149149
Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent()
150+
Fill_rrule::Tangent{T,<:NTuple}) where {T} = NoTangent(), sum(Δ), NoTangent()
150151
function Fill_rrule::AbstractArray)
151152
# all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value")
152153
# sum(Δ)

test/gp/lti_sde.jl

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type
21
using KernelFunctions
2+
using KernelFunctions: kappa
3+
using ChainRulesTestUtils
4+
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
5+
using Test
36
include("../test_util.jl")
47
include("../models/model_test_utils.jl")
58
_logistic(x) = 1 / (1 + exp(-x))
@@ -12,6 +15,34 @@ function _construction_tester(f_naive::GP, storage::StorageType, σ², t::Abstra
1215
return build_lgssm(fx)
1316
end
1417

18+
@testset "ApproxPeriodicKernel" begin
19+
k = ApproxPeriodicKernel()
20+
@test k isa ApproxPeriodicKernel{7}
21+
# Test that it behaves like a normal PeriodicKernel
22+
k_base = PeriodicKernel()
23+
x = rand()
24+
@test kappa(k, x) == kappa(k_base, x)
25+
x = rand(3)
26+
@test kernelmatrix(k, x) kernelmatrix(k_base, x)
27+
# Test dimensionality of LGSSM components
28+
Nt = 10
29+
@testset "$(typeof(t)), $storage, $N" for t in (
30+
sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt)
31+
),
32+
storage in (SArrayStorage{Float64}(), ArrayStorage{Float64}()),
33+
N in (5, 8)
34+
35+
k = ApproxPeriodicKernel{N}()
36+
As, as, Qs, emission_projections, x0 = lgssm_components(k, t, storage)
37+
@test length(As) == Nt
38+
@test all(x -> size(x) == (N * 2, N * 2), As)
39+
@test length(as) == Nt
40+
@test all(x -> size(x) == (N * 2,), as)
41+
@test length(Qs) == Nt
42+
@test all(x -> size(x) == (N * 2, N * 2), Qs)
43+
end
44+
end
45+
1546
println("lti_sde:")
1647
@testset "lti_sde" begin
1748
@testset "block_diagonal" begin
@@ -37,7 +68,11 @@ println("lti_sde:")
3768
)
3869

3970
kernels = [
40-
Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(; c=1.5)
71+
Matern12Kernel(),
72+
Matern32Kernel(),
73+
Matern52Kernel(),
74+
ConstantKernel(; c=1.5),
75+
CosineKernel(),
4176
]
4277

4378
@testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages
@@ -56,53 +91,60 @@ println("lti_sde:")
5691
N = 13
5792
kernels = vcat(
5893
# Base kernels.
59-
(name="base-Matern12Kernel", val=Matern12Kernel()),
94+
(name="base-Matern12Kernel", val=Matern12Kernel(), to_vec_grad=false),
6095
map([Matern32Kernel, Matern52Kernel]) do k
61-
(; name="base-$k", val=k())
96+
(; name="base-$k", val=k(), to_vec_grad=false)
6297
end,
6398

6499
# Scaled kernels.
65100
map([1e-1, 1.0, 10.0, 100.0]) do σ²
66-
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel())
101+
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel(), to_vec_grad=false)
67102
end,
68103

69104
# Stretched kernels.
70105
map([1e-2, 0.1, 1.0, 10.0, 100.0]) do λ
71-
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ))
106+
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ), to_vec_grad=false)
72107
end,
73108

109+
# Approx periodic kernels
110+
map([7, 11]) do N
111+
(
112+
name="approx-periodic-N=$N",
113+
val=ApproxPeriodicKernel{N}(; r=1.0),
114+
to_vec_grad=true,
115+
)
116+
end,
117+
# TEST_TOFIX
74118
# Gradients should be fixed on those composites.
75119
# Error is mostly due do an incompatibility of Tangents
76120
# between Zygote and FiniteDifferences.
77121

78122
# Product kernels
79123
(
80124
name="prod-Matern12Kernel-Matern32Kernel",
81-
val=1.5 * Matern12Kernel() ScaleTransform(0.1) *
82-
Matern32Kernel() ScaleTransform(1.1),
83-
skip_grad=true,
84-
),
85-
(
125+
val=1.5 * Matern12Kernel() ScaleTransform(0.1) * Matern32Kernel()
126+
ScaleTransform(1.1),
127+
to_vec_grad=nothing,
128+
),
129+
(
86130
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
87-
val = 3.0 * Matern32Kernel() *
88-
Matern52Kernel() *
89-
ConstantKernel(),
90-
skip_grad=true,
131+
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
132+
to_vec_grad=nothing,
91133
),
92134

93135
# Summed kernels.
94136
(
95137
name="sum-Matern12Kernel-Matern32Kernel",
96138
val=1.5 * Matern12Kernel() ScaleTransform(0.1) +
97139
0.3 * Matern32Kernel() ScaleTransform(1.1),
98-
skip_grad=true,
99-
),
140+
to_vec_grad=nothing,
141+
),
100142
(
101143
name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
102-
val = 2.0 * Matern32Kernel() +
144+
val=2.0 * Matern32Kernel() +
103145
0.5 * Matern52Kernel() +
104146
1.0 * ConstantKernel(),
105-
skip_grad=true,
147+
to_vec_grad=nothing,
106148
),
107149
)
108150

@@ -126,14 +168,14 @@ println("lti_sde:")
126168
(name="Custom Mean", val=CustomMean(x -> 2x)),
127169
)
128170

129-
@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for
130-
kernel in kernels,
171+
@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for kernel in
172+
kernels,
131173
m in means,
132174
storage in storages,
133175
t in ts,
134176
σ² in σ²s
135177

136-
println("$(kernel.name), $(storage.name), $(t.name), $(σ².name)")
178+
println("$(kernel.name), $(storage.name), $(m.name), $(t.name), $(σ².name)")
137179

138180
# Construct Gauss-Markov model.
139181
f_naive = GP(m.val, kernel.val)
@@ -174,7 +216,21 @@ println("lti_sde:")
174216
end
175217

176218
# Just need to ensure we can differentiate through construction properly.
177-
if !(hasfield(typeof(kernel), :skip_grad) && kernel.skip_grad)
219+
if isnothing(kernel.to_vec_grad)
220+
@test_broken "Gradient tests are not passing"
221+
continue
222+
elseif kernel.to_vec_grad
223+
test_zygote_grad_finite_differences_compatible(
224+
_construction_tester,
225+
f_naive,
226+
storage.val,
227+
σ².val,
228+
t.val;
229+
check_inferred=false,
230+
rtol=1e-6,
231+
atol=1e-6,
232+
)
233+
else
178234
test_zygote_grad(
179235
_construction_tester,
180236
f_naive,

test/test_util.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ test_zygote_grad(f, args...; check_inferred=false, kwargs...) = test_rrule(Zygot
4646
function test_zygote_grad_finite_differences_compatible(f, args...; kwargs...)
4747
x_vec, from_vec = to_vec(args)
4848
function finite_diff_compatible_f(x::AbstractVector)
49-
return @ignore_derivatives(f)(from_vec(x)...)
49+
return @ignore_derivatives(f)(@ignore_derivatives(from_vec)(x)...)
5050
end
5151
test_zygote_grad(finite_diff_compatible_f NoTangent(), x_vec; testset_name="test_rrule: $(f) on $(typeof.(args))", kwargs...)
5252
end
@@ -134,6 +134,14 @@ function ChainRulesTestUtils.test_approx(actual::Tangent{<:Fill}, expected, msg=
134134
test_approx(actual.value, expected.value, msg; kwargs...)
135135
end
136136

137+
function to_vec(x::PeriodicKernel)
138+
x, to_r = to_vec(x.r)
139+
function PeriodicKernel_from_vec(x)
140+
return PeriodicKernel(;r=exp.(to_r(x)))
141+
end
142+
log.(x), PeriodicKernel_from_vec
143+
end
144+
137145
to_vec(x::T) where {T} = generic_struct_to_vec(x)
138146

139147
# This is a copy from FiniteDifferences.jl without the try catch

0 commit comments

Comments
 (0)