19 changes: 19 additions & 0 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -681,6 +681,11 @@ def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
raise ValueError(f"Shape {shape} is not compatible with {self}")
return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),)

def can_broadcast_to(self, shape) -> bool:
"""Check that the shape can be broadcast."""
ones = [i for i, elem in enumerate(shape) if elem == 1]
return list(range(len(ones))) == ones

  def shape_from_registers_shape(
      self, shape: tuple[int, ...]
  ) -> tuple[int, ...]:
@@ -2447,6 +2452,20 @@ def reduce(
    )

  def broadcast(self, shape) -> FragmentedArray:
    if isinstance(self.layout, WGStridedFragLayout):
      if not self.layout.can_broadcast_to(shape):
        raise NotImplementedError(
            f"Only major-most broadcast is implemented. Layout: {self.layout},"
            f" to shape: {shape}."
        )

      one_dims = [i for i, elem in enumerate(self.shape) if elem == 1]
      assert list(range(len(one_dims))) == one_dims, (one_dims,)
      # Broadcasting along the leading size-1 dimensions amounts to repeating
      # each thread's registers once per element of those dimensions.
      return FragmentedArray(
          _registers=np.tile(self.registers, np.prod(shape[:len(one_dims)])),
          _layout=WGStridedFragLayout(shape, self.layout.vec_size),
          _is_signed=self.is_signed,
      )
    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(self.layout)

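Not part of the diff: a minimal NumPy sketch of the idea behind the np.tile call in broadcast above, using a toy row-major flattening in place of the real WGStridedFragLayout register assignment (vec_size and the per-thread split are ignored; all names below are illustrative).

import numpy as np

# Toy stand-in for a strided layout: the "registers" are just the array's
# elements flattened in row-major order.
src_shape = (1, 8)   # the only size-1 dim is major-most (leading)
dst_shape = (4, 8)   # broadcast along that leading dimension
src = np.arange(np.prod(src_shape)).reshape(src_shape)
src_regs = src.reshape(-1)

# Tiling the flat register vector once per element of the leading broadcast
# dimensions reproduces the registers of the broadcast result.
one_dims = [i for i, d in enumerate(src_shape) if d == 1]
dst_regs = np.tile(src_regs, int(np.prod(dst_shape[:len(one_dims)])))

np.testing.assert_array_equal(
    dst_regs, np.broadcast_to(src, dst_shape).reshape(-1)
)

The code in the diff applies the same tiling to each thread's register list, which matches because the strided layout steps through all threads of the warpgroup before moving to the next register.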
16 changes: 16 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -3845,6 +3845,22 @@ def kernel(ctx, gmem_input, gmem_output, _):
    out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
    np.testing.assert_array_equal(result, out_ref)

  @parameterized.parameters(
      (128, 128), (64, 128), (64, 256),
  )
  def test_broadcast_major_strided(self, m, n):
    dtype = jnp.float16
    def kernel(ctx, gmem_input, gmem_output, _):
      t = mgpu.FragmentedArray.load_strided(gmem_input, vec_size=1)
      t.broadcast((m, n)).store_untiled(gmem_output, optimized=False)
    inp = self.prng.uniform(-1, 1, (1, n)).astype(dtype)
    out_shape = jax.ShapeDtypeStruct((m, n), dtype)
    result = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp
    )(inp)
    out_ref = jax.lax.broadcast(inp.reshape((n,)), (m,))
    np.testing.assert_array_equal(result, out_ref)

  @parameterized.parameters(*mtu.RegisterLayout)
  def test_broadcast_splat(self, layout):
    out_shape = (128, 128)
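Not part of the diff: a quick plain-JAX check (arbitrary small shapes) that the reference expression in test_broadcast_major_strided above agrees with an ordinary broadcast. jax.lax.broadcast prepends the requested dimensions, so starting from the squeezed (n,) vector matches broadcasting the (1, n) input.

import jax
import jax.numpy as jnp
import numpy as np

m, n = 4, 8
inp = jnp.arange(n, dtype=jnp.float16).reshape(1, n)

# jax.lax.broadcast prepends a new dimension of size m in front of (n,),
# which yields the same (m, n) array as broadcasting the (1, n) input.
ref_a = jax.lax.broadcast(inp.reshape((n,)), (m,))
ref_b = jnp.broadcast_to(inp, (m, n))
np.testing.assert_array_equal(ref_a, ref_b)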
32 changes: 32 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -2416,6 +2416,38 @@ def kernel(x_ref, y_ref):
    result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32)
    np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128)))

  @parameterized.parameters(
      ((64, 128),),
      ((2, 32, 128),),
  )
  def test_broadcast_wg_strided_majormost_dim(self, out_shape):
    self.skip_if_wg_semantics()
    num_major_dims = len(out_shape) - 1

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
    )
    def kernel(x_ref, side_load_ref, y_ref):
      x_strided = plgpu.load(
          x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), vec_size=1)
      )
      side_load_strided = plgpu.load(
          side_load_ref, (), layout=plgpu.Layout.WG_STRIDED(out_shape, vec_size=1)
      )
      if num_major_dims == 1:
        x_expanded = x_strided[None, :]
      else:
        x_expanded = x_strided[None, None, :]
      y_ref[...] = x_expanded + side_load_strided[...]

    inp = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32)
    side_load = jax.random.uniform(
        jax.random.key(1), shape=out_shape, dtype=jnp.float32
    )
    if num_major_dims == 1:
      expected = jnp.broadcast_to(inp[None, ...], out_shape)
    else:
      expected = jnp.broadcast_to(inp[None, None, ...], out_shape)
    np.testing.assert_array_equal(
        kernel(inp, side_load), jnp.broadcast_to(expected + side_load, out_shape)
    )

  def test_broadcast_in_dim_tcgen05_native_layout(self):
    @functools.partial(
        self.kernel,
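Not part of the diff: the plain-JAX computation that test_broadcast_wg_strided_majormost_dim above checks against, written out for the three-dimensional parameterization. Indexing with None only inserts leading size-1 dimensions; the broadcast to the full shape happens in the addition, which is the step expected to exercise the new major-most strided broadcast.

import jax
import jax.numpy as jnp
import numpy as np

out_shape = (2, 32, 128)  # one of the parameterizations used above
x = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32)
side = jax.random.uniform(jax.random.key(1), shape=out_shape, dtype=jnp.float32)

# x[None, None, :] has shape (1, 1, 128); adding the full-shape operand
# broadcasts it along the two major-most dimensions.
y = x[None, None, :] + side
np.testing.assert_array_equal(y, jnp.broadcast_to(x, out_shape) + side)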