Skip to content

Commit e100352

Browse files
charlesknippyebai
andauthored
Update SSMProblems interface (#119)
* Integrate new SSMProblems interface * Fix unit tests * Removed type piracy * Fix examples and streamline Levy SSM * Formatter * Fix Literate Errors * Remove compat entry for Documenter * Increase HTML size threshold * Suppress output in Levy model * Fix update rate calculation * Update Project.toml --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 3cb7ec8 commit e100352

File tree

12 files changed

+170
-169
lines changed

12 files changed

+170
-169
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.7"
4+
version = "0.7.1"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -26,7 +26,7 @@ Random = "<0.0.1, 1"
2626
Random123 = "1.3"
2727
Requires = "1.0"
2828
StatsFuns = "0.9, 1"
29-
SSMProblems = "0.5"
29+
SSMProblems = "0.6"
3030
julia = "1.10.8"
3131

3232
[extras]

docs/Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,3 @@
22
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
5-
6-
[compat]
7-
Documenter = "0.27"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ DocMeta.setdocmeta!(AdvancedPS, :DocTestSetup, :(using AdvancedPS); recursive=tr
4848

4949
makedocs(;
5050
sitename="AdvancedPS",
51-
format=Documenter.HTML(),
51+
format=Documenter.HTML(; size_threshold=1000 * 2^11), # 1Mb per page
5252
modules=[AdvancedPS],
5353
pages=[
5454
"Home" => "index.md",

examples/gaussian-process/script.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,59 +8,59 @@ using Distributions
88
using Libtask
99
using SSMProblems
1010

11-
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: LatentDynamics{T,T}
11+
struct GaussianProcessDynamics{T<:Real,KT<:Kernel} <: SSMProblems.LatentDynamics
1212
proc::GP{ZeroMean{T},KT}
1313
function GaussianProcessDynamics(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
1414
return new{T,KT}(GP(ZeroMean{T}(), kernel))
1515
end
1616
end
1717

18-
struct LinearGaussianDynamics{T<:Real} <: LatentDynamics{T,T}
19-
a::T
20-
b::T
21-
q::T
18+
struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior
19+
σ::ΣT
2220
end
2321

24-
function SSMProblems.distribution(proc::LinearGaussianDynamics{T}) where {T<:Real}
25-
return Normal(zero(T), proc.q)
22+
SSMProblems.distribution(proc::GaussianPrior) = Normal(0, proc.σ)
23+
24+
struct LinearGaussianDynamics{AT<:Real,BT<:Real,QT<:Real} <: SSMProblems.LatentDynamics
25+
a::AT
26+
b::BT
27+
q::QT
2628
end
2729

2830
function SSMProblems.distribution(proc::LinearGaussianDynamics, ::Int, state)
2931
return Normal(proc.a * state + proc.b, proc.q)
3032
end
3133

32-
struct StochasticVolatility{T<:Real} <: ObservationProcess{T,T} end
34+
struct StochasticVolatility <: SSMProblems.ObservationProcess end
3335

34-
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
35-
return Normal(zero(T), exp((1 / 2) * state))
36+
function SSMProblems.distribution(::StochasticVolatility, ::Int, state)
37+
return Normal(0, exp(state / 2))
3638
end
3739

38-
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
39-
dyn = LinearGaussianDynamics(a, zero(T), q)
40-
obs = StochasticVolatility{T}()
41-
return SSMProblems.StateSpaceModel(dyn, obs)
40+
function LinearGaussianStochasticVolatilityModel(a, q)
41+
prior = GaussianPrior(q)
42+
dyn = LinearGaussianDynamics(a, 0, q)
43+
obs = StochasticVolatility()
44+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
4245
end
4346

4447
function GaussianProcessStateSpaceModel(::Type{T}, kernel::KT) where {T<:Real,KT<:Kernel}
48+
prior = GaussianPrior(one(T))
4549
dyn = GaussianProcessDynamics(T, kernel)
46-
obs = StochasticVolatility{T}()
47-
return SSMProblems.StateSpaceModel(dyn, obs)
50+
obs = StochasticVolatility()
51+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
4852
end
4953

5054
const GPSSM{T,KT<:Kernel} = SSMProblems.StateSpaceModel{
51-
T,
52-
GaussianProcessDynamics{T,KT},
53-
StochasticVolatility{T}
55+
<:GaussianPrior,<:GaussianProcessDynamics{T,KT},StochasticVolatility
5456
};
5557

5658
# for non-markovian models, we can redefine dynamics to reference the trajectory
57-
function AdvancedPS.dynamics(
58-
ssm::AdvancedPS.TracedSSM{<:GPSSM{T},T,T}, step::Int
59-
) where {T<:Real}
59+
function AdvancedPS.dynamics(ssm::AdvancedPS.TracedSSM{<:GPSSM}, step::Int)
6060
prior = ssm.model.dyn.proc(1:(step - 1))
61-
post = posterior(prior, ssm.X[1:(step - 1)])
62-
μ, σ = mean_and_cov(post, [step])
63-
return LinearGaussianDynamics(zero(T), μ[1], sqrt(σ[1]))
61+
post = posterior(prior, ssm.X[1:(step - 1)])
62+
μ, σ = mean_and_cov(post, [step])
63+
return LinearGaussianDynamics(0, μ[1], sqrt(σ[1]))
6464
end
6565

6666
# Everything is now ready to simulate some data.
@@ -70,9 +70,9 @@ _, x, y = sample(rng, true_model, 100);
7070

7171
# Create the model and run the sampler
7272
gpssm = GaussianProcessStateSpaceModel(Float64, SqExponentialKernel());
73-
model = gpssm(y);
73+
model = AdvancedPS.TracedSSM(gpssm, y);
7474
pg = AdvancedPS.PGAS(20);
75-
chains = sample(rng, model, pg, 250; progress=false);
75+
chains = sample(rng, model, pg, 250);
7676
#md nothing #hide
7777

7878
particles = hcat([chain.trajectory.model.X for chain in chains]...);

examples/gaussian-ssm/script.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,31 @@ using SSMProblems
2828
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.
2929

3030
# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
31-
mutable struct Parameters{T<:Real}
32-
a::T
33-
q::T
34-
r::T
31+
mutable struct Parameters{AT<:Real,QT<:Real,RT<:Real}
32+
a::AT
33+
q::QT
34+
r::RT
3535
end
3636

37-
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
38-
a::T
39-
q::T
37+
struct GaussianPrior{ΣT<:Real} <: SSMProblems.StatePrior
38+
σ::ΣT
4039
end
4140

42-
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}; kwargs...) where {T<:Real}
43-
return Normal(zero(T), sqrt(dyn.q^2 / (1 - dyn.a^2)))
41+
struct LinearGaussianDynamics{AT<:Real,QT<:Real} <: SSMProblems.LatentDynamics
42+
a::AT
43+
q::QT
44+
end
45+
46+
function SSMProblems.distribution(prior::GaussianPrior; kwargs...)
47+
return Normal(0, prior.σ)
4448
end
4549

4650
function SSMProblems.distribution(dyn::LinearGaussianDynamics, step::Int, state; kwargs...)
4751
return Normal(dyn.a * state, dyn.q)
4852
end
4953

50-
struct LinearGaussianObservation{T<:Real} <: SSMProblems.ObservationProcess{T,T}
51-
r::T
54+
struct LinearGaussianObservation{RT<:Real} <: SSMProblems.ObservationProcess
55+
r::RT
5256
end
5357

5458
function SSMProblems.distribution(
@@ -58,9 +62,10 @@ function SSMProblems.distribution(
5862
end
5963

6064
function LinearGaussianStateSpaceModel::Parameters)
65+
prior = GaussianPrior(sqrt.q^2 / (1 - θ.a^2)))
6166
dyn = LinearGaussianDynamics.a, θ.q)
6267
obs = LinearGaussianObservation.r)
63-
return SSMProblems.StateSpaceModel(dyn, obs)
68+
return SSMProblems.StateSpaceModel(prior, dyn, obs)
6469
end
6570

6671
# Everything is now ready to simulate some data.
@@ -75,8 +80,9 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)
7580

