Skip to content

Commit 4dfc924

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Use clear dtype annotations to avoid dtype promotion issues
The default dtype changes depending on the x64 mode so it's better to be explicit in code internal to JAX. PiperOrigin-RevId: 801797558
1 parent 7c6378d commit 4dfc924

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

jax/_src/lax_reference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ def ragged_dot(
285285

286286
out = np.zeros((m, n), dtype=lhs.dtype)
287287
result_iota = np.expand_dims(np.arange(out.shape[0]), list(range(1, out.ndim)))
288-
start = 0
288+
result_iota = result_iota.astype(group_sizes.dtype)
289+
start = np.asarray(0, dtype=group_sizes.dtype)
289290
for i, size in enumerate(group_sizes):
290291
out += np.where(
291292
np.logical_and(start <= result_iota, result_iota < (start + size)),

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7631,7 +7631,7 @@ def delete(
76317631
obj = asarray(obj).ravel()
76327632
obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
76337633
obj = sort(obj)
7634-
obj -= arange(len(obj)) # type: ignore[arg-type,operator]
7634+
obj -= arange(len(obj), dtype=obj.dtype) # type: ignore
76357635
i = arange(a.shape[axis] - obj.size)
76367636
i += (i[None, :] >= obj[:, None]).sum(0)
76377637
return a[(slice(None),) * axis + (i,)]

tests/lax_numpy_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,7 +1655,10 @@ def testIntegerPower(self, ptype):
16551655
def testIntegerPowerOverflow(self, x, y):
16561656
# Regression test for https://github.com/jax-ml/jax/issues/5987
16571657
args_maker = lambda: [x, y]
1658-
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
1658+
check_dtypes = platform.system() != 'Windows'
1659+
self._CheckAgainstNumpy(
1660+
np.power, jnp.power, args_maker, check_dtypes=check_dtypes
1661+
)
16591662
self._CompileAndCheck(jnp.power, args_maker)
16601663

16611664
@jtu.sample_product(
@@ -3852,7 +3855,11 @@ def testArrayFromList(self):
38523855

38533856
# out of bounds leads to an OverflowError
38543857
val = jnp.iinfo(jnp.int64).min - 1
3855-
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
3858+
if platform.system() == 'Windows':
3859+
expected_regex = 'int too big to convert'
3860+
else:
3861+
expected_regex = 'Python int too large.*'
3862+
with self.assertRaisesRegex(OverflowError, expected_regex):
38563863
jnp.array([0, val])
38573864

38583865
def testArrayNone(self):

tests/state_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,9 @@ def f(x_ref, *idxs):
506506
return op(x_ref, indexer)
507507

508508
rng = self.rng()
509-
a = rng.randn(*bat_ref_aval.shape)
509+
a = rng.randn(*bat_ref_aval.shape).astype(floatx)
510510
his = [d for d, b in zip(ref_aval.shape, indexed_dims) if b]
511-
idxs = [rng.randint(low=0, high=hi, size=i.shape)
511+
idxs = [rng.randint(low=0, high=hi, size=i.shape, dtype=intx)
512512
for i, hi in zip(bat_idx_avals, his)]
513513

514514
# discharge-of-vmap

0 commit comments

Comments
 (0)