Skip to content

Commit 6ed921f

Browse files
authored
Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" (#150138)
This reverts commit 228c45f (PR #148937) . Now that #148027 is landed, I think it is safe to "reland" the original PR: #148028
1 parent 01e23c3 commit 6ed921f

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
123123
vector::OuterProductOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126-
arith::ConstantOp, vector::SplatOp>();
126+
arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
127127
}
128128

129129
void EmulateUnsupportedFloatsPass::runOnOperation() {

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
688688

689689
Type elementType = getElementTypeOrSelf(memref.getType());
690690
auto vt = VectorType::get(vectorShape, elementType);
691-
Value res = vector::SplatOp::create(b, loc, vt, loads[0]);
691+
Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
692692
foreachIndividualVectorElement(
693693
res,
694694
/*applyFn=*/

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
438438
Value inc = arith::ConstantIndexOp::create(rewriter, loc,
439439
i * blockedChunkSize);
440440
Value incVec =
441-
vector::SplatOp::create(rewriter, loc, indiceType, inc);
441+
vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
442442
Value offsetIndice =
443443
arith::AddIOp::create(rewriter, loc, indice, incVec);
444444

mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
2020
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
2121
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
2222
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
23-
// CHECK: %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
23+
// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
2424
// CHECK: %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
2525
// CHECK: %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
2626
//
2727
// CHECK: %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
2828
// CHECK: %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
2929
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
30-
// CHECK: %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
30+
// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
3131
// CHECK: %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
3232
//
3333
// CHECK: %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
4242
// CHECK: %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
4343
// CHECK: %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
4444
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
45-
// CHECK: %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
45+
// CHECK: %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
4646
// CHECK: %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
4747
// CHECK: %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
4848
// CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>

0 commit comments

Comments
 (0)