diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 0093ed01caf0..7a4d21dbc45f 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -2454,6 +2454,41 @@ def reduce(
     )
 
   def broadcast(self, shape) -> FragmentedArray:
+    if isinstance(self.layout, WGStridedFragLayout):
+      src_shape, dst_shape = self.layout.shape, shape
+      if len(src_shape) > len(dst_shape):
+        raise ValueError(
+            f"Rank mismatch: expected len({src_shape}) <= len({dst_shape})"
+        )
+      if not all(s == 1 or s == d for s, d in zip(src_shape[::-1], dst_shape[::-1])):
+        raise ValueError(
+            "Can only broadcast if each source dimension equals the"
+            " corresponding trailing target dimension or is 1."
+            f" Broadcasting from {src_shape} to {dst_shape}"
+        )
+      rank_diff = len(dst_shape) - len(src_shape)
+      src_shape = (1,) * rank_diff + tuple(src_shape)
+
+      assert len(src_shape) == len(dst_shape), (src_shape, dst_shape)
+      # Length of the trailing run of dimensions on which both shapes agree.
+      len_suffix = next(
+          (i for i in range(len(src_shape)) if src_shape[~i] != dst_shape[~i]),
+          len(src_shape),
+      )
+      if len_suffix > 0 and all(x == 1 for x in src_shape[:-len_suffix]):
+        # Tile the registers once per element of the broadcasted major-most
+        # dimensions; math.prod returns the int 1 for an empty prefix.
+        return FragmentedArray(
+            _registers=np.tile(self.registers, math.prod(dst_shape[:-len_suffix])),
+            _layout=WGStridedFragLayout(shape, self.layout.vec_size),
+            _is_signed=self.is_signed,
+        )
+
+      raise NotImplementedError(
+          "Only major-most broadcasts are implemented for WGStridedFragLayout."
+          f" Broadcasting from {src_shape} to {dst_shape}."
+      )
+
     if not isinstance(self.layout, WGSplatFragLayout):
       raise NotImplementedError(self.layout)
 
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index a81e6b29d043..deb8772937d0 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -3839,6 +3839,30 @@ def kernel(ctx, gmem_input, gmem_output, _):
     out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
     np.testing.assert_array_equal(result, out_ref)
 
+  @parameterized.parameters(
+      ((128,), (4, 128)),
+      ((1, 128), (2, 128)),
+      ((1, 128), (4, 128)),
+      ((1, 256), (2, 256)),
+      ((128,), (1, 3, 1, 2, 4, 128)),
+      ((1, 1, 128), (1, 3, 1, 2, 4, 128)),
+      ((1, 1, 1, 1, 1, 128), (1, 3, 1, 2, 4, 128)),
+      ((2, 4, 128), (1, 3, 1, 2, 4, 128)),
+      ((1, 1, 1, 2, 4, 128), (1, 3, 1, 2, 4, 128)),
+      ((2, 8, 8), (2, 8, 8)),  # Identity broadcast.
+  )
+  def test_broadcast_major_strided(self, in_shape, out_shape):
+    dtype = jnp.float16
+    def kernel(ctx, gmem_input, gmem_output, _):
+      t = mgpu.FragmentedArray.load_strided(gmem_input, vec_size=1)
+      t.broadcast(out_shape).store_untiled(gmem_output, optimized=False)
+    inp = self.prng.uniform(-1, 1, in_shape).astype(dtype)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (inp,),
+        jax.ShapeDtypeStruct(out_shape, dtype), inp,
+    )(inp)
+    np.testing.assert_array_equal(result, jnp.broadcast_to(inp, out_shape))
+
   @parameterized.parameters(*mtu.RegisterLayout)
   def test_broadcast_splat(self, layout):
     out_shape = (128, 128)
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 7b5d844e7015..05c2ef48516a 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -2486,6 +2486,35 @@ def kernel(x_ref, y_ref):
     result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32)
     np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128)))
 
+  @parameterized.parameters(
+      ((4, 128),),
+      ((2, 4, 128),),
+  )
+  def test_broadcast_wg_strided_majormost_dim(self, out_shape):
+    self.skip_if_wg_semantics()  # Lowering not implemented.
+    dtype = jnp.float32
+
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, dtype)
+    )
+    def kernel(x_ref, side_load_ref, y_ref):
+      x_strided = plgpu.load(
+          x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), vec_size=1)
+      )
+      side_load_strided = plgpu.load(
+          side_load_ref, (), layout=plgpu.Layout.WG_STRIDED(out_shape, vec_size=1)
+      )
+      # Add leading unit dimensions so x broadcasts against the side load.
+      for _ in range(len(out_shape) - 1):
+        x_strided = x_strided[None, ...]
+      y_ref[...] = x_strided + side_load_strided
+
+    inp = jax.random.uniform(jax.random.key(0), (128,), dtype)
+    side_load = jax.random.uniform(jax.random.key(1), out_shape, dtype)
+    np.testing.assert_array_equal(
+        kernel(inp, side_load), jnp.broadcast_to(inp, out_shape) + side_load
+    )
+
   def test_broadcast_in_dim_tcgen05_native_layout(self):
     @functools.partial(
         self.kernel,
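
A standalone sketch of the shape arithmetic behind the new WGStridedFragLayout broadcast path may help review: pad the source shape with leading 1s, find the longest trailing run of dimensions on which the shapes agree, and tile the registers once per element of the remaining major-most dimensions. broadcast_reps below is a hypothetical helper named for illustration only; it mirrors the fragmented_array.py hunk but is not part of this change.

import math
import numpy as np

def broadcast_reps(src_shape, dst_shape):
  """Returns how many times the source registers must be tiled."""
  if len(src_shape) > len(dst_shape):
    raise ValueError(f"Rank mismatch: {src_shape} vs {dst_shape}")
  # Pad with leading 1s so both shapes have the same rank.
  src_shape = (1,) * (len(dst_shape) - len(src_shape)) + tuple(src_shape)
  # Longest trailing run of dimensions on which the two shapes agree.
  len_suffix = next(
      (i for i in range(len(src_shape)) if src_shape[~i] != dst_shape[~i]),
      len(src_shape),
  )
  # Every dimension ahead of the matching suffix must be a broadcasted 1.
  if len_suffix == 0 or any(s != 1 for s in src_shape[:-len_suffix]):
    raise NotImplementedError("Only major-most broadcasts are supported.")
  return math.prod(dst_shape[:-len_suffix])  # An int; 1 if the shapes match.

regs = np.arange(4)                           # Stand-in for a register array.
reps = broadcast_reps((1, 128), (2, 3, 128))  # (1, 1, 128) -> (2, 3, 128): 6.
assert np.tile(regs, reps).shape == (regs.size * reps,)
assert broadcast_reps((2, 8, 8), (2, 8, 8)) == 1  # Identity case: tile once.

The int return value matters: np.prod(()) yields a NumPy float, which np.tile rejects as a repetition count, whereas math.prod(()) is the int 1 — this is why the fragmented_array.py hunk uses math.prod, so the identity-shape case tiles cleanly.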