Skip to content

Commit f217947

Browse files
authored
Add Prefetch support and retain discardable attributes in XeTile Canonicalization. (#853)
* Add Prefetch support and retain discardable attributes in XeTile Canonicalization * pre-commit issue
1 parent 9fd715c commit f217947

File tree

2 files changed

+83
-12
lines changed

2 files changed

+83
-12
lines changed

lib/Dialect/XeTile/Transforms/Canonicalization.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ struct UpdateTileOffsetOpPattern final
135135
}
136136
};
137137

138+
struct PrefetchTilePattern final
139+
: public mlir::OpConversionPattern<imex::xetile::PrefetchTileOp> {
140+
using OpConversionPattern<imex::xetile::PrefetchTileOp>::OpConversionPattern;
141+
mlir::LogicalResult
142+
matchAndRewrite(imex::xetile::PrefetchTileOp prefetchOp, OpAdaptor adaptor,
143+
mlir::ConversionPatternRewriter &rewriter) const override {
144+
// Create a new prefetch op.
145+
rewriter.replaceOpWithNewOp<imex::xetile::PrefetchTileOp>(
146+
prefetchOp, adaptor.getTile(), prefetchOp.getL1HintAttr(),
147+
prefetchOp.getL2HintAttr(), prefetchOp.getL3HintAttr());
148+
return mlir::success();
149+
}
150+
};
151+
138152
// Pattern for rewriting LoadTileOp to consume row-major tiles.
139153
struct LoadTileOpPattern final
140154
: public mlir::OpConversionPattern<imex::xetile::LoadTileOp> {
@@ -216,9 +230,13 @@ struct VectorTransposeToXetileTransposeOpPattern
216230
mlir::PatternRewriter &rewriter) const override {
217231
if (op.getVector().getType().getRank() != 2)
218232
return mlir::failure();
233+
// Retain discardable attributes if any.
234+
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
235+
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
219236
// Create an equivalent XeTileTransposeOp
220-
rewriter.replaceOpWithNewOp<imex::xetile::TransposeOp>(
237+
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::TransposeOp>(
221238
op, op.getType(), op.getVector(), op.getPermutation());
239+
newOp->setDiscardableAttrs(discardableAttrs);
222240
return mlir::success();
223241
}
224242
};
@@ -242,6 +260,9 @@ struct VectorBroadcastToXetileBroadcastOpPattern
242260
auto sourceVectorTy = llvm::cast<mlir::VectorType>(op.getSourceType());
243261
auto sourceRank = sourceVectorTy.getRank();
244262
auto sourceShape = sourceVectorTy.getShape();
263+
// Retain the discardable attributes if any.
264+
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
265+
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
245266
// If the source rank is 1 and result rank is 2, we need to create a shape
246267
// cast to convert source to 2D and then create a xetile.broadcast. In this
247268
// case, broadcast dimension is 0 according to vector.broadcast definition.
@@ -251,14 +272,17 @@ struct VectorBroadcastToXetileBroadcastOpPattern
251272
resultTy.getElementType());
252273
auto source2D = rewriter.create<mlir::vector::ShapeCastOp>(
253274
op.getLoc(), source2DTy, op.getSource());
254-
rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
275+
source2D->setDiscardableAttrs(discardableAttrs);
276+
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
255277
op, resultTy, source2D, llvm::ArrayRef<int64_t>({0}));
278+
newOp->setDiscardableAttrs(discardableAttrs);
256279
return mlir::success();
257280
}
258281
// If ranks are same, inner dimension is stretched in vector.broadcast. So
259282
// broadcast dimension is 1 for this case.
260-
rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
283+
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
261284
op, resultTy, op.getSource(), llvm::ArrayRef<int64_t>({1}));
285+
newOp->setDiscardableAttrs(discardableAttrs);
262286
return mlir::success();
263287
}
264288
};
@@ -281,7 +305,9 @@ struct VectorMultiReductionToXeTileReduce
281305
auto reductionDims = op.getReductionDims().getValue();
282306
if (reductionDims.size() != 1)
283307
return mlir::failure();
284-
308+
// Retain discardable attributes if any.
309+
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
310+
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
285311
// Create an equivalent XeTileReduceOp
286312
int64_t reduceDim = llvm::cast<mlir::IntegerAttr>(reductionDims[0])
287313
.getValue()
@@ -294,16 +320,21 @@ struct VectorMultiReductionToXeTileReduce
294320
auto reduceOp = rewriter.create<imex::xetile::ReductionOp>(
295321
op->getLoc(), xetileResultTy, op.getKind(), op.getSource(),
296322
mlir::ArrayRef<int64_t>({reduceDim}));
323+
reduceOp->setDiscardableAttrs(discardableAttrs);
297324
// Shape cast the result back to original shape.
298325
auto shapeCastOp = rewriter.create<mlir::vector::ShapeCastOp>(
299326
op->getLoc(), resultTy, reduceOp.getResult());
327+
shapeCastOp->setDiscardableAttrs(discardableAttrs);
300328
// Finally add the result to the accumulator.
301-
if (llvm::isa<mlir::IntegerType>(sourceTy.getElementType()))
302-
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, shapeCastOp,
303-
op.getAcc());
304-
else
305-
rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, shapeCastOp,
306-
op.getAcc());
329+
if (llvm::isa<mlir::IntegerType>(sourceTy.getElementType())) {
330+
auto accOp = rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(
331+
op, shapeCastOp, op.getAcc());
332+
accOp->setDiscardableAttrs(discardableAttrs);
333+
} else {
334+
auto accOp = rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(
335+
op, shapeCastOp, op.getAcc());
336+
accOp->setDiscardableAttrs(discardableAttrs);
337+
}
307338
return mlir::success();
308339
}
309340
};
@@ -406,6 +437,12 @@ struct XeTileCanonicalizationPass final
406437
op) {
407438
return op.getType().getOrder().asArrayRef() != mlir::ArrayRef({0, 1});
408439
});
440+
// PrefetchTileOp is legal if it does not consume col-major tiles.
441+
target.addDynamicallyLegalOp<imex::xetile::PrefetchTileOp>(
442+
[&](imex::xetile::PrefetchTileOp op) {
443+
return op.getTile().getType().getOrder().asArrayRef() !=
444+
mlir::ArrayRef({0, 1});
445+
});
409446
// LoadTileOp is legal if it does not consume col-major tiles.
410447
target.addDynamicallyLegalOp<imex::xetile::LoadTileOp>(
411448
[&](imex::xetile::LoadTileOp op) {
@@ -437,7 +474,8 @@ struct XeTileCanonicalizationPass final
437474
});
438475
patterns
439476
.add<InitTileOpPattern, LoadTileOpPattern, UpdateTileOffsetOpPattern,
440-
ScfForOpPattern, ScfYieldOpPattern>(typeConverter, context);
477+
PrefetchTilePattern, ScfForOpPattern, ScfYieldOpPattern>(
478+
typeConverter, context);
441479

