Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2408,8 +2408,14 @@ def _reshape_lowering_rule(
if sharding is not None:
raise NotImplementedError("Not implemented: sharding")
[x_aval] = ctx.avals_in
x = _ensure_fa(x, x_aval.dtype)
return x.reshape(new_sizes)
return _ensure_fa(x, x_aval.dtype).reshape(new_sizes)


@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane)
def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
  """Lowers ``lax.squeeze`` by reshaping to the output aval's shape.

  The squeezed dimensions are implied by the output abstract value, so the
  ``dimensions`` parameter itself is not consulted here.
  """
  [in_aval] = ctx.avals_in
  [out_aval] = ctx.avals_out
  # Coerce the operand to a FragmentedArray before reshaping it to the
  # already-squeezed output shape.
  array = _ensure_fa(x, in_aval.dtype)
  return array.reshape(out_aval.shape)


def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,8 +1729,8 @@ def __getitem__(self, idx) -> FragmentedArray:
if any(is_squeezed):
raise NotImplementedError("Integer indexing not implemented (only slicing allowed)")
base_tile_shape = self.layout.base_tile_shape
if len(base_tile_shape) != len(self.shape):
raise NotImplementedError("Tiling has different rank than array")
if untiled_rank := len(self.shape) - len(base_tile_shape):
base_tile_shape = (1,) * untiled_rank + base_tile_shape
if any(b % t for b, t in zip(base_idx, base_tile_shape, strict=True)):
raise ValueError(
"Base indices of array slices must be aligned to the beginning of a"
Expand Down
16 changes: 16 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,21 @@ def kernel(x_ref, out_ref):
x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x.reshape(shape2))

def test_slice_untiled_dim(self):
  # Regression test: indexing a WGMMA-laid-out array along leading
  # dimensions that are not covered by the layout's base tile (the first
  # two axes of `shape` here).
  # NOTE(review): the `[1, 1]` integer index presumably lowers to a slice
  # followed by a squeeze (hence the new `lax.squeeze_p` lowering rule) —
  # confirm against the tracing path.
  self.skip_if_wg_semantics()
  shape = (2, 3, 64, 8)

  @functools.partial(
      self.kernel,
      out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32),
  )
  def kernel(x_ref, out_ref):
    # Load the whole ref with an explicit WGMMA layout (unoptimized copy),
    # then drop the two leading untiled dimensions via integer indexing.
    y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)[1, 1]
    out_ref[...] = y

  x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
  # The kernel's output must match plain NumPy-style indexing on the host.
  np.testing.assert_array_equal(kernel(x), x[1, 1])

def test_add_xy_indexed(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32)
Expand Down Expand Up @@ -2835,6 +2850,7 @@ def test_missing_primitive_lowerings_are_tracked(self):
pallas_primitives.delay_p,
checkify.check_p,
lax.reshape_p,
lax.squeeze_p,
}

self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)
Expand Down
Loading