diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6339cef1b25f..ef20ee735035 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2408,8 +2408,14 @@ def _reshape_lowering_rule( if sharding is not None: raise NotImplementedError("Not implemented: sharding") [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - return x.reshape(new_sizes) + return _ensure_fa(x, x_aval.dtype).reshape(new_sizes) + + +@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane) +def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): + [x_aval] = ctx.avals_in + [y_aval] = ctx.avals_out + return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape) def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index bb8a5576558e..a81784f1e6d5 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1729,8 +1729,8 @@ def __getitem__(self, idx) -> FragmentedArray: if any(is_squeezed): raise NotImplementedError("Integer indexing not implemented (only slicing allowed)") base_tile_shape = self.layout.base_tile_shape - if len(base_tile_shape) != len(self.shape): - raise NotImplementedError("Tiling has different rank than array") + if untiled_rank := len(self.shape) - len(base_tile_shape): + base_tile_shape = (1,) * untiled_rank + base_tile_shape if any(b % t for b, t in zip(base_idx, base_tile_shape, strict=True)): raise ValueError( "Base indices of array slices must be aligned to the beginning of a" diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 49e9bea41837..158754fc697a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -337,6 +337,21 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_slice_untiled_dim(self): + self.skip_if_wg_semantics() + shape = (2, 3, 64, 8) + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32), + ) + def kernel(x_ref, out_ref): + y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)[1, 1] + out_ref[...] = y + + x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x[1, 1]) + def test_add_xy_indexed(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) @@ -2835,6 +2850,7 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.delay_p, checkify.check_p, lax.reshape_p, + lax.squeeze_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)