
Conversation

martiningram
Contributor

Hi everyone,

I'm one of the authors of the paper on deterministic ADVI. There is an open feature request for this in PyMC here so I thought I'd kick things off with this PR.

In simple terms, DADVI is like ADVI, but rather than using a new draw to estimate its objective at each step, it uses a fixed set of draws throughout the optimisation. That means that (1) it can use regular off-the-shelf optimisers rather than stochastic optimisation, making convergence more reliable, and (2) it's possible to use techniques to improve the variance estimates. Both are covered in the paper, along with tools to assess how big the error from using fixed draws is.

This PR covers only the first part -- optimising ADVI with fixed draws. This is because I thought I'd start simple and because I'm hoping that it already addresses a real problem with ADVI, which is the difficulty in assessing convergence.

In addition to adding the code, there is an example notebook in notebooks/deterministic_advi_example.ipynb. It fits DADVI to the PyMC basic linear regression example. I can add more examples, but I thought I'd start simple.

I mostly lifted the code from my research repository, so there are probably some style differences. Let me know what would be important to change.

Note that JAX is needed, but there shouldn't be any other dependencies.

Very keen to hear what you all think! :)

All the best,
Martin


@jessegrabowski added the "enhancements" (New feature or request) label on Aug 14, 2025
@jessegrabowski
Member

This is super cool -- I'm very excited to look more closely over the next few days.

Since you're ultimately building a loss function and sending it to scipy.optimize, do you think we could re-use any of the machinery that exists for doing that in the laplace_approx module, for example this or this?

@ricardoV94
Member

What exactly needs jax?

@martiningram
Contributor Author

martiningram commented Aug 14, 2025

Thank you both very much for having a look so quickly!

@jessegrabowski Good point, yes maybe! I'll take a look.

@ricardoV94 Currently, JAX is used to compute the hvp and the jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done with vmap easily: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR44
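To illustrate the pattern (with a toy quadratic standing in for the model's log density -- these are not the PR's actual functions), the per-draw gradient, its average, and the hvp look roughly like this in JAX:

```python
import jax
import jax.numpy as jnp

# Toy per-draw term: -log p(theta) - entropy for one reparameterised draw z.
# The quadratic log density is only a stand-in for the model's log posterior.
def per_draw_loss(eta, z):
    D = z.shape[0]
    theta = eta[:D] + jnp.exp(eta[D:]) * z
    return 0.5 * jnp.sum(theta**2) - jnp.sum(eta[D:])

# One gradient per fixed draw via vmap, then the average.
def grad_estimate(eta, zs):
    per_draw_grads = jax.vmap(jax.grad(per_draw_loss), in_axes=(None, 0))(eta, zs)
    return per_draw_grads.mean(axis=0)

# Averaged objective and its Hessian-vector product.
def loss(eta, zs):
    return jax.vmap(per_draw_loss, in_axes=(None, 0))(eta, zs).mean()

def hvp(eta, zs, v):
    return jax.jvp(lambda e: jax.grad(loss)(e, zs), (eta,), (v,))[1]
```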

That said, JAX isn't strictly necessary. Anything that can provide the DADVIFuns is fine: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13 . In fact, I have code in the original research repo that turns the regular hvp and gradient function into the DADVIFuns. But I think it'll be slower because of the for loops, e.g. here.

Are you concerned about the JAX dependency? If so, maybe I could have a go at doing a JAX-free version using the code just mentioned and then only support JAX optionally. I do think it might be nice to have since it's probably more efficient and would hopefully also run fast on GPUs. But interested in your thoughts.

Also, I see one of the pre-commit checks seems to be failing. I can do the work to make the pre-commit hooks happy; sorry I haven't done that yet.

@zaxtax
Contributor

zaxtax commented Aug 14, 2025 via email

@ricardoV94
Member

ricardoV94 commented Aug 14, 2025

@ricardoV94 Currently, JAX is used to compute the hvp and the jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done with vmap easily: [...]

PyTensor has the equivalent hessian_product_vector and jacobian, as well as vectorize_graph, which does the same as vmap (or more, if you have multiple batch dimensions).

The reason I ask is that if you don't have anything JAX-specific, you can still end up using JAX, but also C or Numba, which may be better for certain users.

@martiningram
Contributor Author

@ricardoV94 Oh cool, thanks, I didn't realise! I'll take a look at whether I can use those. I agree it would be nice to support as many users as possible.

@ricardoV94
Member

Happy to assist you. If you're vectorizing the Jacobian you probably want to build jacobian(vectorize=True) which can further be vectorized more nicely.

Everything is described here, although it's a bit scattered: https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html
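For what it's worth, a minimal sketch of those pieces (with a toy objective, and assuming a PyTensor version where jacobian accepts vectorize=True as mentioned above) could look like:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

eta = pt.vector("eta")
loss = 0.5 * pt.sum(pt.exp(eta) ** 2)         # stand-in scalar objective

g = pt.grad(loss, eta)                         # symbolic gradient
H = jacobian(g, eta, vectorize=True)           # Jacobian of the gradient, i.e. the Hessian

v = pt.vector("v")
hvp = pt.grad(pt.sum(g * v), eta)              # Hessian-vector product without forming H

f = pytensor.function([eta, v], [loss, g, hvp])
```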

@martiningram
Contributor Author

martiningram commented Aug 15, 2025

Hey @ricardoV94 (and potentially others!), I think I could use your advice with the vectorisation. I think I've read enough to do it without using the functions here, but I'd really like to try to get this vectorised for speed.

To explain a bit: the code expects the definition of DADVIFuns (https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13). The first of these expects two inputs:

  1. The variational parameter vector eta, which is all the means concatenated with the log_sds of the variational parameters. This will have length 2D, where D is the number of parameters in the model (first D is means, second D is log_sds).
  2. A matrix of draws of shape [M, D], with D as before and M the number of draws.

The function should then return the estimate of the KL divergence using these draws, as well as its gradient with respect to the variational parameters. The KL divergence combines the entropy of the approximation (a simple function of the variational parameters only) with the average of the log posterior densities from the draws. That's the part that I'd like to vectorise.
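As a standalone illustration of that contract (a toy standard-normal log posterior with a hand-derived gradient, not the PR's implementation):

```python
import numpy as np

def kl_and_grad(eta, zs):
    """Toy DADVI-style objective: KL estimate (up to a constant) and its gradient.

    eta : (2*D,) vector, first D entries are means, last D are log_sds.
    zs  : (M, D) matrix of fixed standard-normal draws.
    """
    M, D = zs.shape
    means, log_sds = eta[:D], eta[D:]
    thetas = means + np.exp(log_sds) * zs                 # (M, D) reparameterised draws
    # Toy log posterior: standard normal, log p(theta) = -0.5 * ||theta||^2 + const
    mean_log_p = np.mean(-0.5 * np.sum(thetas**2, axis=1))
    entropy = np.sum(log_sds)                             # Gaussian entropy up to a constant
    value = -(mean_log_p + entropy)
    # Gradient derived by hand for this toy log posterior
    dlogp = -thetas                                       # d log p / d theta, per draw
    grad_means = -np.mean(dlogp, axis=0)
    grad_log_sds = -np.mean(dlogp * zs, axis=0) * np.exp(log_sds) - 1.0
    return value, np.concatenate([grad_means, grad_log_sds])

zs = np.random.default_rng(0).standard_normal((30, 2))    # M=30 fixed draws, D=2
value, grad = kl_and_grad(np.zeros(4), zs)
```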

Now in JAX, the way I do this is to...:

  1. Compute the log posterior density for a single draw (i.e. one row in the matrix of draws): https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR37
  2. vmap this computation across all the draws and then take the mean here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR42

Thanks to vectorize_graph, I am hoping I can do something like this with pytensor. My plan was to...:

  1. Define the variational parameter vector eta in pytensor, and transform a single draw with this vector
  2. clone_replace to use the transformed draw as an input to the graph, rather than the current input
  3. Use vectorize_graph to vectorise with respect to the draws
  4. Compute the mean of the densities and get gradients of this mean with respect to the variational parameter vector

This makes sense in my head, but the problem I see is that the pymc model's logp seems to expect a dictionary rather than a flat vector. So, as part of the step to get from the new inputs to the density, I need to turn the flat vector into the dictionary. In PyMC there is DictToArrayBijection, which does this, but I don't think I can use it as part of the pytensor graph.

So in essence, I think I need code to do DictToArrayBijection in pure pytensor. Is there something like that? Or is there another way I am missing? I guess it would be great if I could just have a logp function that takes a flat vector as an input already -- is there a way I can get to that?

Thanks a lot for your help :)

@jessegrabowski
Member

If you get the logp of a pymc model using model.logp (rather than compile_logp or one of the jax helpers), it will just return the symbolic logp graph, which you can then do all your vectorization/replacements on.

The path followed by the laplace code is to freeze the model and extract the negative logp, then create a flat vector input replacing the individual value inputs, then compile the loss_and_grads/hess/hessp functions (optionally in JAX).

My hope is that you can get the correct loss function for DADVI, then you should be able to directly pass it into scipy_optimize_funcs_from_loss and just completely re-use all that.

The 4 steps you outline seem correct to me. pymc.pytensorf.join_nonshared_inputs is the function I think you're looking for to do the pack/unpack operation on the different parameters; I linked to its usage above.
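Roughly, with a toy model and illustrative names rather than the PR's actual code, the four steps plus join_nonshared_inputs could look like this:

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import join_nonshared_inputs
from pytensor.graph.replace import vectorize_graph

# Toy model standing in for the user's model.
with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu, sigma, observed=[0.1, -0.3, 0.2])

point = model.initial_point()
D = int(sum(np.asarray(v).size for v in point.values()))

# Replace the per-variable value inputs of logp with a single flat vector.
[logp_flat], flat_input = join_nonshared_inputs(
    point=point, outputs=[model.logp()], inputs=model.value_vars
)

# Express that flat vector as a reparameterised single draw (steps 1 and 2).
eta = pt.vector("eta")                        # [means, log_sds], length 2 * D
z = pt.vector("z")                            # one standard-normal draw, length D
theta = eta[:D] + pt.exp(eta[D:]) * z
logp_single = pytensor.clone_replace(logp_flat, replace={flat_input: theta})

# Vectorise over a whole [M, D] matrix of fixed draws (step 3).
zs = pt.matrix("zs")
logp_all = vectorize_graph(logp_single, replace={z: zs})

# Average, add the entropy, and take gradients w.r.t. eta (step 4).
objective = -(logp_all.mean() + pt.sum(eta[D:]))           # negative ELBO estimate
grad = pt.grad(objective, eta)
f = pytensor.function([eta, zs], [objective, grad])
```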

@martiningram
Contributor Author

Thanks a lot @jessegrabowski . I'll give it a go!

@martiningram
Contributor Author

Hey all, I think I made good progress with the pytensor version. A first version is here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1b6e7da940ec73fce49f5e13ae1db5369ec011cb0b55974ec04d81e519e923f6R55

I think the only major thing missing is to transform the draws back into the constrained space from the unconstrained space. Is there a code snippet anyone could point me to? Thanks for your help and for all the helpful advice you've already given!

@zaxtax
Contributor

zaxtax commented Aug 16, 2025 via email

@martiningram
Contributor Author

Thanks a lot @zaxtax. I had a go here. I think it works, but I'm doing a for-loop over the draws and I'm a bit worried it might get slow (though hopefully the transformations are fairly cheap).

I guess I could try to vectorise the graph. What do you guys think, would that be worth it? Or is there a better way to approach this?

Thanks again for your help :)


@martiningram
Contributor Author

Thanks a lot for the suggestions and comments @jessegrabowski! I think I implemented them; please let me know.

The only one I didn't go for in the end was your pt.split suggestion -- it looked great, but the suggestion didn't work as-is, and the signature of pt.split looks quite complicated (see here), so after some messing around I thought the original version might be simpler. But please let me know if I'm missing something.

@jessegrabowski left a comment
Member

Looks great!

Bummer about pt.split, maybe @ricardoV94 can chime in on usage. I agree it looks more complex than expected.

random_seed: RandomSeed = None,
n_draws: int = 1000,
keep_untransformed: bool = False,
opt_method: minimize_method = "trust-ncg",
Member

We should agree on a unified name for this, because find_MAP is method, fit_laplace is optimizer_method and now this is opt_method.

My first choice is just method but it clashes with the pmx.fit API (which I don't really like anyway). Thoughts?

Contributor Author

Thanks @jessegrabowski, I agree. I first matched the syntax in find_map but then noticed the clash with the fit method. I'm also not a huge fan of the shared API, tbh.

On the other hand, I personally think method is maybe a bit vague anyhow. I'd be OK going with optimizer_method; then at least it's consistent with fit_laplace. What do you think?

Member

sounds good to me

f_fused,
np.zeros(2 * n_params),
method=opt_method,
jac=True,
Member

You don't need to hard-code this to True; minimize will automatically figure it out if f_fused returns 2 values (and the method uses gradients).

Contributor Author

Thanks, I didn't know that! Will change it.

Contributor Author

Ah hmm, I tried it but trust-ncg failed with the following message:
[screenshot of the error message]
So maybe it only works the way you describe with L-BFGS-B, or perhaps trust-ncg is just extra-finicky?

Member

It's a bug. I'll fix it, but in the meantime you could do something like use_jac = minimize_kwargs.get('use_jac', True), so the user can override it if they want, but it still works fine with the default?

Contributor Author

Hmmm, I'm not sure I totally like that to be honest, because the minimize_kwargs ideally should match what's in scipy.optimize.minimize, so that they can just be passed on as-is and we can keep the docstring:
[screenshot of the current minimize_kwargs docstring]

But if you think it's important to have the ability to set jac=False, I could add it as an argument to fit_dadvi?

Contributor Author

I've made using the Jacobian and hessp optional now, let me know what you think!

result = minimize(
f_fused,
np.zeros(2 * n_params),
jac=True,
Member

Suggested change
jac=True,
jac=minimize_kwargs.get('use_jac', True),

@martiningram
Contributor Author

Thanks again @jessegrabowski for all your helpful comments. What do you think, do you like the new optional Jacobian and hessp? And is there anything else I should do before this could be merged?

@jessegrabowski
Member

Yes, looks nice. Could you add a use_hess option as well, in case some psycho wants to use trust-exact or dogleg?

Otherwise I'm satisfied and already approved :)

@martiningram
Contributor Author

martiningram commented Aug 29, 2025

Thanks @jessegrabowski! I started, but it's slightly complex, right? Because if I understand _compile_functions_for_scipy_optimize correctly, if I pass compute_hess=True, it will add the Hessian to the output of f_fused. But unlike with the Jacobian, where I can then just pass jac=True to scipy.minimize, I can't just pass hess=True -- it looks like scipy.optimize.minimize expects a Hessian function and doesn't allow a bool there like it does for the Jacobian.

I could split f_fused back into two functions -- one returning the value and gradient, and the other returning the Hessian only. Or is there a better option -- what do you think?
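For reference, with plain scipy.optimize.minimize that split would look roughly like this (toy quadratic; names are illustrative only):

```python
import numpy as np
from scipy.optimize import minimize

# Toy stand-ins for the split: one callable returns (value, grad),
# a second returns the Hessian only, since scipy wants a callable for `hess`.
A = np.diag([1.0, 4.0])

def value_and_grad(x):
    return 0.5 * x @ A @ x, A @ x

def hess_only(x):
    return A

result = minimize(
    value_and_grad,
    np.zeros(2),
    method="trust-exact",
    jac=True,          # fun also returns the gradient
    hess=hess_only,    # a callable returning the Hessian matrix, not a bool
)
```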

@jessegrabowski
Member

Oh, I didn't notice you were using scipy.optimize.minimize. You can use better_optimize.minimize, which allows a triple-fused function. Here is where it's actually being called. Notice that you only pass f_fused (which returns either loss, (loss, grad), or (loss, grad, hess)) and f_hessp (which can be None).

This is also why use_jac didn't work automatically in our previous discussion thread. better_optimize.minimize automatically sets the use_jac, use_hess, and use_hessp flags based on the method and the fused function.

I'm sorry for not looking more closely before, that was a source of confusion.

@martiningram
Contributor Author

martiningram commented Aug 29, 2025

Oh cool @jessegrabowski, no worries, I also missed this, sorry! I've switched to better_optimize now, and as you promised, the Hessian seems to work nicely :)

@martiningram
Contributor Author

Hey all, thanks again to all of you for all your helpful advice! Anything still in the way of merging this?

@zaxtax
Contributor

zaxtax commented Sep 3, 2025

@martiningram LGTM! If you are all set, I can merge it.

@martiningram
Contributor Author

@zaxtax fantastic! Yeah, I think from my side it's all ready to go :)

@zaxtax zaxtax merged commit 24e18e8 into pymc-devs:main Sep 3, 2025
17 checks passed
@zaxtax
Contributor

zaxtax commented Sep 4, 2025

@martiningram what's required now to get LR support into DADVI?

@martiningram
Contributor Author

martiningram commented Sep 4, 2025

@zaxtax thanks a lot for merging!

I think the main thing to get LR (linear response) to work is to think about the interface a bit. The thing is, it's usually too expensive to do the LR correction for all parameters, since you have to compute the full Hessian to do that (though this should at least be pretty easy to implement). But it's often enough to get the marginal variance for a subset of the parameters, or for a scalar-valued function of the parameters. That can be done reasonably quickly using conjugate gradients and the hvp, basically.
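Schematically, the conjugate-gradient part is just an iterative solve against the hvp, without ever forming the full Hessian (toy diagonal Hessian here; the exact LR formulas are in the paper):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

dim = 4
H = np.diag([2.0, 3.0, 5.0, 7.0])           # stand-in for the Hessian at the optimum

def hvp(v):
    # In practice this would be the compiled Hessian-vector product.
    return H @ v

op = LinearOperator((dim, dim), matvec=hvp)
g = np.eye(dim)[1]                           # e.g. picks out one parameter of interest
x, info = cg(op, g)                          # solve H x = g iteratively
marginal_quantity = float(g @ x)             # g^T H^{-1} g, the kind of term LR needs
```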

Maybe to keep it simple, I'd think about how to allow a user to do the LR correction for a single scalar parameter. Then, the main complication is how to include the transformations from unconstrained --> constrained where needed. What do you think? If it sounds good, I'll do some thinking and come up with a proposal.

However, as a very first step, I'd like to write a blog post or two to promote the current iteration we just merged, because I think it basically solves the convergence issues with ADVI, which is a pretty cool thing in itself and hopefully already helpful for people.

@jessegrabowski
Member

If you need full Hessians, having support for sparse Jacobians (issue here) might end up being extremely useful. I really want to push on it, but I'm spread too thin :(

@ricardoV94
Member

I'd like to write a blog post or two to promote the current iteration we just merged

Looking forward to it!

@martiningram
Contributor Author

martiningram commented Sep 5, 2025

Thanks a lot guys!

@jessegrabowski the sparse Jacobians look pretty awesome, thanks! I've also been keen to get sparse matrix algebra going for a long time, but I'll have to read a bit to understand the topic, and I'm spread quite thin right now unfortunately. I'll have a read and try to understand what you've done so far!

@ricardoV94 Awesome, thanks!

I have a little blog myself (https://martiningram.github.io/), but if you'd like and there's interest, I'd also love to post it on a PyMC blog if there's one where it would fit. Only if you like the article, of course (I can share it for comments beforehand) -- if there's somewhere it would make sense, let me know.

@fonnesbeck
Member

We have a general VI notebook here, to which we could add a DADVI example, but it might be nice to have a new standalone one that compares ADVI/DADVI, or ADVI/DADVI/Pathfinder.

andreacate pushed a commit to andreacate/pymc-extras that referenced this pull request Sep 6, 2025
* Add first version of deterministic ADVI

* Update API

* Add a notebook example

* Add to API and add a docstring

* Change import in notebook

* Add jax to dependencies

* Add pytensor version

* Fix handling of pymc model

* Add (probably suboptimal) handling of the two backends

* Add transformation

* Follow Ricardo's advice to simplify the transformation step

* Fix naming bug

* Document and clean up

* Fix example

* Update pymc_extras/inference/deterministic_advi/dadvi.py

Co-authored-by: Ricardo Vieira <[email protected]>

* Respond to comments

* Fix with pre commit checks

* Update pymc_extras/inference/deterministic_advi/dadvi.py

Co-authored-by: Jesse Grabowski <[email protected]>

* Implement suggestions

* Rename parameter because it's duplicated otherwise

* Rename to be consistent in use of dadvi

* Rename to `optimizer_method` and drop jac=True

* Add jac=True back in since trust-ncg complained

* Make hessp and jac optional

* Harmonize naming with existing code

* Fix example

* Switch to `better_optimize`

* Replace with pt.split

---------

Co-authored-by: Martin Ingram <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>