37 changes: 37 additions & 0 deletions tests/pallas/ops_test.py
@@ -526,6 +526,43 @@ def kernel(x_ref, o_ref):

     np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5)
 
+  @parameterized.named_parameters(
+      ("sum", jnp.sum, (32, 256)),
+      ("max", jnp.max, (32, 256)),
+      ("min", jnp.min, (32, 256)),
+      ("sum_irregular", jnp.sum, (31, 300)),
+      ("max_irregular", jnp.max, (31, 300)),
+      ("min_irregular", jnp.min, (31, 300)),
+  )
+  def test_reduce_int32(self, reduction_op, input_shape):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("TODO: error on GPU")
+    # TODO(b/395579834): Remove this skip later.
+    if not jtu.if_cloud_tpu_at_least(2025, 9, 1):
+      self.skipTest("Requires libtpu built after 2025-09-01")
+
+    def kernel(x_ref, o_ref):
+      o_ref[0, 0] = reduction_op(x_ref[...])
+
+    x = jax.random.randint(
+        jax.random.key(0),
+        shape=input_shape,
+        minval=-100,
+        maxval=100,
+        dtype=jnp.int32,
+    )
+    result = self.pallas_call(
+        kernel,
+        in_specs=[
+            pl.BlockSpec(input_shape, lambda *_: (0, 0)),
+        ],
+        out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
+        out_shape=jax.ShapeDtypeStruct([1, 1], intx),
+        grid=(1,),
+    )(x)
+
+    np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5)
 
   # TODO(sharadmv): test rank < 2, size < 2
   @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4,
                               min_size_exp=1))
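Outside the test harness, the new test boils down to one pattern: reduce a whole int32 block to a scalar and write it into SMEM. Below is a minimal standalone sketch of that pattern, not the PR's code: it assumes a TPU backend with a recent JAX, and uses `pltpu.SMEM` and `jnp.int32` directly where the test goes through its `smem_on_tpu()` helper and `intx` alias.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu  # assumes a TPU backend


def kernel(x_ref, o_ref):
  # Read the whole input block, reduce it to a scalar, and store it in SMEM.
  o_ref[0, 0] = jnp.sum(x_ref[...])


x = jax.random.randint(
    jax.random.key(0), shape=(32, 256), minval=-100, maxval=100,
    dtype=jnp.int32,
)
result = pl.pallas_call(
    kernel,
    in_specs=[pl.BlockSpec((32, 256), lambda *_: (0, 0))],
    # `pltpu.SMEM` stands in for the test's `smem_on_tpu()` helper.
    out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM),
    out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
    grid=(1,),
)(x)
assert result[0, 0] == jnp.sum(x)
```

The irregular shapes in the parameterization, such as (31, 300), exercise blocks that do not align with the TPU's native tiling, which is where int32 reductions were previously most fragile.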
13 changes: 7 additions & 6 deletions tests/pallas/tpu_ops_test.py
@@ -305,12 +305,13 @@ def kernel(x_ref, mask_ref, o_ref):
       reduce_func = [jnp.sum, jnp.max, jnp.min]
   )
   def test_reduction(self, dtype, axis, reduce_func):
-    if dtype == jnp.int32:
-      if axis == 2:
-        self.skipTest("Int32 reduction on minor is not supported.")
-    # TODO(b/384127570): fix bfloat16 reduction.
-    if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
-      self.skipTest("b/384127570")
+    # TODO(b/395579834): Remove this skip later.
+    if (
+        dtype == jnp.int32
+        and axis == 2
+        and not jtu.if_cloud_tpu_at_least(2025, 9, 1)
+    ):
+      self.skipTest("Requires libtpu built after 2025-09-01")
     in_shape = (2, 16, 128)
     out_shape = list(in_shape)
     out_shape[axis] = 1
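The case this hunk stops skipping unconditionally is a reduction along the minor (lane) dimension, axis 2 of the test's (2, 16, 128) input, which older libtpu builds could not lower for int32. A rough standalone sketch of that kind of kernel follows, under the same caveat that it needs a sufficiently new libtpu; the real test also applies a mask and sweeps several dtypes and axes.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(x_ref, o_ref):
  # axis=2 is the minor (lane) dimension of the (2, 16, 128) block.
  o_ref[...] = jnp.sum(x_ref[...], axis=2, keepdims=True)


x = jnp.arange(2 * 16 * 128, dtype=jnp.int32).reshape(2, 16, 128)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((2, 16, 1), jnp.int32),
)(x)
```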
10 changes: 7 additions & 3 deletions tests/pallas/tpu_pallas_test.py
@@ -1967,9 +1967,13 @@ def kernel(x, out):
   def test_replicated_broadcast_reduction(
       self, m, replicated, reduced_dims, dty, reduce_func
   ):
-    if dty == jnp.int32 and 1 in reduced_dims:
-      # TODO(b/395579834): Remove this skip once we implement this.
-      self.skipTest('int32 reduction on last dimension not supported')
+    # TODO(b/395579834): Remove this skip later.
+    if (
+        dty == jnp.int32
+        and 1 in reduced_dims
+        and is_cloud_tpu_older_than(2025, 9, 1)
+    ):
+      self.skipTest('Requires libtpu built after 2025-09-01')
     if not jtu.is_device_tpu_at_least(4) and len(replicated) == 2:
       self.skipTest(
           'Brodcast in both sublanes and lanes not supported on this hardware'
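All three files converge on the same gate: skip only when the Cloud TPU runtime predates 2025-09-01, instead of skipping unconditionally. Two spellings appear in the diffs, `not jtu.if_cloud_tpu_at_least(...)` and `is_cloud_tpu_older_than(...)`; as used here they express the same condition, with tpu_pallas_test.py importing the latter helper directly. Below is a sketch of the shared pattern as a hypothetical helper; the helper name is an invention for illustration, while `jax._src.test_util` is the usual home of the `jtu` alias in these tests.

```python
import jax._src.test_util as jtu  # assumed home of the date-gate helper


def skip_unless_libtpu_at_least(test_case, year, month, day):
  """Skips `test_case` on Cloud TPU runtimes older than the given date."""
  # TODO(b/395579834): drop once the gated feature is the libtpu baseline.
  if not jtu.if_cloud_tpu_at_least(year, month, day):
    test_case.skipTest(
        f'Requires libtpu built after {year}-{month:02d}-{day:02d}'
    )
```

The design point of the PR is visible in this pattern: rather than deleting coverage for int32 minor-axis reductions, each test now runs wherever the runtime is new enough, so the feature is exercised as soon as a compatible libtpu ships.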