AdvancedVI provides implementations of variational inference (VI) algorithms, a family of algorithms that use optimization to perform scalable approximate Bayesian inference.
AdvancedVI is part of the Turing probabilistic programming ecosystem.
The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities, so that other packages (e.g. Turing) only need to write a light wrapper for integration.
For example, integrating Turing with AdvancedVI.ADVI only involves converting a Turing.Model into a LogDensityProblem and extracting a corresponding Bijectors.bijector.
We will describe a simple example to demonstrate the basic usage of AdvancedVI.
AdvancedVI works with differentiable models specified through the LogDensityProblem interface.
Let's look at a basic logistic regression example with a hierarchical prior.
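For a dataset $(X, y)$ with design matrix $X \in \mathbb{R}^{n \times d}$ and binary responses $y \in \{0, 1\}^n$, we assume the following data-generating process:
$$
\begin{aligned}
\sigma &\sim \text{Student-}t_{3}(0, 1) \text{ truncated to } (0, \infty) \\
\beta &\sim \mathcal{N}\left(0_d, \sigma \mathrm{I}_d\right) \\
y &\sim \mathrm{BernoulliLogit}\left(X \beta\right)
\end{aligned}
$$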
The LogDensityProblem corresponding to this model can be constructed as follows:
import LogDensityProblems
using Distributions
using FillArrays
using LinearAlgebra  # provides the identity operator I used in σ*I
struct LogReg{XType,YType}
X::XType
y::YType
end
function LogDensityProblems.logdensity(model::LogReg, θ)
(; X, y) = model
d = size(X, 2)
# θ = (β, σ): d regression coefficients followed by the prior scale σ
β, σ = θ[1:d], θ[end]
logprior_β = logpdf(MvNormal(Zeros(d), σ*I), β)
logprior_σ = logpdf(truncated(TDist(3.0); lower=0), σ)
logit = X*β
loglike_y = sum(@. logpdf(BernoulliLogit(logit), y))
return loglike_y + logprior_β + logprior_σ
end
function LogDensityProblems.dimension(model::LogReg)
return size(model.X, 2) + 1
end
function LogDensityProblems.capabilities(::Type{<:LogReg})
return LogDensityProblems.LogDensityOrder{0}()
end
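The likelihood term above relies on logpdf(BernoulliLogit(logit), y). The following stdlib-only sketch illustrates what that term computes; it is not the Distributions.jl implementation, and bernoulli_logit_logpdf and loglike are hypothetical names. It uses the identity log P(y | η) = y·η − log(1 + exp(η)) and evaluates the softplus term in a form that does not overflow for large |η|.

```julia
# Hypothetical helper (not Distributions.jl): log P(y | η) for y ~ Bernoulli(sigmoid(η)).
# Uses log P(y | η) = y*η - softplus(η), where softplus(η) = log(1 + exp(η)).
function bernoulli_logit_logpdf(η::Real, y::Bool)
    # Stable softplus: keep the argument of exp non-positive to avoid overflow.
    softplus = η > 0 ? η + log1p(exp(-η)) : log1p(exp(η))
    return (y ? η : zero(η)) - softplus
end

# Hypothetical helper: summed log-likelihood for logits X*β and Boolean labels y.
loglike(X::AbstractMatrix, β::AbstractVector, y::AbstractVector{Bool}) =
    sum(bernoulli_logit_logpdf.(X * β, y))

# Usage on a toy problem with two observations and two features.
ll = loglike([1.0 2.0; 3.0 4.0], [0.5, -0.5], [true, false])
```

The branch in the softplus is a standard trick. For η = 800 the naive log(1 + exp(η)) overflows to Inf, while the stable form returns 800.0.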
Since the support of σ is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a bijector to transform θ.
We will use Bijectors for this purpose.
This corresponds to the automatic differentiation variational inference (ADVI) formulation [1].
import Bijectors
function Bijectors.bijector(model::LogReg)
d = size(model.X, 2)
return Bijectors.Stacked(
Bijectors.bijector.([MvNormal(Zeros(d), 1.0), truncated(TDist(3.0); lower=0)]),
[1:d, (d + 1):(d + 1)],
)
end
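The stacked bijector acts blockwise: it is the identity on the β block and maps the positive σ onto the real line. The sketch below, written with Base only, shows the resulting ADVI change of variables. It assumes the σ block uses a log transform, so the inverse map is exp; the exact map Bijectors selects for a truncated distribution may differ. The names to_constrained and logdensity_unconstrained are hypothetical. The unconstrained log density adds the log-Jacobian of the inverse map, log p(T⁻¹(ζ)) + log|det J_{T⁻¹}(ζ)|.

```julia
# Hypothetical sketch: θ = (β..., σ) with σ > 0 and unconstrained ζ = (β..., log σ).
# Assumption: the σ block is mapped by log, so the inverse map applies exp.
function to_constrained(ζ::AbstractVector{<:Real})
    θ = float.(ζ)            # identity on the β block (copies ζ)
    θ[end] = exp(ζ[end])     # σ = exp(ζ_σ) > 0
    return θ
end

# Change of variables: log p_ζ(ζ) = log p_θ(T⁻¹(ζ)) + log|det J_{T⁻¹}(ζ)|.
# The identity block contributes 0; dσ/dζ_σ = exp(ζ_σ) contributes log(exp(ζ_σ)) = ζ_σ.
logdensity_unconstrained(logp, ζ) = logp(to_constrained(ζ)) + ζ[end]

# Sanity check: an Exponential(1) density on σ (β entry ignored) should still
# integrate to 1 in the unconstrained coordinate ζ_σ (Riemann sum on a wide grid).
logp_exp(θ) = -θ[end]
h = 1e-3
mass = sum(ζσ -> exp(logdensity_unconstrained(logp_exp, [0.0, ζσ])), -20:h:5) * h
```

Without the + ζ[end] Jacobian term, mass would no longer be close to 1. This is why VI in the transformed space needs the log-determinant.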
A simpler approach would be to use Turing, where a Turing.Model can be automatically converted into a LogDensityProblem and a corresponding bijector is generated automatically.
For the dataset, we will use the popular sonar classification dataset from the UCI repository.
This can be downloaded automatically using OpenML.
The sonar dataset corresponds to the dataset id 40.
import OpenML
import DataFrames
data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine")
Let's apply some basic pre-processing (standardizing each feature, i.e. each column of X) and add an intercept column:
X = (X .- mean(X; dims=1)) ./ std(X; dims=1)
X = hcat(X, ones(size(X, 1)))
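Feature standardization computes one mean and one standard deviation per column, which in Julia means dims=1. The following is a minimal stdlib-only sketch of that step together with the intercept column; standardize_with_intercept is a hypothetical helper, and a constant column would produce NaN because its standard deviation is zero.

```julia
using Statistics

# Hypothetical helper: standardize each column (feature) of X to zero mean and
# unit sample standard deviation, then append an intercept column of ones.
function standardize_with_intercept(X::AbstractMatrix{<:Real})
    Z = (X .- mean(X; dims=1)) ./ std(X; dims=1)   # dims=1: one statistic per column
    return hcat(Z, ones(size(Z, 1)))
end

# Usage on a toy design matrix with 4 observations and 2 features.
Xtoy = [1.0 10.0; 2.0 20.0; 3.0 30.0; 4.0 50.0]
Xs = standardize_with_intercept(Xtoy)
```

After the transformation, each of the first two columns of Xs has mean 0 and standard deviation 1, and the last column is all ones.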
The model can now be instantiated as follows:
model = LogReg(X, y)
For the VI algorithm, we will use the following:
using ADTypes, ReverseDiff
using AdvancedVI
alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())
This algorithm minimizes the exclusive (reverse) KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation, using the reparametrization gradient [2][3][4]. This is also commonly referred to as automatic differentiation VI, black-box VI, stochastic gradient VI, and so on.
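To make the reparametrization gradient concrete, here is a hedged, stdlib-only sketch; it is not AdvancedVI's implementation. It fits a mean-field Gaussian q(z) = N(μ, diag(σ²)) to a toy Gaussian target by stochastic gradient ascent on the ELBO, which is equivalent to descending the reverse KL. Samples are written as z = μ + σ ⊙ ε with ε ~ N(0, I), so gradients pass through z by the chain rule. The helper fit_meanfield_gaussian, its step size and sample counts, and the toy target are all hypothetical choices.

```julia
using Random

# Hypothetical helper (not AdvancedVI's API): reparametrization-gradient ascent on the
# ELBO of q(z) = N(μ, Diagonal(exp.(2ω))), with samples z = μ .+ exp.(ω) .* ε.
function fit_meanfield_gaussian(grad_logp, d::Integer;
                                n_iter::Int=5_000, n_samples::Int=10,
                                stepsize::Float64=0.01, rng=MersenneTwister(1))
    μ = zeros(d)
    ω = zeros(d)                       # ω = log standard deviation of q
    for _ in 1:n_iter
        σ = exp.(ω)
        gμ = zeros(d)
        gω = zeros(d)
        for _ in 1:n_samples
            ε = randn(rng, d)          # ε ~ N(0, I), independent of (μ, ω)
            z = μ .+ σ .* ε            # reparametrized sample z ~ q
            g = grad_logp(z)           # ∇z log p(z), e.g. supplied by an AD backend
            gμ .+= g                   # chain rule: ∂z/∂μ = 1
            gω .+= g .* σ .* ε         # chain rule: ∂z/∂ω = σ .* ε
        end
        gμ ./= n_samples
        gω ./= n_samples
        gω .+= 1.0                     # entropy of q is sum(ω) + const, so ∂/∂ω = 1
        μ .+= stepsize .* gμ           # ascent step on the Monte Carlo ELBO estimate
        ω .+= stepsize .* gω
    end
    return μ, exp.(ω)
end

# Toy target: independent Gaussians with means m and standard deviations s.
m = [2.0, -1.0]
s = [0.5, 1.0]
grad_logp(z) = -(z .- m) ./ s .^ 2
μ_hat, σ_hat = fit_meanfield_gaussian(grad_logp, 2)
```

Because this target lies inside the variational family, the fitted μ_hat and σ_hat should approach m and s up to Monte Carlo noise. In the README example, the log-density gradient would instead come from the AD-wrapped LogDensityProblem.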
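KLMinRepGradDescent, in particular, assumes that the target LogDensityProblem provides gradients, i.e. that it supports first-order evaluation through LogDensityProblems.logdensity_and_gradient.
Since our LogReg only declares zeroth-order capabilities, it is straightforward to add gradients with LogDensityProblemsAD: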
import DifferentiationInterface
import LogDensityProblemsAD
model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)
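For intuition about what the AD wrapper supplies, here is a minimal stdlib-only sketch. It computes the analytic gradient of the logistic-regression log-likelihood, ∇β = Xᵀ(y − sigmoid.(Xβ)), and checks it against central finite differences. The prior terms are omitted, and loglike_and_gradient and sigmoid are hypothetical helpers, not part of LogDensityProblemsAD.

```julia
using LinearAlgebra

sigmoid(t) = 1 / (1 + exp(-t))

# Hypothetical helper: logistic-regression log-likelihood (prior terms omitted)
# and its analytic gradient with respect to β, Xᵀ(y - sigmoid.(Xβ)).
function loglike_and_gradient(X::AbstractMatrix, y::AbstractVector{Bool}, β::AbstractVector)
    η = X * β
    ll = sum(@. ifelse(y, η, 0.0) - log1p(exp(η)))   # Σ y*η - log(1 + exp(η))
    grad = X' * (y .- sigmoid.(η))
    return ll, grad
end

# Usage: compare against central finite differences on a toy problem.
Xtoy = [1.0 0.5; -0.3 2.0; 0.8 -1.2]
ytoy = [true, false, true]
βtoy = [0.2, -0.4]
ll, g = loglike_and_gradient(Xtoy, ytoy, βtoy)
h = 1e-6
fd = [(loglike_and_gradient(Xtoy, ytoy, βtoy .+ h .* (1:2 .== j))[1] -
       loglike_and_gradient(Xtoy, ytoy, βtoy .- h .* (1:2 .== j))[1]) / (2h) for j in 1:2]
```

An AD backend such as ReverseDiff computes this kind of gradient automatically for the full log density, including the prior terms, so no hand-derived formula is needed in practice.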
For the variational family, we will consider a MeanFieldGaussian (diagonal covariance) approximation:
using LinearAlgebra
d = LogDensityProblems.dimension(model_ad)
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
The bijector can now be applied to q to match the support of the target problem.
b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)
We can now run VI:
max_iter = 10^3
q_avg, info, _ = AdvancedVI.optimize(
alg,
max_iter,
model_ad,
q_transformed,
)
For more examples and details, please refer to the documentation.
Footnotes
1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research.
2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning. PMLR.
3. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning. PMLR.
4. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In International Conference on Learning Representations.