@@ -526,6 +526,43 @@ def kernel(x_ref, o_ref):
526
526
527
527
np .testing .assert_allclose (result [0 , 0 ], reduction_op (x ), atol = 1e-5 )
528
528
529
@parameterized.named_parameters(
    ("sum", jnp.sum, (32, 256)),
    ("max", jnp.max, (32, 256)),
    ("min", jnp.min, (32, 256)),
    ("sum_irregular", jnp.sum, (31, 300)),
    ("max_irregular", jnp.max, (31, 300)),
    ("min_irregular", jnp.min, (31, 300)),
)
def test_reduce_int32(self, reduction_op, input_shape):
    """Check a full reduction of an int32 array inside a Pallas kernel.

    The kernel applies ``reduction_op`` to the whole input block and writes
    the scalar result to ``o_ref[0, 0]``. The value is then compared with
    ``reduction_op`` applied directly to the same input outside the kernel.

    Args:
      reduction_op: Reduction under test (``jnp.sum``, ``jnp.max`` or
        ``jnp.min`` in the parameter list above).
      input_shape: 2D shape of the random int32 input; it is also used as
        the input block shape, so the kernel sees the whole array at once.
    """
    if jtu.test_device_matches(["gpu"]):
        self.skipTest("TODO: error on GPU")
    # TODO(b/395579834): Remove this skip later.
    if not jtu.if_cloud_tpu_at_least(2025, 9, 1):
        self.skipTest("Requires libtpu built after 2025-09-01")

    def kernel(x_ref, o_ref):
        # Reduce the entire input block down to one scalar output element.
        o_ref[0, 0] = reduction_op(x_ref[...])

    # Values are drawn from [-100, 100) with a fixed key, so the input is
    # deterministic across runs.
    x = jax.random.randint(
        jax.random.key(0),
        shape=input_shape,
        minval=-100,
        maxval=100,
        dtype=jnp.int32,
    )
    result = self.pallas_call(
        kernel,
        in_specs=[
            # Single block covering the full input; the index map ignores
            # the grid position and always selects block (0, 0).
            pl.BlockSpec(input_shape, lambda *_: (0, 0)),
        ],
        out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
        out_shape=jax.ShapeDtypeStruct([1, 1], intx),
        grid=(1,),
    )(x)

    # Integer reductions are exact, so require equality instead of using a
    # floating-point tolerance.
    np.testing.assert_array_equal(result[0, 0], reduction_op(x))
565
+
529
566
# TODO(sharadmv): test rank < 2, size < 2
530
567
@hp .given (select_n_strategy (max_cases = 2 , min_rank = 2 , max_rank = 4 ,
531
568
min_size_exp = 1 ))
0 commit comments