@@ -5,7 +5,6 @@ using Distributions
55using Plots
66using AbstractMCMC
77using Random123
8- using Libtask
98
109"""
1110 plot_update_rate(update_rate, N)
@@ -91,39 +90,35 @@ plot(x; label="x", xlabel="t")
9190plot (y; label= " y" , xlabel= " t" )
9291
9392# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
94- mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
95- X:: Array
93+ mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
94+ X:: Vector{Float64}
9695 θ:: Parameters
97- NonLinearTimeSeries (θ:: Parameters ) = new (zeros ( Float64, θ . T) , θ)
96+ NonLinearTimeSeries (θ:: Parameters ) = new (Float64[] , θ)
9897end
9998
100- function (model:: NonLinearTimeSeries )(rng:: Random.AbstractRNG )
101- x₀ = rand (rng, f₀ (model. θ))
102- model. X[1 ] = x₀
103- score = logpdf (g (model. θ, x₀, 1 ), y[1 ])
104- Libtask. produce (score)
105-
106- for t in 2 : (model. θ. T)
107- state = rand (rng, f (model. θ, model. X[t - 1 ], t - 1 ))
108- model. X[t] = state
109- score = logpdf (g (model. θ, state, t), y[t])
110- Libtask. produce (score)
111- end
99+ # The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
100+ AdvancedPS. initialization (model:: NonLinearTimeSeries ) = f₀ (model. θ)
101+ AdvancedPS. transition (model:: NonLinearTimeSeries , state, step) = f (model. θ, state, step)
102+ function AdvancedPS. observation (model:: NonLinearTimeSeries , state, step)
103+ return logpdf (g (model. θ, state, step), y[step])
112104end
105+ AdvancedPS. isdone (:: NonLinearTimeSeries , step) = step > Tₘ
113106
114107# Here we use the particle gibbs kernel without adaptive resampling.
115108model = NonLinearTimeSeries (θ₀)
116109pg = AdvancedPS. PG (Nₚ, 1.0 )
117110chains = sample (rng, model, pg, Nₛ; progress= false );
118111# md nothing #hide
119112
120- particles = hcat ([chain. trajectory. X for chain in chains]. .. ) # Concat all sampled states
113+ particles = hcat ([chain. trajectory. model . X for chain in chains]. .. ) # Concat all sampled states
121114mean_trajectory = mean (particles; dims= 2 );
122115# md nothing #hide
123116
124117# We can now plot all the generated traces.
125118# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
126- scatter (particles; label= false , opacity= 1.01 , color= :black , xlabel= " t" , ylabel= " state" )
119+ scatter (
120+ particles[:, 1 : 50 ]; label= false , opacity= 0.5 , color= :black , xlabel= " t" , ylabel= " state"
121+ )
127122plot! (x; color= :darkorange , label= " Original Trajectory" )
128123plot! (mean_trajectory; color= :dodgerblue , label= " Mean trajectory" , opacity= 0.9 )
129124
@@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
133128plot_update_rate (update_rate (particles, Nₛ)[:, 1 ], Nₚ)
134129
135130# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
136- # To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
137- mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel
138- X:: Vector{Float64}
139- θ:: Parameters
140- NonLinearSSM (θ:: Parameters ) = new (Float64[], θ)
141- end
142-
143- AdvancedPS. initialization (model:: NonLinearSSM ) = f₀ (model. θ)
144- AdvancedPS. transition (model:: NonLinearSSM , state, step) = f (model. θ, state, step)
145- function AdvancedPS. observation (model:: NonLinearSSM , state, step)
146- return logpdf (g (model. θ, state, step), y[step])
147- end
148- AdvancedPS. isdone (:: NonLinearSSM , step) = step > Tₘ
149-
150- # We can now sample from the model using the PGAS sampler and collect the trajectories.
151131pgas = AdvancedPS. PGAS (Nₚ)
152- model = NonLinearSSM (θ₀)
153132chains = sample (rng, model, pgas, Nₛ; progress= false );
154133particles = hcat ([chain. trajectory. model. X for chain in chains]. .. );
155134mean_trajectory = mean (particles; dims= 2 );
156135
157136# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
158- scatter (particles; label= false , opacity= 0.01 , color= :black , xlabel= " t" , ylabel= " state" )
137+ scatter (
138+ particles[:, 1 : 50 ]; label= false , opacity= 0.5 , color= :black , xlabel= " t" , ylabel= " state"
139+ )
159140plot! (x; color= :darkorange , label= " Original Trajectory" )
160141plot! (mean_trajectory; color= :dodgerblue , label= " Mean trajectory" , opacity= 0.9 )
161142
0 commit comments