Skip to content

Commit 02bd9ec

Browse files
yueshengys authored and Google-ML-Automation committed
[Mosaic] Add support for S32 cross-lane reduction.
PiperOrigin-RevId: 797979994
1 parent 82da6e1 commit 02bd9ec

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

tests/pallas/ops_test.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -526,6 +526,43 @@ def kernel(x_ref, o_ref):
526  526
527  527        np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5)
528  528
     529  +  @parameterized.named_parameters(
     530  +      ("sum", jnp.sum, (32, 256)),
     531  +      ("max", jnp.max, (32, 256)),
     532  +      ("min", jnp.min, (32, 256)),
     533  +      ("sum_irregular", jnp.sum, (31, 300)),
     534  +      ("max_irregular", jnp.max, (31, 300)),
     535  +      ("min_irregular", jnp.min, (31, 300)),
     536  +  )
     537  +  def test_reduce_int32(self, reduction_op, input_shape):
     538  +    if jtu.test_device_matches(["gpu"]):
     539  +      self.skipTest("TODO: error on GPU")
     540  +    # TODO(b/395579834): Remove this skip later.
     541  +    if not jtu.if_cloud_tpu_at_least(2025, 9, 1):
     542  +      self.skipTest("Requires libtpu built after 2025-09-01")
     543  +
     544  +    def kernel(x_ref, o_ref):
     545  +      o_ref[0, 0] = reduction_op(x_ref[...])
     546  +
     547  +    x = jax.random.randint(
     548  +        jax.random.key(0),
     549  +        shape=input_shape,
     550  +        minval=-100,
     551  +        maxval=100,
     552  +        dtype=jnp.int32,
     553  +    )
     554  +    result = self.pallas_call(
     555  +        kernel,
     556  +        in_specs=[
     557  +            pl.BlockSpec(input_shape, lambda *_: (0, 0)),
     558  +        ],
     559  +        out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
     560  +        out_shape=jax.ShapeDtypeStruct([1, 1], intx),
     561  +        grid=(1,),
     562  +    )(x)
     563  +
     564  +    np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5)
     565  +
529  566     # TODO(sharadmv): test rank < 2, size < 2
530  567     @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4,
531  568                                 min_size_exp=1))
min_size_exp=1))

tests/pallas/tpu_ops_test.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -305,12 +305,13 @@ def kernel(x_ref, mask_ref, o_ref):
305  305        reduce_func = [jnp.sum, jnp.max, jnp.min]
306  306    )
307  307    def test_reduction(self, dtype, axis, reduce_func):
308       -    if dtype == jnp.int32:
309       -      if axis == 2:
310       -        self.skipTest("Int32 reduction on minor is not supported.")
311       -    # TODO(b/384127570): fix bfloat16 reduction.
312       -    if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
313       -      self.skipTest("b/384127570")
     308  +    # TODO(b/395579834): Remove this skip later.
     309  +    if (
     310  +        dtype == jnp.int32
     311  +        and axis == 2
     312  +        and not jtu.if_cloud_tpu_at_least(2025, 9, 1)
     313  +    ):
     314  +      self.skipTest("Requires libtpu built after 2025-09-01")
314  315      in_shape = (2, 16, 128)
315  316      out_shape = list(in_shape)
316  317      out_shape[axis] = 1

tests/pallas/tpu_pallas_test.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1967,9 +1967,13 @@ def kernel(x, out):
1967  1967    def test_replicated_broadcast_reduction(
1968  1968        self, m, replicated, reduced_dims, dty, reduce_func
1969  1969    ):
1970        -    if dty == jnp.int32 and 1 in reduced_dims:
1971        -      # TODO(b/395579834): Remove this skip once we implement this.
1972        -      self.skipTest('int32 reduction on last dimension not supported')
      1970  +    # TODO(b/395579834): Remove this skip later.
      1971  +    if (
      1972  +        dty == jnp.int32
      1973  +        and 1 in reduced_dims
      1974  +        and is_cloud_tpu_older_than(2025, 9, 1)
      1975  +    ):
      1976  +      self.skipTest('Requires libtpu built after 2025-09-01')
1973  1977      if not jtu.is_device_tpu_at_least(4) and len(replicated) == 2:
1974  1978        self.skipTest(
1975  1979            'Brodcast in both sublanes and lanes not supported on this hardware'

0 commit comments

Comments (0)