Add deterministic advi #564
Merged: zaxtax merged 30 commits into pymc-devs:main from martiningram:add_basic_deterministic_advi on Sep 3, 2025.
Changes from 25 commits

Commits (30):
488bd9c Add first version of deterministic ADVI
f46f1cd Update API
894f62b Add a notebook example
a1afaf6 Merge branch 'main' into add_basic_deterministic_advi
637fc3b Add to API and add a docstring
3e397f7 Change import in notebook
d954ec7 Add jax to dependencies
aad9f21 Add pytensor version
ef3d86b Fix handling of pymc model
6bf92ef Add (probably suboptimal) handling of the two backends
32aff46 Add transformation
138f8c2 Follow Ricardo's advice to simplify the transformation step
7073a7d Fix naming bug
609aef7 Document and clean up
b611d51 Merge branch 'main' into add_basic_deterministic_advi
f17a090 Fix example
9ab2e1e Update pymc_extras/inference/deterministic_advi/dadvi.py
a8a53f3 Respond to comments
bdee446 Fix with pre-commit checks
3fcafb6 Update pymc_extras/inference/deterministic_advi/dadvi.py
ad46b07 Implement suggestions
6cd0184 Rename parameter because it's duplicated otherwise
d648105 Rename to be consistent in use of dadvi
9d18f80 Rename to `optimizer_method` and drop jac=True
9f86d4f Add jac=True back in since trust-ncg complained
3b090ca Make hessp and jac optional
93cd831 Harmonize naming with existing code
7b84872 Fix example
7cd407e Switch to `better_optimize`
cb070aa Replace with pt.split
New file: pymc_extras/inference/deterministic_advi/dadvi.py (@@ -0,0 +1,237 @@)

```python
import arviz as az
import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt
import xarray

from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
    PointFunc,
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable
from scipy.optimize import minimize

from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
    _compile_functions_for_scipy_optimize,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
    n_draws: int = 1000,
    keep_untransformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
    **minimize_kwargs,
) -> az.InferenceData:
    """
    Does inference using deterministic ADVI (automatic differentiation
    variational inference), DADVI for short.

    For full details see the paper cited in the references:
    https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
        The number of fixed draws to use for the optimisation. More
        draws will result in more accurate estimates, but also
        increase inference time. Usually, the default of 30 is a good
        tradeoff between speed and accuracy.

    random_seed : int
        The random seed to use for the fixed draws. Running the optimisation
        twice with the same seed should arrive at the same result.

    n_draws : int
        The number of draws to return from the variational approximation.

    keep_untransformed : bool
        Whether or not to keep the unconstrained variables (such as
        logs of positive-constrained parameters) in the output.

    optimizer_method : str
        Which optimization method to use. The function calls
        ``scipy.optimize.minimize``, so any of the methods there can
        be used. The default is trust-ncg, which uses second-order
        information and is generally very reliable. Other methods such
        as L-BFGS-B might be faster but potentially more brittle and
        may not converge exactly to the optimum.

    minimize_kwargs :
        Additional keyword arguments to pass to the
        ``scipy.optimize.minimize`` function. See the documentation of
        that function for details.

    Returns
    -------
    :class:`~arviz.InferenceData`
        The inference data containing the results of the DADVI algorithm.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
    Variational Inference with a Deterministic Objective: Faster, More
    Accurate, and Even More Black Box. Journal of Machine Learning
    Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model

    initial_point_dict = model.initial_point()
    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
        objective,
        [var_params],
        compute_grad=True,
        compute_hessp=True,
        compute_hess=False,
    )

    result = minimize(
        f_fused,
        np.zeros(2 * n_params),
        jac=True,
        # (diff truncated in the PR view)
```
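For orientation, a minimal usage sketch (not from the PR: the toy model is invented, and the import path assumes `fit_dadvi` is importable from the new module shown above):

```python
import numpy as np
import pymc as pm

# Assumed import path, matching the module this PR adds.
from pymc_extras.inference.deterministic_advi.dadvi import fit_dadvi

rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

# Fixed draws plus a fixed seed make the objective deterministic, so
# rerunning this call reproduces the same variational optimum.
idata = fit_dadvi(model, n_fixed_draws=30, random_seed=42, n_draws=1000)
print(idata.posterior["mu"].mean().item(), idata.posterior["sigma"].mean().item())
```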
A review suggestion targeted the `jac=True` line in the `minimize` call above:

```diff
-        jac=True,
+        jac=minimize_kwargs.get('use_jac', True),
```
Review discussion:

Reviewer: You don't need to hard-code this to true; `minimize` will automatically figure it out if `f_fused` returns 2 values (and the method uses gradients).
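For context, a standalone toy of the fused pattern under discussion (not PR code; with plain `scipy.optimize.minimize`, `jac=True` is what signals that the callable returns `(value, gradient)`, and the auto-detection the reviewer mentions presumably lives in the `better_optimize` wrapper this PR switches to):

```python
import numpy as np
from scipy.optimize import minimize

def f_fused(x):
    # Returns both the objective value and its gradient, like the
    # compiled f_fused in the diff above.
    return 0.5 * np.sum(x**2), x

def f_hessp(x, p):
    # Hessian of 0.5 * ||x||^2 is the identity, so the product is just p.
    return p

# jac=True tells SciPy the gradient comes fused with the value;
# trust-ncg additionally needs second-order information (hessp here).
result = minimize(f_fused, np.ones(5), jac=True, hessp=f_hessp, method="trust-ncg")
print(result.x)  # converges to ~0
```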
Author: Thanks, I didn't know that! Will change it.
Author: Ah hmm, I tried it, but `trust-ncg` failed with the following message: [screenshot of the error message]. So maybe it only works the way you describe with `L-BFGS-B`, or perhaps `trust-ncg` is just extra-finicky?
Reviewer: It's a bug. I'll fix it, but in the meantime you can do something like `use_jac = minimize_kwargs.get('use_jac', True)`, so the user can override it if he wants, but it still works fine with the default?
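A self-contained sketch of that interim workaround (toy objective; `use_jac` is a made-up key, popped rather than read with `get` so it is not forwarded to SciPy, which would reject an unknown keyword):

```python
import numpy as np
from scipy.optimize import minimize

def f_fused(x):
    return 0.5 * np.sum(x**2), x  # toy (value, gradient) pair

def fit(n_params, **minimize_kwargs):
    # Remove the custom key so the remaining kwargs can be forwarded to
    # scipy.optimize.minimize unchanged. With use_jac=False the objective
    # would need to be compiled to return the value only.
    use_jac = minimize_kwargs.pop("use_jac", True)
    return minimize(f_fused, np.zeros(n_params), jac=use_jac,
                    method="L-BFGS-B", **minimize_kwargs)

print(fit(3).x)
```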
Author: Hmmm, I'm not sure I totally like that, to be honest, because the `minimize_kwargs` ideally should match what's in `scipy.optimize.minimize`, so that they can just be passed on as-is and we can keep the docstring: [screenshot of the docstring]. But if you think it's important to have the ability to set `jac=False`, I could add it as an argument to `fit_dadvi`?
Author: I've made using the Jacobian and hessp optional now, let me know what you think!
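Judging by the merged commits ("Make hessp and jac optional"), the resolution exposes explicit flags rather than a non-SciPy key inside `minimize_kwargs`. A hypothetical sketch of that shape, with a toy objective standing in for the compiled DADVI functions:

```python
import numpy as np
from scipy.optimize import minimize

def f_value(x):
    return 0.5 * np.sum(x**2)  # value only

def f_fused(x):
    return f_value(x), x  # (value, gradient)

def f_hessp(x, p):
    return p  # Hessian-vector product of the toy objective

def run_opt(n_params, use_jac=True, use_hessp=True,
            optimizer_method="trust-ncg", **minimize_kwargs):
    # Pick the objective variant to match the flags, and only pass hessp
    # when a second-order method will actually use it.
    fun = f_fused if use_jac else f_value
    extra = {"hessp": f_hessp} if use_hessp else {}
    return minimize(fun, np.zeros(n_params), jac=use_jac,
                    method=optimizer_method, **extra, **minimize_kwargs)

print(run_opt(4).x)  # second-order path (trust-ncg)
print(run_opt(4, use_hessp=False, optimizer_method="L-BFGS-B").x)  # first-order path
```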