Skip to content

Commit 2880bd3

Browse files
SSMProblems integration (#97)
* SSMProblems - Draft * Tests * Format, GP-SSM * Add levy-ssm * Align timesteps * Update Project.toml --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 1e5dfdd commit 2880bd3

File tree

18 files changed

+395
-90
lines changed

18 files changed

+395
-90
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.5.4"
4+
version = "0.6"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
1111
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
12+
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
1213
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1314

1415
[weakdeps]
@@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask"
2122
AbstractMCMC = "2, 3, 4, 5"
2223
Distributions = "0.23, 0.24, 0.25"
2324
Libtask = "0.8"
25+
Random = "1.6"
2426
Random123 = "1.3"
2527
Requires = "1.0"
2628
StatsFuns = "0.9, 1"
27-
Random = "1.6"
2829
julia = "1.6"
2930

3031
[extras]

examples/gaussian-process/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
66
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
77
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
8+
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"

examples/gaussian-process/script.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@ using AbstractGPs
66
using Plots
77
using Distributions
88
using Libtask
9+
using SSMProblems
910

1011
Parameters = @NamedTuple begin
1112
a::Float64
1213
q::Float64
1314
kernel
1415
end
1516

16-
mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
17+
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
1718
X::Vector{Float64}
19+
observations::Vector{Float64}
1820
θ::Parameters
1921

2022
GPSSM(params::Parameters) = new(Vector{Float64}(), params)
23+
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
2124
end
2225

2326
seed = 1
@@ -29,21 +32,20 @@ q = 0.5
2932

3033
params = Parameters((a, q, SqExponentialKernel()))
3134

32-
f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
33-
h(model::GPSSM) = Normal(0, model.θ.q)
34-
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)
35+
f(θ::Parameters, x, t) = Normal.a * x, θ.q)
36+
h(θ::Parameters) = Normal(0, θ.q)
37+
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)
3538

3639
rng = Random.MersenneTwister(seed)
37-
ref_model = GPSSM(params)
3840

3941
x = zeros(T)
4042
y = similar(x)
41-
x[1] = rand(rng, h(ref_model))
43+
x[1] = rand(rng, h(params))
4244
for t in 1:T
4345
if t < T
44-
x[t + 1] = rand(rng, f(ref_model, x[t], t))
46+
x[t + 1] = rand(rng, f(params, x[t], t))
4547
end
46-
y[t] = rand(rng, g(ref_model, x[t], t))
48+
y[t] = rand(rng, g(params, x[t], t))
4749
end
4850

4951
function gp_update(model::GPSSM, state, step)
@@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step)
5456
return Normal(μ[1], σ[1])
5557
end
5658

57-
AdvancedPS.initialization(::GPSSM) = h(model)
58-
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
59-
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
59+
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
60+
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
61+
return rand(rng, gp_update(model, state, step))
62+
end
63+
64+
function SSMProblems.emission_logdensity(model::GPSSM, state, step)
65+
return logpdf(g(model.θ, state, step), model.observations[step])
66+
end
67+
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
68+
return logpdf(gp_update(model, prev_state, step), current_state)
69+
end
70+
6071
AdvancedPS.isdone(::GPSSM, step) = step > T
6172

62-
model = GPSSM(params)
73+
model = GPSSM(y, params)
6374
pg = AdvancedPS.PGAS(Nₚ)
6475
chains = sample(rng, model, pg, Nₛ)
6576

examples/gaussian-ssm/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
66
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
77
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
8+
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
89
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

examples/gaussian-ssm/script.jl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using AdvancedPS
33
using Random
44
using Distributions
55
using Plots
6+
using SSMProblems
67

78
# We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk
89
# and the observation is linear in the latent states, namely:
@@ -33,32 +34,45 @@ Parameters = @NamedTuple begin
3334
r::Float64
3435
end
3536

36-
mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel
37+
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
3738
X::Vector{Float64}
39+
observations::Vector{Float64}
3840
θ::Parameters
3941
LinearSSM::Parameters) = new(Vector{Float64}(), θ)
42+
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
4043
end
4144

4245
# and the densities defined above.
43-
f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
44-
g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density
45-
f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
46+
f(θ::Parameters, state, t) = Normal.a * state, θ.q) # Transition density
47+
g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density
48+
f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
4649
#md nothing #hide
4750

4851
# We also need to specify the dynamics of the system through the transition equations:
4952
# - `AdvancedPS.initialization`: the initial state density
5053
# - `AdvancedPS.transition`: the state transition density
5154
# - `AdvancedPS.observation`: the observation score given the observed data
5255
# - `AdvancedPS.isdone`: signals the end of the execution for the model
53-
AdvancedPS.initialization(model::LinearSSM) = f₀(model)
54-
AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step)
55-
function AdvancedPS.observation(model::LinearSSM, state, step)
56-
return logpdf(g(model, state, step), y[step])
56+
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
57+
function SSMProblems.transition!!(
58+
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
59+
)
60+
return rand(rng, f(model.θ, state, step))
61+
end
62+
63+
function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
64+
return logpdf(g(model.θ, state, step), model.observations[step])
65+
end
66+
function SSMProblems.transition_logdensity(
67+
model::LinearSSM, prev_state, current_state, step
68+
)
69+
return logpdf(f(model.θ, prev_state, step), current_state)
5770
end
71+
72+
# We need to think seriously about how the data is handled
5873
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ
5974

6075
# Everything is now ready to simulate some data.
61-
6276
a = 0.9 # Scale
6377
q = 0.32 # State variance
6478
r = 1 # Observation variance
@@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed)
7286

7387
x = zeros(Tₘ)
7488
y = zeros(Tₘ)
75-
76-
reference = LinearSSM(θ₀)
77-
x[1] = rand(rng, f₀(reference))
89+
x[1] = rand(rng, f₀(θ₀))
7890
for t in 1:Tₘ
7991
if t < Tₘ
80-
x[t + 1] = rand(rng, f(reference, x[t], t))
92+
x[t + 1] = rand(rng, f(θ₀, x[t], t))
8193
end
82-
y[t] = rand(rng, g(reference, x[t], t))
94+
y[t] = rand(rng, g(θ₀, x[t], t))
8395
end
8496

8597
# Here are the latent and obseravation timeseries
@@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)
88100

89101
# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
90102
# and a model interface.
91-
model = LinearSSM(θ₀)
103+
model = LinearSSM(y, θ₀)
92104
pgas = AdvancedPS.PGAS(Nₚ)
93105
chains = sample(rng, model, pgas, Nₛ; progress=false);
94106
#md nothing #hide

examples/levy-ssm/Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[deps]
2+
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
3+
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
4+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
5+
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
6+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
7+
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"

0 commit comments

Comments
 (0)