diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 5218160a6a77..33f4fed610b8 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -3796,46 +3796,40 @@ def kernel(ctx, *args):
     )(inp)
     np.testing.assert_array_equal(result, inp)
 
-  @parameterized.product(
-      mns=((128, 128), (128, 64), (64, 128)),
-      layout=(mtu.RegisterLayout.WG_STRIDED, mtu.RegisterLayout.WGMMA),
+  @parameterized.parameters(
+      (128, 128), (64, 128), (64, 256)
   )
-  def test_broadcast_major(self, mns, layout):
-    m, n = mns
+  def test_broadcast_in_dim_major_strided(self, m, n):
+    dtype = jnp.float16
+    def kernel(ctx, gmem_input, gmem_output, _):
+      t = mgpu.FragmentedArray.load_strided(
+          gmem_input, vec_size=1
+      )
+      t.broadcast_in_dim((m, n), (1,),
+          mgpu.WGStridedFragLayout(shape=(m, n), vec_size=1),
+      ).store_untiled(gmem_output, optimized=False)
 
-    if n < 128 and layout == mtu.RegisterLayout.WG_STRIDED:
-      self.skipTest(f"{n=} < 128 not supported for {layout=}")
+    inp = self.prng.uniform(-1, 1, (n,)).astype(dtype)
+    out_shape = jax.ShapeDtypeStruct((m, n), dtype)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp
+    )(inp)
+    out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
+    np.testing.assert_array_equal(result, out_ref)
 
+  @parameterized.parameters(
+      (128, 128), (128, 64), (64, 128)
+  )
+  def test_broadcast_in_dim_major_wgmma(self, m, n):
     dtype = jnp.float16
-    load_layout = (
-        layout.to_mgpu((n,), dtype)
-        if layout == mtu.RegisterLayout.WG_STRIDED
-        else mgpu.WGMMA_COL_LAYOUT
-    )
-    broadcast_layout = (
-        mgpu.WGStridedFragLayout((m, n), load_layout.vec_size)
-        if layout == mtu.RegisterLayout.WG_STRIDED
-        else layout.to_mgpu((m, n), dtype)
-    )
-
-    def load(gmem_input):
-      match layout:
-        case mtu.RegisterLayout.WG_STRIDED:
-          return mgpu.FragmentedArray.load_strided(
-              gmem_input, vec_size=load_layout.vec_size
-          )
-        case mtu.RegisterLayout.WGMMA:
-          return mgpu.FragmentedArray.load_untiled(
-              gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
-          )
-        case _:
-          raise NotImplementedError(f"Unsupported layout: {layout}")
 
     def kernel(ctx, gmem_input, gmem_output, _):
-      t = load(gmem_input)
-      t.broadcast_in_dim((m, n), (1,), broadcast_layout).store_untiled(
-          gmem_output, optimized=False
+      t = mgpu.FragmentedArray.load_untiled(
+          gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
       )
+      t.broadcast_in_dim(
+          (m, n), (1,), mgpu.WGMMA_LAYOUT
+      ).store_untiled(gmem_output, optimized=False)
 
     inp = self.prng.uniform(-1, 1, (n,)).astype(dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), dtype)
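
Both new tests compare the kernel output against jax.lax.broadcast_in_dim(inp, (m, n), (1,)), i.e. the length-n input vector replicated along the new major (row) dimension. A minimal host-side sketch of that reference semantics, outside the diff; the 4x8 shape is an arbitrary illustration, not one of the parameterized cases:

# Sketch of the reference the tests assert against:
# broadcast_dimensions=(1,) maps inp[j] to out[i, j] for every row i.
import jax
import jax.numpy as jnp
import numpy as np

m, n = 4, 8
inp = jnp.arange(n, dtype=jnp.float16)
out = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
np.testing.assert_array_equal(out, jnp.tile(inp, (m, 1)))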