Skip to content

Commit 04e5e64

Browse files
authored
[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations (#148350)
This patch generalizes the existing foldBitOpOfBitcasts optimization in the VectorCombine pass to handle additional cast operations beyond just bitcast.

Fixes: [#146037](#146037)

## Summary

The optimization now supports folding bitwise operations (AND/OR/XOR) with the following cast operations:

- bitcast (original functionality)
- trunc (truncate)
- sext (sign extend)
- zext (zero extend)

The transformation pattern is:

    bitop(castop(x), castop(y)) -> castop(bitop(x, y))

This reduces the number of cast instructions from 2 to 1, improving performance on targets where cast operations are expensive or where performing bitwise operations on narrower types is beneficial.

## Implementation Details

- Renamed foldBitOpOfBitcasts to foldBitOpOfCastops to reflect broader functionality
- Extended pattern matching to handle any CastInst operation
- Added validation for each cast type's constraints (e.g., trunc requires source > dest)
- Updated cost model to use the actual cast opcode
- Preserves IR flags from original instructions
- Handles multi-use scenarios appropriately

## Testing

- Added comprehensive tests in test/Transforms/VectorCombine/bitop-of-castops.ll
- Tests cover all supported cast types with all bitwise operations
- Includes negative tests for unsupported patterns
- All existing VectorCombine tests pass
1 parent a270fdf commit 04e5e64

File tree

2 files changed

+339
-30
lines changed

2 files changed

+339
-30
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 77 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class VectorCombine {
115115
bool foldInsExtFNeg(Instruction &I);
116116
bool foldInsExtBinop(Instruction &I);
117117
bool foldInsExtVectorToShuffle(Instruction &I);
118-
bool foldBitOpOfBitcasts(Instruction &I);
118+
bool foldBitOpOfCastops(Instruction &I);
119119
bool foldBitcastShuffle(Instruction &I);
120120
bool scalarizeOpOrCmp(Instruction &I);
121121
bool scalarizeVPIntrinsic(Instruction &I);
@@ -808,48 +808,87 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
808808
return true;
809809
}
810810

811-
bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
812-
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
813-
Value *LHSSrc, *RHSSrc;
814-
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
815-
m_BitCast(m_Value(RHSSrc)))))
811+
/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
812+
/// Supports: bitcast, trunc, sext, zext
813+
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
814+
// Check if this is a bitwise logic operation
815+
auto *BinOp = dyn_cast<BinaryOperator>(&I);
816+
if (!BinOp || !BinOp->isBitwiseLogicOp())
816817
return false;
817818

819+
// Get the cast instructions
820+
auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
821+
auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
822+
if (!LHSCast || !RHSCast) {
823+
LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
824+
return false;
825+
}
826+
827+
// Both casts must be the same type
828+
Instruction::CastOps CastOpcode = LHSCast->getOpcode();
829+
if (CastOpcode != RHSCast->getOpcode())
830+
return false;
831+
832+
// Only handle supported cast operations
833+
switch (CastOpcode) {
834+
case Instruction::BitCast:
835+
case Instruction::Trunc:
836+
case Instruction::SExt:
837+
case Instruction::ZExt:
838+
break;
839+
default:
840+
return false;
841+
}
842+
843+
Value *LHSSrc = LHSCast->getOperand(0);
844+
Value *RHSSrc = RHSCast->getOperand(0);
845+
818846
// Source types must match
819847
if (LHSSrc->getType() != RHSSrc->getType())
820848
return false;
821-
if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
822-
return false;
823849

824-
// Only handle vector types
850+
// Only handle vector types with integer elements
825851
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
826852
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
827853
if (!SrcVecTy || !DstVecTy)
828854
return false;
829855

830-
// Same total bit width
831-
assert(SrcVecTy->getPrimitiveSizeInBits() ==
832-
DstVecTy->getPrimitiveSizeInBits() &&
833-
"Bitcast should preserve total bit width");
856+
if (!SrcVecTy->getScalarType()->isIntegerTy() ||
857+
!DstVecTy->getScalarType()->isIntegerTy())
858+
return false;
834859

835860
// Cost Check :
836-
// OldCost = bitlogic + 2*bitcasts
837-
// NewCost = bitlogic + bitcast
838-
auto *BinOp = cast<BinaryOperator>(&I);
861+
// OldCost = bitlogic + 2*casts
862+
// NewCost = bitlogic + cast
863+
864+
// Calculate specific costs for each cast with instruction context
865+
InstructionCost LHSCastCost =
866+
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
867+
TTI::CastContextHint::None, CostKind, LHSCast);
868+
InstructionCost RHSCastCost =
869+
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
870+
TTI::CastContextHint::None, CostKind, RHSCast);
871+
839872
InstructionCost OldCost =
840-
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
841-
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
842-
TTI::CastContextHint::None) +
843-
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
844-
TTI::CastContextHint::None);
873+
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy, CostKind) +
874+
LHSCastCost + RHSCastCost;
875+
876+
// For new cost, we can't provide an instruction (it doesn't exist yet)
877+
InstructionCost GenericCastCost = TTI.getCastInstrCost(
878+
CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);
879+
845880
InstructionCost NewCost =
846-
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
847-
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
848-
TTI::CastContextHint::None);
881+
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy, CostKind) +
882+
GenericCastCost;
849883

850-
LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
851-
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
852-
<< "\n");
884+
// Account for multi-use casts using specific costs
885+
if (!LHSCast->hasOneUse())
886+
NewCost += LHSCastCost;
887+
if (!RHSCast->hasOneUse())
888+
NewCost += RHSCastCost;
889+
890+
LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
891+
<< " NewCost=" << NewCost << "\n");
853892

854893
if (NewCost > OldCost)
855894
return false;
@@ -862,8 +901,16 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
862901

863902
Worklist.pushValue(NewOp);
864903

865-
// Bitcast the result back
866-
Value *Result = Builder.CreateBitCast(NewOp, I.getType());
904+
// Create the cast operation directly to ensure we get a new instruction
905+
Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
906+
907+
// Preserve cast instruction flags
908+
NewCast->copyIRFlags(LHSCast);
909+
NewCast->andIRFlags(RHSCast);
910+
911+
// Insert the new instruction
912+
Value *Result = Builder.Insert(NewCast);
913+
867914
replaceValue(I, *Result);
868915
return true;
869916
}
@@ -3773,7 +3820,7 @@ bool VectorCombine::run() {
37733820
case Instruction::And:
37743821
case Instruction::Or:
37753822
case Instruction::Xor:
3776-
MadeChange |= foldBitOpOfBitcasts(I);
3823+
MadeChange |= foldBitOpOfCastops(I);
37773824
break;
37783825
default:
37793826
MadeChange |= shrinkType(I);

0 commit comments

Comments
 (0)