Skip to content

Commit 59789e6

Browse files
[Mosaic GPU] Fix a bug with transform inference of subview in the presence of dynamic offsets.
The old code started indexing into the dynamic offset operands from 0 when it began processing the tiled dimensions. That is only correct when none of the preceding non-tiled dimensions has a dynamic offset; otherwise the lookup reads the wrong offset operand. The fix starts the index at the number of dynamic offsets found among the non-tiled dimensions. PiperOrigin-RevId: 802047327
1 parent d3ff2a6 commit 59789e6

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

jax/experimental/mosaic/gpu/transform_inference.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,21 +334,26 @@ def _infer_memref_subview_transforms(
334334
"unchanged."
335335
)
336336

337+
is_dynamic = lambda x: ir.ShapedType.is_dynamic_size(x)
338+
337339
# Check tile transform propagation.
338340
old_tiling = mgpu.TileTransformAttr(tile_transform).tiling
339341
num_tiled_axes = len(old_tiling)
340342
last_n_dims = op.source.type.shape[-num_tiled_axes:]
341343
last_n_sizes = list(op.static_sizes)[-num_tiled_axes:]
342344
last_n_offsets = list(op.static_offsets)[-num_tiled_axes:]
343345

344-
if any(ir.ShapedType.is_dynamic_size(x) for x in last_n_sizes):
346+
if any(is_dynamic(x) for x in last_n_sizes):
345347
raise NotImplementedError(
346348
"Subview transforms with dynamic sizes are not supported."
347349
)
348350

349-
dynamic_index = 0
351+
num_non_tiled_axes = len(op.source.type.shape) - num_tiled_axes
352+
non_tiled_offsets = list(op.static_offsets)[:num_non_tiled_axes]
353+
dynamic_index = sum(1 for x in non_tiled_offsets if is_dynamic(x))
354+
350355
for i in range(len(last_n_offsets)):
351-
if ir.ShapedType.is_dynamic_size(last_n_offsets[i]):
356+
if is_dynamic(last_n_offsets[i]):
352357
if utils.is_known_divisible(
353358
op.offsets[dynamic_index], last_n_sizes[i]
354359
):

tests/mosaic/gpu_transform_inference_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,62 @@ def test_custom_primitive_op_retains_transforms(self):
657657
mgpu.infer_transforms(self.module)
658658
self.assertSequenceEqual(inference_utils.in_transforms(op), [transforms])
659659

660+
@parameterized.parameters([False, True])
def test_infer_transforms_for_subview_handles_dynamic_offsets(
    self, annotate_input
):
  """Checks subview transform inference when leading offsets are dynamic.

  Builds a subview of a (32, 32, 32) ref with sizes (16, 16, 32) whose first
  two offsets are dynamic (the constants 16 and 4) and whose last offset is
  the static 0. A (16, 16) tile transform plus a swizzle of 32 is attached
  either to the enclosing function (`annotate_input=True`) or to the op that
  consumes the subview. After inference, the subview is expected to carry a
  (1, 16) tiling and the same swizzle as both its in- and out-transforms.
  """
  subview_op = user_op = None
  smem = mgpu.utils.smem()
  bf16 = ir.BF16Type.get()
  src_ty = ir.MemRefType.get((32, 32, 32), bf16, memory_space=smem)
  dst_ty = ir.MemRefType.get((16, 16, 32), bf16, memory_space=smem)

  def body(in_ref):
    nonlocal subview_op, user_op
    i32 = ir.IntegerType.get_signless(32)
    dynamic = ir.ShapedType.get_dynamic_size()
    # One operand per dynamic entry in `static_offsets`, in the same order.
    dynamic_offsets = [arith.constant(i32, 16), arith.constant(i32, 4)]
    subview_op = memref.SubViewOp(
        dst_ty,
        in_ref,
        dynamic_offsets,
        [],
        [],
        static_offsets=[dynamic, dynamic, 0],
        static_sizes=[16, 16, 32],
        static_strides=[1, 1, 1],
    )
    user_op = memref.CastOp(dst_ty, subview_op.result)

  with ir.InsertionPoint(self.module.body):
    func_op = func.FuncOp.from_py_func(src_ty)(body).func_op

  def make_transforms(tiling):
    # Tile transform with the given tiling, followed by a swizzle of 32.
    return ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get(tiling),
        mgpu.dialect.SwizzleTransformAttr.get(32),
    ])

  # Seed inference from either the function input or the subview's user.
  annotated_op = func_op if annotate_input else user_op
  annotated_op.attributes["in_transforms"] = ir.ArrayAttr.get(
      [make_transforms((16, 16))]
  )

  mgpu.infer_transforms(self.module)

  expected = [make_transforms((1, 16))]
  self.assertSequenceEqual(
      inference_utils.in_transforms(subview_op), expected
  )
  self.assertSequenceEqual(
      inference_utils.out_transforms(subview_op), expected
  )
660716

661717
if __name__ == "__main__":
662718
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)