@@ -5,6 +5,7 @@
 import pytensor.tensor as pt
 import xarray
 
+from better_optimize import minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -14,7 +15,6 @@
 )
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
-from scipy.optimize import minimize
 
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
@@ -31,6 +31,7 @@ def fit_dadvi(
     optimizer_method: minimize_method = "trust-ncg",
     use_grad: bool = True,
     use_hessp: bool = True,
+    use_hess: bool = False,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
@@ -82,6 +83,11 @@ def fit_dadvi(
     use_hessp:
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
+    use_hess:
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that
+        this is generally not recommended since its computation can be slow
+        and memory-intensive if there are many parameters.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -110,9 +116,9 @@ def fit_dadvi(
     f_fused, f_hessp = _compile_functions_for_scipy_optimize(
         objective,
         [var_params],
-        compute_grad=True,
-        compute_hessp=True,
-        compute_hess=False,
+        compute_grad=use_grad,
+        compute_hessp=use_hessp,
+        compute_hess=use_hess,
     )
 
     derivative_kwargs = {}
@@ -121,6 +127,8 @@ def fit_dadvi(
         derivative_kwargs["jac"] = True
     if use_hessp:
         derivative_kwargs["hessp"] = f_hessp
+    if use_hess:
+        derivative_kwargs["hess"] = True
 
     result = minimize(
         f_fused,
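
For context, a rough usage sketch of the new `use_hess` flag. The import path for `fit_dadvi` and the model-from-context calling pattern are assumptions here; the keyword names follow the signature in the diff, and "trust-exact" is one of the scipy.optimize methods that consumes a full Hessian rather than a Hessian-vector product.

import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi  # import path assumed

observed = np.array([0.2, -0.1, 0.4, 0.0])

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

with model:
    # Small model, so a dense Hessian is cheap. "trust-exact" needs the full
    # Hessian, so the Hessian-vector product is switched off here.
    idata = fit_dadvi(
        optimizer_method="trust-exact",
        use_grad=True,
        use_hessp=False,
        use_hess=True,
    )

Since the dense Hessian grows quadratically with the number of variational parameters, `use_hessp=True` with a hessp-capable method such as "trust-ncg" remains the better default, as the added docstring notes.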