7681
# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
7782
# and a model interface.
78-
pgas = AdvancedPS.PGAS(20)
79-
chains = sample(rng, true_model(y), pgas, 500; progress=false);
83+
N = 20
84+
pgas = AdvancedPS.PGAS(N)
85+
chains = sample(rng, AdvancedPS.TracedSSM(true_model, y), pgas, 500; progress=false);
8086
#md nothing #hide
8187

8288
#
@@ -104,4 +110,4 @@ plot(
104110
xlabel="Iteration",
105111
ylabel="Update rate",
106112
)
107-
hline!([1 - 1 / length(chains)]; label="N: $(length(chains))")
113+
hline!([1 - 1 / N]; label="N: $(N)")

examples/levy-ssm/script.jl

Lines changed: 43 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,16 @@ function simulate(
2727
t = t0
2828
truncated = last_jump < tolerance
2929
while !truncated
30-
t += rand(rng, Exponential(one(T) / rate))
31-
xi = one(T) /* (exp(t / C) - one(T)))
32-
prob = (one(T) + β * xi) * exp(-β * xi)
30+
t += rand(rng, Exponential(1 / rate))
31+
xi = 1 /* (exp(t / C) - 1))
32+
prob = (1 + β * xi) * exp(-β * xi)
3333
if rand(rng) < prob
3434
push!(jumps, xi)
3535
last_jump = xi
3636
end
3737
truncated = last_jump < tolerance
3838
end
39-
times = rand(rng, Uniform(start, finish), length(jumps))
40-
return GammaPath(jumps, times)
39+
return GammaPath(jumps, rand(rng, Uniform(start, finish), length(jumps)))
4140
end
4241
end
4342

@@ -47,85 +46,66 @@ function integral(times::Array{<:Real}, path::GammaPath)
4746
end
4847
end
4948

50-
struct LangevinDynamics{T}
51-
A::Matrix{T}
52-
L::Vector{T}
53-
θ::T
54-
H::Vector{T}
55-
σe::T
49+
struct LangevinDynamics{AT<:AbstractMatrix,LT<:AbstractVector,θT<:Real}
50+
A::AT
51+
L::LT
52+
θ::θT
5653
end
5754

58-
struct NormalMeanVariance{T}
59-
μ::T
60-
σ::T
55+
function Base.exp(dyn::LangevinDynamics, dt)
56+
f_val = exp(dyn.θ * dt)
57+
return [1 (f_val - 1)/dyn.θ; 0 f_val]
6158
end
6259

63-
f(dt, θ) = exp* dt)
64-
function Base.exp(dyn::LangevinDynamics{T}, dt::T) where {T<:Real}
65-
let θ = dyn.θ
66-
f_val = f(dt, θ)
67-
return [one(T) (f_val - 1)/θ; zero(T) f_val]
68-
end
60+
function meancov(t, dyn::LangevinDynamics, path::GammaPath, dist::Normal)
61+
fts = exp.(Ref(dyn), (t .- path.times)) .* Ref(dyn.L)
62+
μ = sum(@. fts * mean(dist) * path.jumps)
63+
Σ = sum(@. fts * transpose(fts) * var(dist) * path.jumps)
64+
return μ, Σ + eltype(Σ)(1e-6) * I
6965
end
7066

71-
function meancov(
72-
t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance
73-
) where {T<:Real}
74-
μ = zeros(T, 2)
75-
Σ = zeros(T, (2, 2))
76-
let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ
77-
for (v, z) in zip(times, jumps)
78-
ft = exp(dyn, (t - v)) * dyn.L
79-
μ += ft .* μw .* z
80-
Σ += ft * transpose(ft) .* σw^2 .* z
81-
end
82-
83-
# Guarantees positive semi-definiteness
84-
return μ, Σ + T(1e-6) * I
85-
end
67+
struct LevyPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior
68+
μ::XT
69+
Σ::ΣT
8670
end
8771

88-
struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
89-
dt::T
90-
dyn::LangevinDynamics{T}
91-
process::GammaProcess{T}
92-
nvm::NormalMeanVariance{T}
93-
end
72+
SSMProblems.distribution(proc::LevyPrior) = MvNormal(proc.μ, proc.Σ)
9473

95-
function SSMProblems.distribution(proc::LevyLangevin{T}) where {T<:Real}
96-
return MultivariateNormal(zeros(T, 2), I)
74+
struct LevyLangevin{T<:Real,LT<:LangevinDynamics,ΓT<:GammaProcess,DT<:Normal} <:
75+
SSMProblems.LatentDynamics
76+
dt::T
77+
dyn::LT
78+
process::ΓT
79+
dist::DT
9780
end
9881

99-
function SSMProblems.distribution(proc::LevyLangevin{T}, step::Int, state) where {T<:Real}
82+
function SSMProblems.distribution(proc::LevyLangevin, step::Int, state)
10083
dt = proc.dt
10184
path = simulate(rng, proc.process, dt, (step - 1) * dt, step * dt)
102-
μ, Σ = meancov(step * dt, proc.dyn, path, proc.nvm)
103-
return MultivariateNormal(exp(proc.dyn, dt) * state + μ, Σ)
85+
μ, Σ = meancov(step * dt, proc.dyn, path, proc.dist)
86+
return MvNormal(exp(proc.dyn, dt) * state + μ, Σ)
10487
end
10588

106-
struct LinearGaussianObservation{T<:Real} <: ObservationProcess{T,T}
107-
H::Vector{T}
108-
R::T
89+
struct LinearGaussianObservation{HT<:AbstractVector,RT<:Real} <:
90+
SSMProblems.ObservationProcess
91+
H::HT
92+
R::RT
10993
end
11094

111-
function SSMProblems.distribution(proc::LinearGaussianObservation, step::Int, state)
95+
function SSMProblems.distribution(proc::LinearGaussianObservation, ::Int, state)
11296
return Normal(transpose(proc.H) * state, proc.R)
11397
end
11498

115-
function LevyModel(dt, θ, σe, C, β, μw, σw; ϵ=1e-10)
116-
A = [0.0 1.0; 0.0 θ]
117-
L = [0.0; 1.0]
118-
H = [1.0, 0]
119-
99+
function LevyModel(dt, θ, σe, C, β, μw, σw; kwargs...)
120100
dyn = LevyLangevin(
121101
dt,
122-
LangevinDynamics(A, L, θ, H, σe),
123-
GammaProcess(C, β; ϵ),
124-
NormalMeanVariance(μw, σw),
102+
LangevinDynamics([0 1; 0 θ], [0; 1], θ),
103+
GammaProcess(C, β; kwargs...),
104+
Normal(μw, σw),
125105
)
126106

127-
obs = LinearGaussianObservation(H, σe)
128-
return StateSpaceModel(dyn, obs)
107+
obs = LinearGaussianObservation([1; 0], σe)
108+
return SSMProblems.StateSpaceModel(LevyPrior(zeros(Bool, 2), I(2)), dyn, obs)
129109
end
130110

131111
# Levy SSM with Langevin dynamics
@@ -139,18 +119,18 @@ end
139119
# Simulation parameters
140120
N = 200
141121
ts = range(0, 100; length=N)
142-
levyssm = LevyModel(step(ts), θ, 1.0, 1.0, 1.0, 0.0, 1.0);
122+
levyssm = LevyModel(step(ts), -0.5, 1, 1.0, 1.0, 0, 1);
143123

144124
# Simulate data
145125
rng = Random.MersenneTwister(1234);
146126
_, X, Y = sample(rng, levyssm, N);
147127

148128
# Run sampler
149129
pg = AdvancedPS.PGAS(50);
150-
chains = sample(rng, levyssm(Y), pg, 100);
130+
chains = sample(rng, AdvancedPS.TracedSSM(levyssm, Y), pg, 100; progress=false);
151131

152132
# Concat all sampled states
153-
marginal_states = hcat([chain.trajectory.model.X for chain in chains]...)
133+
marginal_states = hcat([chain.trajectory.model.X for chain in chains]...);
154134

155135
# Plot marginal state and jump intensities for one trajectory
156136
p1 = plot(
@@ -166,7 +146,6 @@ plot!(
166146
label="Marginal State (x2)",
167147
)
168148

169-
# TODO: collect jumps from the model
170149
p2 = scatter([], []; color=:darkorange, label="Jumps")
171150

172151
plot(

0 commit comments

Comments
 (0)