[mlir][Vector] Remove vector.extractelement and vector.insertelement ops #149603
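
For downstream IR that still uses the removed ops, the documentation and comment updates in this diff suggest a mostly mechanical rewrite. The sketch below is illustrative only: the SSA names (`%v`, `%z`, `%f`, `%i`) are hypothetical, and it assumes the position is already of `index` type (an `i32` position would presumably need an `arith.index_cast` first).

```mlir
// Before (ops removed by this change):
//   %e  = vector.extractelement %v[%i : index] : vector<16xf32>
//   %e0 = vector.extractelement %z[] : vector<f32>
//   %r  = vector.insertelement %f, %v[%i : index] : vector<16xf32>
//   %r0 = vector.insertelement %f, %z[] : vector<f32>

// After (the n-D ops, used here on 1-D and 0-D vectors):
%e  = vector.extract %v[%i] : f32 from vector<16xf32>
%e0 = vector.extract %z[] : f32 from vector<f32>
%r  = vector.insert %f, %v[%i] : f32 into vector<16xf32>
%r0 = vector.insert %f, %z[] : f32 into vector<f32>
```

Note the asymmetry in the trailing type: `vector.extract` is written `<element> from <vector>`, while `vector.insert` is written `<element> into <vector>`.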

Open: wants to merge 1 commit into base: main

2 changes: 1 addition & 1 deletion mlir/docs/Dialects/Vector.md
@@ -294,7 +294,7 @@ LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g.
`llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates
following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR
operations are prefixed by the `vector.` dialect prefix (e.g.
`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector`
`vector.insert`). Such ops operate exclusively on MLIR `n-D` `vector`
types.

### Alternatives For Lowering an n-D Vector Type to LLVM
20 changes: 10 additions & 10 deletions mlir/docs/Tutorials/transform/Ch0.md
@@ -46,7 +46,7 @@ When no support is available, such an operation can be transformed into a loop:
%c8 = arith.constant 8 : index
%init = arith.constant 0.0 : f32
%result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) {
%element = vector.extractelement %0[%i : index] : vector<8xf32>
%element = vector.extract %0[%i] : f32 from vector<8xf32>
%updated = arith.addf %partial, %element : f32
scf.yield %updated : f32
}
@@ -145,7 +145,7 @@ linalg.generic {
%c0 = arith.constant 0.0 : f32
%0 = arith.cmpf ogt %in_one, %c0 : f32
%1 = arith.select %0, %in_one, %c0 : f32
linalg.yield %1 : f32
linalg.yield %1 : f32
}
```

@@ -185,7 +185,7 @@ In the case of `linalg.generic` operations, the iteration space is implicit and
For example, tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `linalg.generic` expressing the same operation on a `2x8` tensor.

```mlir
// A special "multi-for" loop that supports tensor-insertion semantics
// A special "multi-for" loop that supports tensor-insertion semantics
// as opposed to implicit updates. The resulting 8x16 tensor will be produced
// by this loop.
// The trip count of iterators is computed dividing the original tensor size,
@@ -202,9 +202,9 @@ For example, tiling the matrix multiplication presented above with tile sizes `(
// Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced.
%lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1]
: tensor<8x10xf32> to tensor<2x10xf32>
%rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1]
%rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1]
: tensor<10x16xf32> to tensor<10x8xf32>
%result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1]
%result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1]
: tensor<8x16xf32> to tensor<2x8xf32>

// This is exactly the same operation as before, but now operating on smaller
@@ -214,7 +214,7 @@ For example, tiling the matrix multiplication presented above with tile sizes `(
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>)
} ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>)
outs(%result_slice : tensor<2x8xf32>) -> tensor<2x8xf32> {
^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32):
%0 = arith.mulf %lhs_one, %rhs_one : f32
@@ -238,15 +238,15 @@ After materializing loops with tiling, another key code generation transformatio
1. the subset (slice) of the operand that is used by the tile, and
2. the tensor-level structured operation producing the whole tensor that is being sliced.

By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand.
By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand.

Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself. This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled dimension will be used in its entirety.


```mlir
// Same loop as before.
%0 = scf.forall (%i, %j) in (4, 2)
shared_outs(%shared = %init)
%0 = scf.forall (%i, %j) in (4, 2)
shared_outs(%shared = %init)
-> (tensor<8x16xf32>, tensor<8x16xf32>) {
// Scale the loop induction variables by the tile sizes.
%1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
@@ -286,7 +286,7 @@ Let us assume that the matrix multiplication operation is followed by another op
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]
} ins(%partial : tensor<2x8xf32>)
} ins(%partial : tensor<2x8xf32>)
outs(%shared_slice : tensor<2x8xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.mulf %in, %in : f32
@@ -380,7 +380,7 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {

After:
%3 = memref.load %2[] : memref<f32>
%4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
%4 = vector.insert %3, %cst [0] : f32 into vector<32xf32>
%5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
%8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
%9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
100 changes: 0 additions & 100 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -646,55 +646,6 @@ def Vector_DeinterleaveOp :
}];
}

def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"::llvm::cast<VectorType>($_self).getElementType()">]>,
Arguments<(ins AnyVectorOfAnyRank:$vector,
Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyType:$result)> {
let summary = "extractelement operation";
let description = [{
Note: This operation is deprecated. Please use vector.extract instead.

Takes a 0-D or 1-D vector and an optional dynamic index position and
extracts the scalar at that position.

Note that this instruction resembles vector.extract, but is restricted to
0-D and 1-D vectors.
If the vector is 0-D, the position must be std::nullopt.


It is meant to be closer to LLVM's version:
https://llvm.org/docs/LangRef.html#extractelement-instruction

Example:

```mlir
%c = arith.constant 15 : i32
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
%2 = vector.extractelement %z[]: vector<f32>
```
}];
let assemblyFormat = [{
$vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector)
}];

let builders = [
// 0-D builder.
OpBuilder<(ins "Value":$source)>,
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
}];
let hasVerifier = 1;
let hasFolder = 1;
}

def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
@@ -890,57 +841,6 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
let hasCanonicalizer = 1;
}

def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"source operand type matches element type of result",
"result", "source",
"::llvm::cast<VectorType>($_self).getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "insertelement operation";
let description = [{
Note: This operation is deprecated. Please use vector.insert instead.

Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
position and inserts the source into the destination at the proper position.

Note that this instruction resembles vector.insert, but is restricted to 0-D
and 1-D vectors.

It is meant to be closer to LLVM's version:
https://llvm.org/docs/LangRef.html#insertelement-instruction

Example:

```mlir
%c = arith.constant 15 : i32
%f = arith.constant 0.0f : f32
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
%2 = vector.insertelement %f, %z[]: vector<f32>
```
}];
let assemblyFormat = [{
$source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
type($result)
}];

let builders = [
// 0-D builder.
OpBuilder<(ins "Value":$source, "Value":$dest)>,
];
let extraClassDeclaration = [{
Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
}];
let hasVerifier = 1;
let hasFolder = 1;
}

def Vector_InsertOp :
Vector_Op<"insert", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -397,8 +397,8 @@ def DotOp : AVX_LowOp<"dot", [Pure,

```mlir
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
%1 = vector.extractelement %0[%i0 : i32]: vector<8xf32>
%2 = vector.extractelement %0[%i4 : i32]: vector<8xf32>
%1 = vector.extract %0[%i0] : f32 from vector<8xf32>
%2 = vector.extract %0[%i4] : f32 from vector<8xf32>
%d = arith.addf %1, %2 : f32
```
}];
69 changes: 1 addition & 68 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1103,39 +1103,6 @@ class VectorShuffleOpConversion
}
};

class VectorExtractElementOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
using ConvertOpToLLVMPattern<
vector::ExtractElementOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vectorType = extractEltOp.getSourceVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());

// Bail if result type cannot be lowered.
if (!llvmType)
return failure();

if (vectorType.getRank() == 0) {
Location loc = extractEltOp.getLoc();
auto idxType = rewriter.getIndexType();
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
}

rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
return success();
}
};

class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1238,39 +1205,6 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
}
};

class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);

// Bail if result type cannot be lowered.
if (!llvmType)
return failure();

if (vectorType.getRank() == 0) {
Location loc = insertEltOp.getLoc();
auto idxType = rewriter.getIndexType();
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
}

rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
adaptor.getPosition());
return success();
}
};

class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2058,8 +1992,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -688,7 +688,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1639,7 +1639,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
/// %t = vector.extractelement %vec[i] : vector<9xf32>
/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```