-
Notifications
You must be signed in to change notification settings - Fork 70
Open
Description
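jax version: 0.4.38.

The L-BFGS test at tests/lbfgs_test.py:422 fails: the value computed via SciPy (695.4757505059242, float64) and the value computed by jaxopt (636.7615, float32) differ by about 58.7 in absolute terms (about 9.2% relative), far outside the test's tolerance of rtol=1e-6, atol=1e-6.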
Logs:

```text
>       self.assertArraysAllClose(scipy_val, jaxopt_val)
tests/lbfgs_test.py:422:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = np.float64(695.4757505059242), b = Array(636.7615, dtype=float32)
atol = 1e-06, rtol = 1e-06, err_msg = ''
    def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
      if a.dtype == b.dtype == float0:
        np.testing.assert_array_equal(a, b, err_msg=err_msg)
        return
      #a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
      #b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
      kw = {}
      if atol: kw["atol"] = atol
      if rtol: kw["rtol"] = rtol
      with onp.errstate(invalid='ignore'):
        # TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
        # value errors. It should not do that.
>       onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
E       AssertionError:
E       Not equal to tolerance rtol=1e-06, atol=1e-06
E
E       Mismatched elements: 1 / 1 (100%)
E       Max absolute difference among violations: 58.7142759
E       Max relative difference among violations: 0.09220764
E        ACTUAL: array(695.475751)
E        DESIRED: array(636.7615, dtype=float32)
jaxopt/_src/test_util.py:262: AssertionError
```
Metadata
Assignees
Labels
No labels