Commit 1dbf2ac

Authored by Red-Portal, github-actions[bot], and yebai

Minor Touches for ScoreGradELBO (#99)

* fix: move the log density computation of ScoreGradELBO out of the AD path
* update: change the `ScoreGradELBO` objective to use VarGrad underneath
* fix: remove an unnecessary import
* add basic interface tests for the variational objectives
* tweak the stepsize for the ScoreGradELBO inference tests
* add docstrings to the ELBO objectives' forward AD paths
* remove the `n_montecarlo` option from the inference tests and fix its value instead

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
1 parent: 227d58d

11 files changed: +137 −151 lines
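The headline change is the second bullet: `ScoreGradELBO` no longer maintains a baseline history. As best I can tell from the new forward path below, it differentiates the VarGrad objective, the halved sample variance of the log-ratio between the approximation and the target,

```math
\mathrm{VarGrad}(\lambda) = \frac{1}{2}\, \mathbb{V}_{z \sim q_{\lambda}}\left[ \log q_{\lambda}(z) - \log \pi(z) \right],
```

whose gradient agrees in expectation with the score gradient of the ELBO under a leave-one-out control variate. That equivalence is what lets the baseline machinery be deleted wholesale.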

src/objectives/elbo/entropy.jl

Lines changed: 0 additions & 7 deletions

@@ -37,10 +37,3 @@ function estimate_entropy(
         -logpdf(q, mc_sample)
     end
 end
-
-function estimate_entropy_maybe_stl(
-    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
-)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
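This helper reappears verbatim in repgradelbo.jl below, its only caller after this commit. For orientation, a minimal sketch of the dispatch that `maybe_stop_entropy_score` plausibly performs, assuming the two estimator types exercised in the tests; an illustration, not the package source:

```julia
# Sketch: with sticking-the-landing, the entropy is evaluated through q_stop,
# whose parameters sit outside the AD path, so the entropy's score term
# drops out of the gradient. Other estimators keep differentiating through q.
maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
```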

src/objectives/elbo/repgradelbo.jl

Lines changed: 27 additions & 2 deletions
@@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
     return print(io, ")")
 end
 
+function estimate_entropy_maybe_stl(
+    entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
+)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
+
 function estimate_energy_with_samples(prob, samples)
     return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end

@@ -85,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples)
     return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 end
 
-function estimate_repgradelbo_ad_forward(params′, aux)
+"""
+    estimate_repgradelbo_ad_forward(params, aux)
+
+AD-guaranteed forward path of the reparameterization gradient objective.
+
+# Arguments
+- `params`: Variational parameters.
+- `aux`: Auxiliary information excluded from the AD path.
+
+# Auxiliary Information
+`aux` should contain the following entries:
+- `rng`: Random number generator.
+- `obj`: The `RepGradELBO` objective.
+- `problem`: The target `LogDensityProblem`.
+- `adtype`: The `ADType` used for differentiating the forward path.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
+"""
+function estimate_repgradelbo_ad_forward(params, aux)
     (; rng, obj, problem, adtype, restructure, q_stop) = aux
-    q = restructure_ad_forward(adtype, restructure, params′)
+    q = restructure_ad_forward(adtype, restructure, params)
     samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
     energy = estimate_energy_with_samples(problem, samples)
     elbo = energy + entropy
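The `q_stop` entry documented above is the stop-gradient pattern: the stopped copy is built from plain values outside the differentiated call, so derivatives never flow through it. A self-contained toy illustration of the same idea (ForwardDiff only; the names are hypothetical, not the package API):

```julia
using ForwardDiff

# Pretend θ_stop is a frozen copy of θ taken outside the differentiated closure.
surrogate(θ, θ_stop) = θ * θ_stop

θ = 3.0
# Differentiate only through the first argument; θ_stop enters as a constant.
g = ForwardDiff.derivative(t -> surrogate(t, θ), θ)
# g == 3.0 (= θ_stop). Without the stop, d(θ^2)/dθ = 2θ = 6.0.
```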
src/objectives/elbo/scoregradelbo.jl

Lines changed: 39 additions & 94 deletions

@@ -1,113 +1,63 @@
+
 """
     ScoreGradELBO(n_samples; kwargs...)
 
-Evidence lower-bound objective computed with score function gradients.
-```math
-\\begin{aligned}
-\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
-&=
-\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
-    \\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
-\\right]
-+ \\mathbb{H}\\left(q_{\\lambda}\\right),
-\\end{aligned}
-```
-
-To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.
-
-```math
-\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
-    \\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right)
-\\right]
-```
+Evidence lower-bound objective computed with the score function gradient, estimated with the VarGrad objective, also known as the leave-one-out control variate.
 
 # Arguments
-- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
-
-# Keyword Arguments
-- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
-- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
-- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.
 
 # Requirements
 - The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
 - `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
 - The target distribution and the variational approximation have the same support.
-
-Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 """
-struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
-       AdvancedVI.AbstractVariationalObjective
-    entropy::EntropyEst
+struct ScoreGradELBO <: AbstractVariationalObjective
     n_samples::Int
-    baseline_window_size::Int
-    baseline_history::Vector{Float64}
-end
-
-function ScoreGradELBO(
-    n_samples::Int;
-    entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
-    baseline_window_size::Int=10,
-    baseline_history::Vector{Float64}=Float64[],
-)
-    return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
 end
 
 function Base.show(io::IO, obj::ScoreGradELBO)
-    print(io, "ScoreGradELBO(entropy=")
-    print(io, obj.entropy)
-    print(io, ", n_samples=")
+    print(io, "ScoreGradELBO(n_samples=")
     print(io, obj.n_samples)
-    print(io, ", baseline_window_size=")
-    print(io, obj.baseline_window_size)
     return print(io, ")")
 end
 
-function compute_control_variate_baseline(history, window_size)
-    if length(history) == 0
-        return 1.0
-    end
-    min_index = max(1, length(history) - window_size)
-    return mean(history[min_index:end])
-end
-
-function estimate_energy_with_samples(
-    prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
-)
-    fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
-    fv_mean = mean(fv)
-    score_grad = mean(@. samples_logprob * (fv - baseline))
-    score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
-    return fv_mean + (score_grad - score_grad_stop)
-end
-
 function estimate_objective(
     rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
 )
-    samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
-    energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
-    return mean(energy) + entropy
+    samples = rand(rng, q, n_samples)
+    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
+    return mean(ℓπ - ℓq)
 end
 
 function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
     return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 end
 
-function estimate_scoregradelbo_ad_forward(params′, aux)
-    (; rng, obj, problem, adtype, restructure, q_stop) = aux
-    baseline = compute_control_variate_baseline(
-        obj.baseline_history, obj.baseline_window_size
-    )
-    q = restructure_ad_forward(adtype, restructure, params′)
-    samples_stop = rand(rng, q_stop, obj.n_samples)
-    entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
-    samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
-    samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
-    energy = estimate_energy_with_samples(
-        problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
-    )
-    elbo = energy + entropy
-    return -elbo
+"""
+    estimate_scoregradelbo_ad_forward(params, aux)
+
+AD-guaranteed forward path of the score gradient objective.
+
+# Arguments
+- `params`: Variational parameters.
+- `aux`: Auxiliary information excluded from the AD path.
+
+# Auxiliary Information
+`aux` should contain the following entries:
+- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
+- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
+- `adtype`: The `ADType` used for differentiating the forward path.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+"""
+function estimate_scoregradelbo_ad_forward(params, aux)
+    (; samples_stop, logprob_stop, adtype, restructure) = aux
+    q = restructure_ad_forward(adtype, restructure, params)
+    ℓπ = logprob_stop
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
+    f = ℓq - ℓπ
+    return (mean(abs2, f) - mean(f)^2) / 2
 end
 
 function AdvancedVI.estimate_gradient!(

@@ -120,20 +70,15 @@ function AdvancedVI.estimate_gradient!(
     restructure,
     state,
 )
-    q_stop = restructure(params)
-    aux = (
-        rng=rng,
-        adtype=adtype,
-        obj=obj,
-        problem=prob,
-        restructure=restructure,
-        q_stop=q_stop,
-    )
+    q = restructure(params)
+    samples = rand(rng, q, obj.n_samples)
+    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
     AdvancedVI.value_and_gradient!(
         adtype, estimate_scoregradelbo_ad_forward, params, aux, out
     )
-    nelbo = DiffResults.value(out)
-    stat = (elbo=-nelbo,)
-    push!(obj.baseline_history, -nelbo)
+    ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
+    elbo = mean(ℓπ - ℓq)
+    stat = (elbo=elbo,)
     return out, nothing, stat
 end
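For readers skimming the diff: the new surrogate in `estimate_scoregradelbo_ad_forward` is just the (1/N-normalized) sample variance of the log-weights, halved. A sketch restating it outside the package (the helper name is hypothetical; `Statistics` only):

```julia
using Statistics

# VarGrad surrogate as in the new forward path: the halved sample variance
# of f = log q(z) - log π(z), computed over N Monte Carlo samples.
function vargrad_surrogate(ℓq::AbstractVector, ℓπ::AbstractVector)
    f = ℓq - ℓπ
    return (mean(abs2, f) - mean(f)^2) / 2
end

# Differentiating through ℓq only (ℓπ is computed outside the AD path, as in
# `estimate_gradient!`) gives ∂/∂ℓqᵢ = (fᵢ - mean(f)) / N, i.e. the score
# gradient with a leave-one-out-style control variate.
```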

test/inference/repgradelbo_distributionsad.jl

Lines changed: 2 additions & 3 deletions

@@ -14,11 +14,10 @@ end
 @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                  [Float64, Float32],
     (modelname, modelconstr) in Dict(:Normal => normal_meanfield),
-    n_montecarlo in [1, 10],
     (objname, objective) in Dict(
-        :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+        :RepGradELBOClosedFormEntropy => RepGradELBO(10),
         :RepGradELBOStickingTheLanding =>
-            RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+            RepGradELBO(10; entropy=StickingTheLandingEntropy()),
     ),
     (adbackname, adtype) in AD_repgradelbo_distributionsad
test/inference/repgradelbo_locationscale.jl

Lines changed: 3 additions & 4 deletions

@@ -10,16 +10,15 @@ else
     )
 end
 
-@testset "inference ScoreGradELBO VILocationScale" begin
+@testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
         (modelname, modelconstr) in
         Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
-        n_montecarlo in [1, 10],
         (objname, objective) in Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
             :RepGradELBOStickingTheLanding =>
-                RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+                RepGradELBO(10; entropy=StickingTheLandingEntropy()),
         ),
         (adbackname, adtype) in AD_repgradelbo_locationscale

test/inference/repgradelbo_locationscale_bijectors.jl

Lines changed: 2 additions & 3 deletions

@@ -15,11 +15,10 @@ end
     [Float64, Float32],
     (modelname, modelconstr) in
     Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
-    n_montecarlo in [1, 10],
     (objname, objective) in Dict(
-        :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+        :RepGradELBOClosedFormEntropy => RepGradELBO(10),
         :RepGradELBOStickingTheLanding =>
-            RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
+            RepGradELBO(10; entropy=StickingTheLandingEntropy()),
     ),
     (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors

test/inference/scoregradelbo_distributionsad.jl

Lines changed: 2 additions & 7 deletions

@@ -14,12 +14,7 @@ end
 @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                  [Float64, Float32],
     (modelname, modelconstr) in Dict(:Normal => normal_meanfield),
-    n_montecarlo in [1, 10],
-    (objname, objective) in Dict(
-        :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
-        :ScoreGradELBOStickingTheLanding =>
-            ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
-    ),
+    (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
     (adbackname, adtype) in AD_scoregradelbo_distributionsad
 
     seed = (0x38bef07cf9cc549d)

@@ -29,7 +24,7 @@
     (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
 
     T = 1000
-    η = 1e-5
+    η = 1e-4
     opt = Optimisers.Descent(realtype(η))
 
     # For small enough η, the error of SGD, Δλ, is bounded as

test/inference/scoregradelbo_locationscale.jl

Lines changed: 2 additions & 7 deletions

@@ -15,12 +15,7 @@ end
     [Float64, Float32],
     (modelname, modelconstr) in
     Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
-    n_montecarlo in [1, 10],
-    (objname, objective) in Dict(
-        :ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
-        :ScoreGradELBOStickingTheLanding =>
-            ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
-    ),
+    (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
     (adbackname, adtype) in AD_scoregradelbo_locationscale
 
     seed = (0x38bef07cf9cc549d)

@@ -30,7 +25,7 @@
     (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
 
     T = 1000
-    η = 1e-5
+    η = 1e-4
     opt = Optimisers.Descent(realtype(η))
 
     # For small enough η, the error of SGD, Δλ, is bounded as

test/inference/scoregradelbo_locationscale_bijectors.jl

Lines changed: 2 additions & 7 deletions

@@ -15,12 +15,7 @@ end
     [Float64, Float32],
     (modelname, modelconstr) in
     Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
-    n_montecarlo in [1, 10],
-    (objname, objective) in Dict(
-        #:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
-        :ScoreGradELBOStickingTheLanding =>
-            ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
-    ),
+    (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
     (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors
 
     seed = (0x38bef07cf9cc549d)

@@ -30,7 +25,7 @@
     (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
 
     T = 1000
-    η = 1e-5
+    η = 1e-4
     opt = Optimisers.Descent(realtype(η))
 
     b = Bijectors.bijector(model)
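Finally, a sketch of how the slimmed-down objective is exercised end to end. The toy target type and its name are hypothetical; the `ScoreGradELBO` and `estimate_objective` calls mirror the diff above:

```julia
using AdvancedVI, Distributions, LinearAlgebra, LogDensityProblems, Random

# Hypothetical toy target implementing the LogDensityProblems interface
# (unnormalized standard normal in two dimensions).
struct StdNormalTarget end
LogDensityProblems.logdensity(::StdNormalTarget, z) = -sum(abs2, z) / 2
LogDensityProblems.dimension(::StdNormalTarget) = 2
LogDensityProblems.capabilities(::Type{StdNormalTarget}) =
    LogDensityProblems.LogDensityOrder{0}()

obj = ScoreGradELBO(10)                      # 10 Monte Carlo samples, as in the tests
q   = MvNormal(zeros(2), Diagonal(ones(2)))  # variational q with rand/logpdf
elbo = estimate_objective(Random.default_rng(), obj, q, StdNormalTarget())
```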
