
jaxopt's L-BFGS-B with custom gradient not matching with scipy implementation #620

@jithendaraa

Description


Context:

I am migrating code (causalnex's dynotears) from a numpy/scipy implementation to a jax implementation. This essentially involves moving from scipy's L-BFGS-B to jaxopt's implementation so that I can jit the function and run it faster.

Apart from the _func(..) to minimize, the code defines a custom _grad(..) function for the optimization. I converted both _func() and _grad() to their jax counterparts, and am using jaxopt.LBFGSB with the custom grad function like this:

Original numpy/scipy implementation

# imports implied by this excerpt
import warnings
import numpy as np
import scipy.optimize as sopt

# initialise matrix, weights and constraints
wa_est = np.zeros(2 * (p_orders + 1) * d_vars**2)
wa_new = np.zeros(2 * (p_orders + 1) * d_vars**2)
rho, alpha, h_value, h_new = 1.0, 0.0, np.inf, np.inf

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == np.inf):
        wa_new = sopt.minimize(
            _func, 
            wa_est, 
            method="L-BFGS-B", 
            jac=_grad, 
            bounds=bnds
        ).x
        h_new = _h(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")

My current jaxopt implementation

# imports implied by this excerpt
import warnings
import numpy as np
import jax.numpy as jnp
from jaxopt import LBFGSB

# bnds is a list of (lower, upper) tuples, where upper might have None values.
# Make it compatible with the (lower_array, upper_array) pair that jaxopt.LBFGSB expects.
np_bnds = np.array(bnds)
lowers = jnp.array(np_bnds[:, 0].astype(float))
cleaned_uppers = np.where(np_bnds[:, 1] == None, jnp.inf, np_bnds[:, 1])
uppers = jnp.array(cleaned_uppers.astype(float))
jnp_lbfgs_bounds = (lowers, uppers)

lbfgsb_solver = LBFGSB(fun=_func_jax, value_and_grad=True)

for n_iter in range(max_iter):
    while (rho < 1e20) and (h_new > 0.25 * h_value or h_new == jnp.inf):
        wa_new = lbfgsb_solver.run(
            wa_est, 
            bounds=jnp_lbfgs_bounds
        ).params

        h_new = _h_jax(wa_new, d_vars, p_orders)
        if h_new > 0.25 * h_value:
            rho *= 10

    wa_est = wa_new
    h_value = h_new
    alpha += rho * h_value
    if h_value <= h_tol:
        break
    if h_value > h_tol and n_iter == max_iter - 1:
        warnings.warn("Failed to converge. Consider increasing max_iter.")

I have ensured that _func_jax returns (loss, _grad_jax(params)), whereas the original _func() returns just the scalar loss. I'm not expecting exact agreement between the scipy and jaxopt implementations, since I understand there will be numerical differences even when seeds are set, but there seems to be a large mismatch between the two versions.
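
For reference, this is the pattern I am following for the custom gradient, shown on a self-contained toy problem (a bound-constrained quadratic, not the dynotears objective) just to illustrate how I understand the value_and_grad=True contract:

import jax.numpy as jnp
from jaxopt import LBFGSB

def toy_func(x):
    # With value_and_grad=True the objective returns (loss, gradient);
    # my _func_jax does the same, with the gradient coming from _grad_jax.
    loss = jnp.sum((x - 2.0) ** 2)
    grad = 2.0 * (x - 2.0)  # hand-written gradient, standing in for _grad_jax
    return loss, grad

solver = LBFGSB(fun=toy_func, value_and_grad=True)
res = solver.run(jnp.zeros(3), bounds=(jnp.zeros(3), jnp.full(3, jnp.inf)))
print(res.params)  # close to [2., 2., 2.]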

I do get some warnings during my run like:

WARNING: jaxopt.ZoomLineSearch: No interval satisfying curvature condition.Consider increasing maximal possible stepsize of the linesearch.
WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
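
The second warning suggests raising the maximal stepsize of the linesearch. If I read the API right, that would mean passing something like the options below when constructing the solver (these knobs are taken from the LBFGS linesearch options; I have not verified that LBFGSB forwards all of them):

lbfgsb_solver = LBFGSB(
    fun=_func_jax,
    value_and_grad=True,
    max_stepsize=100.0,  # assumed: raises the stepsize cap the warning refers to
    maxls=30,            # assumed: allows more linesearch iterations
)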

It would really help to understand what is causing these differences (and whether or not they are expected).

Versions:

jax: 0.4.31
jaxopt: 0.8.3
numpy: 1.23.5
scipy: 1.13.1
