"""
    ScoreGradELBO(n_samples; kwargs...)

Evidence lower-bound objective computed with score function gradient with the VarGrad
objective, also known as the leave-one-out control variate.

# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.

# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
- The target distribution and the variational approximation have the same support.
"""
struct ScoreGradELBO <: AbstractVariationalObjective
    # Number of Monte Carlo samples drawn per objective/gradient estimate.
    n_samples::Int
end
5518
# Compact one-line display of the objective, e.g. `ScoreGradELBO(n_samples=10)`.
function Base.show(io::IO, obj::ScoreGradELBO)
    print(io, "ScoreGradELBO(n_samples=")
    print(io, obj.n_samples)
    return print(io, ")")
end
6524
"""
    estimate_objective(rng, obj, q, prob; n_samples)

Return a Monte Carlo estimate of the ELBO over `n_samples` samples drawn from `q`.

# Arguments
- `rng`: Random number generator used to draw the samples.
- `obj::ScoreGradELBO`: The objective (supplies the default sample count).
- `q`: Variational approximation; must implement `rand` and `logpdf`.
- `prob`: Target `LogDensityProblem`.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples. (Default: `obj.n_samples`)
"""
function estimate_objective(
    rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
    samples = rand(rng, q, n_samples)
    # log-density of the target and of the variational approximation at each sample
    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
    # ELBO = E_q[log π(z) - log q(z)]
    return mean(ℓπ - ℓq)
end
9133
# Convenience method: estimate the ELBO with the process-global default RNG.
function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
    return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end
9537
"""
    estimate_scoregradelbo_ad_forward(params, aux)

AD-guaranteed forward path of the score gradient objective.

# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.

# Auxiliary Information
`aux` should contain the following entries:
- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the variational distribution from `params`.
"""
function estimate_scoregradelbo_ad_forward(params, aux)
    (; samples_stop, logprob_stop, adtype, restructure) = aux
    q = restructure_ad_forward(adtype, restructure, params)
    ℓπ = logprob_stop
    # Only `logpdf(q, ·)` depends on `params`; the samples and target densities are fixed.
    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
    f = ℓq - ℓπ
    # VarGrad objective: (biased) sample variance of f = log q - log π, halved.
    return (mean(abs2, f) - mean(f)^2) / 2
end
11262
# Estimate the score/VarGrad gradient of the ELBO and store it in `out`.
# Samples and target log-densities are computed outside the AD path; only the
# variational log-density is differentiated (see `estimate_scoregradelbo_ad_forward`).
# NOTE(review): the middle of this argument list was hidden by the diff hunk header;
# the order below follows the surrounding AdvancedVI convention — confirm upstream.
function AdvancedVI.estimate_gradient!(
    rng::Random.AbstractRNG,
    obj::ScoreGradELBO,
    adtype::ADTypes.AbstractADType,
    out::DiffResults.MutableDiffResult,
    prob,
    params,
    restructure,
    state,
)
    q = restructure(params)
    samples = rand(rng, q, obj.n_samples)
    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
    AdvancedVI.value_and_gradient!(
        adtype, estimate_scoregradelbo_ad_forward, params, aux, out
    )
    # Report a Monte Carlo ELBO estimate from the same samples as a diagnostic statistic.
    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
    elbo = mean(ℓπ - ℓq)
    stat = (elbo=elbo,)
    return out, nothing, stat
end
0 commit comments