Skip to content

Commit 67d3496

Browse files
FredericWantiezyebaigithub-actions[bot]
authored
Update the update_ref! API (#85)
* Try to fix the API, `update_ref!` is broken * Introduce a `ParticleSampler` type (#86) * Update smc.jl * Update smc.jl * Apply suggestions from code review * Update smc.jl * Update AdvancedPS.jl * Update src/container.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/AdvancedPS.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review * fix typo * Changing the name --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e545c86 commit 67d3496

File tree

8 files changed

+46
-54
lines changed

8 files changed

+46
-54
lines changed

examples/particle-gibbs/script.jl

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Distributions
55
using Plots
66
using AbstractMCMC
77
using Random123
8-
using Libtask
98

109
"""
1110
plot_update_rate(update_rate, N)
@@ -91,39 +90,35 @@ plot(x; label="x", xlabel="t")
9190
plot(y; label="y", xlabel="t")
9291

9392
# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
94-
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
95-
X::Array
93+
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
94+
X::Vector{Float64}
9695
θ::Parameters
97-
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
96+
NonLinearTimeSeries::Parameters) = new(Float64[], θ)
9897
end
9998

100-
function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
101-
x₀ = rand(rng, f₀(model.θ))
102-
model.X[1] = x₀
103-
score = logpdf(g(model.θ, x₀, 1), y[1])
104-
Libtask.produce(score)
105-
106-
for t in 2:(model.θ.T)
107-
state = rand(rng, f(model.θ, model.X[t - 1], t - 1))
108-
model.X[t] = state
109-
score = logpdf(g(model.θ, state, t), y[t])
110-
Libtask.produce(score)
111-
end
99+
# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
100+
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ)
101+
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step)
102+
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
103+
return logpdf(g(model.θ, state, step), y[step])
112104
end
105+
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ
113106

114107
# Here we use the particle gibbs kernel without adaptive resampling.
115108
model = NonLinearTimeSeries(θ₀)
116109
pg = AdvancedPS.PG(Nₚ, 1.0)
117110
chains = sample(rng, model, pg, Nₛ; progress=false);
118111
#md nothing #hide
119112

120-
particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states
113+
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
121114
mean_trajectory = mean(particles; dims=2);
122115
#md nothing #hide
123116

124117
# We can now plot all the generated traces.
125118
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
126-
scatter(particles; label=false, opacity=1.01, color=:black, xlabel="t", ylabel="state")
119+
scatter(
120+
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
121+
)
127122
plot!(x; color=:darkorange, label="Original Trajectory")
128123
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
129124

@@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
133128
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)
134129

135130
# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
136-
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
137-
mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel
138-
X::Vector{Float64}
139-
θ::Parameters
140-
NonLinearSSM::Parameters) = new(Float64[], θ)
141-
end
142-
143-
AdvancedPS.initialization(model::NonLinearSSM) = f₀(model.θ)
144-
AdvancedPS.transition(model::NonLinearSSM, state, step) = f(model.θ, state, step)
145-
function AdvancedPS.observation(model::NonLinearSSM, state, step)
146-
return logpdf(g(model.θ, state, step), y[step])
147-
end
148-
AdvancedPS.isdone(::NonLinearSSM, step) = step > Tₘ
149-
150-
# We can now sample from the model using the PGAS sampler and collect the trajectories.
151131
pgas = AdvancedPS.PGAS(Nₚ)
152-
model = NonLinearSSM(θ₀)
153132
chains = sample(rng, model, pgas, Nₛ; progress=false);
154133
particles = hcat([chain.trajectory.model.X for chain in chains]...);
155134
mean_trajectory = mean(particles; dims=2);
156135

157136
# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
158-
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
137+
scatter(
138+
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
139+
)
159140
plot!(x; color=:darkorange, label="Original Trajectory")
160141
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
161142

ext/AdvancedPSLibtaskExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function AbstractMCMC.step(
146146

147147
# Perform a particle sweep.
148148
reference = isref ? particles.vals[nparticles] : nothing
149-
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, reference)
149+
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler, reference)
150150

151151
# Pick a particle to be retained.
152152
newtrajectory = rand(rng, particles)
@@ -184,7 +184,7 @@ function AbstractMCMC.sample(
184184
particles = AdvancedPS.ParticleContainer(traces, AdvancedPS.TracedRNG(), rng)
185185

186186
# Perform particle sweep.
187-
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler)
187+
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler)
188188

189189
replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals)
190190

src/AdvancedPS.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using Random123: Random123
88

99
abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end
1010

11+
abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end
12+
1113
""" Abstract type for an abstract model formulated in the state space form
1214
"""
1315
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
@@ -17,8 +19,8 @@ include("resampling.jl")
1719
include("rng.jl")
1820
include("model.jl")
1921
include("container.jl")
20-
include("pgas.jl")
2122
include("smc.jl")
23+
include("pgas.jl")
2224

2325
if !isdefined(Base, :get_extension)
2426
using Requires

src/container.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ end
6161
6262
Update reference trajectory. Defaults to `nothing`
6363
"""
64-
update_ref!(particle::Trace, pc::ParticleContainer) = nothing
64+
function update_ref!(
65+
particle::Trace, pc::ParticleContainer, sampler::AbstractParticleSampler
66+
)
67+
return nothing
68+
end
6569

6670
"""
6771
reset_logweights!(pc::ParticleContainer)
@@ -167,6 +171,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
167171
function resample_propagate!(
168172
::Random.AbstractRNG,
169173
pc::ParticleContainer,
174+
sampler::AbstractParticleSampler,
170175
randcat=DEFAULT_RESAMPLER,
171176
ref::Union{Particle,Nothing}=nothing;
172177
weights=getweights(pc),
@@ -214,7 +219,7 @@ function resample_propagate!(
214219
if ref !== nothing
215220
# Insert the retained particle. This is based on the replaying trick for efficiency
216221
# reasons. If we implement PG using task copying, we need to store Nx * T particles!
217-
update_ref!(ref, pc)
222+
update_ref!(ref, pc, sampler)
218223
@inbounds children[n] = ref
219224
end
220225

@@ -228,6 +233,7 @@ end
228233
function resample_propagate!(
229234
rng::Random.AbstractRNG,
230235
pc::ParticleContainer,
236+
sampler::AbstractParticleSampler,
231237
resampler::ResampleWithESSThreshold,
232238
ref::Union{Particle,Nothing}=nothing;
233239
weights=getweights(pc),
@@ -236,7 +242,7 @@ function resample_propagate!(
236242
ess = inv(sum(abs2, weights))
237243

238244
if ess resampler.threshold * length(pc)
239-
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
245+
resample_propagate!(rng, pc, sampler, resampler.resampler, ref; weights=weights)
240246
else
241247
update_keys!(pc, ref)
242248
end
@@ -311,11 +317,12 @@ function sweep!(
311317
rng::Random.AbstractRNG,
312318
pc::ParticleContainer,
313319
resampler,
320+
sampler::AbstractMCMC.AbstractSampler,
314321
ref::Union{Particle,Nothing}=nothing,
315322
)
316323
# Initial step:
317324
# Resample and propagate particles.
318-
resample_propagate!(rng, pc, resampler, ref)
325+
resample_propagate!(rng, pc, sampler, resampler, ref)
319326

320327
# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
321328
# weights.
@@ -336,7 +343,7 @@ function sweep!(
336343
# For observations ``y₂, …, yₜ``:
337344
while !isdone
338345
# Resample and propagate particles.
339-
resample_propagate!(rng, pc, resampler, ref)
346+
resample_propagate!(rng, pc, sampler, resampler, ref)
340347

341348
# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
342349
# weights.

src/pgas.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function forkr(particle::SSMTrace)
133133
return newtrace
134134
end
135135

136-
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace})
136+
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)
137137
current_step(ref) <= 2 && return nothing # At the beginning of step + 1 since we start at 1
138138
isdone(ref.model, current_step(ref)) && return nothing
139139

src/smc.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct SMC{R} <: AbstractMCMC.AbstractSampler
1+
struct SMC{R} <: AbstractParticleSampler
22
nparticles::Int
33
resampler::R
44
end
@@ -46,12 +46,12 @@ function AbstractMCMC.sample(
4646
particles = ParticleContainer(traces, TracedRNG(), rng)
4747

4848
# Perform particle sweep.
49-
logevidence = sweep!(rng, particles, sampler.resampler)
49+
logevidence = sweep!(rng, particles, sampler.resampler, sampler)
5050

5151
return SMCSample(collect(particles), getweights(particles), logevidence)
5252
end
5353

54-
struct PG{R} <: AbstractMCMC.AbstractSampler
54+
struct PG{R} <: AbstractParticleSampler
5555
"""Number of particles."""
5656
nparticles::Int
5757
"""Resampling algorithm."""
@@ -84,7 +84,7 @@ struct PGSample{T,L}
8484
logevidence::L
8585
end
8686

87-
struct PGAS{R} <: AbstractMCMC.AbstractSampler
87+
struct PGAS{R} <: AbstractParticleSampler
8888
"""Number of particles."""
8989
nparticles::Int
9090
"""Resampling algorithm."""
@@ -96,7 +96,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))
9696
function AbstractMCMC.step(
9797
rng::Random.AbstractRNG,
9898
model::AbstractStateSpaceModel,
99-
sampler::PGAS,
99+
sampler::Union{PGAS,PG},
100100
state::Union{PGState,Nothing}=nothing;
101101
kwargs...,
102102
)
@@ -116,7 +116,7 @@ function AbstractMCMC.step(
116116

117117
# Perform a particle sweep.
118118
reference = isref ? particles.vals[nparticles] : nothing
119-
logevidence = sweep!(rng, particles, sampler.resampler, reference)
119+
logevidence = sweep!(rng, particles, sampler.resampler, sampler, reference)
120120

121121
# Pick a particle to be retained.
122122
newtrajectory = rand(particles.rng, particles)

test/container.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@
7373
ref = AdvancedPS.forkr(selected)
7474
pc_ref.vals[end] = ref
7575

76+
sampler = AdvancedPS.PG(length(logps))
7677
AdvancedPS.resample_propagate!(
77-
Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, ref
78+
Random.GLOBAL_RNG, pc_ref, sampler, AdvancedPS.resample_systematic, ref
7879
)
7980
@test pc_ref.logWs == zeros(3)
8081
@test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3)
@@ -84,7 +85,7 @@
8485
@test pc_ref.vals[end] === particles_ref[end]
8586

8687
# Resample and propagate particles.
87-
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
88+
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc, sampler)
8889
@test pc.logWs == zeros(3)
8990
@test AdvancedPS.getweights(pc) == fill(1 / 3, 3)
9091
@test all(AdvancedPS.getweight(pc, i) == 1 / 3 for i in 1:3)

test/pgas.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
AdvancedPS.Trace(BaseModel(Params(0.9, 0.31, 1)), AdvancedPS.TracedRNG()) for
4747
_ in 1:3
4848
]
49+
sampler = AdvancedPS.PGAS(3)
4950
resampler = AdvancedPS.ResampleWithESSThreshold(1.0)
5051

5152
part = particles[3]
@@ -58,11 +59,11 @@
5859
pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG(), base_rng)
5960

6061
AdvancedPS.reweight!(pc, ref)
61-
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
62+
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)
6263

6364
AdvancedPS.reweight!(pc, ref)
6465
pc.logWs = [-Inf, 0, -Inf] # Force ancestor update to second particle
65-
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
66+
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)
6667

6768
AdvancedPS.reweight!(pc, ref)
6869
@test all(pc.vals[2].model.X[1:2] .≈ ref.model.X[1:2])

0 commit comments

Comments
 (0)