Skip to content

Commit ac6e2ee

Browse files
authored
[mlir][vector] Support direct broadcast conversion (LLVM & SPIRV) (#148027)
Add conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU. Part of deprecation of vector.splat RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
1 parent abce4e9 commit ac6e2ee

File tree

10 files changed

+195
-116
lines changed

10 files changed

+195
-116
lines changed

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
153153

154154
if (inVecType.getShape().empty()) {
155155
Value zerodSplat =
156-
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
156+
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
157157
Value scalarIn =
158158
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
159159
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
166166

167167
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
168168
outType.getElementType());
169-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
169+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
170170

171171
if (inVecType.getRank() > 1) {
172172
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
315315

316316
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
317317
outVecType.getElementType());
318-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
318+
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
319319

320320
if (inVectorTy.getRank() > 1) {
321321
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
383383
int64_t numElements = outVecType.getNumElements();
384384
Value zero = rewriter.createOrFold<arith::ConstantOp>(
385385
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
386-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
386+
Value result =
387+
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
387388

388389
if (inVectorTy.getRank() > 1) {
389390
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
478479
VectorType extScaleResultType = VectorType::get(opWidth, outType);
479480

480481
if (!outVecType) {
481-
Value inCast =
482-
rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
482+
Value inCast = rewriter.create<vector::BroadcastOp>(
483+
loc, VectorType::get(1, inType), in);
483484
// TODO: replace this with non-packed ScaledExtOp
484485
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
485486
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
509510

510511
Value zero = rewriter.create<arith::ConstantOp>(
511512
loc, outType, rewriter.getFloatAttr(outType, 0.0));
512-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
513+
Value result =
514+
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
513515

514516
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
515517
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
523525

524526
VectorType blockResultType = VectorType::get(blockSize, outType);
525527
Value blockResult =
526-
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
528+
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
527529

528530
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
529531
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
587589

588590
if (!outVecType) {
589591
Type inVecType = VectorType::get(1, inType);
590-
Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
592+
Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
591593
// TODO: replace this with non-packed ScaledTruncOp
592594
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
593595
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
616618

617619
int64_t blockSize = computeProduct(ratio);
618620

619-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
621+
Value result =
622+
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
620623

621624
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
622625
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
630633

631634
VectorType blockResultType = VectorType::get(blockSize, outType);
632635
Value blockResult =
633-
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
636+
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
634637

635638
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
636639
i < blockSize;

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
792792
op.getLoc(), vectorType.getElementType(),
793793
rewriter.getZeroAttr(vectorType.getElementType()));
794794
Value result =
795-
rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
795+
rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
796796

797797
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798798

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,7 @@ struct VectorScalableExtractOpLowering
13731373
/// ```
13741374
/// is rewritten into:
13751375
/// ```
1376-
/// %r = splat %f0: vector<2x4xf32>
1376+
/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
13771377
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
13781378
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
13791379
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1406,7 +1406,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
14061406
auto elemType = vType.getElementType();
14071407
Value zero = rewriter.create<arith::ConstantOp>(
14081408
loc, elemType, rewriter.getZeroAttr(elemType));
1409-
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1409+
Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
14101410
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
14111411
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
14121412
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1548,7 +1548,7 @@ class VectorCreateMaskOpConversion
15481548
/*isScalable=*/true));
15491549
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
15501550
adaptor.getOperands()[0]);
1551-
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1551+
Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
15521552
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
15531553
indices, bounds);
15541554
rewriter.replaceOp(op, comp);
@@ -1732,63 +1732,77 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
17321732
}
17331733
};
17341734

1735-
/// The Splat operation is lowered to an insertelement + a shufflevector
1736-
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1737-
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1738-
using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1735+
/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1736+
/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1737+
/// pattern, the higher rank cases are handled by another pattern.
1738+
struct VectorBroadcastScalarToLowRankLowering
1739+
: public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1740+
using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
17391741

