diff options
Diffstat (limited to 'contrib/llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/InstructionSimplify.cpp | 1994 |
1 files changed, 1067 insertions, 927 deletions
diff --git a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp index 796e6e4..b4f3b87 100644 --- a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp @@ -21,9 +21,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" @@ -34,6 +37,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Support/KnownBits.h" #include <algorithm> using namespace llvm; using namespace llvm::PatternMatch; @@ -45,49 +49,30 @@ enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); -namespace { -struct Query { - const DataLayout &DL; - const TargetLibraryInfo *TLI; - const DominatorTree *DT; - AssumptionCache *AC; - const Instruction *CxtI; - - Query(const DataLayout &DL, const TargetLibraryInfo *tli, - const DominatorTree *dt, AssumptionCache *ac = nullptr, - const Instruction *cxti = nullptr) - : DL(DL), TLI(tli), DT(dt), AC(ac), CxtI(cxti) {} -}; -} // end anonymous namespace - -static Value *SimplifyAndInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyBinOp(unsigned, Value *, Value *, const Query &, +static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, - const Query &, unsigned); -static Value *SimplifyCmpInst(unsigned, Value *, Value *, const Query &, + const SimplifyQuery &, unsigned); +static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse); -static Value *SimplifyOrInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyXorInst(Value *, Value *, const Query &, unsigned); + const SimplifyQuery &Q, unsigned MaxRecurse); +static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyCastInst(unsigned, Value *, Type *, - const Query &, unsigned); + const SimplifyQuery &, unsigned); -/// For a boolean type, or a vector of boolean type, return false, or -/// a vector with every element false, as appropriate for the type. +/// For a boolean type or a vector of boolean type, return false or a vector +/// with every element false. static Constant *getFalse(Type *Ty) { - assert(Ty->getScalarType()->isIntegerTy(1) && - "Expected i1 type or a vector of i1!"); - return Constant::getNullValue(Ty); + return ConstantInt::getFalse(Ty); } -/// For a boolean type, or a vector of boolean type, return true, or -/// a vector with every element true, as appropriate for the type. +/// For a boolean type or a vector of boolean type, return true or a vector +/// with every element true. static Constant *getTrue(Type *Ty) { - assert(Ty->getScalarType()->isIntegerTy(1) && - "Expected i1 type or a vector of i1!"); - return Constant::getAllOnesValue(Ty); + return ConstantInt::getTrue(Ty); } /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? @@ -118,13 +103,8 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { return false; // If we have a DominatorTree then do a precise test. - if (DT) { - if (!DT->isReachableFromEntry(P->getParent())) - return true; - if (!DT->isReachableFromEntry(I->getParent())) - return false; + if (DT) return DT->dominates(I, P); - } // Otherwise, if the instruction is in the entry block and is not an invoke, // then it obviously dominates all phi nodes. @@ -140,10 +120,9 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { /// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. /// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". /// Returns the simplified value, or null if no simplification was performed. -static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, - unsigned OpcToExpand, const Query &Q, - unsigned MaxRecurse) { - Instruction::BinaryOps OpcodeToExpand = (Instruction::BinaryOps)OpcToExpand; +static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + Instruction::BinaryOps OpcodeToExpand, + const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -199,9 +178,10 @@ static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. -static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { - Instruction::BinaryOps Opcode = (Instruction::BinaryOps)Opc; +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. @@ -298,8 +278,9 @@ static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, /// otherwise returns null. -static Value *ThreadBinOpOverSelect(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -370,7 +351,7 @@ static Value *ThreadBinOpOverSelect(unsigned Opcode, Value *LHS, Value *RHS, /// comparison by seeing whether both branches of the select result in the same /// value. Returns the common value if so, otherwise returns null. static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) @@ -451,8 +432,9 @@ static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, /// try to simplify the binop by seeing whether evaluating it on the incoming /// phi values yields the same result for every value. If so returns the common /// value, otherwise returns null. -static Value *ThreadBinOpOverPHI(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -494,7 +476,7 @@ static Value *ThreadBinOpOverPHI(unsigned Opcode, Value *LHS, Value *RHS, /// yields the same result every time. If so returns the common result, /// otherwise returns null. static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -527,17 +509,26 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, return CommonValue; } +static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, + Value *&Op0, Value *&Op1, + const SimplifyQuery &Q) { + if (auto *CLHS = dyn_cast<Constant>(Op0)) { + if (auto *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + + // Canonicalize the constant to the RHS if this is a commutative operation. + if (Instruction::isCommutative(Opcode)) + std::swap(Op0, Op1); + } + return nullptr; +} + /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) + return C; // X + undef -> undef if (match(Op1, m_Undef())) @@ -556,12 +547,20 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return Y; // X + ~X -> -1 since ~X = -X-1 + Type *Ty = Op0->getType(); if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) - return Constant::getAllOnesValue(Op0->getType()); + return Constant::getAllOnesValue(Ty); + + // add nsw/nuw (xor Y, signmask), signmask --> Y + // The no-wrapping add guarantees that the top bit will be set by the add. + // Therefore, the xor must be clearing the already set sign bit of Y. + if ((isNSW || isNUW) && match(Op1, m_SignMask()) && + match(Op0, m_Xor(m_Value(Y), m_SignMask()))) + return Y; /// i1 add -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -583,11 +582,8 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Query) { + return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query, RecursionLimit); } /// \brief Compute the base pointer and cumulative constant offsets for V. @@ -602,7 +598,7 @@ Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, /// folding. static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, bool AllowNonInbounds = false) { - assert(V->getType()->getScalarType()->isPointerTy()); + assert(V->getType()->isPtrOrPtrVectorTy()); Type *IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); APInt Offset = APInt::getNullValue(IntPtrTy->getIntegerBitWidth()); @@ -631,8 +627,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, } break; } - assert(V->getType()->getScalarType()->isPointerTy() && - "Unexpected operand type!"); + assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!"); } while (Visited.insert(V).second); Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); @@ -664,10 +659,9 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, /// Given operands for a Sub, see if we can fold the result. /// If not, this returns null. static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL); + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) + return C; // X - undef -> undef // undef - X -> undef @@ -688,11 +682,8 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (isNUW) return Op0; - unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownZero == ~APInt::getSignBit(BitWidth)) { + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (Known.Zero.isMaxSignedValue()) { // Op1 is either 0 or the minimum signed value. If the sub is NSW, then // Op1 must be 0 because negating the minimum signed value is undefined. if (isNSW) @@ -779,7 +770,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); // i1 sub -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -796,24 +787,16 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; // fadd X, -0 ==> X if (match(Op1, m_NegZero())) @@ -845,11 +828,9 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; // fsub X, 0 ==> X if (match(Op1, m_Zero())) @@ -878,40 +859,28 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, } /// Given the operands for an FMul, see if we can fold the result -static Value *SimplifyFMulInst(Value *Op0, Value *Op1, - FastMathFlags FMF, - const Query &Q, - unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL); +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } - - // fmul X, 1.0 ==> X - if (match(Op1, m_FPOne())) - return Op0; + // fmul X, 1.0 ==> X + if (match(Op1, m_FPOne())) + return Op0; - // fmul nnan nsz X, 0 ==> 0 - if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) - return Op1; + // fmul nnan nsz X, 0 ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) + return Op1; - return nullptr; + return nullptr; } /// Given operands for a Mul, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) + return C; // X * undef -> 0 if (match(Op1, m_Undef())) @@ -932,7 +901,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, return X; // i1 mul -> and. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -964,77 +933,87 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, } Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); } + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); } -Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyMulInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit); } -/// Given operands for an SDiv or UDiv, see if we can fold the result. -/// If not, this returns null. -static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - bool isSigned = Opcode == Instruction::SDiv; +/// Check for common or similar folds of integer division or integer remainder. +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { + Type *Ty = Op0->getType(); // X / undef -> undef + // X % undef -> undef if (match(Op1, m_Undef())) return Op1; - // X / 0 -> undef, we don't need to preserve faults! + // X / 0 -> undef + // X % 0 -> undef + // We don't need to preserve faults! if (match(Op1, m_Zero())) - return UndefValue::get(Op1->getType()); + return UndefValue::get(Ty); + + // If any element of a constant divisor vector is zero, the whole op is undef. + auto *Op1C = dyn_cast<Constant>(Op1); + if (Op1C && Ty->isVectorTy()) { + unsigned NumElts = Ty->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (Elt && Elt->isNullValue()) + return UndefValue::get(Ty); + } + } // undef / X -> 0 + // undef % X -> 0 if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); + return Constant::getNullValue(Ty); - // 0 / X -> 0, we don't need to preserve faults! + // 0 / X -> 0 + // 0 % X -> 0 if (match(Op0, m_Zero())) return Op0; + // X / X -> 1 + // X % X -> 0 + if (Op0 == Op1) + return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + // X / 1 -> X - if (match(Op1, m_One())) - return Op0; + // X % 1 -> 0 + // If this is a boolean op (single-bit element type), we can't have + // division-by-zero or remainder-by-zero, so assume the divisor is 1. + if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1)) + return IsDiv ? Op0 : Constant::getNullValue(Ty); - if (Op0->getType()->isIntegerTy(1)) - // It can't be division by zero, hence it must be division by one. - return Op0; + return nullptr; +} - // X / X -> 1 - if (Op0 == Op1) - return ConstantInt::get(Op0->getType(), 1); +/// Given operands for an SDiv or UDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + if (Value *V = simplifyDivRem(Op0, Op1, true)) + return V; + + bool isSigned = Opcode == Instruction::SDiv; // (X * Y) / Y -> X if the multiplication does not overflow. Value *X = nullptr, *Y = nullptr; @@ -1061,7 +1040,7 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (!isSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && match(Op1, m_ConstantInt(C2))) { bool Overflow; - C1->getValue().umul_ov(C2->getValue(), Overflow); + (void)C1->getValue().umul_ov(C2->getValue(), Overflow); if (Overflow) return Constant::getNullValue(Op0->getType()); } @@ -1083,7 +1062,7 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, /// Given operands for an SDiv, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse)) return V; @@ -1091,17 +1070,13 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a UDiv, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse)) return V; @@ -1119,16 +1094,15 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyUDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit); } static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned) { + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; + // undef / X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1166,49 +1140,19 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFDivInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit); } /// Given operands for an SRem or URem, see if we can fold the result. /// If not, this returns null. static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - // X % undef -> undef - if (match(Op1, m_Undef())) - return Op1; - - // undef % X -> 0 - if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); - - // 0 % X -> 0, we don't need to preserve faults! - if (match(Op0, m_Zero())) - return Op0; - - // X % 0 -> undef, we don't need to preserve faults! - if (match(Op1, m_Zero())) - return UndefValue::get(Op0->getType()); - - // X % 1 -> 0 - if (match(Op1, m_One())) - return Constant::getNullValue(Op0->getType()); + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; - if (Op0->getType()->isIntegerTy(1)) - // It can't be remainder by zero, hence it must be remainder by one. - return Constant::getNullValue(Op0->getType()); - - // X % X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); + if (Value *V = simplifyDivRem(Op0, Op1, false)) + return V; // (X % Y) % Y -> X % Y if ((Opcode == Instruction::SRem && @@ -1234,7 +1178,7 @@ static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, /// Given operands for an SRem, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse)) return V; @@ -1242,17 +1186,13 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySRemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a URem, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse)) return V; @@ -1270,16 +1210,15 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyURemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &, unsigned) { + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; + // undef % X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1298,12 +1237,8 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFRemInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit); } /// Returns true if a shift by \c Amount always yields undef. @@ -1335,11 +1270,10 @@ static bool isUndefShift(Value *Amount) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; // 0 shift by X -> 0 if (match(Op0, m_Zero())) @@ -1367,18 +1301,14 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, // If any bits in the shift amount make that value greater than or equal to // the number of bits in the type, the shift is undefined. - unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownOne.getLimitedValue() >= BitWidth) + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (Known.One.getLimitedValue() >= Known.getBitWidth()) return UndefValue::get(Op0->getType()); // If all valid bits in the shift amount are known zero, the first operand is // unchanged. - unsigned NumValidShiftBits = Log2_32_Ceil(BitWidth); - APInt ShiftAmountMask = APInt::getLowBitsSet(BitWidth, NumValidShiftBits); - if ((KnownZero & ShiftAmountMask) == ShiftAmountMask) + unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth()); + if (Known.countMinTrailingZeros() >= NumValidShiftBits) return Op0; return nullptr; @@ -1386,8 +1316,8 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, /// \brief Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. -static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, - bool isExact, const Query &Q, +static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, bool isExact, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) return V; @@ -1403,12 +1333,8 @@ static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, // The low bit cannot be shifted out of an exact shift if it is set. if (isExact) { - unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); - APInt Op0KnownZero(BitWidth, 0); - APInt Op0KnownOne(BitWidth, 0); - computeKnownBits(Op0, Op0KnownZero, Op0KnownOne, Q.DL, /*Depth=*/0, Q.AC, - Q.CxtI, Q.DT); - if (Op0KnownOne[0]) + KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + if (Op0Known.One[0]) return Op0; } @@ -1418,7 +1344,7 @@ static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, /// Given operands for an Shl, see if we can fold the result. /// If not, this returns null. static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse)) return V; @@ -1435,17 +1361,14 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for an LShr, see if we can fold the result. /// If not, this returns null. static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, MaxRecurse)) return V; @@ -1459,18 +1382,14 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, } Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); } /// Given operands for an AShr, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, MaxRecurse)) return V; @@ -1493,14 +1412,12 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, } Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); } +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, ICmpInst *UnsignedICmp, bool IsAnd) { Value *X, *Y; @@ -1569,29 +1486,75 @@ static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { /// Commuted variants are assumed to be handled by calling this function again /// with the parameters swapped. -static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) - return X; +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; - if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) - return X; + // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op0 from this 'or'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op1; + + // Check for any combination of predicates that cover the entire range of + // possibilities. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || + (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) + return getTrue(Op0->getType()); + + return nullptr; +} + +/// Test if a pair of compares with a shared operand and 2 constants has an +/// empty set intersection, full set union, or if one compare is a superset of +/// the other. +static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { + // Look for this pattern: {and/or} (icmp X, C0), (icmp X, C1)). + if (Cmp0->getOperand(0) != Cmp1->getOperand(0)) + return nullptr; - // Look for this pattern: (icmp V, C0) & (icmp V, C1)). - Type *ITy = Op0->getType(); - ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; - Value *V; - if (match(Op0, m_ICmp(Pred0, m_Value(V), m_APInt(C0))) && - match(Op1, m_ICmp(Pred1, m_Specific(V), m_APInt(C1)))) { - // Make a constant range that's the intersection of the two icmp ranges. - // If the intersection is empty, we know that the result is false. - auto Range0 = ConstantRange::makeAllowedICmpRegion(Pred0, *C0); - auto Range1 = ConstantRange::makeAllowedICmpRegion(Pred1, *C1); - if (Range0.intersectWith(Range1).isEmptySet()) - return getFalse(ITy); - } + if (!match(Cmp0->getOperand(1), m_APInt(C0)) || + !match(Cmp1->getOperand(1), m_APInt(C1))) + return nullptr; + + auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0); + auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1); + + // For and-of-compares, check if the intersection is empty: + // (icmp X, C0) && (icmp X, C1) --> empty set --> false + if (IsAnd && Range0.intersectWith(Range1).isEmptySet()) + return getFalse(Cmp0->getType()); + + // For or-of-compares, check if the union is full: + // (icmp X, C0) || (icmp X, C1) --> full set --> true + if (!IsAnd && Range0.unionWith(Range1).isFullSet()) + return getTrue(Cmp0->getType()); + + // Is one range a superset of the other? + // If this is and-of-compares, take the smaller set: + // (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42 + // If this is or-of-compares, take the larger set: + // (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4 + if (Range0.contains(Range1)) + return IsAnd ? Cmp1 : Cmp0; + if (Range1.contains(Range0)) + return IsAnd ? Cmp0 : Cmp1; + return nullptr; +} + +static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { // (icmp (add V, C0), C1) & (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; @@ -1602,6 +1565,7 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (AddInst->getOperand(1) != Op1->getOperand(1)) return nullptr; + Type *ITy = Op0->getType(); bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); @@ -1632,17 +1596,132 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return nullptr; } +static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true)) + return X; + + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true)) + return X; + + if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1)) + return X; + if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0)) + return X; + + return nullptr; +} + +static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { + // (icmp (add V, C0), C1) | (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + } + if (C0->getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; +} + +static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false)) + return X; + + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false)) + return X; + + if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1)) + return X; + if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0)) + return X; + + return nullptr; +} + +static Value *simplifyAndOrOfICmps(Value *Op0, Value *Op1, bool IsAnd) { + // Look through casts of the 'and' operands to find compares. + auto *Cast0 = dyn_cast<CastInst>(Op0); + auto *Cast1 = dyn_cast<CastInst>(Op1); + if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && + Cast0->getSrcTy() == Cast1->getSrcTy()) { + Op0 = Cast0->getOperand(0); + Op1 = Cast1->getOperand(0); + } + + auto *Cmp0 = dyn_cast<ICmpInst>(Op0); + auto *Cmp1 = dyn_cast<ICmpInst>(Op1); + if (!Cmp0 || !Cmp1) + return nullptr; + + Value *V = + IsAnd ? simplifyAndOfICmps(Cmp0, Cmp1) : simplifyOrOfICmps(Cmp0, Cmp1); + if (!V) + return nullptr; + if (!Cast0) + return V; + + // If we looked through casts, we can only handle a constant simplification + // because we are not allowed to create a cast instruction here. + if (auto *C = dyn_cast<Constant>(V)) + return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType()); + + return nullptr; +} + /// Given operands for an And, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) + return C; // X & undef -> 0 if (match(Op1, m_Undef())) @@ -1666,16 +1745,31 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, return Constant::getNullValue(Op0->getType()); // (A | ?) & A = A - Value *A = nullptr, *B = nullptr; - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - (A == Op1 || B == Op1)) + if (match(Op0, m_c_Or(m_Specific(Op1), m_Value()))) return Op1; // A & (A | ?) = A - if (match(Op1, m_Or(m_Value(A), m_Value(B))) && - (A == Op0 || B == Op0)) + if (match(Op1, m_c_Or(m_Specific(Op0), m_Value()))) return Op0; + // A mask that only clears known zeros of a shifted value is a no-op. + Value *X; + const APInt *Mask; + const APInt *ShAmt; + if (match(Op1, m_APInt(Mask))) { + // If all bits in the inverted and shifted mask are clear: + // and (shl X, ShAmt), Mask --> shl X, ShAmt + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).lshr(*ShAmt).isNullValue()) + return Op0; + + // If all bits in the inverted and shifted mask are clear: + // and (lshr X, ShAmt), Mask --> lshr X, ShAmt + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).shl(*ShAmt).isNullValue()) + return Op0; + } + // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || match(Op1, m_Neg(m_Specific(Op0)))) { @@ -1687,32 +1781,8 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, return Op1; } - if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { - if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { - if (Value *V = SimplifyAndOfICmps(ICILHS, ICIRHS)) - return V; - if (Value *V = SimplifyAndOfICmps(ICIRHS, ICILHS)) - return V; - } - } - - // The compares may be hidden behind casts. Look through those and try the - // same folds as above. - auto *Cast0 = dyn_cast<CastInst>(Op0); - auto *Cast1 = dyn_cast<CastInst>(Op1); - if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && - Cast0->getSrcTy() == Cast1->getSrcTy()) { - auto *Cmp0 = dyn_cast<ICmpInst>(Cast0->getOperand(0)); - auto *Cmp1 = dyn_cast<ICmpInst>(Cast1->getOperand(0)); - if (Cmp0 && Cmp1) { - Instruction::CastOps CastOpc = Cast0->getOpcode(); - Type *ResultType = Cast0->getType(); - if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp0, Cmp1))) - return ConstantExpr::getCast(CastOpc, V, ResultType); - if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp1, Cmp0))) - return ConstantExpr::getCast(CastOpc, V, ResultType); - } - } + if (Value *V = simplifyAndOrOfICmps(Op0, Op1, true)) + return V; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, @@ -1746,105 +1816,16 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAndInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); -} - -/// Commuted variants are assumed to be handled by calling this function again -/// with the parameters swapped. -static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { - ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; - if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || - !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) - return nullptr; - - // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). - // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we - // can eliminate Op0 from this 'or'. - if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) - return Op1; - - // Check for any combination of predicates that cover the entire range of - // possibilities. - if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || - (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || - (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || - (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) - return getTrue(Op0->getType()); - - return nullptr; -} - -/// Commuted variants are assumed to be handled by calling this function again -/// with the parameters swapped. -static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) - return X; - - if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) - return X; - - // (icmp (add V, C0), C1) | (icmp V, C0) - ICmpInst::Predicate Pred0, Pred1; - const APInt *C0, *C1; - Value *V; - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) - return nullptr; - - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) - return nullptr; - - auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); - if (AddInst->getOperand(1) != Op1->getOperand(1)) - return nullptr; - - Type *ITy = Op0->getType(); - bool isNSW = AddInst->hasNoSignedWrap(); - bool isNUW = AddInst->hasNoUnsignedWrap(); - - const APInt Delta = *C1 - *C0; - if (C0->isStrictlyPositive()) { - if (Delta == 2) { - if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) - return getTrue(ITy); - if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) - return getTrue(ITy); - } - if (Delta == 1) { - if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) - return getTrue(ITy); - if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) - return getTrue(ITy); - } - } - if (C0->getBoolValue() && isNUW) { - if (Delta == 2) - if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) - return getTrue(ITy); - if (Delta == 1) - if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - - return nullptr; +Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for an Or, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) + return C; // X | undef -> -1 if (match(Op1, m_Undef())) @@ -1868,34 +1849,61 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, return Constant::getAllOnesValue(Op0->getType()); // (A & ?) | A = A - Value *A = nullptr, *B = nullptr; - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - (A == Op1 || B == Op1)) + if (match(Op0, m_c_And(m_Specific(Op1), m_Value()))) return Op1; // A | (A & ?) = A - if (match(Op1, m_And(m_Value(A), m_Value(B))) && - (A == Op0 || B == Op0)) + if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) return Op0; // ~(A & ?) | A = -1 - if (match(Op0, m_Not(m_And(m_Value(A), m_Value(B)))) && - (A == Op1 || B == Op1)) + if (match(Op0, m_Not(m_c_And(m_Specific(Op1), m_Value())))) return Constant::getAllOnesValue(Op1->getType()); // A | ~(A & ?) = -1 - if (match(Op1, m_Not(m_And(m_Value(A), m_Value(B)))) && - (A == Op0 || B == Op0)) + if (match(Op1, m_Not(m_c_And(m_Specific(Op1), m_Value())))) return Constant::getAllOnesValue(Op0->getType()); - if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { - if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { - if (Value *V = SimplifyOrOfICmps(ICILHS, ICIRHS)) - return V; - if (Value *V = SimplifyOrOfICmps(ICIRHS, ICILHS)) - return V; - } - } + Value *A, *B; + // (A & ~B) | (A ^ B) -> (A ^ B) + // (~B & A) | (A ^ B) -> (A ^ B) + // (A & ~B) | (B ^ A) -> (B ^ A) + // (~B & A) | (B ^ A) -> (B ^ A) + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (match(Op0, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // Commute the 'or' operands. + // (A ^ B) | (A & ~B) -> (A ^ B) + // (A ^ B) | (~B & A) -> (A ^ B) + // (B ^ A) | (A & ~B) -> (B ^ A) + // (B ^ A) | (~B & A) -> (B ^ A) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + (match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + // (A & B) | (~A ^ B) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (B ^ ~A) + // (B & A) | (B ^ ~A) -> (B ^ ~A) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // (~A ^ B) | (A & B) -> (~A ^ B) + // (~A ^ B) | (B & A) -> (~A ^ B) + // (B ^ ~A) | (A & B) -> (B ^ ~A) + // (B ^ ~A) | (B & A) -> (B ^ ~A) + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + if (Value *V = simplifyAndOrOfICmps(Op0, Op1, false)) + return V; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, @@ -1914,37 +1922,27 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, MaxRecurse)) return V; - // (A & C)|(B & D) - Value *C = nullptr, *D = nullptr; - if (match(Op0, m_And(m_Value(A), m_Value(C))) && - match(Op1, m_And(m_Value(B), m_Value(D)))) { - ConstantInt *C1 = dyn_cast<ConstantInt>(C); - ConstantInt *C2 = dyn_cast<ConstantInt>(D); - if (C1 && C2 && (C1->getValue() == ~C2->getValue())) { + // (A & C1)|(B & C2) + const APInt *C1, *C2; + if (match(Op0, m_And(m_Value(A), m_APInt(C1))) && + match(Op1, m_And(m_Value(B), m_APInt(C2)))) { + if (*C1 == ~*C2) { // (A & C1)|(B & C2) // If we have: ((V + N) & C1) | (V & C2) // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0 // replace with V+N. - Value *V1, *V2; - if ((C2->getValue() & (C2->getValue() + 1)) == 0 && // C2 == 0+1+ - match(A, m_Add(m_Value(V1), m_Value(V2)))) { + Value *N; + if (C2->isMask() && // C2 == 0+1+ + match(A, m_c_Add(m_Specific(B), m_Value(N)))) { // Add commutes, try both ways. - if (V1 == B && - MaskedValueIsZero(V2, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return A; - if (V2 == B && - MaskedValueIsZero(V1, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return A; } // Or commutes, try both ways. - if ((C1->getValue() & (C1->getValue() + 1)) == 0 && - match(B, m_Add(m_Value(V1), m_Value(V2)))) { + if (C1->isMask() && + match(B, m_c_Add(m_Specific(A), m_Value(N)))) { // Add commutes, try both ways. - if (V1 == A && - MaskedValueIsZero(V2, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return B; - if (V2 == A && - MaskedValueIsZero(V1, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; } } @@ -1959,25 +1957,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyOrInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a Xor, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) + return C; // A ^ undef -> undef if (match(Op1, m_Undef())) @@ -2013,14 +2002,11 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyXorInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit); } + static Type *GetCompareTy(Value *Op) { return CmpInst::makeCmpResultType(Op->getType()); } @@ -2254,34 +2240,55 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, /// Fold an icmp when its operands have i1 scalar type. static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q) { + Value *RHS, const SimplifyQuery &Q) { Type *ITy = GetCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. - if (!OpTy->getScalarType()->isIntegerTy(1)) + if (!OpTy->isIntOrIntVectorTy(1)) return nullptr; - switch (Pred) { - default: - break; - case ICmpInst::ICMP_EQ: - // X == 1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_NE: - // X != 0 -> X - if (match(RHS, m_Zero())) + // A boolean compared to true/false can be simplified in 14 out of the 20 + // (10 predicates * 2 constants) possible combinations. Cases not handled here + // require a 'not' of the LHS, so those must be transformed in InstCombine. + if (match(RHS, m_Zero())) { + switch (Pred) { + case CmpInst::ICMP_NE: // X != 0 -> X + case CmpInst::ICMP_UGT: // X >u 0 -> X + case CmpInst::ICMP_SLT: // X <s 0 -> X return LHS; - break; - case ICmpInst::ICMP_UGT: - // X >u 0 -> X - if (match(RHS, m_Zero())) + + case CmpInst::ICMP_ULT: // X <u 0 -> false + case CmpInst::ICMP_SGT: // X >s 0 -> false + return getFalse(ITy); + + case CmpInst::ICMP_UGE: // X >=u 0 -> true + case CmpInst::ICMP_SLE: // X <=s 0 -> true + return getTrue(ITy); + + default: break; + } + } else if (match(RHS, m_One())) { + switch (Pred) { + case CmpInst::ICMP_EQ: // X == 1 -> X + case CmpInst::ICMP_UGE: // X >=u 1 -> X + case CmpInst::ICMP_SLE: // X <=s -1 -> X return LHS; + + case CmpInst::ICMP_UGT: // X >u 1 -> false + case CmpInst::ICMP_SLT: // X <s -1 -> false + return getFalse(ITy); + + case CmpInst::ICMP_ULE: // X <=u 1 -> true + case CmpInst::ICMP_SGE: // X >=s -1 -> true + return getTrue(ITy); + + default: break; + } + } + + switch (Pred) { + default: break; case ICmpInst::ICMP_UGE: - // X >=u 1 -> X - if (match(RHS, m_One())) - return LHS; if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; @@ -2296,16 +2303,6 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; - case ICmpInst::ICMP_SLT: - // X <s 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_SLE: - // X <=s -1 -> X - if (match(RHS, m_One())) - return LHS; - break; case ICmpInst::ICMP_ULE: if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); @@ -2317,12 +2314,11 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, /// Try hard to fold icmp with zero RHS because this is a common case. static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q) { + Value *RHS, const SimplifyQuery &Q) { if (!match(RHS, m_Zero())) return nullptr; Type *ITy = GetCompareTy(LHS); // The return type. - bool LHSKnownNonNegative, LHSKnownNegative; switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); @@ -2340,43 +2336,202 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; - case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + case ICmpInst::ICMP_SLT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getFalse(ITy); break; - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SLE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(ITy); break; - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getTrue(ITy); break; - case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; } + } return nullptr; } +/// Many binary operators with a constant operand have an easy-to-compute +/// range of outputs. This can be used to fold a comparison to always true or +/// always false. +static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (BO.getOpcode()) { + case Instruction::Add: + if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { + // FIXME: If we have both nuw and nsw, we should reduce the range further. + if (BO.hasNoUnsignedWrap()) { + // 'add nuw x, C' produces [C, UINT_MAX]. + Lower = *C; + } else if (BO.hasNoSignedWrap()) { + if (C->isNegative()) { + // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + } + break; + + case Instruction::And: + if (match(BO.getOperand(1), m_APInt(C))) + // 'and x, C' produces [0, C]. + Upper = *C + 1; + break; + + case Instruction::Or: + if (match(BO.getOperand(1), m_APInt(C))) + // 'or x, C' produces [C, UINT_MAX]. + Lower = *C; + break; + + case Instruction::AShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. + Lower = APInt::getSignedMinValue(Width).ashr(*C); + Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + unsigned ShiftAmount = Width - 1; + if (!C->isNullValue() && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + if (C->isNegative()) { + // 'ashr C, x' produces [C, C >> (Width-1)] + Lower = *C; + Upper = C->ashr(ShiftAmount) + 1; + } else { + // 'ashr C, x' produces [C >> (Width-1), C] + Lower = C->ashr(ShiftAmount); + Upper = *C + 1; + } + } + break; + + case Instruction::LShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'lshr x, C' produces [0, UINT_MAX >> C]. + Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'lshr C, x' produces [C >> (Width-1), C]. + unsigned ShiftAmount = Width - 1; + if (!C->isNullValue() && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + Lower = C->lshr(ShiftAmount); + Upper = *C + 1; + } + break; + + case Instruction::Shl: + if (match(BO.getOperand(0), m_APInt(C))) { + if (BO.hasNoUnsignedWrap()) { + // 'shl nuw C, x' produces [C, C << CLZ(C)] + Lower = *C; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? + if (C->isNegative()) { + // 'shl nsw C, x' produces [C << CLO(C)-1, C] + unsigned ShiftAmount = C->countLeadingOnes() - 1; + Lower = C->shl(ShiftAmount); + Upper = *C + 1; + } else { + // 'shl nsw C, x' produces [C, C << CLZ(C)-1] + unsigned ShiftAmount = C->countLeadingZeros() - 1; + Lower = *C; + Upper = C->shl(ShiftAmount) + 1; + } + } + } + break; + + case Instruction::SDiv: + if (match(BO.getOperand(1), m_APInt(C))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C->countLeadingZeros() < Width - 1) { + // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin.sdiv(*C); + Upper = IntMax.sdiv(*C); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(BO.getOperand(0), m_APInt(C))) { + if (C->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C, x' produces [-|C|, |C|]. + Upper = C->abs() + 1; + Lower = (-Upper) + 1; + } + } + break; + + case Instruction::UDiv: + if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { + // 'udiv x, C' produces [0, UINT_MAX / C]. + Upper = APInt::getMaxValue(Width).udiv(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'udiv C, x' produces [0, C]. + Upper = *C + 1; + } + break; + + case Instruction::SRem: + if (match(BO.getOperand(1), m_APInt(C))) { + // 'srem x, C' produces (-|C|, |C|). + Upper = C->abs(); + Lower = (-Upper) + 1; + } + break; + + case Instruction::URem: + if (match(BO.getOperand(1), m_APInt(C))) + // 'urem x, C' produces [0, C). + Upper = *C; + break; + + default: + break; + } +} + static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, Value *RHS) { const APInt *C; @@ -2390,114 +2545,12 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, if (RHS_CR.isFullSet()) return ConstantInt::getTrue(GetCompareTy(RHS)); - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. + // Find the range of possible values for binary operators. unsigned Width = C->getBitWidth(); APInt Lower = APInt(Width, 0); APInt Upper = APInt(Width, 0); - const APInt *C2; - if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) { - // 'urem x, C2' produces [0, C2). - Upper = *C2; - } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) { - // 'srem x, C2' produces (-|C2|, |C2|). - Upper = C2->abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) { - // 'udiv C2, x' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) { - // 'udiv x, C2' produces [0, UINT_MAX / C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (*C2 != 0) - Upper = NegOne.udiv(*C2) + 1; - } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) { - if (C2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = *C2; - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv C2, x' produces [-|C2|, |C2|]. - Upper = C2->abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (C2->countLeadingZeros() < Width - 1) { - // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin.sdiv(*C2); - Upper = IntMax.sdiv(*C2); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) { - // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)] - Lower = *C2; - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) { - if (C2->isNegative()) { - // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2] - unsigned ShiftAmount = C2->countLeadingOnes() - 1; - Lower = C2->shl(ShiftAmount); - Upper = *C2 + 1; - } else { - // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1] - unsigned ShiftAmount = C2->countLeadingZeros() - 1; - Lower = *C2; - Upper = C2->shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) { - // 'lshr x, C2' produces [0, UINT_MAX >> C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (C2->ult(Width)) - Upper = NegOne.lshr(*C2) + 1; - } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) { - // 'lshr C2, x' produces [C2 >> (Width-1), C2]. - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - Lower = C2->lshr(ShiftAmount); - Upper = *C2 + 1; - } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) { - // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2]. - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->ult(Width)) { - Lower = IntMin.ashr(*C2); - Upper = IntMax.ashr(*C2) + 1; - } - } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - if (C2->isNegative()) { - // 'ashr C2, x' produces [C2, C2 >> (Width-1)] - Lower = *C2; - Upper = C2->ashr(ShiftAmount) + 1; - } else { - // 'ashr C2, x' produces [C2 >> (Width-1), C2] - Lower = C2->ashr(ShiftAmount); - Upper = *C2 + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) { - // 'or x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) { - // 'and x, C2' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) { - // 'add nuw x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } + if (auto *BO = dyn_cast<BinaryOperator>(LHS)) + setLimitsForBinOp(*BO, Lower, Upper); ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); @@ -2516,8 +2569,11 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +/// TODO: A large part of this logic is duplicated in InstCombine's +/// foldICmpBinOp(). We should be able to share that and avoid the code +/// duplication. static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = GetCompareTy(LHS); // The return type. @@ -2597,15 +2653,11 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getTrue(ITy); if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { - bool RHSKnownNonNegative, RHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (RHSKnownNonNegative && YKnownNegative) + KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (RHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); - if (RHSKnownNegative || YKnownNonNegative) + if (RHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); } } @@ -2617,31 +2669,25 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getFalse(ITy); if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { - bool LHSKnownNonNegative, LHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNonNegative && YKnownNegative) + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); - if (LHSKnownNegative || YKnownNonNegative) + if (LHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); } } } // icmp pred (and X, Y), X - if (LBO && match(LBO, m_CombineOr(m_And(m_Value(), m_Specific(RHS)), - m_And(m_Specific(RHS), m_Value())))) { + if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { if (Pred == ICmpInst::ICMP_UGT) return getFalse(ITy); if (Pred == ICmpInst::ICMP_ULE) return getTrue(ITy); } // icmp pred X, (and X, Y) - if (RBO && match(RBO, m_CombineOr(m_And(m_Value(), m_Specific(LHS)), - m_And(m_Specific(LHS), m_Value())))) { + if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) { if (Pred == ICmpInst::ICMP_UGE) return getTrue(ITy); if (Pred == ICmpInst::ICMP_ULT) @@ -2672,28 +2718,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred (urem X, Y), Y if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getFalse(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2703,28 +2748,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred X, (urem Y, X) if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getTrue(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2773,14 +2817,14 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // - CI2 is one // - CI isn't zero if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() || - *CI2Val == 1 || !CI->isZero()) { + CI2Val->isOneValue() || !CI->isZero()) { if (Pred == ICmpInst::ICMP_EQ) return ConstantInt::getFalse(RHS->getContext()); if (Pred == ICmpInst::ICMP_NE) return ConstantInt::getTrue(RHS->getContext()); } } - if (CIVal->isSignBit() && *CI2Val == 1) { + if (CIVal->isSignMask() && CI2Val->isOneValue()) { if (Pred == ICmpInst::ICMP_UGT) return ConstantInt::getFalse(RHS->getContext()); if (Pred == ICmpInst::ICMP_ULE) @@ -2796,10 +2840,19 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, break; case Instruction::UDiv: case Instruction::LShr: - if (ICmpInst::isSigned(Pred)) + if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact()) break; - LLVM_FALLTHROUGH; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::SDiv: + if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact()) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::AShr: if (!LBO->isExact() || !RBO->isExact()) break; @@ -2827,7 +2880,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, /// Simplify integer comparisons where at least one operand of the compare /// matches an integer min/max idiom. static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = GetCompareTy(LHS); // The return type. Value *A, *B; @@ -3031,7 +3084,7 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); @@ -3064,8 +3117,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If both operands have range metadata, use the metadata // to simplify the comparison. if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { - auto RHS_Instr = dyn_cast<Instruction>(RHS); - auto LHS_Instr = dyn_cast<Instruction>(LHS); + auto RHS_Instr = cast<Instruction>(RHS); + auto LHS_Instr = cast<Instruction>(LHS); if (RHS_Instr->getMetadata(LLVMContext::MD_range) && LHS_Instr->getMetadata(LLVMContext::MD_range)) { @@ -3245,11 +3298,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // icmp eq|ne X, Y -> false|true if X != Y - if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + if (ICmpInst::isEquality(Pred) && isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { - LLVMContext &Ctx = LHS->getType()->getContext(); - return Pred == ICmpInst::ICMP_NE ? - ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); + return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy); } if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse)) @@ -3297,22 +3348,6 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // If a bit is known to be zero for A and known to be one for B, - // then A and B cannot be equal. - if (ICmpInst::isEquality(Pred)) { - const APInt *RHSVal; - if (match(RHS, m_APInt(RHSVal))) { - unsigned BitWidth = RHSVal->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AC, - Q.CxtI, Q.DT); - if (((LHSKnownZero & *RHSVal) != 0) || ((LHSKnownOne & ~(*RHSVal)) != 0)) - return Pred == ICmpInst::ICMP_EQ ? ConstantInt::getFalse(ITy) - : ConstantInt::getTrue(ITy); - } - } - // If the comparison is with the result of a select instruction, check whether // comparing with either branch of the select always yields the same value. if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) @@ -3329,18 +3364,14 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } /// Given operands for an FCmpInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - FastMathFlags FMF, const Query &Q, + FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!"); @@ -3462,22 +3493,22 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - FastMathFlags FMF, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } /// See if V simplifies when its operand Op is replaced with RepOp. static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const Query &Q, + const SimplifyQuery &Q, unsigned MaxRecurse) { // Trivial replacement. if (V == Op) return RepOp; + // We cannot replace a constant, and shouldn't even try. + if (isa<Constant>(Op)) + return nullptr; + auto *I = dyn_cast<Instruction>(V); if (!I) return nullptr; @@ -3620,7 +3651,7 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *TrueVal, /// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, - Value *FalseVal, const Query &Q, + Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { ICmpInst::Predicate Pred; Value *CmpLHS, *CmpRHS; @@ -3699,7 +3730,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, - Value *FalseVal, const Query &Q, + Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { // select true, X, Y -> X // select false, X, Y -> Y @@ -3715,9 +3746,9 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, return TrueVal; if (isa<UndefValue>(CondVal)) { // select undef, X, Y -> X or Y - if (isa<Constant>(TrueVal)) - return TrueVal; - return FalseVal; + if (isa<Constant>(FalseVal)) + return FalseVal; + return TrueVal; } if (isa<UndefValue>(TrueVal)) // select C, undef, X -> X return FalseVal; @@ -3732,18 +3763,14 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, } Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit); } /// Given operands for an GetElementPtrInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, - const Query &Q, unsigned) { + const SimplifyQuery &Q, unsigned) { // The type of the GEP pointer operand. unsigned AS = cast<PointerType>(Ops[0]->getType()->getScalarType())->getAddressSpace(); @@ -3757,6 +3784,8 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, Type *GEPTy = PointerType::get(LastType, AS); if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); if (isa<UndefValue>(Ops[0])) return UndefValue::get(GEPTy); @@ -3842,27 +3871,25 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, } // Check to see if this is constant foldable. - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (!isa<Constant>(Ops[i])) - return nullptr; + if (!all_of(Ops, [](Value *V) { return isa<Constant>(V); })) + return nullptr; - return ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ops[0]), - Ops.slice(1)); + auto *CE = ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ops[0]), + Ops.slice(1)); + if (auto *CEFolded = ConstantFoldConstant(CE, Q.DL)) + return CEFolded; + return CE; } Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyGEPInst(SrcTy, Ops, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyGEPInst(SrcTy, Ops, Q, RecursionLimit); } /// Given operands for an InsertValueInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, - ArrayRef<unsigned> Idxs, const Query &Q, + ArrayRef<unsigned> Idxs, const SimplifyQuery &Q, unsigned) { if (Constant *CAgg = dyn_cast<Constant>(Agg)) if (Constant *CVal = dyn_cast<Constant>(Val)) @@ -3888,18 +3915,16 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return nullptr; } -Value *llvm::SimplifyInsertValueInst( - Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const DataLayout &DL, - const TargetLibraryInfo *TLI, const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, + ArrayRef<unsigned> Idxs, + const SimplifyQuery &Q) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); } /// Given operands for an ExtractValueInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, - const Query &, unsigned) { + const SimplifyQuery &, unsigned) { if (auto *CAgg = dyn_cast<Constant>(Agg)) return ConstantFoldExtractValueInstruction(CAgg, Idxs); @@ -3922,18 +3947,13 @@ static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, } Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, - AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyExtractValueInst(Agg, Idxs, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit); } /// Given operands for an ExtractElementInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &, +static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &, unsigned) { if (auto *CVec = dyn_cast<Constant>(Vec)) { if (auto *CIdx = dyn_cast<Constant>(Idx)) @@ -3956,15 +3976,13 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &, return nullptr; } -Value *llvm::SimplifyExtractElementInst( - Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) { - return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, + const SimplifyQuery &Q) { + return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit); } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { +static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { // If all of the PHI's incoming values are the same then replace the PHI node // with the common value. Value *CommonValue = nullptr; @@ -3997,7 +4015,7 @@ static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { } static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, - Type *Ty, const Query &Q, unsigned MaxRecurse) { + Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { if (auto *C = dyn_cast<Constant>(Op)) return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); @@ -4031,12 +4049,141 @@ static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, } Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCastInst(CastOpc, Op, Ty, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); +} + +/// For the given destination element of a shuffle, peek through shuffles to +/// match a root vector source operand that contains that element in the same +/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s). +static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, + int MaskVal, Value *RootVec, + unsigned MaxRecurse) { + if (!MaxRecurse--) + return nullptr; + + // Bail out if any mask value is undefined. That kind of shuffle may be + // simplified further based on demanded bits or other folds. + if (MaskVal == -1) + return nullptr; + + // The mask value chooses which source operand we need to look at next. + int InVecNumElts = Op0->getType()->getVectorNumElements(); + int RootElt = MaskVal; + Value *SourceOp = Op0; + if (MaskVal >= InVecNumElts) { + RootElt = MaskVal - InVecNumElts; + SourceOp = Op1; + } + + // If the source operand is a shuffle itself, look through it to find the + // matching root vector. + if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) { + return foldIdentityShuffles( + DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), + SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); + } + + // TODO: Look through bitcasts? What if the bitcast changes the vector element + // size? + + // The source operand is not a shuffle. Initialize the root vector value for + // this shuffle if that has not been done yet. + if (!RootVec) + RootVec = SourceOp; + + // Give up as soon as a source operand does not match the existing root value. + if (RootVec != SourceOp) + return nullptr; + + // The element must be coming from the same lane in the source vector + // (although it may have crossed lanes in intermediate shuffles). + if (RootElt != DestElt) + return nullptr; + + return RootVec; +} + +static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, + Type *RetTy, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (isa<UndefValue>(Mask)) + return UndefValue::get(RetTy); + + Type *InVecTy = Op0->getType(); + unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); + unsigned InVecNumElts = InVecTy->getVectorNumElements(); + + SmallVector<int, 32> Indices; + ShuffleVectorInst::getShuffleMask(Mask, Indices); + assert(MaskNumElts == Indices.size() && + "Size of Indices not same as number of mask elements?"); + + // Canonicalization: If mask does not select elements from an input vector, + // replace that input vector with undef. + bool MaskSelects0 = false, MaskSelects1 = false; + for (unsigned i = 0; i != MaskNumElts; ++i) { + if (Indices[i] == -1) + continue; + if ((unsigned)Indices[i] < InVecNumElts) + MaskSelects0 = true; + else + MaskSelects1 = true; + } + if (!MaskSelects0) + Op0 = UndefValue::get(InVecTy); + if (!MaskSelects1) + Op1 = UndefValue::get(InVecTy); + + auto *Op0Const = dyn_cast<Constant>(Op0); + auto *Op1Const = dyn_cast<Constant>(Op1); + + // If all operands are constant, constant fold the shuffle. + if (Op0Const && Op1Const) + return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + + // Canonicalization: if only one input vector is constant, it shall be the + // second one. + if (Op0Const && !Op1Const) { + std::swap(Op0, Op1); + ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts); + } + + // A shuffle of a splat is always the splat itself. Legal if the shuffle's + // value type is same as the input vectors' type. + if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) + if (isa<UndefValue>(Op1) && RetTy == InVecTy && + OpShuf->getMask()->getSplatValue()) + return Op0; + + // Don't fold a shuffle with undef mask elements. This may get folded in a + // better way using demanded bits or other analysis. + // TODO: Should we allow this? + if (find(Indices, -1) != Indices.end()) + return nullptr; + + // Check if every element of this shuffle can be mapped back to the + // corresponding element of a single root vector. If so, we don't need this + // shuffle. This handles simple identity shuffles as well as chains of + // shuffles that may widen/narrow and/or move elements across lanes and back. + Value *RootVec = nullptr; + for (unsigned i = 0; i != MaskNumElts; ++i) { + // Note that recursion is limited for each vector element, so if any element + // exceeds the limit, this will fail to simplify. + RootVec = + foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse); + + // We can't replace a widening/narrowing shuffle with one of its operands. + if (!RootVec || RootVec->getType() != RetTy) + return nullptr; + } + return RootVec; +} + +/// Given operands for a ShuffleVectorInst, fold the result or return null. +Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, + Type *RetTy, const SimplifyQuery &Q) { + return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } //=== Helper functions for higher up the class hierarchy. @@ -4044,64 +4191,46 @@ Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FAdd: return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::Sub: - return SimplifySubInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - - case Instruction::Mul: return SimplifyMulInst (LHS, RHS, Q, MaxRecurse); + case Instruction::Mul: + return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); case Instruction::FMul: - return SimplifyFMulInst (LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SDiv: return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); - case Instruction::UDiv: return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SDiv: + return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); + case Instruction::UDiv: + return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); case Instruction::FDiv: - return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SRem: return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); - case Instruction::URem: return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SRem: + return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); + case Instruction::URem: + return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); case Instruction::FRem: - return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::Shl: - return SimplifyShlInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::LShr: - return SimplifyLShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); + return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); case Instruction::AShr: - return SimplifyAShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); - case Instruction::And: return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); - case Instruction::Or: return SimplifyOrInst (LHS, RHS, Q, MaxRecurse); - case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); + return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); + case Instruction::And: + return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Or: + return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Xor: + return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); default: - if (Constant *CLHS = dyn_cast<Constant>(LHS)) - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); - - // If the operation is associative, try some generic simplifications. - if (Instruction::isAssociative(Opcode)) - if (Value *V = SimplifyAssociativeBinOp(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a select instruction check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) - if (Value *V = ThreadBinOpOverSelect(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a phi instruction, check whether - // operating on all incoming values of the phi always yields the same value. - if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) - if (Value *V = ThreadBinOpOverPHI(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - return nullptr; + llvm_unreachable("Unexpected opcode"); } } @@ -4110,7 +4239,7 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// In contrast to SimplifyBinOp, try to use FastMathFlag when folding the /// result. In case we don't need FastMathFlags, simply fall to SimplifyBinOp. static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const FastMathFlags &FMF, const Query &Q, + const FastMathFlags &FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: @@ -4127,36 +4256,26 @@ static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, } Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); } Value *llvm::SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const FastMathFlags &FMF, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFPBinOp(Opcode, LHS, RHS, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyFPBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); } /// Given operands for a CmpInst, see if we can fold the result. static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); } Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } static bool IsIdempotent(Intrinsic::ID ID) { @@ -4249,7 +4368,7 @@ static bool maskIsAllZeroOrUndef(Value *Mask) { template <typename IterTy> static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { Intrinsic::ID IID = F->getIntrinsicID(); unsigned NumOperands = std::distance(ArgBegin, ArgEnd); @@ -4267,6 +4386,7 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, case Intrinsic::fabs: { if (SignBitMustBeZero(*ArgBegin, Q.TLI)) return *ArgBegin; + return nullptr; } default: return nullptr; @@ -4296,19 +4416,21 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: { // X + undef -> undef - if (isa<UndefValue>(RHS)) + if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); return nullptr; } case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: { + // 0 * X -> { 0, false } // X * 0 -> { 0, false } - if (match(RHS, m_Zero())) + if (match(LHS, m_Zero()) || match(RHS, m_Zero())) return Constant::getNullValue(ReturnType); + // undef * X -> { 0, false } // X * undef -> { 0, false } - if (match(RHS, m_Undef())) + if (match(LHS, m_Undef()) || match(RHS, m_Undef())) return Constant::getNullValue(ReturnType); return nullptr; @@ -4341,8 +4463,9 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, } template <typename IterTy> -static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, - const Query &Q, unsigned MaxRecurse) { +static Value *SimplifyCall(ImmutableCallSite CS, Value *V, IterTy ArgBegin, + IterTy ArgEnd, const SimplifyQuery &Q, + unsigned MaxRecurse) { Type *Ty = V->getType(); if (PointerType *PTy = dyn_cast<PointerType>(Ty)) Ty = PTy->getElementType(); @@ -4361,7 +4484,7 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, if (Value *Ret = SimplifyIntrinsic(F, ArgBegin, ArgEnd, Q, MaxRecurse)) return Ret; - if (!canConstantFoldCallTo(F)) + if (!canConstantFoldCallTo(CS, F)) return nullptr; SmallVector<Constant *, 4> ConstantArgs; @@ -4373,181 +4496,170 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, ConstantArgs.push_back(C); } - return ConstantFoldCall(F, ConstantArgs, Q.TLI); + return ConstantFoldCall(CS, F, ConstantArgs, Q.TLI); } -Value *llvm::SimplifyCall(Value *V, User::op_iterator ArgBegin, - User::op_iterator ArgEnd, const DataLayout &DL, - const TargetLibraryInfo *TLI, const DominatorTree *DT, - AssumptionCache *AC, const Instruction *CxtI) { - return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, + User::op_iterator ArgBegin, User::op_iterator ArgEnd, + const SimplifyQuery &Q) { + return ::SimplifyCall(CS, V, ArgBegin, ArgEnd, Q, RecursionLimit); } -Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCall(V, Args.begin(), Args.end(), - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); +Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, + ArrayRef<Value *> Args, const SimplifyQuery &Q) { + return ::SimplifyCall(CS, V, Args.begin(), Args.end(), Q, RecursionLimit); } /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC) { + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); Value *Result; switch (I->getOpcode()) { default: - Result = ConstantFoldInstruction(I, DL, TLI); + Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Add: Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = - SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FDiv: Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FRem: Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, - AC, I); + cast<BinaryOperator>(I)->isExact(), Q); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, - AC, I); + cast<BinaryOperator>(I)->isExact(), Q); break; case Instruction::And: - Result = - SimplifyAndInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::Or: - Result = - SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::Xor: - Result = - SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::ICmp: - Result = - SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), I->getOperand(0), - I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), + I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FCmp: - Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + Result = + SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), + I->getOperand(1), I->getFastMathFlags(), Q); break; case Instruction::Select: Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), DL, TLI, DT, AC, I); + I->getOperand(2), Q); break; case Instruction::GetElementPtr: { - SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); + SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), - Ops, DL, TLI, DT, AC, I); + Ops, Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); Result = SimplifyInsertValueInst(IV->getAggregateOperand(), IV->getInsertedValueOperand(), - IV->getIndices(), DL, TLI, DT, AC, I); + IV->getIndices(), Q); break; } case Instruction::ExtractValue: { auto *EVI = cast<ExtractValueInst>(I); Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), - EVI->getIndices(), DL, TLI, DT, AC, I); + EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { auto *EEI = cast<ExtractElementInst>(I); - Result = SimplifyExtractElementInst( - EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I); + Result = SimplifyExtractElementInst(EEI->getVectorOperand(), + EEI->getIndexOperand(), Q); + break; + } + case Instruction::ShuffleVector: { + auto *SVI = cast<ShuffleVectorInst>(I); + Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), + SVI->getMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I)); + Result = SimplifyPHINode(cast<PHINode>(I), Q); break; case Instruction::Call: { CallSite CS(cast<CallInst>(I)); - Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), DL, - TLI, DT, AC, I); + Result = SimplifyCall(CS, CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), + Q); break; } #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), - DL, TLI, DT, AC, I); + Result = + SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + break; + case Instruction::Alloca: + // No simplifications for Alloca and it can't be constant folded. + Result = nullptr; break; } // In general, it is possible for computeKnownBits to determine all bits in a // value even when the operands are not all constants. if (!Result && I->getType()->isIntOrIntVectorTy()) { - unsigned BitWidth = I->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT); - if ((KnownZero | KnownOne).isAllOnesValue()) - Result = ConstantInt::get(I->getType(), KnownOne); + KnownBits Known = computeKnownBits(I, Q.DL, /*Depth*/ 0, Q.AC, I, Q.DT, ORE); + if (Known.isConstant()) + Result = ConstantInt::get(I->getType(), Known.getConstant()); } /// If called on unreachable code, the above logic may report that the @@ -4599,7 +4711,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, I = Worklist[Idx]; // See if this instruction simplifies. - SimpleV = SimplifyInstruction(I, DL, TLI, DT, AC); + SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC}); if (!SimpleV) continue; @@ -4638,3 +4750,31 @@ bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, assert(SimpleV && "Must provide a simplified value."); return replaceAndRecursivelySimplifyImpl(I, SimpleV, TLI, DT, AC); } + +namespace llvm { +const SimplifyQuery getBestSimplifyQuery(Pass &P, Function &F) { + auto *DTWP = P.getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + auto *TLIWP = P.getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIWP ? &TLIWP->getTLI() : nullptr; + auto *ACWP = P.getAnalysisIfAvailable<AssumptionCacheTracker>(); + auto *AC = ACWP ? &ACWP->getAssumptionCache(F) : nullptr; + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} + +const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &AR, + const DataLayout &DL) { + return {DL, &AR.TLI, &AR.DT, &AR.AC}; +} + +template <class T, class... TArgs> +const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM, + Function &F) { + auto *DT = AM.template getCachedResult<DominatorTreeAnalysis>(F); + auto *TLI = AM.template getCachedResult<TargetLibraryAnalysis>(F); + auto *AC = AM.template getCachedResult<AssumptionAnalysis>(F); + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} +template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &, + Function &); +} |