Skip to content

Commit 75aa629

Browse files
authored
[mlir][vector] Add a check to ensure input vector rank equals target shape rank (#149239)
The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported. Fixes #148368.
1 parent 6d8d6f6 commit 75aa629

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ struct UnrollTransferReadPattern
169169
auto sourceVectorType = readOp.getVectorType();
170170
SmallVector<int64_t> strides(targetShape->size(), 1);
171171
Location loc = readOp.getLoc();
172-
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
172+
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
173173

174174
// Prepare the result vector;
175175
Value result =
@@ -225,6 +225,14 @@ struct UnrollTransferWritePattern
225225
SmallVector<int64_t> strides(targetShape->size(), 1);
226226
Location loc = writeOp.getLoc();
227227
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
228+
// Bail-out if rank(source) != rank(target). The main limitation here is the
229+
// fact that `ExtractStridedSlice` requires the rank for the input and
230+
// output to match. If needed, we can relax this later.
231+
if (originalSize.size() != targetShape->size())
232+
return rewriter.notifyMatchFailure(
233+
writeOp,
234+
"expected source input vector rank to match target shape rank");
235+
228236
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
229237
writeOp.getIndices().end());
230238
SmallVector<int64_t> loopOrder =

mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>)
6868

6969
// -----
7070

71+
// Ensure that cases with mismatched target and source shape ranks
72+
// do not lead to a crash.
73+
// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns`
74+
// is currently hard-coded to [2, 2].
75+
76+
// CHECK-LABEL: func @negative_transfer_write
77+
// CHECK-NOT: vector.extract_strided_slice
78+
// CHECK: vector.transfer_write
79+
// CHECK: return
80+
func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) {
81+
%c0 = arith.constant 0 : index
82+
%alloc = memref.alloc() : memref<6x34x62xi8>
83+
vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
84+
return
85+
}
86+
87+
// -----
88+
7189
// CHECK-LABEL: func @transfer_readwrite_unroll
7290
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
7391
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

0 commit comments

Comments
 (0)