17401742
LogicalResult
1741-
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1743+
matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
17421744
ConversionPatternRewriter &rewriter) const override {
1743-
VectorType resultType = cast<VectorType>(splatOp.getType());
1745+
if (isa<VectorType>(broadcast.getSourceType()))
1746+
return rewriter.notifyMatchFailure(
1747+
broadcast, "broadcast from vector type not handled");
1748+
1749+
VectorType resultType = broadcast.getType();
17441750
if (resultType.getRank() > 1)
1745-
return failure();
1751+
return rewriter.notifyMatchFailure(broadcast,
1752+
"broadcast to 2+-d handled elsewhere");
17461753

17471754
// First insert it into a poison vector so we can shuffle it.
1748-
auto vectorType = typeConverter->convertType(splatOp.getType());
1755+
auto vectorType = typeConverter->convertType(broadcast.getType());
17491756
Value poison =
1750-
rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1757+
rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
17511758
auto zero = rewriter.create<LLVM::ConstantOp>(
1752-
splatOp.getLoc(),
1759+
broadcast.getLoc(),
17531760
typeConverter->convertType(rewriter.getIntegerType(32)),
17541761
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
17551762

17561763
// For 0-d vector, we simply do `insertelement`.
17571764
if (resultType.getRank() == 0) {
17581765
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1759-
splatOp, vectorType, poison, adaptor.getInput(), zero);
1766+
broadcast, vectorType, poison, adaptor.getSource(), zero);
17601767
return success();
17611768
}
17621769

17631770
// For 1-d vector, we additionally do a `vectorshuffle`.
17641771
auto v = rewriter.create<LLVM::InsertElementOp>(
1765-
splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1772+
broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
17661773

1767-
int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1774+
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
17681775
SmallVector<int32_t> zeroValues(width, 0);
17691776

17701777
// Shuffle the value across the desired number of elements.
1771-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
1778+
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
17721779
zeroValues);
17731780
return success();
17741781
}
17751782
};
17761783

1777-
/// The Splat operation is lowered to an insertelement + a shufflevector
1778-
/// operation. Splat to only 2+-d vector result types are lowered by the
1779-
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1780-
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1781-
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1784+
/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1785+
/// operation. Only broadcasts to 2+-d vector result types are lowered by this
1786+
/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1787+
/// are not converted to LLVM, only broadcasts from scalars are.
1788+
struct VectorBroadcastScalarToNdLowering
1789+
: public ConvertOpToLLVMPattern<BroadcastOp> {
1790+
using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
17821791

17831792
LogicalResult
1784-
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1793+
matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
17851794
ConversionPatternRewriter &rewriter) const override {
1786-
VectorType resultType = splatOp.getType();
1795+
if (isa<VectorType>(broadcast.getSourceType()))
1796+
return rewriter.notifyMatchFailure(
1797+
broadcast, "broadcast from vector type not handled");
1798+
1799+
VectorType resultType = broadcast.getType();
17871800
if (resultType.getRank() <= 1)
1788-
return failure();
1801+
return rewriter.notifyMatchFailure(
1802+
broadcast, "broadcast to 1-d or 0-d handled elsewhere");
17891803

17901804
// First insert it into an undef vector so we can shuffle it.
1791-
auto loc = splatOp.getLoc();
1805+
auto loc = broadcast.getLoc();
17921806
auto vectorTypeInfo =
17931807
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
17941808
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1799,26 +1813,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
17991813
// Construct returned value.
18001814
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
18011815

1802-
// Construct a 1-D vector with the splatted value that we insert in all the
1803-
// places within the returned descriptor.
1816+
// Construct a 1-D vector with the broadcasted value that we insert in all
1817+
// the places within the returned descriptor.
18041818
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
18051819
auto zero = rewriter.create<LLVM::ConstantOp>(
18061820
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
18071821
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
18081822
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1809-
adaptor.getInput(), zero);
1823+
adaptor.getSource(), zero);
18101824

18111825
// Shuffle the value across the desired number of elements.
18121826
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
18131827
SmallVector<int32_t> zeroValues(width, 0);
18141828
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
18151829

1816-
// Iterate of linear index, convert to coords space and insert splatted 1-D
1817-
// vector in each position.
1830+
// Iterate of linear index, convert to coords space and insert broadcasted
1831+
// 1-D vector in each position.
18181832
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
18191833
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
18201834
});
1821-
rewriter.replaceOp(splatOp, desc);
1835+
rewriter.replaceOp(broadcast, desc);
18221836
return success();
18231837
}
18241838
};
@@ -2177,6 +2191,19 @@ class TransposeOpToMatrixTransposeOpLowering
21772191
}
21782192
};
21792193

