Skip to content

Commit 6251272

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Add common helpers int_dtype_for_shape and int_dtype_for_dim.
Use them consistently when choosing a dtype for an array index. PiperOrigin-RevId: 800579624
1 parent 9495be7 commit 6251272

File tree

10 files changed

+178
-59
lines changed

10 files changed

+178
-59
lines changed

jax/_src/lax/lax.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from jax._src.interpreters import pxla
5858
from jax._src.interpreters.batching import RaggedAxis
5959
from jax._src.lax import slicing
60+
from jax._src.lax import utils as lax_utils
6061
from jax._src.mesh import get_abstract_mesh, get_concrete_mesh
6162
from jax._src.lax.utils import (
6263
input_dtype, dtype_to_string, standard_abstract_eval,
@@ -7100,12 +7101,13 @@ def _squeeze_lower(ctx, operand, *, dimensions):
71007101

71017102
def shape_as_value(shape: core.Shape):
71027103
"""Converts a shape that may contain Poly values into a JAX value."""
7104+
dtype = lax_utils.int_dtype_for_shape(shape, signed=True)
71037105
if len(shape) == 0:
7104-
return full((0,), np.array(0, np.int64))
7106+
return full((0,), np.array(0, dtype=dtype))
71057107
if core.is_constant_shape(shape):
7106-
return np.asarray(shape, dtype=np.int64)
7108+
return np.asarray(shape, dtype=dtype)
71077109
dims = [
7108-
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
7110+
expand_dims(convert_element_type(core.dimension_as_value(d), dtype),
71097111
(0,))
71107112
for d in shape
71117113
]
@@ -8113,9 +8115,10 @@ def _operands_to_keys(*operands, num_keys=1):
81138115

81148116
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
81158117
shape = primals[0].shape
8118+
index_dtype = lax_utils.int_dtype_for_shape(shape, signed=False)
81168119
sorted_primals_and_idx = sort_p.bind(
81178120
*primals,
8118-
broadcasted_iota(dtypes.canonicalize_dtype(np.uint64), shape, dimension),
8121+
broadcasted_iota(index_dtype, shape, dimension),
81198122
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
81208123
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
81218124
dimension))

