Batch-and-Match algorithm for minimizing the covariance-weighted Fisher divergence #218
Merged (+405 −1)

Commits (20, all authored by Red-Portal):
- 6816292 add batch and match
- 09317e2 update HISTORY
- 2b72d51 fun formatter
- b746ab7 fix add missing test file
- bb70d02 add documentation and update docstring for batch-and-match
- 0103702 run formatter
- 867fc27 run formatter
- c5e725b run formatter
- e3e4e3b run formatter
- 188489e run formatter
- b4da148 fix docs
- 3032c66 fix docs
- 5773df6 fix docs
- 1f1ff5a fix remove dead code
- e19fb58 fix compute average outside of loop for batch-and-match
- 11b461e Merge branch 'batchmatch' of github.com:TuringLang/AdvancedVI.jl into…
- 1ca0f00 fix remove reference in docstring
- 445f6c0 fix capitalization in dosctring
- 40ad7e9 refactor move duplicate code in batch match to a common function
- 9b8ee3b Merge branch 'main' into batchmatch
New documentation file (+61 lines):

# [`FisherMinBatchMatch`](@id fisherminbatchmatch)

## Description

This algorithm, known as batch-and-match (BaM), aims to minimize the covariance-weighted 2nd-order Fisher divergence by running a proximal point-type method[^CMPMGBS24].
On certain low-dimensional problems, BaM can converge very quickly without any tuning.
Since `FisherMinBatchMatch` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that make the measure-valued operations tractable.

```@docs
FisherMinBatchMatch
```
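
As a rough usage sketch (illustrative only): the toy target `ToyTarget` below is a placeholder standard-normal `LogDensityProblem`, the initial `FullRankGaussian` follows the family's `(location, scale)` convention, and the exact `optimize` call signature and return values should be double-checked against the general documentation.

```julia
using AdvancedVI, LogDensityProblems, LinearAlgebra, Random

# Toy target: a standard normal in 2 dimensions, implementing the
# LogDensityProblems interface with first-order (gradient) capability.
struct ToyTarget end
LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x)/2
LogDensityProblems.dimension(::ToyTarget) = 2
function LogDensityProblems.capabilities(::Type{ToyTarget})
    return LogDensityProblems.LogDensityOrder{1}()
end
LogDensityProblems.logdensity_and_gradient(::ToyTarget, x) = (-sum(abs2, x)/2, -x)

prob = ToyTarget()
d    = LogDensityProblems.dimension(prob)

# Full-rank Gaussian initialization and the BaM algorithm with its default settings.
q0  = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
alg = FisherMinBatchMatch(; n_samples=32)

# Assumed entry point and return convention; verify against the `optimize` docs.
q, info, state = optimize(Random.default_rng(), alg, 100, prob, q0)
```
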
The associated objective value can be estimated through the following:

```@docs; canonical=false
estimate_objective(
    ::Random.AbstractRNG,
    ::FisherMinBatchMatch,
    ::MvLocationScale,
    ::Any;
    ::Int,
)
```

[^CMPMGBS24]: Cai, D., Modi, C., Pillaud-Vivien, L., Margossian, C. C., Gower, R. M., Blei, D. M., & Saul, L. K. (2024). Batch and match: black-box variational inference with a score-based divergence. In *Proceedings of the International Conference on Machine Learning*.
## [Methodology](@id fisherminbatchmatch_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{F}_{\mathrm{cov}}(q, \pi),
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as

```math
\mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,
```

where ${\lVert x \rVert}_{A}^2 = x^{\top} A x$ is a weighted norm.
$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd-order Fisher divergence defined as

```math
\mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }.
```
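
For intuition, a simple Monte Carlo estimate of $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ for a full-rank Gaussian $q = \mathcal{N}(\mu, C C^{\top})$ can be sketched as follows (an illustrative snippet, not part of the package; the target score $\nabla \log \pi$ is assumed to be available as a function `logpi_grad`):

```julia
using LinearAlgebra, Random

# Monte Carlo estimate of F_cov(q, π) with q = N(μ, CC') and target score ∇logπ.
function fcov_estimate(rng, μ::AbstractVector, C::LowerTriangular, logpi_grad, n_samples::Int)
    Σ   = C*C'                                # Cov(q)
    est = 0.0
    for _ in 1:n_samples
        z = C*randn(rng, length(μ)) + μ       # z ~ q
        g = -(Σ \ (z - μ)) - logpi_grad(z)    # ∇log(q/π)(z) = ∇logq(z) - ∇logπ(z)
        est += dot(g, Σ*g)                    # ‖∇log(q/π)(z)‖²_{Cov(q)}
    end
    return est/n_samples
end

# Example: standard normal target, whose score is ∇logπ(z) = -z.
fcov_estimate(Random.default_rng(), zeros(2), LowerTriangular([1.0 0.0; 0.5 1.0]), z -> -z, 10^4)
```
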
The use of the weighted norm ${\lVert \cdot \rVert}_{\mathrm{Cov}(q)}^2$ facilitates the use of a proximal point-type method for minimizing $\mathrm{F}_{2}(q, \pi)$.
In particular, BaM iterates the update

```math
q_{t+1} = \argmin_{q \in \mathcal{Q}} \left\{ \mathrm{F}_{\mathrm{cov}}(q, \pi) + \frac{2}{\lambda_t} \mathrm{KL}\left(q_t, q\right) \right\} .
```

Since $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ is intractable, it is replaced with a Monte Carlo approximation with a number of samples `n_samples`.
Furthermore, by restricting $\mathcal{Q}$ to a Gaussian variational family, the update rule admits a closed-form solution[^CMPMGBS24].
Notice that the update does not involve the parameterization of $q_t$, which makes `FisherMinBatchMatch` a measure-space algorithm.
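
Concretely, with the step size schedule used in this implementation, $\lambda_t = d \cdot n_{\mathrm{samples}} / t$, the Gaussian update computed from the samples $z_1, \ldots, z_B \sim q_t$ and scores $g_b = \nabla \log \pi(z_b)$ takes the closed form (see the `step` method in the source file below):

```math
\begin{aligned}
U &= \lambda_t \bar{\Gamma} + \frac{\lambda_t}{1 + \lambda_t} \bar{g} \bar{g}^{\top},
\qquad
V = \Sigma_t + \lambda_t \bar{C} + \frac{\lambda_t}{1 + \lambda_t} (\mu_t - \bar{z})(\mu_t - \bar{z})^{\top}, \\
\Sigma_{t+1} &= 2 V {\left( I + {(I + 4 U V)}^{1/2} \right)}^{-1},
\qquad
\mu_{t+1} = \frac{1}{1 + \lambda_t} \mu_t + \frac{\lambda_t}{1 + \lambda_t} \left( \Sigma_{t+1} \bar{g} + \bar{z} \right),
\end{aligned}
```

where $\bar{z}$ and $\bar{g}$ are the sample means of the $z_b$ and $g_b$, and $\bar{C}$ and $\bar{\Gamma}$ are their sample covariances.
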
Historically, the idea of using a proximal point-type update to minimize a Fisher divergence-like objective was first introduced under the name Gaussian score matching[^MGMYBS23].
BaM can be viewed as a successor to this algorithm.

[^MGMYBS23]: Modi, C., Gower, R., Margossian, C., Yao, Y., Blei, D., & Saul, L. (2023). Variational inference with Gaussian score matching. In *Advances in Neural Information Processing Systems*, 36.
New source file (+195 lines):
"""
    FisherMinBatchMatch(n_samples, subsampling)
    FisherMinBatchMatch(; n_samples, subsampling)

Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.

# (Keyword) Arguments
- `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! warning
    `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted Fisher divergence.

!!! note
    `FisherMinBatchMatch` requires a sufficiently large `n_samples` to converge quickly.

!!! note
    The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `FisherMinBatchMatch` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

    callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct FisherMinBatchMatch{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    n_samples::Int = 32
    subsampling::Sub = nothing
end
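
As an illustration (a hypothetical example, not part of this PR), a callback compatible with the signature documented above could look like the following:

```julia
using LinearAlgebra: norm

# Hypothetical example callback; it accepts extra keyword arguments so it keeps
# working regardless of the exact set of keywords passed by the algorithm.
function logging_callback(; iteration, q, kwargs...)
    if iteration % 10 == 0
        @info "BaM iteration" iteration location_norm=norm(q.location)
    end
    return nothing  # a returned NamedTuple would be merged into the reported info
end
```
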
struct BatchMatchState{Q,P,Sigma,Sub,UBuf,GradBuf}
    q::Q
    prob::P
    sigma::Sigma
    iteration::Int
    sub_st::Sub
    u_buf::UBuf
    grad_buf::GradBuf
end

function init(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    (; n_samples, subsampling) = alg
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        throw(
            ArgumentError(
                "`FisherMinBatchMatch` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
            ),
        )
    end
    sub_st = isnothing(subsampling) ? nothing : init(rng, subsampling)
    params, _ = Optimisers.destructure(q)
    n_dims = LogDensityProblems.dimension(prob)
    u_buf = Matrix{eltype(params)}(undef, n_dims, n_samples)
    grad_buf = Matrix{eltype(params)}(undef, n_dims, n_samples)
    return BatchMatchState(q, prob, cov(q), 0, sub_st, u_buf, grad_buf)
end

output(::FisherMinBatchMatch, state) = state.q

function rand_batch_match_samples_with_objective!(
    rng::Random.AbstractRNG,
    q::MvLocationScale,
    n_samples::Int,
    prob,
    u_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
    grad_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
)
    μ = q.location
    C = q.scale
    u = Random.randn!(rng, u_buf)
    z = C*u .+ μ
    logπ_sum = zero(eltype(μ))
    for b in 1:n_samples
        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b))
        grad_buf[:, b] = gb
        logπ_sum += logπb
    end
    logπ_avg = logπ_sum/n_samples

    # Estimate objective values
    #
    # F = E[| ∇log(q/π) (z) |_{CC'}^2]                      (definition)
    #   = E[| C' (∇logq(z) - ∇logπ(z)) |^2]                 (Σ = CC')
    #   = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2]   (z = Cu + μ)
    #   = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2]
    #   = E[| -u - C'∇logπ(z) |^2]
    fisher = sum(abs2, -u_buf - (C'*grad_buf))/n_samples

    return u_buf, z, grad_buf, fisher, logπ_avg
end
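
As a side note, the whitening identity used in the objective estimate above can be checked numerically in a few lines (a hypothetical sanity check, not part of the PR), here against a standard normal target whose score is ∇logπ(z) = -z:

```julia
using LinearAlgebra, Random

rng = Random.default_rng()
d   = 3
μ   = randn(rng, d)
C   = LowerTriangular(randn(rng, d, d) + 3I)   # well-conditioned scale factor
u   = randn(rng, d)
z   = C*u + μ                                  # z ~ N(μ, CC') by construction

∇logπ = -z                                     # score of the target N(0, I)
∇logq = -(C*C') \ (z - μ)                      # score of q = N(μ, CC')

lhs = sum(abs2, C'*(∇logq - ∇logπ))            # ‖∇log(q/π)(z)‖²_{CC'}
rhs = sum(abs2, -u - C'*∇logπ)                 # whitened form used above
@assert lhs ≈ rhs
```
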
function step(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; n_samples, subsampling) = alg
    (; q, prob, sigma, iteration, sub_st, u_buf, grad_buf) = state

    d = LogDensityProblems.dimension(prob)
    μ = q.location
    C = q.scale
    Σ = sigma
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

    u_buf, z, grad_buf, fisher, logπ_avg = rand_batch_match_samples_with_objective!(
        rng, q, n_samples, prob_sub, u_buf, grad_buf
    )

    # BaM closed-form updates (Cai et al., 2024)
    zbar, C = mean_and_cov(z, 2)        # sample mean and covariance of the draws (rebinds C)
    gbar, Γ = mean_and_cov(grad_buf, 2) # sample mean and covariance of the scores ∇logπ(z)

    μmz = μ - zbar
    λ = convert(eltype(μ), d*n_samples / iteration)  # step size schedule λ_t = d⋅n_samples/t

    U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar')
    V = Symmetric(Σ + λ*C + (λ/(1 + λ)*μmz)*μmz')

    Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V))))
    μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar)
    q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist)

    elbo = logπ_avg + entropy(q)
    info = (iteration=iteration, covweighted_fisher=fisher, elbo=elbo)

    state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, u_buf, grad_buf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q, state)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end
"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the covariance-weighted Fisher divergence of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::FisherMinBatchMatch`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used by the algorithm during optimization, i.e., `alg.n_samples`.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    _, _, _, fisher, _ = rand_batch_match_samples_with_objective!(rng, q, n_samples, prob)
    return fisher
end
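
For example, the divergence estimate for a fitted approximation might be queried as follows (a hypothetical usage sketch; `q` and `prob` are assumed to be a fitted `FullRankGaussian` approximation and a gradient-capable `LogDensityProblem`, respectively):

```julia
alg  = FisherMinBatchMatch(; n_samples=32)
fcov = estimate_objective(Random.default_rng(), alg, q, prob; n_samples=256)
@info "Covariance-weighted Fisher divergence estimate" fcov
```
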