@@ -1373,7 +1373,7 @@ struct VectorScalableExtractOpLowering
1373
1373
// / ```
1374
1374
// / is rewritten into:
1375
1375
// / ```
1376
- // / %r = splat %f0: vector<2x4xf32>
1376
+ // / %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1377
1377
// / %va = vector.extractvalue %a[0] : vector<2x4xf32>
1378
1378
// / %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1379
1379
// / %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1406,7 +1406,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1406
1406
auto elemType = vType.getElementType ();
1407
1407
Value zero = rewriter.create <arith::ConstantOp>(
1408
1408
loc, elemType, rewriter.getZeroAttr (elemType));
1409
- Value desc = rewriter.create <vector::SplatOp >(loc, vType, zero);
1409
+ Value desc = rewriter.create <vector::BroadcastOp >(loc, vType, zero);
1410
1410
for (int64_t i = 0 , e = vType.getShape ().front (); i != e; ++i) {
1411
1411
Value extrLHS = rewriter.create <ExtractOp>(loc, op.getLhs (), i);
1412
1412
Value extrRHS = rewriter.create <ExtractOp>(loc, op.getRhs (), i);
@@ -1548,7 +1548,7 @@ class VectorCreateMaskOpConversion
1548
1548
/* isScalable=*/ true ));
1549
1549
auto bound = getValueOrCreateCastToIndexLike (rewriter, loc, idxType,
1550
1550
adaptor.getOperands ()[0 ]);
1551
- Value bounds = rewriter.create <SplatOp >(loc, indices.getType (), bound);
1551
+ Value bounds = rewriter.create <BroadcastOp >(loc, indices.getType (), bound);
1552
1552
Value comp = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1553
1553
indices, bounds);
1554
1554
rewriter.replaceOp (op, comp);
@@ -1732,63 +1732,77 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1732
1732
}
1733
1733
};
1734
1734
1735
- // / The Splat operation is lowered to an insertelement + a shufflevector
1736
- // / operation. Splat to only 0-d and 1-d vector result types are lowered.
1737
- struct VectorSplatOpLowering : public ConvertOpToLLVMPattern <vector::SplatOp> {
1738
- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1735
+ // / A broadcast of a scalar is lowered to an insertelement + a shufflevector
1736
+ // / operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1737
+ // / pattern, the higher rank cases are handled by another pattern.
1738
+ struct VectorBroadcastScalarToLowRankLowering
1739
+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1740
+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1739
1741
1740
1742
LogicalResult
1741
- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
1743
+ matchAndRewrite (vector::BroadcastOp broadcast , OpAdaptor adaptor,
1742
1744
ConversionPatternRewriter &rewriter) const override {
1743
- VectorType resultType = cast<VectorType>(splatOp.getType ());
1745
+ if (isa<VectorType>(broadcast.getSourceType ()))
1746
+ return rewriter.notifyMatchFailure (
1747
+ broadcast, " broadcast from vector type not handled" );
1748
+
1749
+ VectorType resultType = broadcast.getType ();
1744
1750
if (resultType.getRank () > 1 )
1745
- return failure ();
1751
+ return rewriter.notifyMatchFailure (broadcast,
1752
+ " broadcast to 2+-d handled elsewhere" );
1746
1753
1747
1754
// First insert it into a poison vector so we can shuffle it.
1748
- auto vectorType = typeConverter->convertType (splatOp .getType ());
1755
+ auto vectorType = typeConverter->convertType (broadcast .getType ());
1749
1756
Value poison =
1750
- rewriter.create <LLVM::PoisonOp>(splatOp .getLoc (), vectorType);
1757
+ rewriter.create <LLVM::PoisonOp>(broadcast .getLoc (), vectorType);
1751
1758
auto zero = rewriter.create <LLVM::ConstantOp>(
1752
- splatOp .getLoc (),
1759
+ broadcast .getLoc (),
1753
1760
typeConverter->convertType (rewriter.getIntegerType (32 )),
1754
1761
rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
1755
1762
1756
1763
// For 0-d vector, we simply do `insertelement`.
1757
1764
if (resultType.getRank () == 0 ) {
1758
1765
rewriter.replaceOpWithNewOp <LLVM::InsertElementOp>(
1759
- splatOp , vectorType, poison, adaptor.getInput (), zero);
1766
+ broadcast , vectorType, poison, adaptor.getSource (), zero);
1760
1767
return success ();
1761
1768
}
1762
1769
1763
1770
// For 1-d vector, we additionally do a `vectorshuffle`.
1764
1771
auto v = rewriter.create <LLVM::InsertElementOp>(
1765
- splatOp .getLoc (), vectorType, poison, adaptor.getInput (), zero);
1772
+ broadcast .getLoc (), vectorType, poison, adaptor.getSource (), zero);
1766
1773
1767
- int64_t width = cast<VectorType>(splatOp .getType ()).getDimSize (0 );
1774
+ int64_t width = cast<VectorType>(broadcast .getType ()).getDimSize (0 );
1768
1775
SmallVector<int32_t > zeroValues (width, 0 );
1769
1776
1770
1777
// Shuffle the value across the desired number of elements.
1771
- rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(splatOp , v, poison,
1778
+ rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(broadcast , v, poison,
1772
1779
zeroValues);
1773
1780
return success ();
1774
1781
}
1775
1782
};
1776
1783
1777
- // / The Splat operation is lowered to an insertelement + a shufflevector
1778
- // / operation. Splat to only 2+-d vector result types are lowered by the
1779
- // / SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1780
- struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern <SplatOp> {
1781
- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1784
+ // / The broadcast of a scalar is lowered to an insertelement + a shufflevector
1785
+ // / operation. Only broadcasts to 2+-d vector result types are lowered by this
1786
+ // / pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1787
+ // / are not converted to LLVM, only broadcasts from scalars are.
1788
+ struct VectorBroadcastScalarToNdLowering
1789
+ : public ConvertOpToLLVMPattern<BroadcastOp> {
1790
+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1782
1791
1783
1792
LogicalResult
1784
- matchAndRewrite (SplatOp splatOp , OpAdaptor adaptor,
1793
+ matchAndRewrite (BroadcastOp broadcast , OpAdaptor adaptor,
1785
1794
ConversionPatternRewriter &rewriter) const override {
1786
- VectorType resultType = splatOp.getType ();
1795
+ if (isa<VectorType>(broadcast.getSourceType ()))
1796
+ return rewriter.notifyMatchFailure (
1797
+ broadcast, " broadcast from vector type not handled" );
1798
+
1799
+ VectorType resultType = broadcast.getType ();
1787
1800
if (resultType.getRank () <= 1 )
1788
- return failure ();
1801
+ return rewriter.notifyMatchFailure (
1802
+ broadcast, " broadcast to 1-d or 0-d handled elsewhere" );
1789
1803
1790
1804
// First insert it into a poison vector so we can shuffle it.
1791
- auto loc = splatOp .getLoc ();
1805
+ auto loc = broadcast .getLoc ();
1792
1806
auto vectorTypeInfo =
1793
1807
LLVM::detail::extractNDVectorTypeInfo (resultType, *getTypeConverter ());
1794
1808
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy ;
@@ -1799,26 +1813,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1799
1813
// Construct returned value.
1800
1814
Value desc = rewriter.create <LLVM::PoisonOp>(loc, llvmNDVectorTy);
1801
1815
1802
- // Construct a 1-D vector with the splatted value that we insert in all the
1803
- // places within the returned descriptor.
1816
+ // Construct a 1-D vector with the broadcasted value that we insert in all
1817
+ // the places within the returned descriptor.
1804
1818
Value vdesc = rewriter.create <LLVM::PoisonOp>(loc, llvm1DVectorTy);
1805
1819
auto zero = rewriter.create <LLVM::ConstantOp>(
1806
1820
loc, typeConverter->convertType (rewriter.getIntegerType (32 )),
1807
1821
rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
1808
1822
Value v = rewriter.create <LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1809
- adaptor.getInput (), zero);
1823
+ adaptor.getSource (), zero);
1810
1824
1811
1825
// Shuffle the value across the desired number of elements.
1812
1826
int64_t width = resultType.getDimSize (resultType.getRank () - 1 );
1813
1827
SmallVector<int32_t > zeroValues (width, 0 );
1814
1828
v = rewriter.create <LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1815
1829
1816
- // Iterate of linear index, convert to coords space and insert splatted 1-D
1817
- // vector in each position.
1830
+ // Iterate over the linear index, convert to coords space and insert broadcasted
1831
+ // 1-D vector in each position.
1818
1832
nDVectorIterate (vectorTypeInfo, rewriter, [&](ArrayRef<int64_t > position) {
1819
1833
desc = rewriter.create <LLVM::InsertValueOp>(loc, desc, v, position);
1820
1834
});
1821
- rewriter.replaceOp (splatOp , desc);
1835
+ rewriter.replaceOp (broadcast , desc);
1822
1836
return success ();
1823
1837
}
1824
1838
};
@@ -2177,6 +2191,19 @@ class TransposeOpToMatrixTransposeOpLowering
2177
2191
}
2178
2192
};
2179
2193
2194
+ // / Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2195
+ // / `vector.broadcast` through other patterns.
2196
+ struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern <vector::SplatOp> {
2197
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2198
+ LogicalResult
2199
+ matchAndRewrite (vector::SplatOp splat, OpAdaptor adaptor,
2200
+ ConversionPatternRewriter &rewriter) const override {
2201
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splat, splat.getType (),
2202
+ adaptor.getInput ());
2203
+ return success ();
2204
+ }
2205
+ };
2206
+
2180
2207
} // namespace
2181
2208
2182
2209
void mlir::vector::populateVectorRankReducingFMAPattern (
@@ -2216,7 +2243,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
2216
2243
VectorInsertOpConversion, VectorPrintOpConversion,
2217
2244
VectorTypeCastOpConversion, VectorScaleOpConversion,
2218
2245
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2219
- VectorSplatOpLowering, VectorSplatNdOpLowering,
2246
+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2247
+ VectorBroadcastScalarToNdLowering,
2220
2248
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2221
2249
MaskedReductionOpConversion, VectorInterleaveOpLowering,
2222
2250
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
0 commit comments