Context:

I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a JAX implementation. This essentially involves moving from scipy's L-BFGS-B to jaxopt's implementation so I can jit the function and run it faster.

Apart from the `_func(..)` to minimize, the code has a custom `_grad(..)` function defined for the optimization. I converted both `_func()` and `_grad()` to their JAX counterparts, and am using `jaxopt.LBFGSB` with the custom grad function as shown below.
Original numpy/scipy implementation

```python
# initialise matrix, weights and constraints
wa_est = np.zeros(2 * (p_orders + 1) * d_vars**2)
wa_new = np.zeros(2 * (p_orders + 1) * d_vars**2)
rho, alpha, h_value, h_new = 1.0, 0.0, np.inf, np.inf

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == np.inf):
        wa_new = sopt.minimize(
            _func,
            wa_est,
            method="L-BFGS-B",
            jac=_grad,
            bounds=bnds
        ).x
        h_new = _h(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10
    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
```
My current jaxopt implementation

```python
# bnds is a list of (lower, upper) tuples, where upper might have None values.
# Make it compatible with what jaxopt.LBFGSB expects
np_bnds = np.array(bnds)
lowers = jnp.array(np_bnds[:, 0].astype(float))
cleaned_uppers = np.where(np_bnds[:, 1] == None, jnp.inf, np_bnds[:, 1])
uppers = jnp.array(cleaned_uppers.astype(float))
jnp_lbfgs_bounds = (lowers, uppers)

lbfgsb_solver = LBFGSB(fun=_func_jax, value_and_grad=True)

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == jnp.inf):
        wa_new = lbfgsb_solver.run(
            wa_est,
            bounds=jnp_lbfgs_bounds
        ).params
        h_new = _h_jax(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10
    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")
```
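Note that the solver above relies on jaxopt's default stopping criteria, while the scipy call relies on scipy's. A sketch of pinning these explicitly, in case the two solvers' defaults simply differ (this assumes `maxiter`, `tol` and `history_size` are the relevant `LBFGSB` constructor arguments; the numbers in the comments are scipy's documented L-BFGS-B defaults):

```python
# Sketch, not my actual code: pin stopping criteria so they are roughly
# comparable to scipy's documented L-BFGS-B defaults
# (maxiter=15000, gtol=1e-5, maxcor=10).
lbfgsb_solver = LBFGSB(
    fun=_func_jax,
    value_and_grad=True,
    maxiter=15_000,    # scipy's default iteration budget
    tol=1e-5,          # roughly analogous to scipy's gtol on the gradient
    history_size=10,   # matches scipy's maxcor default
)
```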
I have ensured that `_func_jax` returns `(loss, _grad_jax(params))`, compared to `_func()`, which returns just the scalar. I'm not expecting exact answers between the scipy and jaxopt implementations, since I understand there will be numerical issues even if seeds are set. But there seems to be a large mismatch between the scipy and jaxopt versions.
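For reference, here is a self-contained toy version of the `value_and_grad` wiring described above. The quadratic loss and its gradient are hypothetical stand-ins for the actual dynotears `_func_jax`/`_grad_jax`; the last line shows the autodiff variant I could compare against to rule out a bug in the hand-written gradient:

```python
import jax
import jax.numpy as jnp
from jaxopt import LBFGSB

# Hypothetical stand-ins for the real dynotears loss and hand-written gradient.
def _loss_toy(params):
    return jnp.sum((params - 1.0) ** 2)

def _grad_toy(params):
    return 2.0 * (params - 1.0)

def _func_toy(params):
    # With value_and_grad=True the solver expects (scalar_loss, gradient)
    # from this callable instead of just the scalar loss.
    return _loss_toy(params), _grad_toy(params)

solver = LBFGSB(fun=_func_toy, value_and_grad=True)
x0 = jnp.zeros(3)
# Lower bound 0, unbounded above -- mirroring the None uppers converted to inf.
bounds = (jnp.zeros(3), jnp.full(3, jnp.inf))
res = solver.run(x0, bounds=bounds)

# Autodiff variant of the same objective, to cross-check the manual gradient.
auto_solver = LBFGSB(fun=jax.value_and_grad(_loss_toy), value_and_grad=True)
```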
I do get some warnings during my run like:

```
WARNING: jaxopt.ZoomLineSearch: No interval satisfying curvature condition.Consider increasing maximal possible stepsize of the linesearch.
WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
```
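The warning's own suggestion points at the maximal line-search stepsize. A sketch of the knobs I assume it refers to (parameter names taken from the jaxopt LBFGS solver family; I have not confirmed that `LBFGSB` exposes exactly these):

```python
# Sketch under the assumption that LBFGSB accepts the same line-search
# arguments as LBFGS in jaxopt 0.8.3.
lbfgsb_solver = LBFGSB(
    fun=_func_jax,
    value_and_grad=True,
    max_stepsize=10.0,   # "maximal possible stepsize of the linesearch" from the warning
    maxls=50,            # allow more line-search iterations before bailing out
    # linesearch="hager-zhang",  # alternative to the default zoom line search
)
```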
It would really help to understand what is causing these differences (and whether or not they are expected).
Versions:
jax: 0.4.31
jaxopt: 0.8.3
numpy: 1.23.5
scipy: 1.13.1