Skip to content

Commit d3ff2a6

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Enable the bounds checker in the Pallas:SC bounds checker test
PiperOrigin-RevId: 802046061
1 parent 52f6273 commit d3ff2a6

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

tests/pallas/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,7 @@ jax_multiplatform_test(
11981198
] + py_deps([
11991199
"numpy",
12001200
"absl/testing",
1201+
"absl/flags",
12011202
]),
12021203
)
12031204

tests/pallas/tpu_pallas_sparsecore_debug_check_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import unittest
2929

3030
from absl.testing import absltest
31+
from absl import flags
3132
import jax
3233
from jax._src import test_util as jtu
3334
from jax.experimental import pallas as pl
@@ -118,6 +119,12 @@ def kernel(_):
118119
)
119120

120121
def test_trigger_bounds_checker(self):
122+
if "xla_sc_assert_level" in flags.FLAGS:
123+
# The test crashes the process anyway, so no need to be clean.
124+
flags.FLAGS.xla_sc_assert_level = "bounds"
125+
else:
126+
self.skipTest("TODO: Find another way to enable bounds checking.")
127+
121128
x = jnp.arange(8, dtype=jnp.int32)
122129
# Index 8 is out-of-bounds.
123130
indices = jnp.array([0, 1, 2, 3, 4, 5, 6, 8], dtype=jnp.int32)

0 commit comments

Comments
 (0)