AdvancedVI.jl

AdvancedVI provides implementations of variational inference (VI) algorithms, a family of algorithms for scalable approximate Bayesian inference that leverage optimization. AdvancedVI is part of the Turing probabilistic programming ecosystem. The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities so that other packages, e.g. Turing, only need to write a light wrapper for integration. For example, integrating Turing with AdvancedVI.ADVI only involves converting a Turing.Model into a LogDensityProblem and extracting a corresponding Bijectors.bijector.

Basic Example

This section walks through a simple example to demonstrate the basic usage of AdvancedVI. AdvancedVI works with differentiable models specified through the LogDensityProblem interface. Let's look at a basic logistic regression example with a hierarchical prior. For a dataset $(X, y)$ with the design matrix $X \in \mathbb{R}^{n \times d}$ and the response variables $y \in \{0, 1\}^n$, we assume the following data generating process:

$$ \begin{aligned} \sigma &\sim \text{Student-t}_{3}(0, 1) \\ \beta &\sim \text{Normal}\left(0_d, \sigma \mathrm{I}_d\right) \\ y &\sim \mathrm{BernoulliLogit}\left(X \beta\right) \end{aligned} $$

The LogDensityProblem corresponding to this model can be constructed as

import LogDensityProblems
using Distributions
using FillArrays
using LinearAlgebra  # provides `I`, used below in the prior on β

struct LogReg{XType,YType}
    X::XType
    y::YType
end

function LogDensityProblems.logdensity(model::LogReg, θ)
    (; X, y) = model
    d = size(X, 2)
    # θ stacks the d regression coefficients followed by the scale parameter σ.
    β, σ = θ[1:d], θ[end]

    # The prior on σ is a Student-t with 3 degrees of freedom truncated to σ > 0.
    logprior_β = logpdf(MvNormal(Zeros(d), σ*I), β)
    logprior_σ = logpdf(truncated(TDist(3.0); lower=0), σ)

    logit = X*β
    loglike_y = sum(@. logpdf(BernoulliLogit(logit), y))
    return loglike_y + logprior_β + logprior_σ
end

function LogDensityProblems.dimension(model::LogReg)
    return size(model.X, 2) + 1
end

function LogDensityProblems.capabilities(::Type{<:LogReg})
    return LogDensityProblems.LogDensityOrder{0}()
end
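
For reference, the unnormalized log joint density that `logdensity` returns can be written out as follows. This is a sketch derived from the code above, not an additional definition from the package: it reads $\sigma \mathrm{I}_d$ as the covariance of $\beta$ (as `MvNormal(Zeros(d), σ*I)` does), writes $x_i^{\top}$ for the $i$-th row of $X$, and uses $p(\sigma)$ for the Student-t density with 3 degrees of freedom truncated to $\sigma > 0$.

$$
\log \pi(\beta, \sigma) = \sum_{i=1}^{n} \left[ y_i \, x_i^{\top} \beta - \log\left(1 + e^{x_i^{\top} \beta}\right) \right] - \frac{d}{2} \log(2 \pi \sigma) - \frac{\lVert \beta \rVert^2}{2 \sigma} + \log p(\sigma)
$$

The first term is the Bernoulli log-likelihood with logits $X \beta$, the next two terms are the Gaussian log-prior on $\beta$, and the last term is the log-prior on $\sigma$.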

Since the support of σ is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a bijector to transform θ. We will use Bijectors for this purpose. This corresponds to the automatic differentiation variational inference (ADVI) formulation [1].

import Bijectors

function Bijectors.bijector(model::LogReg)
    d = size(model.X, 2)
    return Bijectors.Stacked(
        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), truncated(TDist(3.0); lower=0)]),
        [1:d, (d + 1):(d + 1)],
    )
end

A simpler approach would be to use Turing, where a Turing.Model can be automatically converted into a LogDensityProblem and a corresponding bijector is generated automatically.

For the dataset, we will use the popular sonar classification dataset from the UCI repository. This can be automatically downloaded using OpenML. The sonar dataset corresponds to the dataset id 40.

import OpenML
import DataFrames
data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine")

Let's apply some basic pre-processing and add an intercept column:

X = (X .- mean(X; dims=2)) ./ std(X; dims=2)
X = hcat(X, ones(size(X, 1)))

The model can now be instantiated as follows:

model = LogReg(X, y)

For the VI algorithm, we will use the following:

using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())

This algorithm minimizes the exclusive (reverse) KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation, using the reparametrization gradient [2, 3, 4]. This approach is also commonly referred to as automatic differentiation VI, black-box VI, stochastic gradient VI, and so on.
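
To illustrate the idea behind the reparametrization gradient, here is a self-contained sketch that uses only Julia's standard library. It is not how AdvancedVI is implemented. It assumes a toy one-dimensional target with $\log \pi(z) = -z^2/2$ and a Gaussian approximation with mean `m` and standard deviation `s`. All names (`grad_logtarget`, `repgrad`, `fit_gaussian`) are hypothetical and exist only for this illustration.

```julia
using Random, Statistics

# Gradient of the toy (unnormalized) target log-density log π(z) = -z^2 / 2.
grad_logtarget(z) = -z

# Monte Carlo estimate of the ELBO gradient with respect to (m, s), where
# q = Normal(m, s^2) is reparametrized as z = m + s * ε with ε ~ Normal(0, 1).
function repgrad(rng, m, s; n_samples=10)
    ε = randn(rng, n_samples)
    z = m .+ s .* ε
    g = grad_logtarget.(z)
    grad_m = mean(g)                 # chain rule: dz/dm = 1
    grad_s = mean(g .* ε) + 1 / s    # dz/ds = ε; 1/s is the entropy gradient
    return grad_m, grad_s
end

# Stochastic gradient ascent on the ELBO, which is equivalent to
# stochastic gradient descent on the exclusive KL divergence.
function fit_gaussian(rng; n_iters=2000, stepsize=0.05)
    m, s = 3.0, 0.2
    for _ in 1:n_iters
        grad_m, grad_s = repgrad(rng, m, s)
        m += stepsize * grad_m
        s = max(s + stepsize * grad_s, 1e-3)  # keep the scale positive
    end
    return m, s
end

m, s = fit_gaussian(MersenneTwister(1))
```

Since the toy target is a standard normal, `m` should end up near 0 and `s` near 1, up to the noise of the stochastic gradient estimates.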

KLMinRepGradDescent, in particular, assumes that the target LogDensityProblem provides gradients. The simplest way to supply them is to wrap the model with LogDensityProblemsAD:

import DifferentiationInterface
import LogDensityProblemsAD

model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)

For the variational family, we will consider a mean-field Gaussian approximation (MeanFieldGaussian), i.e. a Gaussian with a diagonal scale matrix:

using LinearAlgebra

d = LogDensityProblems.dimension(model_ad)
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))

The bijector can now be applied to q to match the support of the target problem.

b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)
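
The effect of this transformation on the density of the approximation follows from the standard change-of-variables identity. The notation is introduced here for illustration: $b$ is the bijector mapping the constrained $\theta$ to the unconstrained space, $q$ is the Gaussian defined on that unconstrained space, and $q_{\theta}$ is the resulting distribution over $\theta$.

$$
\log q_{\theta}(\theta) = \log q\left(b(\theta)\right) + \log \left| \det J_{b}(\theta) \right|
$$

For instance, assuming the positive parameter $\sigma$ is mapped through a log transform, $b(\sigma) = \log \sigma$, the Jacobian is $1/\sigma$ and the correction term for that coordinate is $-\log \sigma$.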

We can now run VI:

max_iter = 10^3
q_avg, info, _ = AdvancedVI.optimize(
    alg,
    max_iter,
    model_ad,
    q_transformed;
)

For more examples and details, please refer to the documentation.

Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research.

  2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning. PMLR.

  3. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning. PMLR.

  4. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In International Conference on Learning Representations.
