Merged
Changes from 25 commits
Commits
30 commits
488bd9c
Add first version of deterministic ADVI
Aug 1, 2025
f46f1cd
Update API
Aug 12, 2025
894f62b
Add a notebook example
Aug 14, 2025
a1afaf6
Merge branch 'main' into add_basic_deterministic_advi
Aug 14, 2025
637fc3b
Add to API and add a docstring
Aug 14, 2025
3e397f7
Change import in notebook
Aug 14, 2025
d954ec7
Add jax to dependencies
Aug 14, 2025
aad9f21
Add pytensor version
Aug 16, 2025
ef3d86b
Fix handling of pymc model
Aug 16, 2025
6bf92ef
Add (probably suboptimal) handling of the two backends
Aug 16, 2025
32aff46
Add transformation
Aug 18, 2025
138f8c2
Follow Ricardo's advice to simplify the transformation step
Aug 19, 2025
7073a7d
Fix naming bug
Aug 19, 2025
609aef7
Document and clean up
Aug 19, 2025
b611d51
Merge branch 'main' into add_basic_deterministic_advi
Aug 19, 2025
f17a090
Fix example
Aug 19, 2025
9ab2e1e
Update pymc_extras/inference/deterministic_advi/dadvi.py
martiningram Aug 20, 2025
a8a53f3
Respond to comments
Aug 20, 2025
bdee446
Fix with pre commit checks
Aug 20, 2025
3fcafb6
Update pymc_extras/inference/deterministic_advi/dadvi.py
martiningram Aug 28, 2025
ad46b07
Implement suggestions
Aug 28, 2025
6cd0184
Rename parameter because it's duplicated otherwise
Aug 28, 2025
d648105
Rename to be consistent in use of dadvi
Aug 28, 2025
9d18f80
Rename to `optimizer_method` and drop jac=True
Aug 28, 2025
9f86d4f
Add jac=True back in since trust-ncg complained
Aug 28, 2025
3b090ca
Make hessp and jac optional
Aug 28, 2025
93cd831
Harmonize naming with existing code
Aug 28, 2025
7b84872
Fix example
Aug 29, 2025
7cd407e
Switch to `better_optimize`
Aug 29, 2025
cb070aa
Replace with pt.split
Aug 29, 2025
975 changes: 975 additions & 0 deletions notebooks/deterministic_advi_example.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc_extras.inference.dadvi.dadvi import fit_dadvi
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
__all__ = [
"find_MAP",
"fit",
"fit_laplace",
"fit_pathfinder",
"fit_dadvi",
]
Empty file.
237 changes: 237 additions & 0 deletions pymc_extras/inference/dadvi/dadvi.py
@@ -0,0 +1,237 @@
import arviz as az
import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt
import xarray

from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
PointFunc,
apply_function_over_dataset,
coords_and_dims_for_inferencedata,
)
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable
from scipy.optimize import minimize

from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
_compile_functions_for_scipy_optimize,
)


def fit_dadvi(
model: Model | None = None,
n_fixed_draws: int = 30,
random_seed: RandomSeed = None,
n_draws: int = 1000,
keep_untransformed: bool = False,
optimizer_method: minimize_method = "trust-ncg",
**minimize_kwargs,
) -> az.InferenceData:
"""
Performs inference using deterministic ADVI (automatic differentiation
variational inference), DADVI for short.

For full details see the paper cited in the references:
https://www.jmlr.org/papers/v25/23-1015.html

Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.

n_fixed_draws : int
The number of fixed draws to use for the optimisation. More
draws will result in more accurate estimates, but also
increase inference time. Usually, the default of 30 is a good
trade-off between speed and accuracy.

random_seed: int
The random seed to use for the fixed draws. Running the optimisation
twice with the same seed should arrive at the same result.

n_draws: int
The number of draws to return from the variational approximation.

keep_untransformed: bool
Whether or not to keep the unconstrained variables (such as
logs of positively constrained parameters) in the output.

optimizer_method: str
Which optimization method to use. The function calls
``scipy.optimize.minimize``, so any of the methods there can
be used. The default is ``trust-ncg``, which uses second-order
information and is generally very reliable. Other methods, such
as ``L-BFGS-B``, might be faster but are potentially more brittle and
may not converge exactly to the optimum.

minimize_kwargs:
Additional keyword arguments to pass to the
``scipy.optimize.minimize`` function. See the documentation of
that function for details.

Returns
-------
:class:`~arviz.InferenceData`
The inference data containing the results of the DADVI algorithm.

References
----------
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
Variational Inference with a Deterministic Objective: Faster, More
Accurate, and Even More Black Box. Journal of Machine Learning
Research, 25(18), 1–39.
"""

model = pymc.modelcontext(model) if model is None else model

initial_point_dict = model.initial_point()
n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

var_params, objective = create_dadvi_graph(
model,
n_fixed_draws=n_fixed_draws,
random_seed=random_seed,
n_params=n_params,
)

f_fused, f_hessp = _compile_functions_for_scipy_optimize(
objective,
[var_params],
compute_grad=True,
compute_hessp=True,
compute_hess=False,
)

result = minimize(
f_fused,
np.zeros(2 * n_params),
jac=True,
Member:
You don't need to hard-code this to True; minimize will automatically figure it out if f_fused returns 2 values (and the method uses gradients).

Contributor Author:
Thanks, I didn't know that! Will change it.

Contributor Author:
Ah hmm, I tried it, but trust-ncg failed with the following message:
[screenshot of the trust-ncg error message]
So maybe it only works the way you describe with L-BFGS-B, or perhaps trust-ncg is just extra-finicky?

Member:
It's a bug. I'll fix it, but in the meantime you can do something like use_jac = minimize_kwargs.get('use_jac', True), so the user can override it if they want, but it still works fine with the default?

Contributor Author:
Hmm, I'm not sure I totally like that, to be honest, because minimize_kwargs should ideally match what's in scipy.optimize.minimize, so that they can just be passed on as-is and we can keep the docstring:
[screenshot of the docstring]

But if you think it's important to be able to set jac=False, I could add it as an argument to fit_dadvi?

Contributor Author:
I've made using the Jacobian and hessp optional now, let me know what you think!

Member:
Suggested change:
-jac=True,
+jac=minimize_kwargs.get('use_jac', True),

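To make the discussion above concrete, here is a minimal standalone sketch (not part of this diff) of the fused-objective convention: with jac=True, scipy.optimize.minimize expects the callable to return both the value and the gradient, and trust-ncg additionally needs Hessian information such as a hessp callback. The quadratic objective is purely illustrative.

```python
import numpy as np
from scipy.optimize import minimize

def f_fused(x):
    # Returns (objective value, gradient), the convention used when jac=True
    return np.sum(x**2), 2 * x

result = minimize(
    f_fused,
    np.ones(4),
    jac=True,                   # tell scipy that f_fused also returns the gradient
    method="trust-ncg",
    hessp=lambda x, p: 2 * p,   # Hessian-vector product of f(x) = sum(x^2)
)
print(result.x)                 # converges to the minimum at the origin
```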
method=optimizer_method,
hessp=f_hessp,
**minimize_kwargs,
)

opt_var_params = result.x
opt_means, opt_log_sds = np.split(opt_var_params, 2)

# Make the draws:
generator = np.random.default_rng(seed=random_seed)
draws_raw = generator.standard_normal(size=(n_draws, n_params))

draws = opt_means + draws_raw * np.exp(opt_log_sds)
draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)

return transformed_draws


def create_dadvi_graph(
model: Model,
n_params: int,
n_fixed_draws: int = 30,
random_seed: RandomSeed = None,
) -> tuple[TensorVariable, TensorVariable]:
"""
Sets up the DADVI graph in pytensor and returns it.

Parameters
----------
model : pm.Model
The PyMC model to be fit.

n_params: int
The total number of parameters in the model.

n_fixed_draws : int
The number of fixed draws to use.

random_seed: int
The random seed to use for the fixed draws.

Returns
-------
Tuple[TensorVariable, TensorVariable]
A tuple whose first element contains the variational parameters,
and whose second contains the DADVI objective.
"""

# Make the fixed draws
generator = np.random.default_rng(seed=random_seed)
draws = generator.standard_normal(size=(n_fixed_draws, n_params))

inputs = model.continuous_value_vars + model.discrete_value_vars
initial_point_dict = model.initial_point()
logp = model.logp()

# Graph in terms of a flat input
[logp], flat_input = join_nonshared_inputs(
point=initial_point_dict, outputs=[logp], inputs=inputs
)

var_params = pt.vector(name="eta", shape=(2 * n_params,))
means, log_sds = var_params[:n_params], var_params[n_params:]

draw_matrix = pt.constant(draws)
samples = means + pt.exp(log_sds) * draw_matrix

logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})

mean_log_density = pt.mean(logp_vectorized_draws)
entropy = pt.sum(log_sds)

objective = -mean_log_density - entropy

return var_params, objective


def transform_draws(
unstacked_draws: xarray.Dataset,
model: Model,
keep_untransformed: bool = False,
):
"""
Transforms the unconstrained draws back into the constrained space.

Parameters
----------
unstacked_draws : xarray.Dataset
The draws to constrain back into the original space.

model : Model
The PyMC model the variables were derived from.

keep_untransformed: bool
Whether or not to keep the unconstrained variables in the output.

Returns
-------
:class:`~arviz.InferenceData`
Draws transformed back to the original constrained space.
"""

filtered_var_names = model.unobserved_value_vars
vars_to_sample = list(
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
)
fn = pytensor.function(model.value_vars, vars_to_sample)
point_func = PointFunc(fn)

coords, dims = coords_and_dims_for_inferencedata(model)

transformed_result = apply_function_over_dataset(
point_func,
unstacked_draws,
output_var_names=[x.name for x in vars_to_sample],
coords=coords,
dims=dims,
)

return transformed_result
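
For reference, here is a minimal end-to-end usage sketch of the new fit_dadvi entry point; the toy model and data below are illustrative only and not part of this PR:

```python
import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi

# Toy data for a simple normal model (illustrative only)
rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("obs", mu, sigma, observed=y)

    # Uses the current model context; the fixed draws are controlled by
    # n_fixed_draws and the returned posterior draws by n_draws.
    idata = fit_dadvi(random_seed=0, n_draws=1000)
```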
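For intuition about what create_dadvi_graph builds symbolically, the following is a small NumPy restatement of the DADVI objective, evaluated against a hypothetical standard-normal target; the names and toy log density are illustrative only:

```python
import numpy as np

n_params, n_fixed_draws = 3, 30
rng = np.random.default_rng(0)
Z = rng.standard_normal((n_fixed_draws, n_params))  # fixed draws, reused at every evaluation

def logp(theta):
    # Hypothetical target log density: independent standard normals
    return -0.5 * np.sum(theta**2)

def dadvi_objective(eta):
    means, log_sds = np.split(eta, 2)
    samples = means + np.exp(log_sds) * Z       # reparameterised samples
    mean_log_density = np.mean([logp(s) for s in samples])
    entropy = np.sum(log_sds)                   # Gaussian entropy up to a constant
    return -mean_log_density - entropy          # deterministic, so standard optimisers apply

print(dadvi_objective(np.zeros(2 * n_params)))
```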
5 changes: 5 additions & 0 deletions pymc_extras/inference/fit.py
@@ -40,3 +40,8 @@ def fit(method: str, **kwargs) -> az.InferenceData:
from pymc_extras.inference import fit_laplace

return fit_laplace(**kwargs)

if method == "dadvi":
from pymc_extras.inference import fit_dadvi

return fit_dadvi(**kwargs)
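
With the dispatcher change above, DADVI is also reachable through the generic fit entry point; a minimal sketch, with a purely illustrative toy model:

```python
import pymc as pm

from pymc_extras.inference import fit

with pm.Model():
    x = pm.Normal("x")
    # Keyword arguments other than `method` are forwarded to fit_dadvi unchanged.
    idata = fit(method="dadvi", random_seed=1, n_draws=500)
```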