1- using TemporalGPs: build_lgssm, StorageType, is_of_storage_type
21using KernelFunctions
2+ using KernelFunctions: kappa
3+ using ChainRulesTestUtils
4+ using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
5+ using Test
36include (" ../test_util.jl" )
47include (" ../models/model_test_utils.jl" )
58_logistic (x) = 1 / (1 + exp (- x))
@@ -12,6 +15,34 @@ function _construction_tester(f_naive::GP, storage::StorageType, σ², t::Abstra
1215 return build_lgssm (fx)
1316end
1417
18+ @testset " ApproxPeriodicKernel" begin
19+ k = ApproxPeriodicKernel ()
20+ @test k isa ApproxPeriodicKernel{7 }
21+ # Test that it behaves like a normal PeriodicKernel
22+ k_base = PeriodicKernel ()
23+ x = rand ()
24+ @test kappa (k, x) == kappa (k_base, x)
25+ x = rand (3 )
26+ @test kernelmatrix (k, x) ≈ kernelmatrix (k_base, x)
27+ # Test dimensionality of LGSSM components
28+ Nt = 10
29+ @testset " $(typeof (t)) , $storage , $N " for t in (
30+ sort (rand (Nt)), RegularSpacing (0.0 , 0.1 , Nt)
31+ ),
32+ storage in (SArrayStorage {Float64} (), ArrayStorage {Float64} ()),
33+ N in (5 , 8 )
34+
35+ k = ApproxPeriodicKernel {N} ()
36+ As, as, Qs, emission_projections, x0 = lgssm_components (k, t, storage)
37+ @test length (As) == Nt
38+ @test all (x -> size (x) == (N * 2 , N * 2 ), As)
39+ @test length (as) == Nt
40+ @test all (x -> size (x) == (N * 2 ,), as)
41+ @test length (Qs) == Nt
42+ @test all (x -> size (x) == (N * 2 , N * 2 ), Qs)
43+ end
44+ end
45+
1546println (" lti_sde:" )
1647@testset " lti_sde" begin
1748 @testset " block_diagonal" begin
@@ -37,7 +68,11 @@ println("lti_sde:")
3768 )
3869
3970 kernels = [
40- Matern12Kernel (), Matern32Kernel (), Matern52Kernel (), ConstantKernel (; c= 1.5 )
71+ Matern12Kernel (),
72+ Matern32Kernel (),
73+ Matern52Kernel (),
74+ ConstantKernel (; c= 1.5 ),
75+ CosineKernel (),
4176 ]
4277
4378 @testset " $kernel , $(storage. name) " for kernel in kernels, storage in storages
@@ -56,53 +91,60 @@ println("lti_sde:")
5691 N = 13
5792 kernels = vcat (
5893 # Base kernels.
59- (name= " base-Matern12Kernel" , val= Matern12Kernel ()),
94+ (name= " base-Matern12Kernel" , val= Matern12Kernel (), to_vec_grad = false ),
6095 map ([Matern32Kernel, Matern52Kernel]) do k
61- (; name= " base-$k " , val= k ())
96+ (; name= " base-$k " , val= k (), to_vec_grad = false )
6297 end ,
6398
6499 # Scaled kernels.
65100 map ([1e-1 , 1.0 , 10.0 , 100.0 ]) do σ²
66- (; name= " scaled-σ²=$σ² " , val= σ² * Matern32Kernel ())
101+ (; name= " scaled-σ²=$σ² " , val= σ² * Matern32Kernel (), to_vec_grad = false )
67102 end ,
68103
69104 # Stretched kernels.
70105 map ([1e-2 , 0.1 , 1.0 , 10.0 , 100.0 ]) do λ
71- (; name= " stretched-λ=$λ " , val= Matern32Kernel () ∘ ScaleTransform (λ))
106+ (; name= " stretched-λ=$λ " , val= Matern32Kernel () ∘ ScaleTransform (λ), to_vec_grad = false )
72107 end ,
73108
109+ # Approx periodic kernels
110+ map ([7 , 11 ]) do N
111+ (
112+ name= " approx-periodic-N=$N " ,
113+ val= ApproxPeriodicKernel {N} (; r= 1.0 ),
114+ to_vec_grad= true ,
115+ )
116+ end ,
117+ # TEST_TOFIX
74118 # Gradients should be fixed on those composites.
75119 # Error is mostly due do an incompatibility of Tangents
76120 # between Zygote and FiniteDifferences.
77121
78122 # Product kernels
79123 (
80124 name= " prod-Matern12Kernel-Matern32Kernel" ,
81- val= 1.5 * Matern12Kernel () ∘ ScaleTransform (0.1 ) *
82- Matern32Kernel () ∘ ScaleTransform (1.1 ),
83- skip_grad = true ,
84- ),
85- (
125+ val= 1.5 * Matern12Kernel () ∘ ScaleTransform (0.1 ) * Matern32Kernel () ∘
126+ ScaleTransform (1.1 ),
127+ to_vec_grad = nothing ,
128+ ),
129+ (
86130 name= " prod-Matern32Kernel-Matern52Kernel-ConstantKernel" ,
87- val = 3.0 * Matern32Kernel () *
88- Matern52Kernel () *
89- ConstantKernel (),
90- skip_grad= true ,
131+ val= 3.0 * Matern32Kernel () * Matern52Kernel () * ConstantKernel (),
132+ to_vec_grad= nothing ,
91133 ),
92134
93135 # Summed kernels.
94136 (
95137 name= " sum-Matern12Kernel-Matern32Kernel" ,
96138 val= 1.5 * Matern12Kernel () ∘ ScaleTransform (0.1 ) +
97139 0.3 * Matern32Kernel () ∘ ScaleTransform (1.1 ),
98- skip_grad = true ,
99- ),
140+ to_vec_grad = nothing ,
141+ ),
100142 (
101143 name= " sum-Matern32Kernel-Matern52Kernel-ConstantKernel" ,
102- val = 2.0 * Matern32Kernel () +
144+ val= 2.0 * Matern32Kernel () +
103145 0.5 * Matern52Kernel () +
104146 1.0 * ConstantKernel (),
105- skip_grad = true ,
147+ to_vec_grad = nothing ,
106148 ),
107149 )
108150
@@ -126,14 +168,14 @@ println("lti_sde:")
126168 (name= " Custom Mean" , val= CustomMean (x -> 2 x)),
127169 )
128170
129- @testset " $(kernel. name) , $(m. name) , $(storage. name) , $(t. name) , $(σ². name) " for
130- kernel in kernels,
171+ @testset " $(kernel. name) , $(m. name) , $(storage. name) , $(t. name) , $(σ². name) " for kernel in
172+ kernels,
131173 m in means,
132174 storage in storages,
133175 t in ts,
134176 σ² in σ²s
135177
136- println (" $(kernel. name) , $(storage. name) , $(t. name) , $(σ². name) " )
178+ println (" $(kernel. name) , $(storage. name) , $(m . name) , $( t. name) , $(σ². name) " )
137179
138180 # Construct Gauss-Markov model.
139181 f_naive = GP (m. val, kernel. val)
@@ -174,7 +216,21 @@ println("lti_sde:")
174216 end
175217
176218 # Just need to ensure we can differentiate through construction properly.
177- if ! (hasfield (typeof (kernel), :skip_grad ) && kernel. skip_grad)
219+ if isnothing (kernel. to_vec_grad)
220+ @test_broken " Gradient tests are not passing"
221+ continue
222+ elseif kernel. to_vec_grad
223+ test_zygote_grad_finite_differences_compatible (
224+ _construction_tester,
225+ f_naive,
226+ storage. val,
227+ σ². val,
228+ t. val;
229+ check_inferred= false ,
230+ rtol= 1e-6 ,
231+ atol= 1e-6 ,
232+ )
233+ else
178234 test_zygote_grad (
179235 _construction_tester,
180236 f_naive,
0 commit comments