@@ -27,17 +27,16 @@ function simulate(
2727 t = t0
2828 truncated = last_jump < tolerance
2929 while ! truncated
30- t += rand (rng, Exponential (one (T) / rate))
31- xi = one (T) / (β * (exp (t / C) - one (T) ))
32- prob = (one (T) + β * xi) * exp (- β * xi)
30+ t += rand (rng, Exponential (1 / rate))
31+ xi = 1 / (β * (exp (t / C) - 1 ))
32+ prob = (1 + β * xi) * exp (- β * xi)
3333 if rand (rng) < prob
3434 push! (jumps, xi)
3535 last_jump = xi
3636 end
3737 truncated = last_jump < tolerance
3838 end
39- times = rand (rng, Uniform (start, finish), length (jumps))
40- return GammaPath (jumps, times)
39+ return GammaPath (jumps, rand (rng, Uniform (start, finish), length (jumps)))
4140 end
4241end
4342
@@ -47,85 +46,66 @@ function integral(times::Array{<:Real}, path::GammaPath)
4746 end
4847end
4948
50- struct LangevinDynamics{T}
51- A:: Matrix{T}
52- L:: Vector{T}
53- θ:: T
54- H:: Vector{T}
55- σe:: T
49+ struct LangevinDynamics{AT<: AbstractMatrix ,LT<: AbstractVector ,θT<: Real }
50+ A:: AT
51+ L:: LT
52+ θ:: θT
5653end
5754
58- struct NormalMeanVariance{T}
59- μ :: T
60- σ :: T
55+ function Base . exp (dyn :: LangevinDynamics , dt)
56+ f_val = exp (dyn . θ * dt)
57+ return [ 1 (f_val - 1 ) / dyn . θ; 0 f_val]
6158end
6259
63- f (dt, θ) = exp (θ * dt)
64- function Base. exp (dyn:: LangevinDynamics{T} , dt:: T ) where {T<: Real }
65- let θ = dyn. θ
66- f_val = f (dt, θ)
67- return [one (T) (f_val - 1 )/ θ; zero (T) f_val]
68- end
60+ function meancov (t, dyn:: LangevinDynamics , path:: GammaPath , dist:: Normal )
61+ fts = exp .(Ref (dyn), (t .- path. times)) .* Ref (dyn. L)
62+ μ = sum (@. fts * mean (dist) * path. jumps)
63+ Σ = sum (@. fts * transpose (fts) * var (dist) * path. jumps)
64+ return μ, Σ + eltype (Σ)(1e-6 ) * I
6965end
7066
71- function meancov (
72- t:: T , dyn:: LangevinDynamics , path:: GammaPath , nvm:: NormalMeanVariance
73- ) where {T<: Real }
74- μ = zeros (T, 2 )
75- Σ = zeros (T, (2 , 2 ))
76- let times = path. times, jumps = path. jumps, μw = nvm. μ, σw = nvm. σ
77- for (v, z) in zip (times, jumps)
78- ft = exp (dyn, (t - v)) * dyn. L
79- μ += ft .* μw .* z
80- Σ += ft * transpose (ft) .* σw^ 2 .* z
81- end
82-
83- # Guarantees positive semi-definiteness
84- return μ, Σ + T (1e-6 ) * I
85- end
67+ struct LevyPrior{XT<: AbstractVector ,ΣT<: AbstractMatrix } <: StatePrior
68+ μ:: XT
69+ Σ:: ΣT
8670end
8771
88- struct LevyLangevin{T} <: LatentDynamics{T,Vector{T}}
89- dt:: T
90- dyn:: LangevinDynamics{T}
91- process:: GammaProcess{T}
92- nvm:: NormalMeanVariance{T}
93- end
72+ SSMProblems. distribution (proc:: LevyPrior ) = MvNormal (proc. μ, proc. Σ)
9473
95- function SSMProblems. distribution (proc:: LevyLangevin{T} ) where {T<: Real }
96- return MultivariateNormal (zeros (T, 2 ), I)
74+ struct LevyLangevin{T<: Real ,LT<: LangevinDynamics ,ΓT<: GammaProcess ,DT<: Normal } < :
75+ SSMProblems. LatentDynamics
76+ dt:: T
77+ dyn:: LT
78+ process: :ΓT
79+ dist:: DT
9780end
9881
99- function SSMProblems. distribution (proc:: LevyLangevin{T} , step:: Int , state) where {T <: Real }
82+ function SSMProblems. distribution (proc:: LevyLangevin , step:: Int , state)
10083 dt = proc. dt
10184 path = simulate (rng, proc. process, dt, (step - 1 ) * dt, step * dt)
102- μ, Σ = meancov (step * dt, proc. dyn, path, proc. nvm )
103- return MultivariateNormal (exp (proc. dyn, dt) * state + μ, Σ)
85+ μ, Σ = meancov (step * dt, proc. dyn, path, proc. dist )
86+ return MvNormal (exp (proc. dyn, dt) * state + μ, Σ)
10487end
10588
106- struct LinearGaussianObservation{T<: Real } <: ObservationProcess{T,T}
107- H:: Vector{T}
108- R:: T
89+ struct LinearGaussianObservation{HT<: AbstractVector ,RT<: Real } < :
90+ SSMProblems. ObservationProcess
91+ H:: HT
92+ R:: RT
10993end
11094
111- function SSMProblems. distribution (proc:: LinearGaussianObservation , step :: Int , state)
95+ function SSMProblems. distribution (proc:: LinearGaussianObservation , :: Int , state)
11296 return Normal (transpose (proc. H) * state, proc. R)
11397end
11498
115- function LevyModel (dt, θ, σe, C, β, μw, σw; ϵ= 1e-10 )
116- A = [0.0 1.0 ; 0.0 θ]
117- L = [0.0 ; 1.0 ]
118- H = [1.0 , 0 ]
119-
99+ function LevyModel (dt, θ, σe, C, β, μw, σw; kwargs... )
120100 dyn = LevyLangevin (
121101 dt,
122- LangevinDynamics (A, L, θ, H, σe ),
123- GammaProcess (C, β; ϵ ),
124- NormalMeanVariance (μw, σw),
102+ LangevinDynamics ([ 0 1 ; 0 θ], [ 0 ; 1 ], θ ),
103+ GammaProcess (C, β; kwargs ... ),
104+ Normal (μw, σw),
125105 )
126106
127- obs = LinearGaussianObservation (H , σe)
128- return StateSpaceModel (dyn, obs)
107+ obs = LinearGaussianObservation ([ 1 ; 0 ] , σe)
108+ return SSMProblems . StateSpaceModel (LevyPrior ( zeros (Bool, 2 ), I ( 2 )), dyn, obs)
129109end
130110
131111# Levy SSM with Langevin dynamics
@@ -139,18 +119,18 @@ end
139119# Simulation parameters
140120N = 200
141121ts = range (0 , 100 ; length= N)
142- levyssm = LevyModel (step (ts), θ , 1.0 , 1.0 , 1.0 , 0.0 , 1.0 );
122+ levyssm = LevyModel (step (ts), - 0.5 , 1 , 1.0 , 1.0 , 0 , 1 );
143123
144124# Simulate data
145125rng = Random. MersenneTwister (1234 );
146126_, X, Y = sample (rng, levyssm, N);
147127
148128# Run sampler
149129pg = AdvancedPS. PGAS (50 );
150- chains = sample (rng, levyssm ( Y), pg, 100 );
130+ chains = sample (rng, AdvancedPS . TracedSSM (levyssm, Y), pg, 100 ; progress = false );
151131
152132# Concat all sampled states
153- marginal_states = hcat ([chain. trajectory. model. X for chain in chains]. .. )
133+ marginal_states = hcat ([chain. trajectory. model. X for chain in chains]. .. );
154134
155135# Plot marginal state and jump intensities for one trajectory
156136p1 = plot (
@@ -166,7 +146,6 @@ plot!(
166146 label= " Marginal State (x2)" ,
167147)
168148
169- # TODO : collect jumps from the model
170149p2 = scatter ([], []; color= :darkorange , label= " Jumps" )
171150
172151plot (
0 commit comments