Add deterministic advi #564

Conversation
What exactly needs jax?
Thank you both very much for having a look so quickly!

@jessegrabowski Good point, yes maybe! I'll take a look.

@ricardoV94 Currently, JAX is used to compute the hvp and the jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done with `vmap` easily: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR44

That said, JAX isn't strictly necessary. Anything that can provide the `DADVIFuns` is fine: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13. In fact, I have code in the original research repo (https://github.com/martiningram/dadvi/blob/main/dadvi/objective_from_model.py#L5) that turns the regular hvp and gradient function into the `DADVIFuns`, but I think it'll be slower because of the for loops, e.g. here: https://github.com/martiningram/dadvi/blob/main/dadvi/objective_from_model.py#L56.

Are you concerned about the JAX dependency? If so, maybe I could have a go at doing a JAX-free version using the code just mentioned and then only support JAX optionally. I do think it might be nice to have since it's probably more efficient and would hopefully also run fast on GPUs. But I'm interested in your thoughts.

Also, I see one of the pre-commit checks seems to be failing. I can do the work to make the pre-commit hooks happy, sorry I haven't done that yet.
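To make that concrete, here is a minimal JAX sketch of the pattern described above (fixed draws, `vmap` over them, autodiff for the gradient and hvp). It is an illustration only, with a stand-in `log_post`; the PR's actual code and function names differ.

```python
import jax
import jax.numpy as jnp

# Stand-in log posterior (standard normal) so the sketch runs end to end
def log_post(theta):
    return -0.5 * jnp.sum(theta ** 2)

def dadvi_objective(var_params, fixed_zs):
    means, log_sds = jnp.split(var_params, 2)

    def per_draw(z):
        theta = means + jnp.exp(log_sds) * z   # reparameterised draw from the approximation
        return -log_post(theta)                # negative log posterior for one fixed draw

    neg_logps = jax.vmap(per_draw)(fixed_zs)   # all fixed draws at once, no Python loop
    entropy = jnp.sum(log_sds)                 # Gaussian entropy up to an additive constant
    return jnp.mean(neg_logps) - entropy

value_and_grad = jax.value_and_grad(dadvi_objective)

def hvp(var_params, fixed_zs, v):
    # Hessian-vector product via forward-over-reverse differentiation
    grad_fn = lambda p: jax.grad(dadvi_objective)(p, fixed_zs)
    return jax.jvp(grad_fn, (var_params,), (v,))[1]

# Example: 2 model parameters, 20 fixed draws
zs = jax.random.normal(jax.random.PRNGKey(0), (20, 2))
val, grad = value_and_grad(jnp.zeros(4), zs)
```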
I think a jax dependency is fine. But if it's optional that's obviously even better!
PyTensor has the equivalent of JAX's `vmap`. The reason I ask is that if you don't have anything jax-specific, you can still end up using jax, but also C or numba, which may be better for certain users.
@ricardoV94 Oh cool, thanks, I didn't realise! I'll take a look and see if I can use those. I agree it would be nice to support as many users as possible.
Happy to assist you. If you're vectorizing the Jacobian, you probably want to build the graph for a single draw and then vectorize it. Everything is described here, although a bit scattered: https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html
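For illustration, one way to read this advice (my sketch, not necessarily the construction Ricardo has in mind): build the gradient graph for a single draw with `pt.grad`, then batch it over the fixed draws with `vectorize_graph`. The `logp` below is a toy stand-in for the model's log density.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

theta = pt.vector("theta")
logp = -0.5 * (theta ** 2).sum()        # toy stand-in for the model's log density graph

g = pt.grad(logp, theta)                # gradient for a single draw

thetas = pt.matrix("thetas")            # one row per fixed draw
g_all = vectorize_graph(g, replace={theta: thetas})   # batch the single-draw gradient graph
mean_grad = g_all.mean(axis=0)          # average gradient over the fixed draws

f = pytensor.function([thetas], mean_grad)
print(f(np.random.randn(10, 3)))        # 10 fixed draws of a 3-parameter "model"
```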
Hey @ricardoV94 (and potentially others!), I think I could use your advice with the vectorisation. I think I've read enough to do it without using the functions here, but I'd really like to try to get this vectorised for speed.

To explain a bit: the code expects a function of the variational parameters and the fixed draws. The function should then return the estimate of the KL divergence using these draws, as well as its gradient with respect to the variational parameters. The KL divergence is the sum of the entropy of the approximation (a simple function of the variational parameters only) and the average of the log posterior densities from the draws. That's the part that I'd like to vectorise. In JAX, the way I do this is to turn each fixed standard-normal draw into a draw from the approximation using the variational parameters, evaluate the log posterior density of each draw with `vmap`, average the results, add the entropy term, and let autodiff provide the gradient.

This makes sense in my head, but the problem I see is that the pymc model's logp graph takes the individual value variables as inputs rather than a single flat parameter vector. So in essence, I think I need code that evaluates the model's log density given one flat vector of unconstrained parameters. Thanks a lot for your help :)
If you get the logp of a pymc model using `model.logp()`, it will indeed be a graph of the individual value variables. The path followed by the laplace code is to freeze the model and extract the negative logp, then create a flat vector input replacing the individual value inputs, then compile the loss_and_grads/hess/hessp functions (optionally in jax). My hope is that you can get the correct loss function for DADVI and then pass it directly into that same machinery. The 4 steps you outline seem correct to me.
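A hedged sketch of the flat-vector step described here, pieced together from public PyMC helpers (`model.logp()` and `pymc.pytensorf.join_nonshared_inputs`) on a toy model; the laplace/DADVI code in the repo may organise this differently.

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import join_nonshared_inputs

with pm.Model() as model:   # toy model just for illustration
    sigma = pm.HalfNormal("sigma")
    x = pm.Normal("x", sigma=sigma)

point = model.initial_point()
logp = model.logp()         # scalar logp graph on the separate value variables

# Rewrite the graph so it takes one flat parameter vector instead of per-variable inputs
[flat_logp], flat_input = join_nonshared_inputs(
    point=point, outputs=[logp], inputs=model.value_vars
)

# Compile the negative logp and its gradient with respect to the flat vector
neg_logp_and_grad = pytensor.function(
    [flat_input], [-flat_logp, pt.grad(-flat_logp, flat_input)]
)

n_params = int(sum(np.asarray(v).size for v in point.values()))
value, grad = neg_logp_and_grad(np.zeros(n_params))
```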
Thanks a lot @jessegrabowski. I'll give it a go!
Hey all, I think I made good progress with the pytensor version. A first version is here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1b6e7da940ec73fce49f5e13ae1db5369ec011cb0b55974ec04d81e519e923f6R55. I think the only major thing missing is to transform the draws back into the constrained space from the unconstrained space. Is there a code snippet anyone could point me to? Thanks for your help and for all the helpful advice you've already given!
You can make a pytensor function from the model value variables to the output variables. An example of that is how get_jaxified_graph is used in the jax based samplers: https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py#L682. If you look in the source of get_jaxified_graph you can see how it's done.
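As a concrete (hedged) illustration of that suggestion: a plain compiled PyTensor function from the value variables to `model.unobserved_value_vars` already maps one unconstrained draw back to the constrained space; the jax samplers do the same thing through get_jaxified_graph. The toy model and input values below are made up.

```python
import pymc as pm
import pytensor

with pm.Model() as model:   # toy model just for illustration
    sigma = pm.HalfNormal("sigma")
    x = pm.Normal("x", sigma=sigma)

# model.unobserved_value_vars expresses the constrained variables (and deterministics)
# in terms of the value variables, so compiling it gives the unconstrained -> constrained map
to_constrained = pytensor.function(model.value_vars, model.unobserved_value_vars)

# One unconstrained draw; argument order follows model.value_vars (here sigma_log__, x)
constrained_values = to_constrained(-0.5, 1.2)
```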
Thanks a lot @zaxtax. I had a go here. I think it works, but I'm doing a for-loop over the draws and I'm a bit worried it might get slow (though hopefully the transformations are fairly cheap). I guess I could try to vectorise the graph. What do you guys think, would that be worth it? Or is there a better way to approach this? Thanks again for your help :)
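For what vectorising the graph could look like, here is a rough sketch using `vectorize_graph` to add a leading draw dimension to every value variable, so all draws are transformed in one compiled call instead of a Python loop. This is an illustration on a toy model, not the code that ended up in the PR.

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

with pm.Model() as model:   # toy model just for illustration
    sigma = pm.HalfNormal("sigma")
    x = pm.Normal("x", sigma=sigma)

# Give every (scalar) value variable a leading draw dimension
batched_inputs = [pt.vector(f"{v.name}_draws") for v in model.value_vars]

# Vectorize the unconstrained -> constrained graph over that dimension
batched_outputs = vectorize_graph(
    model.unobserved_value_vars,
    replace=dict(zip(model.value_vars, batched_inputs)),
)
transform_all_draws = pytensor.function(batched_inputs, batched_outputs)

n_draws = 100
rng = np.random.default_rng(0)
constrained = transform_all_draws(rng.normal(size=n_draws), rng.normal(size=n_draws))
```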
return var_params, objective, n_params

def transform_draws(unstacked_draws, model, n_draws, keep_untransformed=False):
You may want to make use of: https://github.com/pymc-devs/pymc/blob/09603234f416866c9c9e32d1954fb691bd1580cd/pymc/backends/arviz.py#L669
Thanks a lot for the suggestions and comments @jessegrabowski! I think I implemented them, please let me know. The only one I didn't go for in the end was your `pt.split` suggestion, which looked more complex than I expected.
Looks great!

Bummer about `pt.split`, maybe @ricardoV94 can chime in on usage. I agree it looks more complex than expected.
pymc_extras/inference/dadvi/dadvi.py
Outdated
random_seed: RandomSeed = None,
n_draws: int = 1000,
keep_untransformed: bool = False,
opt_method: minimize_method = "trust-ncg",
We should agree on a unified name for this, because `find_MAP` uses `method`, `fit_laplace` uses `optimizer_method`, and now this is `opt_method`. My first choice is just `method`, but it clashes with the `pmx.fit` API (which I don't really like anyway). Thoughts?
Thanks @jessegrabowski, I agree. I first matched the syntax in `find_map` but then noticed the clash with the `fit` method. I'm also not a huge fan of the shared API tbh. On the other hand, I personally think `method` is maybe a bit vague anyhow. I'd be OK to go with `optimizer_method`, then at least it's consistent with `fit_laplace`. What do you think?
sounds good to me
pymc_extras/inference/dadvi/dadvi.py
Outdated
f_fused,
np.zeros(2 * n_params),
method=opt_method,
jac=True,
You don't need to hard-code this to true, `minimize` will automatically figure it out if `f_fused` returns 2 values (and the method uses gradients).
Thanks, I didn't know that! Will change it
It's a bug. I'll fix it, but in the meantime you can do something like `use_jac = minimize_kwargs.get('use_jac', True)`, so the user can override it if he wants, but it still works fine with the default?
Hmmm, I'm not sure I totally like that to be honest, because the `minimize_kwargs` ideally should match what's in `scipy.optimize.minimize`, so that they can just be passed on as-is and we can keep the docstring. But if you think it's important to have the ability to set `jac=False`, I could add it as an argument to `fit_dadvi`?
I've made using the Jacobian and hessp optional now, let me know what you think!
pymc_extras/inference/dadvi/dadvi.py
Outdated
result = minimize(
    f_fused,
    np.zeros(2 * n_params),
    jac=True,
Suggested change:
jac=True,
jac=minimize_kwargs.get('use_jac', True),
Thanks again @jessegrabowski for all your helpful comments. What do you think, do you like the new optional Jacobian and hessp? And is there anything else I should do before this could be merged?
Yes, looks nice. Could you add a `use_hess` option as well, in case some psycho wants to use the full Hessian? Otherwise I'm satisfied and already approved :)
Thanks @jessegrabowski! I started, but it's slightly complex, right? Because if I understand correctly, I would have to split up `f_fused`.
Oh, I didn't notice you were using `scipy.optimize.minimize` directly. I'm sorry for not looking more closely before, that was a source of confusion.
Oh cool @jessegrabowski, no worries, I also missed this, sorry! I've switched to `better_optimize` now.
Hey all, thanks again to all of you for all your helpful advice! Anything still in the way of merging this?

@martiningram LGTM! If you are all set, I can merge it.

@zaxtax Fantastic! Yeah, I think from my side it's all ready to go :)

@martiningram what's required now to get LR support into DADVI?
@zaxtax thanks a lot for merging! I think the main thing to get LR to work is to think about the interface a bit. The thing is, it's usually too expensive to do the LR correction for all parameters, since you have to compute the full Hessian to do that (though this should be pretty easy to implement at least). But it's often enough to get the marginal variance for a subset of the parameters or for a scalar-valued function of the parameters. That can be done reasonably quickly using conjugate gradients and the hvp, basically.

Maybe to keep it simple, I'd think about how to allow a user to do the LR correction for a single scalar parameter. Then, the main complication is how to include the transformations from unconstrained --> constrained where needed. What do you think? If it sounds good, I'll do some thinking and come up with a proposal.

However, as a very first step, I'd like to write a blog post or two to promote the current iteration we just merged. Because I think it basically solves the convergence issues with ADVI, which is a pretty cool thing in itself and is hopefully already helpful for people.
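To illustrate just the computational pattern mentioned above (the exact linear-response formulas are in the paper): the expensive piece is solving a linear system with the objective's Hessian, and conjugate gradients can do that using only Hessian-vector products, without ever forming the full Hessian.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

def solve_with_hvp(hvp, v, n_params):
    """Solve H x = v by conjugate gradients, where H is only available through hvp(x)."""
    H = LinearOperator((n_params, n_params), matvec=hvp)
    x, info = cg(H, v)
    if info != 0:
        raise RuntimeError("CG did not converge")
    return x

# Toy check: H = diag(1..5), so the hvp is elementwise multiplication and the answer is 1/diag
diag = np.arange(1.0, 6.0)
x = solve_with_hvp(lambda p: diag * p, np.ones(5), 5)
```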
If you need full hessians, having support for sparse jacobians (issue here) might end up being extremely useful. I really want to push on it, but I'm spread too thin :(

Looking forward to it!
Thanks a lot guys!

@jessegrabowski the sparse Jacobians look pretty awesome, thanks! I've also been keen to get sparse matrix algebra going for a long time. But I'll also have to read a bit to understand the topic, and I'm also spread quite thin right now unfortunately. Still, I'll have a read and will try to understand what you've done so far!

@ricardoV94 Awesome, thanks! I have a little blog myself (https://martiningram.github.io/), but if you'd like and there's interest, I'd also love to post it on some PyMC blog. Is there one where it would fit? Only if you like the article, of course (I can share it for comments beforehand); if there's one where it would make sense, let me know.

We have a general VI notebook here, to which we could add a DADVI example, but it might be nice to have a new standalone one that compares ADVI/DADVI or ADVI/DADVI/Pathfinder.
* Add first version of deterministic ADVI
* Update API
* Add a notebook example
* Add to API and add a docstring
* Change import in notebook
* Add jax to dependencies
* Add pytensor version
* Fix handling of pymc model
* Add (probably suboptimal) handling of the two backends
* Add transformation
* Follow Ricardo's advice to simplify the transformation step
* Fix naming bug
* Document and clean up
* Fix example
* Update pymc_extras/inference/deterministic_advi/dadvi.py
  Co-authored-by: Ricardo Vieira <[email protected]>
* Respond to comments
* Fix with pre commit checks
* Update pymc_extras/inference/deterministic_advi/dadvi.py
  Co-authored-by: Jesse Grabowski <[email protected]>
* Implement suggestions
* Rename parameter because it's duplicated otherwise
* Rename to be consistent in use of dadvi
* Rename to `optimizer_method` and drop jac=True
* Add jac=True back in since trust-ncg complained
* Make hessp and jac optional
* Harmonize naming with existing code
* Fix example
* Switch to `better_optimize`
* Replace with pt.split

---------

Co-authored-by: Martin Ingram <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Hi everyone,
I'm one of the authors of the paper on deterministic ADVI. There is an open feature request for this in PyMC here, so I thought I'd kick things off with this PR.
In simple terms, DADVI is like ADVI, but rather than using a new draw to estimate its objective at each step, it uses a fixed set of draws throughout the optimisation. That means that (1) it can use regular off-the-shelf optimisers rather than stochastic optimisation, making convergence more reliable, and (2) it's possible to use techniques to improve the variance estimates. Both are described in the paper, along with tools to assess how big the error from using the fixed draws is.
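To make the core recipe concrete, here is a tiny self-contained illustration (numpy plus scipy on a stand-in standard-normal posterior, not the PR's implementation): draw the noise once, keep it fixed, and hand the resulting deterministic objective to an off-the-shelf optimiser.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n_params, n_fixed_draws = 3, 30
zs = rng.standard_normal((n_fixed_draws, n_params))   # drawn once, then held fixed

def log_post(theta):
    return -0.5 * np.sum(theta ** 2)                   # stand-in log posterior

def dadvi_objective(eta):
    means, log_sds = eta[:n_params], eta[n_params:]
    thetas = means + np.exp(log_sds) * zs              # same draws at every optimisation step
    avg_logp = np.mean([log_post(t) for t in thetas])
    entropy = np.sum(log_sds)                          # Gaussian entropy up to a constant
    return -(avg_logp + entropy)                       # deterministic objective (negative ELBO)

result = minimize(dadvi_objective, np.zeros(2 * n_params))   # any off-the-shelf optimiser
means, sds = result.x[:n_params], np.exp(result.x[n_params:])
```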
This PR covers only the first part -- optimising ADVI with fixed draws. This is because I thought I'd start simple and because I'm hoping that it already addresses a real problem with ADVI, which is the difficulty in assessing convergence.
In addition to adding the code, there is an example notebook in `notebooks/deterministic_advi_example.ipynb`. It fits DADVI to the PyMC basic linear regression example. I can add more examples, but I thought I'd start simple. I mostly lifted the code from my research repository, so there are probably some style differences. Let me know what would be important to change.
Note that JAX is needed, but there shouldn't be any other dependencies.
Very keen to hear what you all think! :)
All the best,
Martin