442480
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
443481
std::move(patterns))))

test/Dialect/XeTile/Transforms/canonicalization.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ gpu.module @test_module {
33
gpu.func @test_static_memref(%arg0 : memref<512x128xf16, strided<[1, 512], offset:0>>, %arg1 : index, %arg2 : index) {
44
%0 = xetile.init_tile %arg0 [%arg1, %arg2] : memref<512x128xf16, strided<[1, 512], offset:0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
55
%3 = xetile.load_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
6+
xetile.prefetch_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
67
// With static offsets
78
%1 = xetile.init_tile %arg0 [12, %arg1] : memref<512x128xf16, strided<[1, 512], offset:0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
89
// Update offsets
910
%c16 = arith.constant 16 : index
1011
%c32 = arith.constant 32 : index
1112
%2 = xetile.update_tile_offset %1, [%c32, %c16] : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>, index, index -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
1213
%4 = xetile.load_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
14+
xetile.prefetch_tile %1 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
1315
gpu.return
1416
}
1517
}
@@ -23,22 +25,25 @@ gpu.module @test_module {
2325
// CHECK: %[[T0:.*]] = xetile.init_tile %[[RCAST]][%[[ARG2]], %[[ARG1]]] : memref<128x512xf16, strided<[512, 1]>> -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
2426
// CHECK: %[[T1:.*]] = xetile.load_tile %[[T0]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
2527
// CHECK: %[[T2:.*]] = xetile.transpose %[[T1]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
28+
// CHECK: xetile.prefetch_tile %[[T0]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>
2629
// CHECK: %[[RCAST0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [128, 512], strides: [512, 1] : memref<512x128xf16, strided<[1, 512]>> to memref<128x512xf16, strided<[512, 1]>>
2730
// CHECK: %[[T3:.*]] = xetile.init_tile %[[RCAST0]][%[[ARG1]], 12] : memref<128x512xf16, strided<[512, 1]>> -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
2831
// CHECK: %[[T4:.*]] = xetile.update_tile_offset %[[T3]], [%[[C16]], %[[C32]]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>, index, index -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
2932
// CHECK: %[[T5:.*]] = xetile.load_tile %[[T4]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
3033
// CHECK: %[[T6:.*]] = xetile.transpose %[[T5]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
34+
// CHECK: xetile.prefetch_tile %[[T3]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>
3135

3236
// -----
3337
gpu.module @test_module {
3438
gpu.func @test_dynamic_memref(%arg0 : memref<?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) {
3539
%0 = xetile.init_tile %arg0 [%arg1, %arg2], [%arg3, %arg4], [%arg5, %arg6] : memref<?x?xf16> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
36-
%1 = xetile.load_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
40+
xetile.load_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
3741
// Update offsets
3842
%c16 = arith.constant 16 : index
3943
%c32 = arith.constant 32 : index
4044
%2 = xetile.update_tile_offset %0, [%c32, %c16] : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>, index, index -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
4145
%3 = xetile.load_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
46+
xetile.prefetch_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
4247
gpu.return
4348
}
4449
}
@@ -59,6 +64,7 @@ gpu.module @test_module {
5964
// CHECK: %[[T3:.*]] = xetile.update_tile_offset %[[T0]], [%[[C16]], %[[C32]]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>, index, index -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
6065
// CHECK: %[[T4:.*]] = xetile.load_tile %[[T3]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
6166
// CHECK: %[[T5:.*]] = xetile.transpose %[[T4]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
67+
// CHECK: xetile.prefetch_tile %[[T3]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>
6268

6369
// -----
6470
gpu.module @test_module {
@@ -272,10 +278,37 @@ gpu.module @test_module {
272278
}
273279
}
274280

281+
// CHECK-LABEL: @test_multireduction_1
282+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<64x256xf32>, %[[ARG1:[a-zA-Z0-9]+]]: vector<256xf32>) -> vector<256xf32>
283+
// CHECK: %[[T0:.*]] = xetile.reduction <add>, %[[ARG0]] [0] : vector<64x256xf32> -> vector<1x256xf32>
284+
// CHECK: %[[T1:.*]] = vector.shape_cast %[[T0]] : vector<1x256xf32> to vector<256xf32>
285+
// CHECK: %[[T2:.*]] = arith.addf %[[T1]], %[[ARG1]] : vector<256xf32>
286+
// CHECK: gpu.return %[[T2]] : vector<256xf32>
287+
275288
// -----
276289
gpu.module @test_module {
277290
gpu.func @test_multireduction_2(%arg0 : vector<64x256xi8>, %arg1 : vector<256xi8>) -> vector<256xi8> {
278291
%0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<64x256xi8> to vector<256xi8>
279292
gpu.return %0 : vector<256xi8>
280293
}
281294
}
295+
296+
// CHECK-LABEL: @test_multireduction_2
297+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<64x256xi8>, %[[ARG1:[a-zA-Z0-9]+]]: vector<256xi8>) -> vector<256xi8>
298+
// CHECK: %[[T0:.*]] = xetile.reduction <add>, %[[ARG0]] [0] : vector<64x256xi8> -> vector<1x256xi8>
299+
// CHECK: %[[T1:.*]] = vector.shape_cast %[[T0]] : vector<1x256xi8> to vector<256xi8>
300+
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[ARG1]] : vector<256xi8>
301+
// CHECK: gpu.return %[[T2]] : vector<256xi8>
302+
303+
// -----
304+
gpu.module @test_module {
305+
gpu.func @test_transpose_1(%arg0 : vector<16x32xf32>) -> vector<32x16xf32> {
306+
%0 = vector.transpose %arg0, [1, 0] : vector<16x32xf32> to vector<32x16xf32>
307+
gpu.return %0 : vector<32x16xf32>
308+
}
309+
}
310+
311+
// CHECK-LABEL: @test_transpose_1
312+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<16x32xf32>) -> vector<32x16xf32>
313+
// CHECK: %[[T0:.*]] = xetile.transpose %arg0, [1, 0] : vector<16x32xf32> -> vector<32x16xf32>
314+
// CHECK: gpu.return %[[T0]] : vector<32x16xf32>

0 commit comments

Comments
 (0)