jax/_src/lax/slicing.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from jax._src.interpreters import mlir
3939
from jax._src.interpreters import partial_eval as pe
4040
from jax._src.lax import lax
41+
from jax._src.lax import utils as lax_utils
4142
from jax._src.lax.utils import (
4243
_argnum_weak_type,
4344
input_dtype,
@@ -2107,16 +2108,19 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
21072108
output_shape):
21082109
"""Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
21092110
dnums = dimension_numbers
2110-
intarray = partial(np.array, dtype=np.int64)
2111-
operand_dims = lax.shape_as_value(operand.shape)
2112-
indices = lax.convert_element_type(indices, np.int64)
2111+
index_dtype = lax_utils.int_dtype_for_shape(operand.shape, signed=True)
2112+
intarray = partial(np.array, dtype=index_dtype)
2113+
operand_dims = lax.shape_as_value(operand.shape).astype(index_dtype)
2114+
indices = lax.convert_element_type(indices, index_dtype)
21132115
num_batch_dims = len(indices.shape) - 1
21142116

2115-
upper_bound = (
2116-
operand_dims[intarray(dnums.start_index_map)] -
2117-
lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)])
2117+
upper_bound = operand_dims[
2118+
intarray(dnums.start_index_map)
2119+
] - lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)].astype(
2120+
index_dtype
2121+
)
21182122
mask = lax.bitwise_and(
2119-
lax.ge(indices, np.int64(0)),
2123+
lax.ge(indices, index_dtype.type(0)),
21202124
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
21212125
mask = lax.reduce_and(mask, [num_batch_dims])
21222126

@@ -2727,18 +2731,20 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
27272731

27282732
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
27292733
for i in dnums.scatter_dims_to_operand_dims)
2734+
27302735
# Stack upper_bounds into a Array[n]
27312736
upper_bound = lax.shape_as_value(upper_bounds)
27322737
# This fix fails lax_test_no_jax_array
2733-
upper_bound = lax.min(upper_bound,
2734-
lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max),
2735-
np.int64))
2736-
2738+
upper_bound = lax.min(
2739+
upper_bound,
2740+
upper_bound.dtype.type(
2741+
min(np.iinfo(upper_bound.dtype).max, np.iinfo(indices.dtype).max)
2742+
),
2743+
)
2744+
upper_bound = lax.convert_element_type(upper_bound, indices.dtype)
27372745
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
27382746
(len(indices.shape) - 1,))
2739-
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
2740-
upper_bound)
2741-
2747+
return lax.clamp(indices.dtype.type(0), indices, upper_bound)
27422748

27432749
def _scatter_addsub_jvp(
27442750
prim, primals, tangents, *, update_jaxpr, update_consts, dimension_numbers,
@@ -3132,10 +3138,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
31323138
for update_dim in dnums.update_window_dims:
31333139
ids_shape[update_dim] = 1
31343140
num_ids = math.prod(ids_shape)
3135-
if core.is_constant_dim(num_ids):
3136-
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
3137-
else:
3138-
id_dtype = dtypes.canonicalize_dtype(np.uint64)
3141+
id_dtype = lax_utils.int_dtype_for_dim(num_ids, signed=False)
31393142
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
31403143
lax._ones(updates, dtype=id_dtype))
31413144

jax/_src/lax/utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818

1919
from functools import partial
2020

21+
import numpy as np
22+
2123
from jax._src import core
2224
from jax._src import dispatch
25+
from jax._src import dtypes
2326
from jax._src import mesh as mesh_lib
2427
from jax._src import state
2528
from jax._src.named_sharding import DuplicateSpecError, NamedSharding
2629
from jax._src.partition_spec import PartitionSpec as P
2730
from jax._src.util import safe_zip
31+
from jax._src.typing import DimSize, DType, Shape
2832

2933
zip, unsafe_zip = safe_zip, zip
3034

31-
import numpy as np
3235

3336
def input_dtype(x, *_, **__):
3437
return x.dtype
@@ -223,3 +226,36 @@ def dtype_to_string(dtype):
223226
except AttributeError:
224227
pass
225228
return str(dtype)
229+
230+
_int32_max = np.iinfo(np.int32).max
231+
_uint32_max = np.iinfo(np.uint32).max
232+
233+
def int_dtype_for_dim(d: DimSize, *, signed: bool) -> DType:
  """Returns an integer dtype wide enough for a dimension of size `d`.

  Args:
    d: the dimension size. It may be a dimension for which
      `core.is_constant_dim` is false.
    signed: if True, choose among signed dtypes; otherwise unsigned ones.

  Returns:
    The default int (or uint) dtype when `d` is not a constant dimension.
    Otherwise the 64-bit dtype if `d` exceeds the 32-bit maximum for the
    requested signedness, and the 32-bit dtype if it does not.
  """
  if not core.is_constant_dim(d):
    # The size cannot be compared against a limit, so defer to the default
    # integer dtype of the requested signedness.
    if signed:
      return dtypes.default_int_dtype()
    return dtypes.default_uint_dtype()
  if signed:
    narrow, wide, limit = np.int32, np.int64, _int32_max
  else:
    narrow, wide, limit = np.uint32, np.uint64, _uint32_max
  return np.dtype(wide if d > limit else narrow)
243+
244+
def int_dtype_for_shape(shape: Shape, *, signed: bool) -> DType:
245+
"""Returns a integer dtype large enough to contain indices in `shape`."""
246+
if signed:
247+
for d in shape:
248+
if core.is_constant_dim(d):
249+
if d > _int32_max:
250+
return np.dtype(np.int64)
251+
else:
252+
return dtypes.default_int_dtype()
253+
return np.dtype(np.int32)
254+
else:
255+
for d in shape:
256+
if core.is_constant_dim(d):
257+
if d > _uint32_max:
258+
return np.dtype(np.uint64)
259+
else:
260+
return dtypes.default_uint_dtype()
261+
return np.dtype(np.uint32)

jax/_src/numpy/indexing.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jax._src import errors
3333
from jax._src.lax import lax
3434
from jax._src.lax import slicing
35+
from jax._src.lax import utils as lax_utils
3536
from jax._src.numpy import einsum
3637
from jax._src.numpy import error as jnp_error
3738
from jax._src.numpy import lax_numpy
@@ -305,8 +306,7 @@ def replace(tup, val):
305306
lst[axis_int] = val
306307
return tuple(lst)
307308

308-
use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape)
309-
index_dtype = np.dtype('int64' if use_64bit_index else 'int32')
309+
# a.shape is a whole shape, not a single dimension, so choose the index dtype
# with the shape-level helper (the per-dimension helper expects one DimSize).
index_dtype = lax_utils.int_dtype_for_shape(a.shape, signed=True)
310310
indices = lax.convert_element_type(indices, index_dtype)
311311

312312
axis_size = a.shape[axis_int]
@@ -850,10 +850,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
850850
collapsed_slice_dims: list[int] = []
851851
start_index_map: list[int] = []
852852

853-
use_64bit_index = (
854-
any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
855-
config.enable_x64.value)
856-
index_dtype = np.dtype('int64') if use_64bit_index else np.dtype('int32')
853+
index_dtype = lax_utils.int_dtype_for_shape(x_shape, signed=True)
857854

858855
# Gather indices.
859856
# Pairs of (array, start_dim) values. These will be broadcast into

jax/_src/numpy/lax_numpy.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from jax._src.lax import lax
4949
from jax._src.lax import slicing as lax_slicing
5050
from jax._src.lax import special as lax_special
51+
from jax._src.lax import utils as lax_utils
5152
from jax._src.lib import xla_client as xc
5253
from jax._src.numpy.array_constructors import array, asarray
5354
from jax._src.numpy import array_creation
@@ -7166,8 +7167,12 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
71667167
if ndim < 0:
71677168
raise ValueError("ndim argument to diag_indices must be nonnegative, got {}"
71687169
.format(ndim))
7169-
# TODO(phawkins): Use an int64 index if n >= 2**31.
7170-
return (lax.iota(int, n),) * ndim
7170+
index_dtype = lax_utils.int_dtype_for_dim(n, signed=True)
7171+
# We'd give the correct output values with int32, but use the default dtype to
7172+
# match NumPy type semantics if x64 mode is enabled for now.
7173+
if index_dtype == np.dtype(np.int32):
7174+
index_dtype = dtypes.default_int_dtype()
7175+
return (lax.iota(index_dtype, n),) * ndim
71717176

71727177

71737178
@export
@@ -9258,7 +9263,8 @@ def body_fun(state, _):
92589263

92599264

92609265
def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
9261-
working_dtype = np.dtype('int32') if sorted_arr.size + query.size < np.iinfo(np.int32).max else np.dtype('int64')
9266+
working_dtype = lax_utils.int_dtype_for_dim(sorted_arr.size + query.size,
9267+
signed=False)
92629268
def _rank(x):
92639269
idx = lax.iota(working_dtype, x.shape[0])
92649270
return array_creation.zeros_like(idx).at[argsort(x)].set(idx)
@@ -9354,7 +9360,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
93549360
a, v = util.promote_dtypes(a, v)
93559361
if sorter is not None:
93569362
a = a[sorter]
9357-
dtype = np.dtype('int32') if a.shape[0] <= np.iinfo(np.int32).max else np.dtype('int64')
9363+
dtype = lax_utils.int_dtype_for_dim(a.shape[0], signed=True)
93589364
if a.shape[0] == 0:
93599365
return array_creation.zeros_like(v, dtype=dtype)
93609366
impl = {

jax/_src/numpy/linalg.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src.custom_derivatives import custom_jvp
3131
from jax._src.lax import lax
3232
from jax._src.lax import linalg as lax_linalg
33+
from jax._src.lax import utils as lax_utils
3334
from jax._src.numpy import array_creation
3435
from jax._src.numpy import einsum
3536
from jax._src.numpy import indexing
@@ -294,11 +295,8 @@ def svd(
294295
s = lax.abs(v)
295296
if compute_uv:
296297
sign = lax.sign(v)
297-
idx_dtype = (
298-
np.int64
299-
if int(s.shape[s.ndim - 1]) > np.iinfo(np.int32).max
300-
else np.int32
301-
)
298+
idx_dtype = lax_utils.int_dtype_for_dim(
299+
s.shape[s.ndim - 1], signed=False)
302300
idxs = lax.broadcasted_iota(idx_dtype, s.shape, dimension=s.ndim - 1)
303301
s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
304302
s = lax.rev(s, dimensions=[s.ndim - 1])

jax/_src/numpy/setops.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from jax._src import dtypes
2727
from jax._src.lax import lax
2828
from jax._src.lax import slicing as lax_slicing
29+
from jax._src.lax import utils as lax_utils
2930
from jax._src.numpy.array_creation import empty, full, full_like, ones, zeros
3031
from jax._src.numpy.lax_numpy import (
3132
append, arange, concatenate, diff,
@@ -344,11 +345,7 @@ def _intersect1d_sorted_mask(arr1: Array, arr2: Array,
344345
assert arr1.ndim == arr2.ndim == 1
345346
arr = concatenate((arr1, arr2))
346347
if return_indices:
347-
use_64bit_index = (
348-
not core.is_constant_dim(arr.shape[0])
349-
or arr.shape[0] >= np.iinfo(np.int32).max
350-
)
351-
idx_dtype = np.int64 if use_64bit_index else np.int32
348+
idx_dtype = lax_utils.int_dtype_for_dim(arr.shape[0], signed=True)
352349
iota = lax.broadcasted_iota(idx_dtype, np.shape(arr), dimension=0)
353350
aux, indices = lax.sort_key_val(arr, iota)
354351
else:

jax/_src/numpy/sorting.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import numpy as np
1919

2020
from jax._src import api
21-
from jax._src import core
2221
from jax._src import dtypes
2322
from jax._src.lax import lax
23+
from jax._src.lax import utils as lax_utils
2424
from jax._src.numpy import util
2525
from jax._src.util import canonicalize_axis, set_module
2626
from jax._src.typing import Array, ArrayLike
@@ -154,11 +154,12 @@ def argsort(
154154
arr = arr.ravel()
155155
axis = 0
156156
dimension = canonicalize_axis(axis, arr.ndim)
157-
use_64bit_index = core.is_constant_dim(arr.shape[dimension]) and arr.shape[dimension] >= (1 << 31)
158-
# TODO(phawkins): we should probably use int64 indices for unknown dimensions,
159-
# but that requires that we first support using int64 in a non-x64
160-
# computation.
161-
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else int, arr.shape, dimension)
157+
idx_dtype = lax_utils.int_dtype_for_dim(arr.shape[dimension], signed=True)
158+
# We'd give the correct output values with int32, but use the default dtype to
159+
# match NumPy type semantics if x64 mode is enabled for now.
160+
if idx_dtype == np.dtype(np.int32):
161+
idx_dtype = dtypes.default_int_dtype()
162+
iota = lax.broadcasted_iota(idx_dtype, arr.shape, dimension)
162163
# For stable descending sort, we reverse the array and indices to ensure that
163164
# duplicates remain in their original order when the final indices are reversed.
164165
# For non-stable descending sort, we can avoid these extra operations.
@@ -425,7 +426,11 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A
425426
if np.ndim(key_arrays[0]) == 0:
426427
return lax.full((), 0, dtypes.default_int_dtype())
427428
axis = canonicalize_axis(axis, np.ndim(key_arrays[0]))
428-
use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31)
429-
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else int,
430-
np.shape(key_arrays[0]), axis)
429+
idx_dtype = lax_utils.int_dtype_for_dim(key_arrays[0].shape[axis],
430+
signed=True)
431+
# We'd give the correct output values with int32, but use the default dtype to
432+
# match NumPy type semantics if x64 mode is enabled for now.
433+
if idx_dtype == np.dtype(np.int32):
434+
idx_dtype = dtypes.default_int_dtype()
435+
iota = lax.broadcasted_iota(idx_dtype, np.shape(key_arrays[0]), axis)
431436
return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1]

0 commit comments

Comments
 (0)