Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3796,46 +3796,40 @@ def kernel(ctx, *args):
)(inp)
np.testing.assert_array_equal(result, inp)

@parameterized.product(
mns=((128, 128), (128, 64), (64, 128)),
layout=(mtu.RegisterLayout.WG_STRIDED, mtu.RegisterLayout.WGMMA),
@parameterized.parameters(
(128, 128), (64, 128), (64, 256)
)
def test_broadcast_major(self, mns, layout):
m, n = mns
def test_broadcast_in_dim_major_strided(self, m, n):
dtype = jnp.float16
def kernel(ctx, gmem_input, gmem_output, _):
t = mgpu.FragmentedArray.load_strided(
gmem_input, vec_size=1
)
t.broadcast_in_dim((m, n), (1,),
mgpu.WGStridedFragLayout(shape=(m, n), vec_size=1),
).store_untiled(gmem_output, optimized=False)

if n < 128 and layout == mtu.RegisterLayout.WG_STRIDED:
self.skipTest(f"{n=} < 128 not supported for {layout=}")
inp = self.prng.uniform(-1, 1, (n,)).astype(dtype)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp
)(inp)
out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
np.testing.assert_array_equal(result, out_ref)

@parameterized.parameters(
(128, 128), (128, 64), (64, 128)
)
def test_broadcast_in_dim_major_wgmma(self, m, n):
dtype = jnp.float16
load_layout = (
layout.to_mgpu((n,), dtype)
if layout == mtu.RegisterLayout.WG_STRIDED
else mgpu.WGMMA_COL_LAYOUT
)
broadcast_layout = (
mgpu.WGStridedFragLayout((m, n), load_layout.vec_size)
if layout == mtu.RegisterLayout.WG_STRIDED
else layout.to_mgpu((m, n), dtype)
)

def load(gmem_input):
match layout:
case mtu.RegisterLayout.WG_STRIDED:
return mgpu.FragmentedArray.load_strided(
gmem_input, vec_size=load_layout.vec_size
)
case mtu.RegisterLayout.WGMMA:
return mgpu.FragmentedArray.load_untiled(
gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
)
case _:
raise NotImplementedError(f"Unsupported layout: {layout}")

def kernel(ctx, gmem_input, gmem_output, _):
t = load(gmem_input)
t.broadcast_in_dim((m, n), (1,), broadcast_layout).store_untiled(
gmem_output, optimized=False
t = mgpu.FragmentedArray.load_untiled(
gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
)
t.broadcast_in_dim(
(m, n), (1,), mgpu.WGMMA_LAYOUT
).store_untiled(gmem_output, optimized=False)

inp = self.prng.uniform(-1, 1, (n,)).astype(dtype)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
Expand Down
Loading