Commit 52f6273

apaszke authored and Google-ML-Automation committed
Don't run Pallas:SC debug_print tests in multithreaded pytest invocations

They mess with stdout capturing, so we run them separately.

PiperOrigin-RevId: 802042120
1 parent 2b43197 commit 52f6273
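
The stdout-capturing issue the message refers to comes from pytest's default per-test output capture: under parallel (pytest-xdist) runs, worker processes capture and buffer stdout/stderr, which breaks tests that assert on what the TPU runtime prints. A minimal sketch of the resulting two-phase pattern; the file and class names here (tests/noisy_test.py::PrintTest) are illustrative, not from this commit:

# Phase 1: run the bulk of the suite in parallel, deselecting tests that need
# raw stdout/stderr (illustrative paths, not the real JAX test files).
python -m pytest -n auto --deselect=tests/noisy_test.py::PrintTest tests

# Phase 2: rerun the deselected tests alone with capture disabled
# (-s is pytest's shorthand for --capture=no), so their output reaches
# the real stdout/stderr and can be asserted on.
python -m pytest -s tests/noisy_test.py::PrintTest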

2 files changed: +79 -73 lines changed

ci/run_pytest_tpu.sh

Lines changed: 4 additions & 1 deletion
@@ -71,6 +71,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
   # Run single-accelerator tests in parallel
   JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
     --deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
+    --deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \
     --deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \
     --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples


@@ -108,7 +109,9 @@ else
 fi

 # Run Pallas printing tests, which need to run with I/O capturing disabled.
-TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest
+TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest \
+  -s tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
+  -s tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest

 # Store the return value of the third command.
 third_cmd_retval=$?
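
For reference, running just these printing tests outside the CI script should look roughly like the following, assuming a TPU-enabled environment and the repository root as the working directory; the test paths, the DebugPrintTest class, and TPU_STDERR_LOG_LEVEL=0 are taken from the diff above, while plain python stands in for "$JAXCI_PYTHON":

TPU_STDERR_LOG_LEVEL=0 python -m pytest -s \
    tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
    tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest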

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 75 additions & 72 deletions
@@ -43,26 +43,10 @@ def setUp(self):
     super().setUp()


-class VectorSubcoreTest(PallasSCTest):
-
-  # Used for testing masked loads and stores below
-  MASK_FNS = [lambda x: x < 4, lambda x: x >= 4, lambda x: x % 2 == 0]
-
-  @parameterized.product(
-      dtype=[jnp.int32, jnp.float32], op=[jnp.add, jnp.subtract]
-  )
-  def test_add_sub_one(self, dtype, op):
-    x = jnp.arange(8, dtype=dtype)
-
-    @plsc.vector_subcore_kernel(out_shape=x)
-    def kernel(x_ref, o_ref):
-      x = x_ref[...]
-      o_ref[...] = op(x, 1)
-
-    np.testing.assert_array_equal(kernel(x), op(x, 1))
+class DebugPrintTest(PallasSCTest):

   @parameterized.product(dtype=[jnp.int32, jnp.float32])
-  def test_debug_print(self, dtype):
+  def test_vector_subcore(self, dtype):
     x = jnp.arange(16, dtype=dtype)
     debug_int = 1234552
     debug_float = 12344.625
@@ -99,6 +83,79 @@ def kernel(x_hbm_ref, _):
     self.assertIn(str(debug_float), get_output())
     self.assertIn("No values", get_output())

+  def test_scalar_subcore(self):
+    int32s = jnp.arange(512, dtype=jnp.int32).reshape(64, 8)
+    int16s = jnp.arange(512, dtype=jnp.int16).reshape(32, 16)
+    int8s = jnp.arange(512, dtype=jnp.int8).reshape(16, 32)
+    debug_int = 1234552
+    debug_float = 12344.625
+
+    @plsc.scalar_subcore_kernel(
+        out_shape=int32s,
+        mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=self.num_cores),
+    )
+    def kernel(int32s_hbm_ref, int16s_hbm_ref, int8s_hbm_ref, o_hbm_ref):
+      @functools.partial(
+          pl.run_scoped,
+          tmp_ref=pltpu.VMEM_SHARED(int32s.shape, int32s.dtype),
+          sem=pltpu.SemaphoreType.DMA,
+      )
+      def _(tmp_ref, sem):
+        @pl.when(lax.axis_index("core") == 0)
+        def _():
+          pltpu.async_copy(int32s_hbm_ref, tmp_ref, sem).wait()
+          pltpu.async_copy(tmp_ref, o_hbm_ref, sem).wait()
+          pl.debug_print("s32 array", tmp_ref)
+          pl.debug_print("s16 array", int16s_hbm_ref)
+          pl.debug_print("s8 array", int8s_hbm_ref)
+          pl.debug_print("Single int", debug_int)
+          pl.debug_print("Single float", debug_float)
+          pl.debug_print("No values")
+
+    compiled_kernel = jax.jit(
+        kernel, compiler_options={"xla_tpu_enable_sc_log_recorder": "true"}
+    )
+    with jtu.capture_stderr() as get_output:
+      jax.block_until_ready(compiled_kernel(int32s, int16s, int8s))
+    print(get_output())
+    self.assertIn("s32 array, data: s32", get_output())
+    self.assertIn("{ 8, 9, 10, 11, 12, 13, 14, 15 }", get_output())
+    self.assertIn("s16 array, data: s16", get_output())
+    self.assertIn(
+        "{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }",
+        get_output(),
+    )
+    self.assertIn("s8 array, data: s8", get_output())
+    self.assertIn(
+        "{ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47"
+        ", 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 }",
+        get_output(),
+    )
+    self.assertIn("Single int", get_output())
+    self.assertIn(str(debug_int), get_output())
+    self.assertIn("Single float", get_output())
+    self.assertIn(str(debug_float), get_output())
+    self.assertIn("No values", get_output())
+
+
+class VectorSubcoreTest(PallasSCTest):
+
+  # Used for testing masked loads and stores below
+  MASK_FNS = [lambda x: x < 4, lambda x: x >= 4, lambda x: x % 2 == 0]
+
+  @parameterized.product(
+      dtype=[jnp.int32, jnp.float32], op=[jnp.add, jnp.subtract]
+  )
+  def test_add_sub_one(self, dtype, op):
+    x = jnp.arange(8, dtype=dtype)
+
+    @plsc.vector_subcore_kernel(out_shape=x)
+    def kernel(x_ref, o_ref):
+      x = x_ref[...]
+      o_ref[...] = op(x, 1)
+
+    np.testing.assert_array_equal(kernel(x), op(x, 1))
+
   def test_add_one_block_specs(self):
     x = jnp.arange(32, dtype=jnp.int32)

@@ -624,60 +681,6 @@ class ScalarSubcoreTest(PallasSCTest):
   def num_cores(self):
     return sc_core._num_available_cores()

-  def test_debug_print(self):
-    int32s = jnp.arange(512, dtype=jnp.int32).reshape(64, 8)
-    int16s = jnp.arange(512, dtype=jnp.int16).reshape(32, 16)
-    int8s = jnp.arange(512, dtype=jnp.int8).reshape(16, 32)
-    debug_int = 1234552
-    debug_float = 12344.625
-
-    @plsc.scalar_subcore_kernel(
-        out_shape=int32s,
-        mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=self.num_cores),
-    )
-    def kernel(int32s_hbm_ref, int16s_hbm_ref, int8s_hbm_ref, o_hbm_ref):
-      @functools.partial(
-          pl.run_scoped,
-          tmp_ref=pltpu.VMEM_SHARED(int32s.shape, int32s.dtype),
-          sem=pltpu.SemaphoreType.DMA,
-      )
-      def _(tmp_ref, sem):
-        @pl.when(lax.axis_index("core") == 0)
-        def _():
-          pltpu.async_copy(int32s_hbm_ref, tmp_ref, sem).wait()
-          pltpu.async_copy(tmp_ref, o_hbm_ref, sem).wait()
-          pl.debug_print("s32 array", tmp_ref)
-          pl.debug_print("s16 array", int16s_hbm_ref)
-          pl.debug_print("s8 array", int8s_hbm_ref)
-          pl.debug_print("Single int", debug_int)
-          pl.debug_print("Single float", debug_float)
-          pl.debug_print("No values")
-
-    compiled_kernel = jax.jit(
-        kernel, compiler_options={"xla_tpu_enable_sc_log_recorder": "true"}
-    )
-    with jtu.capture_stderr() as get_output:
-      jax.block_until_ready(compiled_kernel(int32s, int16s, int8s))
-    print(get_output())
-    self.assertIn("s32 array, data: s32", get_output())
-    self.assertIn("{ 8, 9, 10, 11, 12, 13, 14, 15 }", get_output())
-    self.assertIn("s16 array, data: s16", get_output())
-    self.assertIn(
-        "{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }",
-        get_output(),
-    )
-    self.assertIn("s8 array, data: s8", get_output())
-    self.assertIn(
-        "{ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47"
-        ", 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 }",
-        get_output(),
-    )
-    self.assertIn("Single int", get_output())
-    self.assertIn(str(debug_int), get_output())
-    self.assertIn("Single float", get_output())
-    self.assertIn(str(debug_float), get_output())
-    self.assertIn("No values", get_output())
-
   def test_copy(self):
     x = jnp.arange(16)