2194+
/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2195+
/// `vector.broadcast` through other patterns.
2196+
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
2197+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2198+
LogicalResult
2199+
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
2200+
ConversionPatternRewriter &rewriter) const override {
2201+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
2202+
adaptor.getInput());
2203+
return success();
2204+
}
2205+
};
2206+
21802207
} // namespace
21812208

21822209
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2216,7 +2243,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
22162243
VectorInsertOpConversion, VectorPrintOpConversion,
22172244
VectorTypeCastOpConversion, VectorScaleOpConversion,
22182245
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2219-
VectorSplatOpLowering, VectorSplatNdOpLowering,
2246+
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2247+
VectorBroadcastScalarToNdLowering,
22202248
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
22212249
MaskedReductionOpConversion, VectorInterleaveOpLowering,
22222250
VectorDeinterleaveOpLowering, VectorFromElementsLowering,

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
444444
Location loc = xferOp.getLoc();
445445
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
446446
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
447-
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
447+
auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
448448
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
449449

450450
return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
12611261
if (auto insertOp = getInsertOp(xferOp))
12621262
return insertOp.getDest();
12631263
Location loc = xferOp.getLoc();
1264-
return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1265-
xferOp.getPadding());
1264+
return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
1265+
xferOp.getPadding());
12661266
}
12671267

12681268
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
15831583
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
15841584
// Inititalize vector with padding value.
15851585
Location loc = xferOp.getLoc();
1586-
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1587-
xferOp.getPadding());
1586+
return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
1587+
xferOp.getPadding());
15881588
}
15891589
};
15901590

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
7979
}
8080
};
8181

82+
// Convert `vector.splat` to `vector.broadcast`. There is a path from
83+
// `vector.broadcast` to SPIRV via other patterns.
84+
struct VectorSplatToBroadcast final
85+
: public OpConversionPattern<vector::SplatOp> {
86+
using OpConversionPattern::OpConversionPattern;
87+
LogicalResult
88+
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
89+
ConversionPatternRewriter &rewriter) const override {
90+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
91+
adaptor.getInput());
92+
return success();
93+
}
94+
};
95+
8296
struct VectorBitcastConvert final
8397
: public OpConversionPattern<vector::BitCastOp> {
8498
using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,27 @@ struct VectorReductionFloatMinMax final
556570
}
557571
};
558572

559-
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
573+
class VectorScalarBroadcastPattern final
574+
: public OpConversionPattern<vector::BroadcastOp> {
560575
public:
561-
using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
576+
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
562577

563578
LogicalResult
564-
matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
579+
matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
565580
ConversionPatternRewriter &rewriter) const override {
581+
if (isa<VectorType>(op.getSourceType())) {
582+
return rewriter.notifyMatchFailure(
583+
op, "only conversion of 'broadcast from scalar' is supported");
584+
}
566585
Type dstType = getTypeConverter()->convertType(op.getType());
567586
if (!dstType)
568587
return failure();
569588
if (isa<spirv::ScalarType>(dstType)) {
570-
rewriter.replaceOp(op, adaptor.getInput());
589+
rewriter.replaceOp(op, adaptor.getSource());
571590
} else {
572591
auto dstVecType = cast<VectorType>(dstType);
573592
SmallVector<Value, 4> source(dstVecType.getNumElements(),
574-
adaptor.getInput());
593+
adaptor.getSource());
575594
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
576595
source);
577596
}
@@ -1089,11 +1108,11 @@ void mlir::populateVectorToSPIRVPatterns(
10891108
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
10901109
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
10911110
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1092-
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1093-
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1094-
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1095-
VectorStepOpConvert>(typeConverter, patterns.getContext(),
1096-
PatternBenefit(1));
1111+
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
1112+
VectorShuffleOpConvert, VectorInterleaveOpConvert,
1113+
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
1114+
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
1115+
typeConverter, patterns.getContext(), PatternBenefit(1));
10971116

10981117
// Make sure that the more specialized dot product pattern has higher benefit
10991118
// than the generic one that extracts all elements.

0 commit comments

Comments
 (0)