@@ -43,26 +43,10 @@ def setUp(self):
43
43
super ().setUp ()
44
44
45
45
46
- class VectorSubcoreTest (PallasSCTest ):
47
-
48
- # Used for testing masked loads and stores below
49
- MASK_FNS = [lambda x : x < 4 , lambda x : x >= 4 , lambda x : x % 2 == 0 ]
50
-
51
- @parameterized .product (
52
- dtype = [jnp .int32 , jnp .float32 ], op = [jnp .add , jnp .subtract ]
53
- )
54
- def test_add_sub_one (self , dtype , op ):
55
- x = jnp .arange (8 , dtype = dtype )
56
-
57
- @plsc .vector_subcore_kernel (out_shape = x )
58
- def kernel (x_ref , o_ref ):
59
- x = x_ref [...]
60
- o_ref [...] = op (x , 1 )
61
-
62
- np .testing .assert_array_equal (kernel (x ), op (x , 1 ))
46
+ class DebugPrintTest (PallasSCTest ):
63
47
64
48
@parameterized .product (dtype = [jnp .int32 , jnp .float32 ])
65
- def test_debug_print (self , dtype ):
49
+ def test_vector_subcore (self , dtype ):
66
50
x = jnp .arange (16 , dtype = dtype )
67
51
debug_int = 1234552
68
52
debug_float = 12344.625
@@ -99,6 +83,79 @@ def kernel(x_hbm_ref, _):
99
83
self .assertIn (str (debug_float ), get_output ())
100
84
self .assertIn ("No values" , get_output ())
101
85
86
+ def test_scalar_subcore (self ):
87
+ int32s = jnp .arange (512 , dtype = jnp .int32 ).reshape (64 , 8 )
88
+ int16s = jnp .arange (512 , dtype = jnp .int16 ).reshape (32 , 16 )
89
+ int8s = jnp .arange (512 , dtype = jnp .int8 ).reshape (16 , 32 )
90
+ debug_int = 1234552
91
+ debug_float = 12344.625
92
+
93
+ @plsc .scalar_subcore_kernel (
94
+ out_shape = int32s ,
95
+ mesh = plsc .ScalarSubcoreMesh (axis_name = "core" , num_cores = self .num_cores ),
96
+ )
97
+ def kernel (int32s_hbm_ref , int16s_hbm_ref , int8s_hbm_ref , o_hbm_ref ):
98
+ @functools .partial (
99
+ pl .run_scoped ,
100
+ tmp_ref = pltpu .VMEM_SHARED (int32s .shape , int32s .dtype ),
101
+ sem = pltpu .SemaphoreType .DMA ,
102
+ )
103
+ def _ (tmp_ref , sem ):
104
+ @pl .when (lax .axis_index ("core" ) == 0 )
105
+ def _ ():
106
+ pltpu .async_copy (int32s_hbm_ref , tmp_ref , sem ).wait ()
107
+ pltpu .async_copy (tmp_ref , o_hbm_ref , sem ).wait ()
108
+ pl .debug_print ("s32 array" , tmp_ref )
109
+ pl .debug_print ("s16 array" , int16s_hbm_ref )
110
+ pl .debug_print ("s8 array" , int8s_hbm_ref )
111
+ pl .debug_print ("Single int" , debug_int )
112
+ pl .debug_print ("Single float" , debug_float )
113
+ pl .debug_print ("No values" )
114
+
115
+ compiled_kernel = jax .jit (
116
+ kernel , compiler_options = {"xla_tpu_enable_sc_log_recorder" : "true" }
117
+ )
118
+ with jtu .capture_stderr () as get_output :
119
+ jax .block_until_ready (compiled_kernel (int32s , int16s , int8s ))
120
+ print (get_output ())
121
+ self .assertIn ("s32 array, data: s32" , get_output ())
122
+ self .assertIn ("{ 8, 9, 10, 11, 12, 13, 14, 15 }" , get_output ())
123
+ self .assertIn ("s16 array, data: s16" , get_output ())
124
+ self .assertIn (
125
+ "{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }" ,
126
+ get_output (),
127
+ )
128
+ self .assertIn ("s8 array, data: s8" , get_output ())
129
+ self .assertIn (
130
+ "{ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47"
131
+ ", 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 }" ,
132
+ get_output (),
133
+ )
134
+ self .assertIn ("Single int" , get_output ())
135
+ self .assertIn (str (debug_int ), get_output ())
136
+ self .assertIn ("Single float" , get_output ())
137
+ self .assertIn (str (debug_float ), get_output ())
138
+ self .assertIn ("No values" , get_output ())
139
+
140
+
141
+ class VectorSubcoreTest (PallasSCTest ):
142
+
143
+ # Used for testing masked loads and stores below
144
+ MASK_FNS = [lambda x : x < 4 , lambda x : x >= 4 , lambda x : x % 2 == 0 ]
145
+
146
+ @parameterized .product (
147
+ dtype = [jnp .int32 , jnp .float32 ], op = [jnp .add , jnp .subtract ]
148
+ )
149
+ def test_add_sub_one (self , dtype , op ):
150
+ x = jnp .arange (8 , dtype = dtype )
151
+
152
+ @plsc .vector_subcore_kernel (out_shape = x )
153
+ def kernel (x_ref , o_ref ):
154
+ x = x_ref [...]
155
+ o_ref [...] = op (x , 1 )
156
+
157
+ np .testing .assert_array_equal (kernel (x ), op (x , 1 ))
158
+
102
159
def test_add_one_block_specs (self ):
103
160
x = jnp .arange (32 , dtype = jnp .int32 )
104
161
@@ -624,60 +681,6 @@ class ScalarSubcoreTest(PallasSCTest):
624
681
def num_cores (self ):
625
682
return sc_core ._num_available_cores ()
626
683
627
- def test_debug_print (self ):
628
- int32s = jnp .arange (512 , dtype = jnp .int32 ).reshape (64 , 8 )
629
- int16s = jnp .arange (512 , dtype = jnp .int16 ).reshape (32 , 16 )
630
- int8s = jnp .arange (512 , dtype = jnp .int8 ).reshape (16 , 32 )
631
- debug_int = 1234552
632
- debug_float = 12344.625
633
-
634
- @plsc .scalar_subcore_kernel (
635
- out_shape = int32s ,
636
- mesh = plsc .ScalarSubcoreMesh (axis_name = "core" , num_cores = self .num_cores ),
637
- )
638
- def kernel (int32s_hbm_ref , int16s_hbm_ref , int8s_hbm_ref , o_hbm_ref ):
639
- @functools .partial (
640
- pl .run_scoped ,
641
- tmp_ref = pltpu .VMEM_SHARED (int32s .shape , int32s .dtype ),
642
- sem = pltpu .SemaphoreType .DMA ,
643
- )
644
- def _ (tmp_ref , sem ):
645
- @pl .when (lax .axis_index ("core" ) == 0 )
646
- def _ ():
647
- pltpu .async_copy (int32s_hbm_ref , tmp_ref , sem ).wait ()
648
- pltpu .async_copy (tmp_ref , o_hbm_ref , sem ).wait ()
649
- pl .debug_print ("s32 array" , tmp_ref )
650
- pl .debug_print ("s16 array" , int16s_hbm_ref )
651
- pl .debug_print ("s8 array" , int8s_hbm_ref )
652
- pl .debug_print ("Single int" , debug_int )
653
- pl .debug_print ("Single float" , debug_float )
654
- pl .debug_print ("No values" )
655
-
656
- compiled_kernel = jax .jit (
657
- kernel , compiler_options = {"xla_tpu_enable_sc_log_recorder" : "true" }
658
- )
659
- with jtu .capture_stderr () as get_output :
660
- jax .block_until_ready (compiled_kernel (int32s , int16s , int8s ))
661
- print (get_output ())
662
- self .assertIn ("s32 array, data: s32" , get_output ())
663
- self .assertIn ("{ 8, 9, 10, 11, 12, 13, 14, 15 }" , get_output ())
664
- self .assertIn ("s16 array, data: s16" , get_output ())
665
- self .assertIn (
666
- "{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }" ,
667
- get_output (),
668
- )
669
- self .assertIn ("s8 array, data: s8" , get_output ())
670
- self .assertIn (
671
- "{ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47"
672
- ", 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 }" ,
673
- get_output (),
674
- )
675
- self .assertIn ("Single int" , get_output ())
676
- self .assertIn (str (debug_int ), get_output ())
677
- self .assertIn ("Single float" , get_output ())
678
- self .assertIn (str (debug_float ), get_output ())
679
- self .assertIn ("No values" , get_output ())
680
-
681
684
def test_copy (self ):
682
685
x = jnp .arange (16 )
683
686
0 commit comments