@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
1408
1408
// / ```
1409
1409
// / is rewritten into:
1410
1410
// / ```
1411
- // / %r = splat %f0: vector<2x4xf32>
1411
+ // / %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1412
1412
// / %va = vector.extractvalue %a[0] : vector<2x4xf32>
1413
1413
// / %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1414
1414
// / %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1441
1441
auto elemType = vType.getElementType ();
1442
1442
Value zero = rewriter.create <arith::ConstantOp>(
1443
1443
loc, elemType, rewriter.getZeroAttr (elemType));
1444
- Value desc = rewriter.create <vector::SplatOp >(loc, vType, zero);
1444
+ Value desc = rewriter.create <vector::BroadcastOp >(loc, vType, zero);
1445
1445
for (int64_t i = 0 , e = vType.getShape ().front (); i != e; ++i) {
1446
1446
Value extrLHS = rewriter.create <ExtractOp>(loc, op.getLhs (), i);
1447
1447
Value extrRHS = rewriter.create <ExtractOp>(loc, op.getRhs (), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
1583
1583
/* isScalable=*/ true ));
1584
1584
auto bound = getValueOrCreateCastToIndexLike (rewriter, loc, idxType,
1585
1585
adaptor.getOperands ()[0 ]);
1586
- Value bounds = rewriter.create <SplatOp >(loc, indices.getType (), bound);
1586
+ Value bounds = rewriter.create <BroadcastOp >(loc, indices.getType (), bound);
1587
1587
Value comp = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1588
1588
indices, bounds);
1589
1589
rewriter.replaceOp (op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1767
1767
}
1768
1768
};
1769
1769
1770
- // / The Splat operation is lowered to an insertelement + a shufflevector
1771
- // / operation. Splat to only 0-d and 1-d vector result types are lowered.
1772
- struct VectorSplatOpLowering : public ConvertOpToLLVMPattern <vector::SplatOp> {
1773
- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1770
+ // / A broadcast of a scalar is lowered to an insertelement + a shufflevector
1771
+ // / operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1772
+ // / pattern, the higher rank cases are handled by another pattern.
1773
+ struct VectorBroadcastScalarToLowRankLowering
1774
+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1775
+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1774
1776
1775
1777
LogicalResult
1776
- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
1778
+ matchAndRewrite (vector::BroadcastOp broadcast , OpAdaptor adaptor,
1777
1779
ConversionPatternRewriter &rewriter) const override {
1778
- VectorType resultType = cast<VectorType>(splatOp.getType ());
1780
+
1781
+ if (isa<VectorType>(broadcast.getSourceType ()))
1782
+ return rewriter.notifyMatchFailure (
1783
+ broadcast, " broadcast from vector type not handled" );
1784
+
1785
+ VectorType resultType = broadcast.getType ();
1779
1786
if (resultType.getRank () > 1 )
1780
- return failure ();
1787
+ return rewriter.notifyMatchFailure (broadcast,
1788
+ " broadcast to 2+-d handled elsewhere" );
1781
1789
1782
1790
// First insert it into a poison vector so we can shuffle it.
1783
- auto vectorType = typeConverter->convertType (splatOp .getType ());
1791
+ auto vectorType = typeConverter->convertType (broadcast .getType ());
1784
1792
Value poison =
1785
- rewriter.create <LLVM::PoisonOp>(splatOp .getLoc (), vectorType);
1793
+ rewriter.create <LLVM::PoisonOp>(broadcast .getLoc (), vectorType);
1786
1794
auto zero = rewriter.create <LLVM::ConstantOp>(
1787
- splatOp .getLoc (),
1795
+ broadcast .getLoc (),
1788
1796
typeConverter->convertType (rewriter.getIntegerType (32 )),
1789
1797
rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
1790
1798
1791
1799
// For 0-d vector, we simply do `insertelement`.
1792
1800
if (resultType.getRank () == 0 ) {
1793
1801
rewriter.replaceOpWithNewOp <LLVM::InsertElementOp>(
1794
- splatOp , vectorType, poison, adaptor.getInput (), zero);
1802
+ broadcast , vectorType, poison, adaptor.getSource (), zero);
1795
1803
return success ();
1796
1804
}
1797
1805
1798
1806
// For 1-d vector, we additionally do a `vectorshuffle`.
1799
1807
auto v = rewriter.create <LLVM::InsertElementOp>(
1800
- splatOp .getLoc (), vectorType, poison, adaptor.getInput (), zero);
1808
+ broadcast .getLoc (), vectorType, poison, adaptor.getSource (), zero);
1801
1809
1802
- int64_t width = cast<VectorType>(splatOp .getType ()).getDimSize (0 );
1810
+ int64_t width = cast<VectorType>(broadcast .getType ()).getDimSize (0 );
1803
1811
SmallVector<int32_t > zeroValues (width, 0 );
1804
1812
1805
1813
// Shuffle the value across the desired number of elements.
1806
- rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(splatOp , v, poison,
1814
+ rewriter.replaceOpWithNewOp <LLVM::ShuffleVectorOp>(broadcast , v, poison,
1807
1815
zeroValues);
1808
1816
return success ();
1809
1817
}
1810
1818
};
1811
1819
1812
- // / The Splat operation is lowered to an insertelement + a shufflevector
1813
- // / operation. Splat to only 2+-d vector result types are lowered by the
1814
- // / SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1815
- struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern <SplatOp> {
1816
- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1820
+ // / The broadcast of a scalar is lowered to an insertelement + a shufflevector
1821
+ // / operation. Only broadcasts to 2+-d vector result types are lowered by this
1822
+ // / pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1823
+ // / are not converted to LLVM, only broadcasts from scalars are.
1824
+ struct VectorBroadcastScalarToNdLowering
1825
+ : public ConvertOpToLLVMPattern<BroadcastOp> {
1826
+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1817
1827
1818
1828
LogicalResult
1819
- matchAndRewrite (SplatOp splatOp , OpAdaptor adaptor,
1829
+ matchAndRewrite (BroadcastOp broadcast , OpAdaptor adaptor,
1820
1830
ConversionPatternRewriter &rewriter) const override {
1821
- VectorType resultType = splatOp.getType ();
1831
+
1832
+ if (isa<VectorType>(broadcast.getSourceType ()))
1833
+ return rewriter.notifyMatchFailure (
1834
+ broadcast, " broadcast from vector type not handled" );
1835
+
1836
+ VectorType resultType = broadcast.getType ();
1822
1837
if (resultType.getRank () <= 1 )
1823
- return failure ();
1838
+ return rewriter.notifyMatchFailure (
1839
+ broadcast, " broadcast to 1-d or 0-d handled elsewhere" );
1824
1840
1825
1841
// First insert it into an undef vector so we can shuffle it.
1826
- auto loc = splatOp .getLoc ();
1842
+ auto loc = broadcast .getLoc ();
1827
1843
auto vectorTypeInfo =
1828
1844
LLVM::detail::extractNDVectorTypeInfo (resultType, *getTypeConverter ());
1829
1845
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy ;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1834
1850
// Construct returned value.
1835
1851
Value desc = rewriter.create <LLVM::PoisonOp>(loc, llvmNDVectorTy);
1836
1852
1837
- // Construct a 1-D vector with the splatted value that we insert in all the
1838
- // places within the returned descriptor.
1853
+ // Construct a 1-D vector with the broadcasted value that we insert in all
1854
+ // the places within the returned descriptor.
1839
1855
Value vdesc = rewriter.create <LLVM::PoisonOp>(loc, llvm1DVectorTy);
1840
1856
auto zero = rewriter.create <LLVM::ConstantOp>(
1841
1857
loc, typeConverter->convertType (rewriter.getIntegerType (32 )),
1842
1858
rewriter.getZeroAttr (rewriter.getIntegerType (32 )));
1843
1859
Value v = rewriter.create <LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1844
- adaptor.getInput (), zero);
1860
+ adaptor.getSource (), zero);
1845
1861
1846
1862
// Shuffle the value across the desired number of elements.
1847
1863
int64_t width = resultType.getDimSize (resultType.getRank () - 1 );
1848
1864
SmallVector<int32_t > zeroValues (width, 0 );
1849
1865
v = rewriter.create <LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1850
1866
1851
- // Iterate of linear index, convert to coords space and insert splatted 1-D
1852
- // vector in each position.
1867
+ // Iterate of linear index, convert to coords space and insert broadcasted
1868
+ // 1-D vector in each position.
1853
1869
nDVectorIterate (vectorTypeInfo, rewriter, [&](ArrayRef<int64_t > position) {
1854
1870
desc = rewriter.create <LLVM::InsertValueOp>(loc, desc, v, position);
1855
1871
});
1856
- rewriter.replaceOp (splatOp , desc);
1872
+ rewriter.replaceOp (broadcast , desc);
1857
1873
return success ();
1858
1874
}
1859
1875
};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
2035
2051
}
2036
2052
};
2037
2053
2054
+ // / Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2055
+ // / `vector.broadcast` through other patterns.
2056
+ struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern <vector::SplatOp> {
2057
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2058
+ LogicalResult
2059
+ matchAndRewrite (vector::SplatOp splat, OpAdaptor adaptor,
2060
+ ConversionPatternRewriter &rewriter) const override {
2061
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splat, splat.getType (),
2062
+ adaptor.getInput ());
2063
+ return success ();
2064
+ }
2065
+ };
2066
+
2038
2067
} // namespace
2039
2068
2040
2069
void mlir::vector::populateVectorRankReducingFMAPattern (
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
2063
2092
VectorInsertOpConversion, VectorPrintOpConversion,
2064
2093
VectorTypeCastOpConversion, VectorScaleOpConversion,
2065
2094
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2066
- VectorSplatOpLowering, VectorSplatNdOpLowering,
2095
+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2096
+ VectorBroadcastScalarToNdLowering,
2067
2097
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2068
2098
MaskedReductionOpConversion, VectorInterleaveOpLowering,
2069
2099
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
0 commit comments