|
38 | 38 | from jax._src.interpreters import mlir
|
39 | 39 | from jax._src.interpreters import partial_eval as pe
|
40 | 40 | from jax._src.lax import lax
|
| 41 | +from jax._src.lax import utils as lax_utils |
41 | 42 | from jax._src.lax.utils import (
|
42 | 43 | _argnum_weak_type,
|
43 | 44 | input_dtype,
|
@@ -2107,16 +2108,19 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
|
2107 | 2108 | output_shape):
|
2108 | 2109 | """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking."""
|
2109 | 2110 | dnums = dimension_numbers
|
2110 |
| - intarray = partial(np.array, dtype=np.int64) |
2111 |
| - operand_dims = lax.shape_as_value(operand.shape) |
2112 |
| - indices = lax.convert_element_type(indices, np.int64) |
| 2111 | + index_dtype = lax_utils.int_dtype_for_shape(operand.shape, signed=True) |
| 2112 | + intarray = partial(np.array, dtype=index_dtype) |
| 2113 | + operand_dims = lax.shape_as_value(operand.shape).astype(index_dtype) |
| 2114 | + indices = lax.convert_element_type(indices, index_dtype) |
2113 | 2115 | num_batch_dims = len(indices.shape) - 1
|
2114 | 2116 |
|
2115 |
| - upper_bound = ( |
2116 |
| - operand_dims[intarray(dnums.start_index_map)] - |
2117 |
| - lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)]) |
| 2117 | + upper_bound = operand_dims[ |
| 2118 | + intarray(dnums.start_index_map) |
| 2119 | + ] - lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)].astype( |
| 2120 | + index_dtype |
| 2121 | + ) |
2118 | 2122 | mask = lax.bitwise_and(
|
2119 |
| - lax.ge(indices, np.int64(0)), |
| 2123 | + lax.ge(indices, index_dtype.type(0)), |
2120 | 2124 | lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
|
2121 | 2125 | mask = lax.reduce_and(mask, [num_batch_dims])
|
2122 | 2126 |
|
@@ -2727,18 +2731,20 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
2727 | 2731 |
|
2728 | 2732 | upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
|
2729 | 2733 | for i in dnums.scatter_dims_to_operand_dims)
|
| 2734 | + |
2730 | 2735 | # Stack upper_bounds into a Array[n]
|
2731 | 2736 | upper_bound = lax.shape_as_value(upper_bounds)
|
2732 | 2737 | # This fix fails lax_test_no_jax_array
|
2733 |
| - upper_bound = lax.min(upper_bound, |
2734 |
| - lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max), |
2735 |
| - np.int64)) |
2736 |
| - |
| 2738 | + upper_bound = lax.min( |
| 2739 | + upper_bound, |
| 2740 | + upper_bound.dtype.type( |
| 2741 | + min(np.iinfo(upper_bound.dtype).max, np.iinfo(indices.dtype).max) |
| 2742 | + ), |
| 2743 | + ) |
| 2744 | + upper_bound = lax.convert_element_type(upper_bound, indices.dtype) |
2737 | 2745 | upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
2738 | 2746 | (len(indices.shape) - 1,))
|
2739 |
| - return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64), |
2740 |
| - upper_bound) |
2741 |
| - |
| 2747 | + return lax.clamp(indices.dtype.type(0), indices, upper_bound) |
2742 | 2748 |
|
2743 | 2749 | def _scatter_addsub_jvp(
|
2744 | 2750 | prim, primals, tangents, *, update_jaxpr, update_consts, dimension_numbers,
|
@@ -3132,10 +3138,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
|
3132 | 3138 | for update_dim in dnums.update_window_dims:
|
3133 | 3139 | ids_shape[update_dim] = 1
|
3134 | 3140 | num_ids = math.prod(ids_shape)
|
3135 |
| - if core.is_constant_dim(num_ids): |
3136 |
| - id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64 |
3137 |
| - else: |
3138 |
| - id_dtype = dtypes.canonicalize_dtype(np.uint64) |
| 3141 | + id_dtype = lax_utils.int_dtype_for_dim(num_ids, signed=False) |
3139 | 3142 | update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
|
3140 | 3143 | lax._ones(updates, dtype=id_dtype))
|
3141 | 3144 |
|
|
0 commit comments