diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 4391 |
1 files changed, 2278 insertions, 2113 deletions
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 961497f..428f94b 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -35,17 +35,12 @@ using namespace PatternMatch; // How many times is a select replaced by one of its operands? STATISTIC(NumSel, "Number of select opts"); -// Initialization Routines -static ConstantInt *getOne(Constant *C) { - return ConstantInt::get(cast<IntegerType>(C->getType()), 1); -} - -static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { +static ConstantInt *extractElement(Constant *V, Constant *Idx) { return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); } -static bool HasAddOverflow(ConstantInt *Result, +static bool hasAddOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -58,28 +53,28 @@ static bool HasAddOverflow(ConstantInt *Result, /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. -static bool AddWithOverflow(Constant *&Result, Constant *In1, +static bool addWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getAdd(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasAddOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasAddOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasAddOverflow(cast<ConstantInt>(Result), + return hasAddOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } -static bool HasSubOverflow(ConstantInt *Result, +static bool hasSubOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -93,23 +88,23 @@ static bool HasSubOverflow(ConstantInt *Result, /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. -static bool SubWithOverflow(Constant *&Result, Constant *In1, +static bool subWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getSub(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasSubOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasSubOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasSubOverflow(cast<ConstantInt>(Result), + return hasSubOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } @@ -126,26 +121,26 @@ static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { /// Given an exploded icmp instruction, return true if the comparison only /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the /// result of the comparison is true when the input value is signed. -static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, +static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, bool &TrueIfSigned) { switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; - return RHS->isZero(); + return RHS == 0; case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 TrueIfSigned = true; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_SGT: // True if LHS s> -1 TrueIfSigned = false; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_UGT: // True if LHS u> RHS and RHS == high-bit-mask - 1 TrueIfSigned = true; - return RHS->isMaxValue(true); + return RHS.isMaxSignedValue(); case ICmpInst::ICMP_UGE: // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) TrueIfSigned = true; - return RHS->getValue().isSignBit(); + return RHS.isSignBit(); default: return false; } @@ -154,19 +149,20 @@ static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. -static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { +/// TODO: Refactor with decomposeBitTestICmp()? +static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (RHS->isZero()) + if (C == 0) return ICmpInst::isRelational(Pred); - if (RHS->isOne()) { + if (C == 1) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; } - } else if (RHS->isAllOnesValue()) { + } else if (C.isAllOnesValue()) { if (Pred == ICmpInst::ICMP_SGT) { Pred = ICmpInst::ICMP_SGE; return true; @@ -176,16 +172,10 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { return false; } -/// Return true if the constant is of the form 1+0+. This is the same as -/// lowones(~X). -static bool isHighOnes(const ConstantInt *CI) { - return (~CI->getValue() + 1).isPowerOf2(); -} - /// Given a signed integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -208,7 +198,7 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// Given an unsigned integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -231,9 +221,10 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// /// If AndCst is non-null, then the loaded value is masked with that constant /// before doing the comparison. This handles cases like "A[i]&4 == 0". -Instruction *InstCombiner:: -FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, - CmpInst &ICI, ConstantInt *AndCst) { +Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, + CmpInst &ICI, + ConstantInt *AndCst) { Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) return nullptr; @@ -319,7 +310,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, - CompareRHS, DL, TLI); + CompareRHS, DL, &TLI); // If the result is undef for this element, ignore it. if (isa<UndefValue>(C)) { // Extend range state machines to cover this element in case there is an @@ -509,7 +500,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, +static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); @@ -526,7 +517,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -556,7 +547,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -893,6 +884,10 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, if (!GEPLHS->hasAllConstantIndices()) return nullptr; + // Make sure the pointers have the same type. + if (GEPLHS->getType() != RHS->getType()) + return nullptr; + Value *PtrBase, *Index; std::tie(PtrBase, Index) = getAsConstantIndexedAddress(GEPLHS, DL); @@ -919,7 +914,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, /// Fold comparisons between a GEP instruction and something else. At this point /// we know that the GEP is on the LHS of the comparison. -Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, +Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I) { // Don't transform signed compares of GEPs into index compares. Even if the @@ -941,7 +936,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL); + Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); // If not, synthesize the offset the hard way. if (!Offset) @@ -1003,12 +998,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If one of the GEPs has all zero indices, recurse. if (GEPLHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. if (GEPRHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { @@ -1056,8 +1051,9 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } -Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, - Value *Other) { +Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, + const AllocaInst *Alloca, + const Value *Other) { assert(ICI.isEquality() && "Cannot fold non-equality comparison."); // It would be tempting to fold away comparisons between allocas and any @@ -1076,8 +1072,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned MaxIter = 32; // Break cycles and bound to constant-time. - SmallVector<Use *, 32> Worklist; - for (Use &U : Alloca->uses()) { + SmallVector<const Use *, 32> Worklist; + for (const Use &U : Alloca->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1086,8 +1082,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned NumCmps = 0; while (!Worklist.empty()) { assert(Worklist.size() <= MaxIter); - Use *U = Worklist.pop_back_val(); - Value *V = U->getUser(); + const Use *U = Worklist.pop_back_val(); + const Value *V = U->getUser(); --MaxIter; if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || @@ -1096,7 +1092,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else if (isa<LoadInst>(V)) { // Loading from the pointer doesn't escape it. continue; - } else if (auto *SI = dyn_cast<StoreInst>(V)) { + } else if (const auto *SI = dyn_cast<StoreInst>(V)) { // Storing *to* the pointer is fine, but storing the pointer escapes it. if (SI->getValueOperand() == U->get()) return nullptr; @@ -1105,7 +1101,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, if (NumCmps++) return nullptr; // Found more than one cmp. continue; - } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) { + } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { switch (Intrin->getIntrinsicID()) { // These intrinsics don't escape or compare the pointer. Memset is safe // because we don't allow ptrtoint. Memcpy and memmove are safe because @@ -1120,7 +1116,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else { return nullptr; } - for (Use &U : V->uses()) { + for (const Use &U : V->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1134,9 +1130,9 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, - Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred) { +Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, + Value *X, ConstantInt *CI, + ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1181,52 +1177,995 @@ Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } -/// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are -/// both known to be integer constants. -Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS) { - ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); - const APInt &CmpRHSV = CmpRHS->getValue(); +/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> +/// (icmp eq/ne A, Log2(AP2/AP1)) -> +/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)). +Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + bool IsAShr = isa<AShrOperator>(I.getOperand(0)); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + int Shift; + if (IsAShr && AP1.isNegative()) + Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + else + Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + + if (Shift > 0) { + if (IsAShr && AP1 == AP2.ashr(Shift)) { + // There are multiple solutions if we are comparing against -1 and the LHS + // of the ashr is not a power of two. + if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) + return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } else if (AP1 == AP2.lshr(Shift)) { + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + } + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// Handle "(icmp eq/ne (shl AP2, A), AP1)" -> +/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)). +Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp( + I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// The caller has matched a pattern of the form: +/// I = icmp ugt (add (add A, B), CI2), CI1 +/// If this is of the form: +/// sum = a + b +/// if (sum+128 >u 255) +/// Then replace it with llvm.sadd.with.overflow.i8. +/// +static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, + ConstantInt *CI2, ConstantInt *CI1, + InstCombiner &IC) { + // The transformation we're trying to do here is to transform this into an + // llvm.sadd.with.overflow. To do this, we have to replace the original add + // with a narrower add, and discard the add-with-constant that is part of the + // range check (if we can't eliminate it, this isn't profitable). + + // In order to eliminate the add-with-constant, the compare can be its only + // use. + Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); + if (!AddWithCst->hasOneUse()) + return nullptr; + + // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. + if (!CI2->getValue().isPowerOf2()) + return nullptr; + unsigned NewWidth = CI2->getValue().countTrailingZeros(); + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) + return nullptr; + + // The width of the new add formed is 1 more than the bias. + ++NewWidth; + + // Check to see that CI1 is an all-ones value with NewWidth bits. + if (CI1->getBitWidth() == NewWidth || + CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) + return nullptr; + + // This is only really a signed overflow check if the inputs have been + // sign-extended; check for that condition. For example, if CI2 is 2^31 and + // the operands of the add are 64 bits wide, we need at least 33 sign bits. + unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + return nullptr; + + // In order to replace the original add with a narrower + // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant + // and truncates that discard the high bits of the add. Verify that this is + // the case. + Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); + for (User *U : OrigAdd->users()) { + if (U == AddWithCst) + continue; + + // Only accept truncates for now. We would really like a nice recursive + // predicate like SimplifyDemandedBits, but which goes downwards the use-def + // chain to see which bits of a value are actually demanded. If the + // original add had another add which was then immediately truncated, we + // could still do the transformation. + TruncInst *TI = dyn_cast<TruncInst>(U); + if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) + return nullptr; + } + + // If the pattern matches, truncate the inputs to the narrower type and + // use the sadd_with_overflow intrinsic to efficiently compute both the + // result and the overflow bit. + Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); + Value *F = Intrinsic::getDeclaration(I.getModule(), + Intrinsic::sadd_with_overflow, NewType); + + InstCombiner::BuilderTy *Builder = IC.Builder; + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + Builder->SetInsertPoint(OrigAdd); + + Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc"); + Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc"); + CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); + Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + + // The inner add was the result of the narrow add, zero extended to the + // wider type. Replace it with the result computed by the intrinsic. + IC.replaceInstUsesWith(*OrigAdd, ZExt); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "sadd.overflow"); +} + +// Fold icmp Pred X, C. +Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; + + Value *A = nullptr, *B = nullptr; + + // Match the following pattern, which is a common idiom when writing + // overflow-safe integer arithmetic functions. The source performs an addition + // in wider type and explicitly checks for overflow using comparisons against + // INT_MIN and INT_MAX. Simplify by using the sadd_with_overflow intrinsic. + // + // TODO: This could probably be generalized to handle other overflow-safe + // operations if we worked out the formulas to compute the appropriate magic + // constants. + // + // sum = a + b + // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 + { + ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (Pred == ICmpInst::ICMP_UGT && + match(X, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = processUGT_ADDCST_ADD( + Cmp, A, B, CI2, cast<ConstantInt>(Cmp.getOperand(1)), *this)) + return Res; + } + + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (*C == 0 && Pred == ICmpInst::ICMP_SGT) { + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + + // FIXME: Use m_APInt to allow folds for splat constants. + ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); + if (!CI) + return nullptr; + + // Canonicalize icmp instructions based on dominating conditions. + BasicBlock *Parent = Cmp.getParent(); + BasicBlock *Dom = Parent->getSinglePredecessor(); + auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; + ICmpInst::Predicate Pred2; + BasicBlock *TrueBB, *FalseBB; + ConstantInt *CI2; + if (BI && match(BI, m_Br(m_ICmp(Pred2, m_Specific(X), m_ConstantInt(CI2)), + TrueBB, FalseBB)) && + TrueBB != FalseBB) { + ConstantRange CR = + ConstantRange::makeAllowedICmpRegion(Pred, CI->getValue()); + ConstantRange DominatingCR = + (Parent == TrueBB) + ? ConstantRange::makeExactICmpRegion(Pred2, CI2->getValue()) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred2), CI2->getValue()); + ConstantRange Intersection = DominatingCR.intersectWith(CR); + ConstantRange Difference = DominatingCR.difference(CR); + if (Intersection.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (Difference.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + + // If this is a normal comparison, it demands all bits. If it is a sign + // bit comparison, it only demands the sign bit. + bool UnusedBit; + bool IsSignBit = isSignBitCheck(Pred, CI->getValue(), UnusedBit); + + // Canonicalizing a sign bit comparison that gets used in a branch, + // pessimizes codegen by generating branch on zero instruction instead + // of a test and branch. So we avoid canonicalizing in such situations + // because test and branch instruction has better branch displacement + // than compare and branch instruction. + if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) { + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); + } + } + + return nullptr; +} + +/// Fold icmp (trunc X, Y), C. +Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, + Instruction *Trunc, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Trunc->getOperand(0); + if (*C == 1 && C->getBitWidth() > 1) { + // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (Cmp.isEquality() && Trunc->hasOneUse()) { + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all + // of the high bits truncated out of x are known. + unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), + SrcBits = X->getType()->getScalarSizeInBits(); + APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); + computeKnownBits(X, KnownZero, KnownOne, 0, &Cmp); + + // If all the high bits are known, we can do this xform. + if ((KnownZero | KnownOne).countLeadingOnes() >= SrcBits - DstBits) { + // Pull in the high bits from known-ones set. + APInt NewRHS = C->zext(SrcBits); + NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); + } + } + + return nullptr; +} + +/// Fold icmp (xor X, Y), C. +Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt *C) { + Value *X = Xor->getOperand(0); + Value *Y = Xor->getOperand(1); + const APInt *XorC; + if (!match(Y, m_APInt(XorC))) + return nullptr; + + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if ((Pred == ICmpInst::ICMP_SLT && *C == 0) || + (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + + // If the sign bit of the XorCst is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorC->isNegative()) { + Cmp.setOperand(0, X); + Worklist.Add(Xor); + return &Cmp; + } + + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; + + // If so, the new one isn't. + isTrueIfPositive ^= true; + + Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + } + + if (Xor->hasOneUse()) { + // (icmp u/s (xor X SignBit), C) -> (icmp s/u X, (xor C SignBit)) + if (!Cmp.isEquality() && XorC->isSignBit()) { + Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + + // (icmp u/s (xor X ~SignBit), C) -> (icmp s/u X, (xor C ~SignBit)) + if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { + Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + Pred = Cmp.getSwappedPredicate(Pred); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + } + + // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // (icmp ult (xor X, C), -C) -> (icmp uge X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + + return nullptr; +} + +/// Fold icmp (and (sh X, Y), C2), C1. +Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1, const APInt *C2) { + BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); + if (!Shift || !Shift->isShift()) + return nullptr; + + // If this is: (X >> C3) & C2 != C1 (where any shift and any compare could + // exist), turn it into (X & (C2 << C3)) != (C1 << C3). This happens a LOT in + // code produced by the clang front-end, for bitfield access. + // This seemingly simple opportunity to fold away a shift turns out to be + // rather complicated. See PR17827 for details. + unsigned ShiftOpcode = Shift->getOpcode(); + bool IsShl = ShiftOpcode == Instruction::Shl; + const APInt *C3; + if (match(Shift->getOperand(1), m_APInt(C3))) { + bool CanFold = false; + if (ShiftOpcode == Instruction::AShr) { + // There may be some constraints that make this possible, but nothing + // simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. We can + // also fold a signed comparison if the mask value and comparison value + // are not negative. These constraints may not be obvious, but we can + // prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the shifted mask value and the + // shifted comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || + (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) + CanFold = true; + } + + if (CanFold) { + APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); + // Check to see if we are shifting out any of the bits being compared. + if (SameAsC1 != *C1) { + // If we shifted bits out, the fold is not going to work out. As a + // special case, check to see if this means that the result is always + // true or false now. + if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); + } else { + Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); + APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3); + And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); + And->setOperand(0, Shift->getOperand(0)); + Worklist.Add(Shift); // Shift is dead. + return &Cmp; + } + } + } + + // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is + // preferable because it allows the C2 << Y expression to be hoisted out of a + // loop if Y is invariant and X is not. + if (Shift->hasOneUse() && *C1 == 0 && Cmp.isEquality() && + !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { + // Compute C2 << Y. + Value *NewShift = + IsShl ? Builder->CreateLShr(And->getOperand(1), Shift->getOperand(1)) + : Builder->CreateShl(And->getOperand(1), Shift->getOperand(1)); + + // Compute X & (C2 << Y). + Value *NewAnd = Builder->CreateAnd(Shift->getOperand(0), NewShift); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + + return nullptr; +} + +/// Fold icmp (and X, C2), C1. +Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C1) { + const APInt *C2; + if (!match(And->getOperand(1), m_APInt(C2))) + return nullptr; + + if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + return nullptr; + + // If the LHS is an 'and' of a truncate and we can widen the and/compare to + // the input width without changing the value produced, eliminate the cast: + // + // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1' + // + // We can do this transformation if the constants do not have their sign bits + // set or if it is an equality comparison. Extending a relational comparison + // when we're checking the sign bit would not work. + Value *W; + if (match(And->getOperand(0), m_Trunc(m_Value(W))) && + (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + // TODO: Is this a good transform for vectors? Wider types may reduce + // throughput. Should this transform be limited (even for scalars) by using + // ShouldChangeType()? + if (!Cmp.getType()->isVectorTy()) { + Type *WideType = W->getType(); + unsigned WideScalarBits = WideType->getScalarSizeInBits(); + Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); + Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName()); + return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); + } + } + + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + return I; + + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> + // (icmp pred (and A, (or (shl 1, B), 1), 0)) + // + // iff pred isn't signed + if (!Cmp.isSigned() && *C1 == 0 && match(And->getOperand(1), m_One())) { + Constant *One = cast<Constant>(And->getOperand(1)); + Value *Or = And->getOperand(0); + Value *A, *B, *LShr; + if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && + match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { + unsigned UsesRemoved = 0; + if (And->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + + // Compute A & ((1 << B) | 1) + Value *NewOr = nullptr; + if (auto *C = dyn_cast<Constant>(B)) { + if (UsesRemoved >= 1) + NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + } + } + + // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a + // result greater than C1. + unsigned NumTZ = C2->countTrailingZeros(); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && + APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { + Constant *Zero = Constant::getNullValue(And->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + return nullptr; +} + +/// Fold icmp (and X, Y), C. +Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C) { + if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) + return I; + + // TODO: These all require that Y is constant too, so refactor with the above. + + // Try to optimize things like "A[i] & 42 == 0" to index computations. + Value *X = And->getOperand(0); + Value *Y = And->getOperand(1); + if (auto *LI = dyn_cast<LoadInst>(X)) + if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !LI->isVolatile() && isa<ConstantInt>(Y)) { + ConstantInt *C2 = cast<ConstantInt>(Y); + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, Cmp, C2)) + return Res; + } + + if (!Cmp.isEquality()) + return nullptr; + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT + : CmpInst::ICMP_ULE; + return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); + } + + // (X & C2) == 0 -> (trunc X) >= 0 + // (X & C2) != 0 -> (trunc X) < 0 + // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. + const APInt *C2; + if (And->hasOneUse() && *C == 0 && match(Y, m_APInt(C2))) { + int32_t ExactLogBase2 = C2->exactLogBase2(); + if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { + Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); + if (And->getType()->isVectorTy()) + NTy = VectorType::get(NTy, And->getType()->getVectorNumElements()); + Value *Trunc = Builder->CreateTrunc(X, NTy); + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE + : CmpInst::ICMP_SLT; + return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy)); + } + } + + return nullptr; +} + +/// Fold icmp (or X, Y), C. +Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (*C == 1) { + // icmp slt signum(V) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse()) + return nullptr; + + Value *P, *Q; + if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { + // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 + // -> and (icmp eq P, null), (icmp eq Q, null). + Value *CmpP = + Builder->CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); + Value *CmpQ = + Builder->CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); + auto LogicOpc = Pred == ICmpInst::Predicate::ICMP_EQ ? Instruction::And + : Instruction::Or; + return BinaryOperator::Create(LogicOpc, CmpP, CmpQ); + } + + return nullptr; +} + +/// Fold icmp (mul X, Y), C. +Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, + BinaryOperator *Mul, + const APInt *C) { + const APInt *MulC; + if (!match(Mul->getOperand(1), m_APInt(MulC))) + return nullptr; + + // If this is a test of the sign bit and the multiply is sign-preserving with + // a constant operand, use the multiply LHS operand instead. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (MulC->isNegative()) + Pred = ICmpInst::getSwappedPredicate(Pred); + return new ICmpInst(Pred, Mul->getOperand(0), + Constant::getNullValue(Mul->getType())); + } + + return nullptr; +} + +/// Fold icmp (shl 1, Y), C. +static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, + const APInt *C) { + Value *Y; + if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) + return nullptr; + + Type *ShiftType = Shl->getType(); + uint32_t TypeBits = C->getBitWidth(); + bool CIsPowerOf2 = C->isPowerOf2(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isUnsigned()) { + // (1 << Y) pred C -> Y pred Log2(C) + if (!CIsPowerOf2) { + // (1 << Y) < 30 -> Y <= 4 + // (1 << Y) <= 30 -> Y <= 4 + // (1 << Y) >= 30 -> Y > 4 + // (1 << Y) > 30 -> Y > 4 + if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_ULE; + else if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_UGT; + } + + // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 + // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 + unsigned CLog2 = C->logBase2(); + if (CLog2 == TypeBits - 1) { + if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_EQ; + else if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_NE; + } + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); + } else if (Cmp.isSigned()) { + Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); + if (C->isAllOnesValue()) { + // (1 << Y) <= -1 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) > -1 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } else if (!(*C)) { + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) <= 0 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) >= 0 -> Y != 31 + // (1 << Y) > 0 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } + } else if (Cmp.isEquality() && CIsPowerOf2) { + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + } + + return nullptr; +} + +/// Fold icmp (shl X, Y), C. +Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, + BinaryOperator *Shl, + const APInt *C) { + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) + return foldICmpShlOne(Cmp, Shl, C); + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited, it will be simplified. + unsigned TypeBits = C->getBitWidth(); + if (ShiftAmt->uge(TypeBits)) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Shl->getOperand(0); + if (Cmp.isEquality()) { + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); + if (Shl->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, LShrC); + + // If the shift is NSW and we compare to 0, then it is just shifting out + // sign bits, no need for an AND either. + if (Shl->hasNoSignedWrap() && *C == 0) + return new ICmpInst(Pred, X, LShrC); + + if (Shl->hasOneUse()) { + // Otherwise, strength reduce the shift into an and. + Constant *Mask = ConstantInt::get(Shl->getType(), + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(Pred, And, LShrC); + } + } + + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) + return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + // (X << 31) <s 0 --> (X & 1) != 0 + Constant *Mask = ConstantInt::get( + X->getType(), + APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } + + // When the shift is nuw and pred is >u or <=u, comparison only really happens + // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the + // <=u case can be further converted to match <u (see below). + if (Shl->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { + // Derivation for the ult case: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && + "Encountered `ult 0` that should have been eliminated by " + "InstSimplify."); + APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1 + : C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); + } + + // Transform (icmp pred iM (shl iM %v, N), C) + // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) + // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N. + // This enables us to get rid of the shift in favor of a trunc that may be + // free on the target. It has the additional benefit of comparing to a + // smaller constant that may be more target-friendly. + unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); + if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + DL.isLegalInteger(TypeBits - Amt)) { + Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); + if (X->getType()->isVectorTy()) + TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); + Constant *NewC = + ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); + } + + return nullptr; +} + +/// Fold icmp ({al}shr X, Y), C. +Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, + BinaryOperator *Shr, + const APInt *C) { + // An exact shr only shifts out zero bits, so: + // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 + Value *X = Shr->getOperand(0); + CmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && *C == 0) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) + return nullptr; + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited it will be simplified. + unsigned TypeBits = C->getBitWidth(); + unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return nullptr; + + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + if (!Cmp.isEquality()) { + // If we have an unsigned comparison and an ashr, we can't simplify this. + // Similarly for signed comparisons with lshr. + if (Cmp.isSigned() != IsAShr) + return nullptr; + + // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv + // by a power of 2. Since we already have logic to simplify these, + // transform to div and then simplify the resultant comparison. + if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) + return nullptr; + + // Revisit the shift (to delete it). + Worklist.Add(Shr); + + Constant *DivCst = ConstantInt::get( + Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + + Value *Tmp = IsAShr ? Builder->CreateSDiv(X, DivCst, "", Shr->isExact()) + : Builder->CreateUDiv(X, DivCst, "", Shr->isExact()); + + Cmp.setOperand(0, Tmp); + + // If the builder folded the binop, just return it. + BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); + if (!TheDiv) + return &Cmp; + + // Otherwise, fold this div/compare. + assert(TheDiv->getOpcode() == Instruction::SDiv || + TheDiv->getOpcode() == Instruction::UDiv); + + Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); + assert(Res && "This div/cst should have folded!"); + return Res; + } + + // Handle equality comparisons of shift-by-constant. + + // If the comparison constant changes with the shift, the comparison cannot + // succeed (bits of the comparison constant cannot match the shifted value). + // This should be known by InstSimplify and already be folded to true/false. + assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || + (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + "Expected icmp+shr simplify did not occur."); + + // Check if the bits shifted out are known to be zero. If so, we can compare + // against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); + if (Shr->hasOneUse()) { + if (Shr->isExact()) + return new ICmpInst(Pred, X, ShiftedCmpRHS); + + // Otherwise strength reduce the shift into an 'and'. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Value *And = Builder->CreateAnd(X, Mask, Shr->getName() + ".mask"); + return new ICmpInst(Pred, And, ShiftedCmpRHS); + } + + return nullptr; +} + +/// Fold icmp (udiv X, Y), C. +Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, + BinaryOperator *UDiv, + const APInt *C) { + const APInt *C2; + if (!match(UDiv->getOperand(0), m_APInt(C2))) + return nullptr; + + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + + // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) + Value *Y = UDiv->getOperand(1); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C->isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + } + + // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C))); + } + + return nullptr; +} + +/// Fold icmp ({su}div X, Y), C. +Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, + BinaryOperator *Div, + const APInt *C) { + // Fold: icmp pred ([us]div X, C2), C -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + const APInt *C2; + if (!match(Div->getOperand(1), m_APInt(C2))) + return nullptr; // FIXME: If the operand types don't match the type of the divide // then don't attempt this transform. The code below doesn't have the // logic to deal with a signed divide and an unsigned compare (and - // vice versa). This is because (x /s C1) <s C2 produces different - // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even - // (x /u C1) <u C2. Simply casting the operands and result won't + // vice versa). This is because (x /s C2) <s C produces different + // results than (x /s C2) <u C or (x /u C2) <s C or even + // (x /u C2) <u C. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails // if it finds it. - bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; - if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) + bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; + if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) return nullptr; - if (DivRHS->isZero()) - return nullptr; // The ProdOV computation fails on divide by zero. - if (DivIsSigned && DivRHS->isAllOnesValue()) - return nullptr; // The overflow computation also screws up here - if (DivRHS->isOne()) { - // This eliminates some funny cases with INT_MIN. - ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. - return &ICI; - } - - // Compute Prod = CI * DivRHS. We are essentially solving an equation - // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and - // C2 (CI). By solving for X we can turn this into a range check - // instead of computing a divide. + + // The ProdOV computation fails on divide by 0 and divide by -1. Cases with + // INT_MIN will also fail if the divisor is 1. Although folds of all these + // division-by-constant cases should be present, we can not assert that they + // have happened before we reach this icmp instruction. + if (*C2 == 0 || *C2 == 1 || (DivIsSigned && C2->isAllOnesValue())) + return nullptr; + + // TODO: We could do all of the computations below using APInt. + Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); + Constant *DivRHS = cast<Constant>(Div->getOperand(1)); + + // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // By solving for X, we can turn this into a range check instead of computing + // a divide. Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); - // Determine if the product overflows by seeing if the product is - // not equal to the divide. Make sure we do the same kind of divide - // as in the LHS instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) : - ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + // Determine if the product overflows by seeing if the product is not equal to + // the divide. Make sure we do the same kind of divide as in the LHS + // instruction that we're folding. + bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) + : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; - // Get the ICmp opcode - ICmpInst::Predicate Pred = ICI.getPredicate(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + Constant *RangeSize = + Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -1245,1134 +2184,1094 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (!HiOverflow) { // If this is not an exact divide, then many values in the range collapse // to the same result value. - HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } - } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. - if (CmpRHSV == 0) { // (X / pos) op 0 + } else if (C2->isStrictlyPositive()) { // Divisor is > 0. + if (*C == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); HiBound = RangeSize; - } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos + } else if (C->isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; + Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } - } else if (DivRHS->isNegative()) { // Divisor is < 0. - if (DivI->isExact()) - RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - if (CmpRHSV == 0) { // (X / neg) op 0 + } else if (C2->isNegative()) { // Divisor is < 0. + if (Div->isExact()) + RangeSize = ConstantExpr::getNeg(RangeSize); + if (*C == 0) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) LoBound = AddOne(RangeSize); - HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); + HiBound = ConstantExpr::getNeg(RangeSize); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (CmpRHSV.isStrictlyPositive()) { // (X / neg) op pos + } else if (C->isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; + LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT Pred = ICmpInst::getSwappedPredicate(Pred); } - Value *X = DivI->getOperand(0); + Value *X = Div->getOperand(0); switch (Pred) { - default: llvm_unreachable("Unhandled icmp opcode!"); - case ICmpInst::ICMP_EQ: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, true)); - case ICmpInst::ICMP_NE: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, false)); - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - return new ICmpInst(Pred, X, LoBound); - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + default: llvm_unreachable("Unhandled icmp opcode!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, HiBound); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, HiBound); + return replaceInstUsesWith(Cmp, + insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), + DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + return new ICmpInst(Pred, X, LoBound); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } + + return nullptr; } -/// Handle "icmp(([al]shr X, cst1), cst2)". -Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, - ConstantInt *ShAmt) { - const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); +/// Fold icmp (sub X, Y), C. +Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, + BinaryOperator *Sub, + const APInt *C) { + Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); + ICmpInst::Predicate Pred = Cmp.getPredicate(); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = CmpRHSV.getBitWidth(); - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - if (ShAmtVal >= TypeBits || ShAmtVal == 0) + // The following transforms are only worth it if the only user of the subtract + // is the icmp. + if (!Sub->hasOneUse()) return nullptr; - if (!ICI.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (Shr->getOpcode() == Instruction::AShr && - (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; - - // Revisit the shift (to delete it). - Worklist.Add(Shr); - - Constant *DivCst = - ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + if (Sub->hasNoSignedWrap()) { + // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) + if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - Value *Tmp = - Shr->getOpcode() == Instruction::AShr ? - Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) : - Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()); + // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) + if (Pred == ICmpInst::ICMP_SGT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - ICI.setOperand(0, Tmp); + // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - // If the builder folded the binop, just return it. - BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &ICI; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); - - Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst)); - assert(Res && "This div/cst should have folded!"); - return Res; + // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 1) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = CmpRHSV << ShAmtVal; - ConstantInt *ShiftedCmpRHS = Builder->getInt(Comp); - if (Shr->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); + const APInt *C2; + if (!match(X, m_APInt(C2))) + return nullptr; - if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + // C2 - Y <u C -> (Y | (C - 1)) == C2 + // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == (*C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateOr(Y, *C - 1), X); - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (Shr->hasOneUse() && Shr->isExact()) - return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + // C2 - Y >u C -> (Y | C) != C2 + // iff C2 & C == C and C + 1 is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateOr(Y, *C), X); - if (Shr->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = Builder->getInt(Val); - - Value *And = Builder->CreateAnd(Shr->getOperand(0), - Mask, Shr->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); - } return nullptr; } -/// Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> -/// (icmp eq/ne A, Log2(const2/const1)) -> -/// (icmp eq/ne A, Log2(const2) - Log2(const1)). -Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); - - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; - - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; +/// Fold icmp (add X, Y), C. +Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, + BinaryOperator *Add, + const APInt *C) { + Value *Y = Add->getOperand(1); + const APInt *C2; + if (Cmp.isEquality() || !match(Y, m_APInt(C2))) + return nullptr; - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + // Fold icmp pred (add X, C2), C. + Value *X = Add->getOperand(0); + Type *Ty = Add->getType(); + auto CR = + ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); + const APInt &Upper = CR.getUpper(); + const APInt &Lower = CR.getLower(); + if (Cmp.isSigned()) { + if (Lower.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); + } else { + if (Lower.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); + } - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) + if (!Add->hasOneUse()) return nullptr; - bool IsAShr = isa<AShrOperator>(Op); - if (IsAShr) { - if (AP2.isAllOnesValue()) - return nullptr; - if (AP2.isNegative() != AP1.isNegative()) - return nullptr; - if (AP2.sgt(AP1)) - return nullptr; - } - if (!AP1) - // 'A' must be large enough to shift out the highest set bit. - return getICmp(I.ICMP_UGT, A, - ConstantInt::get(A->getType(), AP2.logBase2())); + // X+C <u C2 -> (X & -C2) == C + // iff C & (C2-1) == 0 + // C2 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); + + // X+C >u C2 -> (X & ~C2) != C + // iff C & C2 == 0 + // C2+1 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && + (*C2 & *C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + return nullptr; +} - int Shift; - if (IsAShr && AP1.isNegative()) - Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); - else - Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction. +Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; - if (Shift > 0) { - if (IsAShr && AP1 == AP2.ashr(Shift)) { - // There are multiple solutions if we are comparing against -1 and the LHS - // of the ashr is not a power of two. - if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) - return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); - } else if (AP1 == AP2.lshr(Shift)) { - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + BinaryOperator *BO; + if (match(Cmp.getOperand(0), m_BinOp(BO))) { + switch (BO->getOpcode()) { + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + return I; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + return I; + break; + default: + break; } + // TODO: These folds could be refactored to be part of the above calls. + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + return I; } - // Shifting const2 will never be equal to const1. - return getConstant(false); -} -/// Handle "(icmp eq/ne (shl const2, A), const1)" -> -/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). -Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); + Instruction *LHSI; + if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && + LHSI->getOpcode() == Instruction::Trunc) + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + return I; - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; - - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + return nullptr; +} - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) +/// Fold an icmp equality instruction with binary operator LHS and constant RHS: +/// icmp eq/ne BO, C. +Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C) { + // TODO: Some of these folds could work with arbitrary constants, but this + // function is limited to scalar and vector splat constants. + if (!Cmp.isEquality()) return nullptr; - unsigned AP2TrailingZeros = AP2.countTrailingZeros(); - - if (!AP1 && AP2TrailingZeros != 0) - return getICmp(I.ICMP_UGE, A, - ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool isICMP_NE = Pred == ICmpInst::ICMP_NE; + Constant *RHS = cast<Constant>(Cmp.getOperand(1)); + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); + + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (*C == 0 && BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { + Value *NewRem = Builder->CreateURem(BOp0, BOp1, BO->getName()); + return new ICmpInst(Pred, NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: { + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + if (BO->hasOneUse()) { + Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); + return new ICmpInst(Pred, BOp0, SubC); + } + } else if (*C == 0) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(Pred, BOp0, NegVal); + if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(Pred, NegVal, BOp1); + if (BO->hasOneUse()) { + Value *Neg = Builder->CreateNeg(BOp1); + Neg->takeName(BO); + return new ICmpInst(Pred, BOp0, Neg); + } + } + break; + } + case Instruction::Xor: + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (*C == 0) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Sub: + if (BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp0, m_APInt(BOC))) { + // Replace ((sub BOC, B) != C) with (B != BOC-C). + Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); + return new ICmpInst(Pred, BOp1, SubC); + } else if (*C == 0) { + // Replace ((sub A, B) != 0) with (A != B). + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Or: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { + // Comparing if all bits outside of a constant mask are set? + // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1)); + Value *And = Builder->CreateAnd(BOp0, NotBOC); + return new ICmpInst(Pred, And, NotBOC); + } + break; + } + case Instruction::And: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (C == BOC && C->isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BO, Constant::getNullValue(RHS->getType())); + + // Don't perform the following transforms if the AND has multiple uses + if (!BO->hasOneUse()) + break; - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 + if (BOC->isSignBit()) { + Constant *Zero = Constant::getNullValue(BOp0->getType()); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, BOp0, Zero); + } - // Get the distance between the lowest bits that are set. - int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + // ((X & ~7) == 0) --> X < 8 + if (*C == 0 && (~(*BOC) + 1).isPowerOf2()) { + Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, BOp0, NegBOC); + } + } + break; + } + case Instruction::Mul: + if (*C == 0 && BO->hasNoSignedWrap()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && *BOC != 0) { + // The trivial case (mul X, 0) is handled by InstSimplify. + // General case : (mul X, C) != 0 iff X != 0 + // (mul X, C) == 0 iff X == 0 + return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType())); + } + } + break; + case Instruction::UDiv: + if (*C == 0) { + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, BOp1, BOp0); + } + break; + default: + break; + } + return nullptr; +} - if (Shift > 0 && AP2.shl(Shift) == AP1) - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. +Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + const APInt *C) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); + if (!II || !Cmp.isEquality()) + return nullptr; - // Shifting const2 will never be equal to const1. - return getConstant(false); + // Handle icmp {eq|ne} <intrinsic>, intcst. + switch (II->getIntrinsicID()) { + case Intrinsic::bswap: + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, Builder->getInt(C->byteSwap())); + return &Cmp; + case Intrinsic::ctlz: + case Intrinsic::cttz: + // ctz(A) == bitwidth(A) -> A == 0 and likewise for != + if (*C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, ConstantInt::getNullValue(II->getType())); + return &Cmp; + } + break; + case Intrinsic::ctpop: { + // popcount(A) == 0 -> A == 0 and likewise for != + // popcount(A) == bitwidth(A) -> A == -1 and likewise for != + bool IsZero = *C == 0; + if (IsZero || *C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + auto *NewOp = IsZero ? Constant::getNullValue(II->getType()) + : Constant::getAllOnesValue(II->getType()); + Cmp.setOperand(1, NewOp); + return &Cmp; + } + break; + } + default: + break; + } + return nullptr; } -/// Handle "icmp (instr, intcst)". -Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, - Instruction *LHSI, - ConstantInt *RHS) { - const APInt &RHSV = RHS->getValue(); +/// Handle icmp with constant (but not simple integer constant) RHS. +Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Constant *RHSC = dyn_cast<Constant>(Op1); + Instruction *LHSI = dyn_cast<Instruction>(Op0); + if (!RHSC || !LHSI) + return nullptr; switch (LHSI->getOpcode()) { - case Instruction::Trunc: - if (RHS->isOne() && RHSV.getBitWidth() > 1) { - // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI->getOperand(0), m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); + case Instruction::GetElementPtr: + // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null + if (RHSC->isNullValue() && + cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; + case Instruction::PHI: + // Only fold icmp into the PHI if the phi and icmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. + if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: { + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = nullptr, *Op2 = nullptr; + ConstantInt *CI = nullptr; + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { + Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op1); } - if (ICI.isEquality() && LHSI->hasOneUse()) { - // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all - // of the high bits truncated out of x are known. - unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), - SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); - - // If all the high bits are known, we can do this xform. - if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { - // Pull in the high bits from known-ones set. - APInt NewRHS = RHS->getValue().zext(SrcBits); - NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits-DstBits); - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - Builder->getInt(NewRHS)); - } + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { + Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op2); + } + + // We only want to perform this transformation if it will not lead to + // additional code. This is true if either both sides of the select + // fold to a constant (in which case the icmp is replaced with a select + // which will usually simplify) or this is the only user of the + // select (in which case we are trading a select+icmp for a simpler + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (LHSI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. + Transform = + replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, Op1 ? 2 : 1); + } + if (Transform) { + if (!Op1) + Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, + I.getName()); + if (!Op2) + Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, + I.getName()); + return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); } break; + } + case Instruction::IntToPtr: + // icmp pred inttoptr(X), null -> icmp pred X, 0 + if (RHSC->isNullValue() && + DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; - case Instruction::Xor: // (icmp pred (xor X, XorCst), CI) - if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - // If this is a comparison that tests the signbit (X < 0) or (x > -1), - // fold the xor. - if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || - (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { - Value *CompareVal = LHSI->getOperand(0); - - // If the sign bit of the XorCst is not set, there is no change to - // the operation, just stop using the Xor. - if (!XorCst->isNegative()) { - ICI.setOperand(0, CompareVal); - Worklist.Add(LHSI); - return &ICI; - } + case Instruction::Load: + // Try to optimize things like "A[i] > 4" to index computations. + if (GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !cast<LoadInst>(LHSI)->isVolatile()) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) + return Res; + } + break; + } - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; + return nullptr; +} - // If so, the new one isn't. - isTrueIfPositive ^= true; +/// Try to fold icmp (binop), X or icmp X, (binop). +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, - SubOne(RHS)); - else - return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, - AddOne(RHS)); - } + // Special logic for binary operators. + BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); + BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); + if (!BO0 && !BO1) + return nullptr; - if (LHSI->hasOneUse()) { - // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) - if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { - const APInt &SignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ SignBit)); - } + CmpInst::Predicate Pred = I.getPredicate(); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; + if (BO0 && isa<OverflowingBinaryOperator>(BO0)) + NoOp0WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); + if (BO1 && isa<OverflowingBinaryOperator>(BO1)) + NoOp1WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) - if (!ICI.isEquality() && XorCst->isMaxValue(true)) { - const APInt &NotSignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - Pred = ICI.getSwappedPredicate(Pred); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ NotSignBit)); - } - } + // Analyze the case when either Op0 or Op1 is an add instruction. + // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Add) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Add) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && - XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); - - // (icmp ult (xor X, C), -C) -> (icmp uge X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && - XorCst->getValue() == -RHSV && RHSV.isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); + // icmp (X+cst) < 0 --> X < -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) + if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == Op1 || B == Op1) && NoOp0WrapProblem) + return new ICmpInst(Pred, A == Op1 ? B : A, + Constant::getNullValue(Op1->getType())); + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == Op0 || D == Op0) && NoOp1WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), + C == Op0 ? D : C); + + // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && + NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y, *Z; + if (A == C) { + // C + B == C + D -> B == D + Y = B; + Z = D; + } else if (A == D) { + // D + B == C + D -> B == C + Y = B; + Z = C; + } else if (B == C) { + // A + C == C + D -> A == D + Y = A; + Z = D; + } else { + assert(B == D); + // A + D == C + D -> A == C + Y = A; + Z = C; } - break; - case Instruction::And: // (icmp pred (and X, AndCst), RHS) - if (LHSI->hasOneUse() && isa<ConstantInt>(LHSI->getOperand(1)) && - LHSI->getOperand(0)->hasOneUse()) { - ConstantInt *AndCst = cast<ConstantInt>(LHSI->getOperand(1)); - - // If the LHS is an AND of a truncating cast, we can widen the - // and/compare to be the input width without changing the value - // produced, eliminating a cast. - if (TruncInst *Cast = dyn_cast<TruncInst>(LHSI->getOperand(0))) { - // We can do this transformation if either the AND constant does not - // have its sign bit set or if it is an equality comparison. - // Extending a relational comparison when we're checking the sign - // bit would not work. - if (ICI.isEquality() || - (!AndCst->isNegative() && RHSV.isNonNegative())) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getZExt(RHS, Cast->getSrcTy())); - } - } - - // If the LHS is an AND of a zext, and we have an equality compare, we can - // shrink the and/compare to the smaller type, eliminating the cast. - if (ZExtInst *Cast = dyn_cast<ZExtInst>(LHSI->getOperand(0))) { - IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy()); - // Make sure we don't compare the upper bits, SimplifyDemandedBits - // should fold the icmp to true/false in that case. - if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getTrunc(AndCst, Ty)); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getTrunc(RHS, Ty)); - } - } - - // If this is: (X >> C1) & C2 != C3 (where any shift and any compare - // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This - // happens a LOT in code produced by the C front-end, for bitfield - // access. - BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0)); - if (Shift && !Shift->isShift()) - Shift = nullptr; - - ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : nullptr; - - // This seemingly simple opportunity to fold away a shift turns out to - // be rather complicated. See PR17827 - // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. - if (ShAmt) { - bool CanFold = false; - unsigned ShiftOpcode = Shift->getOpcode(); - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, - // but nothing simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { - // For a left shift, we can fold if the comparison is not signed. - // We can also fold a signed comparison if the mask value and - // comparison value are not negative. These constraints may not be - // obvious, but we can prove that they are correct using an SMT - // solver. - if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) - CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { - // For a logical right shift, we can fold if the comparison is not - // signed. We can also fold a signed comparison if the shifted mask - // value and the shifted comparison value are not negative. - // These constraints may not be obvious, but we can prove that they - // are correct using an SMT solver. - if (!ICI.isSigned()) - CanFold = true; - else { - ConstantInt *ShiftedAndCst = - cast<ConstantInt>(ConstantExpr::getShl(AndCst, ShAmt)); - ConstantInt *ShiftedRHSCst = - cast<ConstantInt>(ConstantExpr::getShl(RHS, ShAmt)); - - if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) - CanFold = true; - } - } + return new ICmpInst(Pred, Y, Z); + } - if (CanFold) { - Constant *NewCst; - if (ShiftOpcode == Instruction::Shl) - NewCst = ConstantExpr::getLShr(RHS, ShAmt); - else - NewCst = ConstantExpr::getShl(RHS, ShAmt); - - // Check to see if we are shifting out any of the bits being - // compared. - if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { - // If we shifted bits out, the fold is not going to work out. - // As a special case, check to see if this means that the - // result is always true or false now. - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return replaceInstUsesWith(ICI, Builder->getTrue()); + // icmp slt (X + -1), Y -> icmp sle X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); + + // icmp sge (X + -1), Y -> icmp sgt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); + + // icmp sle (X + 1), Y -> icmp slt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); + + // icmp sgt (X + 1), Y -> icmp sge X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); + + // icmp sgt X, (Y + -1) -> icmp sge X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); + + // icmp sle X, (Y + -1) -> icmp slt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); + + // icmp sge X, (Y + 1) -> icmp sgt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); + + // icmp slt X, (Y + 1) -> icmp sle X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); + + // if C1 has greater magnitude than C2: + // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y + // s.t. C3 = C1 - C2 + // + // if C2 has greater magnitude than C1: + // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) + // s.t. C3 = C2 - C1 + if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && + (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) + if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) + if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { + const APInt &AP1 = C1->getValue(); + const APInt &AP2 = C2->getValue(); + if (AP1.isNegative() == AP2.isNegative()) { + APInt AP1Abs = C1->getValue().abs(); + APInt AP2Abs = C2->getValue().abs(); + if (AP1Abs.uge(AP2Abs)) { + ConstantInt *C3 = Builder->getInt(AP1 - AP2); + Value *NewAdd = Builder->CreateNSWAdd(A, C3); + return new ICmpInst(Pred, NewAdd, C); } else { - ICI.setOperand(1, NewCst); - Constant *NewAndCst; - if (ShiftOpcode == Instruction::Shl) - NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); - else - NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); - LHSI->setOperand(1, NewAndCst); - LHSI->setOperand(0, Shift->getOperand(0)); - Worklist.Add(Shift); // Shift is dead. - return &ICI; + ConstantInt *C3 = Builder->getInt(AP2 - AP1); + Value *NewAdd = Builder->CreateNSWAdd(C, C3); + return new ICmpInst(Pred, A, NewAdd); } } } - // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is - // preferable because it allows the C<<Y expression to be hoisted out - // of a loop if Y is invariant and X is not. - if (Shift && Shift->hasOneUse() && RHSV == 0 && - ICI.isEquality() && !Shift->isArithmeticShift() && - !isa<Constant>(Shift->getOperand(0))) { - // Compute C << Y. - Value *NS; - if (Shift->getOpcode() == Instruction::LShr) { - NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); - } else { - // Insert a logical shift. - NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); - } + // Analyze the case when either Op0 or Op1 is a sub instruction. + // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). + A = nullptr; + B = nullptr; + C = nullptr; + D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Sub) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Sub) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // Compute X & (C << Y). - Value *NewAnd = - Builder->CreateAnd(Shift->getOperand(0), NS, LHSI->getName()); + // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + if (A == Op1 && NoOp0WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); + + // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + if (C == Op0 && NoOp1WrapProblem) + return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, A, C); + + // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, D, B); + + // icmp (0-X) < cst --> x > -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { + Value *X; + if (match(BO0, m_Neg(m_Value(X)))) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(I.getSwappedPredicate(), X, + ConstantExpr::getNeg(RHSC)); + } - ICI.setOperand(0, NewAnd); - return &ICI; - } + BinaryOperator *SRem = nullptr; + // icmp (srem X, Y), Y + if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1)) + SRem = BO0; + // icmp Y, (srem X, Y) + else if (BO1 && BO1->getOpcode() == Instruction::SRem && + Op0 == BO1->getOperand(1)) + SRem = BO1; + if (SRem) { + // We don't check hasOneUse to avoid increasing register pressure because + // the value we use is the same value this instruction was already using. + switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { + default: + break; + case ICmpInst::ICMP_EQ: + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + case ICmpInst::ICMP_NE: + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), + Constant::getAllOnesValue(SRem->getType())); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), + Constant::getNullValue(SRem->getType())); + } + } - // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> - // (icmp pred (and X, (or (shl 1, Y), 1), 0)) - // - // iff pred isn't signed - { - Value *X, *Y, *LShr; - if (!ICI.isSigned() && RHSV == 0) { - if (match(LHSI->getOperand(1), m_One())) { - Constant *One = cast<Constant>(LHSI->getOperand(1)); - Value *Or = LHSI->getOperand(0); - if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && - match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { - unsigned UsesRemoved = 0; - if (LHSI->hasOneUse()) - ++UsesRemoved; - if (Or->hasOneUse()) - ++UsesRemoved; - if (LShr->hasOneUse()) - ++UsesRemoved; - Value *NewOr = nullptr; - // Compute X & ((1 << Y) | 1) - if (auto *C = dyn_cast<Constant>(Y)) { - if (UsesRemoved >= 1) - NewOr = - ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, - LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { - Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); - ICI.setOperand(0, NewAnd); - return &ICI; - } - } - } + if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && BO0->hasOneUse() && + BO1->hasOneUse() && BO0->getOperand(1) == BO1->getOperand(1)) { + switch (BO0->getOpcode()) { + default: + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } - } - // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any - // bit set in (X & AndCst) will produce a result greater than RHSV. - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = AndCst->getValue().countTrailingZeros(); - if ((NTZ < AndCst->getBitWidth()) && - APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV)) - return new ICmpInst(ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); + if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + } } - } - - // Try to optimize things like "A[i]&42 == 0" to index computations. - if (LoadInst *LI = dyn_cast<LoadInst>(LHSI->getOperand(0))) { - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LI->getOperand(0))) - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !LI->isVolatile() && isa<ConstantInt>(LHSI->getOperand(1))) { - ConstantInt *C = cast<ConstantInt>(LHSI->getOperand(1)); - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV,ICI, C)) - return Res; - } - } + break; + case Instruction::Mul: + if (!I.isEquality()) + break; - // X & -C == -C -> X > u ~C - // X & -C != -C -> X <= u ~C - // iff C is a power of 2 - if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2()) - return new ICmpInst( - ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE, - LHSI->getOperand(0), SubOne(RHS)); - - // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) - // iff C is a power of 2 - if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { - if (auto *CI = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - const APInt &AI = CI->getValue(); - int32_t ExactLogBase2 = AI.exactLogBase2(); - if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { - Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); - Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); - return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_SLT, - Trunc, Constant::getNullValue(NTy)); + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). + if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get( + I.getContext(), + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - AP.countTrailingZeros())); + Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(I.getPredicate(), And1, And2); } } + break; + case Instruction::UDiv: + case Instruction::LShr: + if (I.isSigned()) + break; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + case Instruction::AShr: + if (!BO0->isExact() || !BO1->isExact()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + case Instruction::Shl: { + bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); + bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); + if (!NUW && !NSW) + break; + if (!NSW && I.isSigned()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); } - break; - - case Instruction::Or: { - if (RHS->isOne()) { - // icmp slt signum(V) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI, m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); } + } - if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse()) - break; - Value *P, *Q; - if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { - // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 - // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, - Constant::getNullValue(P->getType())); - Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, - Constant::getNullValue(Q->getType())); - Instruction *Op; - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - Op = BinaryOperator::CreateAnd(ICIP, ICIQ); - else - Op = BinaryOperator::CreateOr(ICIP, ICIQ); - return Op; + if (BO0) { + // Transform A & (L - 1) `ult` L --> L != 0 + auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); + auto BitwiseAnd = + m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + + if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { + auto *Zero = Constant::getNullValue(BO0->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); } - break; } - case Instruction::Mul: { // (icmp pred (mul X, Val), CI) - ConstantInt *Val = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!Val) break; + return nullptr; +} - // If this is a signed comparison to 0 and the mul is sign preserving, - // use the mul LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && !Val->isZero() && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(Val->isNegative() ? - ICmpInst::getSwappedPredicate(pred) : pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); +/// Fold icmp Pred min|max(X, Y), X. +static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *X = Cmp.getOperand(1); + + // Canonicalize minimum or maximum operand to LHS of the icmp. + if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || + match(X, m_c_SMax(m_Specific(Op0), m_Value())) || + match(X, m_c_UMin(m_Specific(Op0), m_Value())) || + match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { + std::swap(Op0, X); + Pred = Cmp.getSwappedPredicate(); + } - break; + Value *Y; + if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { + // smin(X, Y) == X --> X s<= Y + // smin(X, Y) s>= X --> X s<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); + + // smin(X, Y) != X --> X s> Y + // smin(X, Y) s< X --> X s> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); + + // These cases should be handled in InstSimplify: + // smin(X, Y) s<= X --> true + // smin(X, Y) s> X --> false + return nullptr; } - case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) - uint32_t TypeBits = RHSV.getBitWidth(); - ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!ShAmt) { - Value *X; - // (1 << X) pred P2 -> X pred Log2(P2) - if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV.isPowerOf2(); - ICmpInst::Predicate Pred = ICI.getPredicate(); - if (ICI.isUnsigned()) { - if (!RHSVIsPowerOf2) { - // (1 << X) < 30 -> X <= 4 - // (1 << X) <= 30 -> X <= 4 - // (1 << X) >= 30 -> X > 4 - // (1 << X) > 30 -> X > 4 - if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_ULE; - else if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_UGT; - } - unsigned RHSLog2 = RHSV.logBase2(); - - // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) < 2147483648 -> X < 31 -> X != 31 - if (RHSLog2 == TypeBits-1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } + if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { + // smax(X, Y) == X --> X s>= Y + // smax(X, Y) s<= X --> X s>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - return new ICmpInst(Pred, X, - ConstantInt::get(RHS->getType(), RHSLog2)); - } else if (ICI.isSigned()) { - if (RHSV.isAllOnesValue()) { - // (1 << X) <= -1 -> X == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) > -1 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } else if (!RHSV) { - // (1 << X) < 0 -> X == 31 - // (1 << X) <= 0 -> X == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) >= 0 -> X != 31 - // (1 << X) > 0 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } - } else if (ICI.isEquality()) { - if (RHSVIsPowerOf2) - return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - } - } - break; - } + // smax(X, Y) != X --> X s< Y + // smax(X, Y) s> X --> X s< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - if (ShAmt->uge(TypeBits)) - break; + // These cases should be handled in InstSimplify: + // smax(X, Y) s>= X --> true + // smax(X, Y) s< X --> false + return nullptr; + } - if (ICI.isEquality()) { - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - Constant *Comp = - ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), - ShAmt); - if (Comp != RHS) {// Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { + // umin(X, Y) == X --> X u<= Y + // umin(X, Y) u>= X --> X u<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap()) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - Constant *Mask = Builder->getInt(APInt::getLowBitsSet(TypeBits, - TypeBits - ShAmtVal)); - - Value *And = - Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getLShr(RHS, ShAmt)); - } - } + // umin(X, Y) != X --> X u> Y + // umin(X, Y) u< X --> X u> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); - - // Otherwise, if this is a comparison of the sign bit, simplify to and/test. - bool TrueIfSigned = false; - if (LHSI->hasOneUse() && - isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { - // (X << 31) <s 0 --> (X&1) != 0 - Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), - APInt::getOneBitSet(TypeBits, - TypeBits-ShAmt->getZExtValue()-1)); - Value *And = - Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); - return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); - } + // These cases should be handled in InstSimplify: + // umin(X, Y) u<= X --> true + // umin(X, Y) u> X --> false + return nullptr; + } - // Transform (icmp pred iM (shl iM %v, N), CI) - // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N)) - // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N. - // This enables to get rid of the shift in favor of a trunc which can be - // free on the target. It has the additional benefit of comparing to a - // smaller constant, which will be target friendly. - unsigned Amt = ShAmt->getLimitedValue(TypeBits-1); - if (LHSI->hasOneUse() && - Amt != 0 && RHSV.countTrailingZeros() >= Amt) { - Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); - Constant *NCI = ConstantExpr::getTrunc( - ConstantExpr::getAShr(RHS, - ConstantInt::get(RHS->getType(), Amt)), - NTy); - return new ICmpInst(ICI.getPredicate(), - Builder->CreateTrunc(LHSI->getOperand(0), NTy), - NCI); - } + if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { + // umax(X, Y) == X --> X u>= Y + // umax(X, Y) u<= X --> X u>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); - break; + // umax(X, Y) != X --> X u< Y + // umax(X, Y) u> X --> X u< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // These cases should be handled in InstSimplify: + // umax(X, Y) u>= X --> true + // umax(X, Y) u< X --> false + return nullptr; } - case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { - // Handle equality comparisons of shift-by-constant. - BinaryOperator *BO = cast<BinaryOperator>(LHSI); - if (ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - if (Instruction *Res = FoldICmpShrCst(ICI, BO, ShAmt)) - return Res; - } + return nullptr; +} + +Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { + if (!I.isEquality()) + return nullptr; - // Handle exact shr's. - if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { - if (RHSV.isMinValue()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C, *D; + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); } - break; - } - case Instruction::UDiv: - if (ConstantInt *DivLHS = dyn_cast<ConstantInt>(LHSI->getOperand(0))) { - Value *X = LHSI->getOperand(1); - const APInt &C1 = RHS->getValue(); - const APInt &C2 = DivLHS->getValue(); - assert(C2 != 0 && "udiv 0, X should have been simplified already."); - // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C1.isMaxValue() && - "icmp ugt X, UINT_MAX should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_ULE, X, - ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); - } - // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { - assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), C2.udiv(C1))); + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + ConstantInt *C1, *C2; + if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) && + Op1->hasOneUse()) { + Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); + Value *Xor = Builder->CreateXor(C, NC); + return new ICmpInst(I.getPredicate(), A, Xor); } + + // A^B == A^D -> B == D + if (A == C) + return new ICmpInst(I.getPredicate(), B, D); + if (A == D) + return new ICmpInst(I.getPredicate(), B, C); + if (B == C) + return new ICmpInst(I.getPredicate(), A, D); + if (B == D) + return new ICmpInst(I.getPredicate(), A, C); } - // fall-through - case Instruction::SDiv: - // Fold: icmp pred ([us]div X, C1), C2 -> range test - // Fold this div into the comparison, producing a range check. - // Determine, based on the divide type, what the range is being - // checked. If there is an overflow on the low or high side, remember - // it, otherwise compute the range [low, hi) bounding the new value. - // See: InsertRangeTest above for the kinds of replacements possible. - if (ConstantInt *DivRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1))) - if (Instruction *R = FoldICmpDivCst(ICI, cast<BinaryOperator>(LHSI), - DivRHS)) - return R; - break; + } - case Instruction::Sub: { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(0)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - - // C1-X <u C2 -> (X|(C2-1)) == C1 - // iff C1 & (C2-1) == C2-1 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateOr(LHSI->getOperand(1), RHSV - 1), - LHSC); - - // C1-X >u C2 -> (X|C2) != C1 - // iff C1 & C2 == C2 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC); - break; + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); } - case Instruction::Add: - // Fold: icmp pred (add X, C1), C2 - if (!ICI.isEquality()) { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { + Value *X = nullptr, *Y = nullptr, *Z = nullptr; + + if (A == C) { + X = B; + Y = D; + Z = A; + } else if (A == D) { + X = B; + Y = C; + Z = A; + } else if (B == C) { + X = A; + Y = D; + Z = B; + } else if (B == D) { + X = A; + Y = C; + Z = B; + } - ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV) - .subtract(LHSV); + if (X) { // Build (X^Y) & Z + Op1 = Builder->CreateXor(X, Y); + Op1 = Builder->CreateAnd(Op1, Z); + I.setOperand(0, Op1); + I.setOperand(1, Constant::getNullValue(Op1->getType())); + return &I; + } + } - if (ICI.isSigned()) { - if (CR.getLower().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } else { - if (CR.getLower().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } + // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) + // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) + ConstantInt *Cst1; + if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && + match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || + (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && + match(Op1, m_ZExt(m_Value(A))))) { + APInt Pow2 = Cst1->getValue() + 1; + if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && + Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + return new ICmpInst(I.getPredicate(), A, + Builder->CreateTrunc(B, A->getType())); + } - // X-C1 <u C2 -> (X & -C2) == C1 - // iff C1 & (C2-1) == 0 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateAnd(LHSI->getOperand(0), -RHSV), - ConstantExpr::getNeg(LHSC)); - - // X-C1 >u C2 -> (X & ~C2) != C1 - // iff C1 & C2 == 0 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateAnd(LHSI->getOperand(0), ~RHSV), - ConstantExpr::getNeg(LHSC)); + // (A >> C) == (B >> C) --> (A^B) u< (1 << C) + // For lshr and ashr pairs. + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE + ? ICmpInst::ICMP_UGE + : ICmpInst::ICMP_ULT; + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); + return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); } - break; } - // Simplify icmp_eq and icmp_ne instructions with integer constant RHS. - if (ICI.isEquality()) { - bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - - // If the first operand is (add|sub|and|or|xor|rem) with a constant, and - // the second operand is a constant, simplify a bit. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(LHSI)) { - switch (BO->getOpcode()) { - case Instruction::SRem: - // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (RHSV == 0 && isa<ConstantInt>(BO->getOperand(1)) &&BO->hasOneUse()){ - const APInt &V = cast<ConstantInt>(BO->getOperand(1))->getValue(); - if (V.sgt(1) && V.isPowerOf2()) { - Value *NewRem = - Builder->CreateURem(BO->getOperand(0), BO->getOperand(1), - BO->getName()); - return new ICmpInst(ICI.getPredicate(), NewRem, - Constant::getNullValue(BO->getType())); - } - } - break; - case Instruction::Add: - // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - if (ConstantInt *BOp1C = dyn_cast<ConstantInt>(BO->getOperand(1))) { - if (BO->hasOneUse()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getSub(RHS, BOp1C)); - } else if (RHSV == 0) { - // Replace ((add A, B) != 0) with (A != -B) if A or B is - // efficiently invertible, or if the add has just this one use. - Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); - - if (Value *NegVal = dyn_castNegVal(BOp1)) - return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); - if (Value *NegVal = dyn_castNegVal(BOp0)) - return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); - if (BO->hasOneUse()) { - Value *Neg = Builder->CreateNeg(BOp1); - Neg->takeName(BO); - return new ICmpInst(ICI.getPredicate(), BOp0, Neg); - } - } - break; - case Instruction::Xor: - if (BO->hasOneUse()) { - if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getXor(RHS, BOC)); - } else if (RHSV == 0) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Sub: - if (BO->hasOneUse()) { - if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { - // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(1), - ConstantExpr::getSub(BOp0C, RHS)); - } else if (RHSV == 0) { - // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Or: - // If bits are being or'd in that are not present in the constant we - // are comparing against, then the comparison could never succeed! - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - Constant *NotCI = ConstantExpr::getNot(RHS); - if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); - - // Comparing if all bits outside of a constant mask are set? - // Replace (X | C) == -1 with (X & ~C) == ~C. - // This removes the -1 constant. - if (BO->hasOneUse() && RHS->isAllOnesValue()) { - Constant *NotBOC = ConstantExpr::getNot(BOC); - Value *And = Builder->CreateAnd(BO->getOperand(0), NotBOC); - return new ICmpInst(ICI.getPredicate(), And, NotBOC); - } - } - break; + // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); + Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), + I.getName() + ".mask"); + return new ICmpInst(I.getPredicate(), And, + Constant::getNullValue(Cst1->getType())); + } + } - case Instruction::And: - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // If bits are being compared against that are and'd out, then the - // comparison can never succeed! - if ((RHSV & ~BOC->getValue()) != 0) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); - - // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (RHS == BOC && RHSV.isPowerOf2()) - return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : - ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); - - // Don't perform the following transforms if the AND has multiple uses - if (!BO->hasOneUse()) - break; + // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to + // "icmp (and X, mask), cst" + uint64_t ShAmt = 0; + if (Op0->hasOneUse() && + match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) && + match(Op1, m_ConstantInt(Cst1)) && + // Only do this when A has multiple uses. This is most important to do + // when it exposes other optimizations. + !A->hasOneUse()) { + unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); + + if (ShAmt < ASize) { + APInt MaskV = + APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); + MaskV <<= ShAmt; - // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 - if (BOC->getValue().isSignBit()) { - Value *X = BO->getOperand(0); - Constant *Zero = Constant::getNullValue(X->getType()); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; - return new ICmpInst(pred, X, Zero); - } + APInt CmpV = Cst1->getValue().zext(ASize); + CmpV <<= ShAmt; - // ((X & ~7) == 0) --> X < 8 - if (RHSV == 0 && isHighOnes(BOC)) { - Value *X = BO->getOperand(0); - Constant *NegX = ConstantExpr::getNeg(BOC); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; - return new ICmpInst(pred, X, NegX); - } - } - break; - case Instruction::Mul: - if (RHSV == 0 && BO->hasNoSignedWrap()) { - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // The trivial case (mul X, 0) is handled by InstSimplify - // General case : (mul X, C) != 0 iff X != 0 - // (mul X, C) == 0 iff X == 0 - if (!BOC->isZero()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - Constant::getNullValue(RHS->getType())); - } - } - break; - default: break; - } - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) { - // Handle icmp {eq|ne} <intrinsic>, intcst. - switch (II->getIntrinsicID()) { - case Intrinsic::bswap: - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, Builder->getInt(RHSV.byteSwap())); - return &ICI; - case Intrinsic::ctlz: - case Intrinsic::cttz: - // ctz(A) == bitwidth(a) -> A == 0 and likewise for != - if (RHSV == RHS->getType()->getBitWidth()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, ConstantInt::get(RHS->getType(), 0)); - return &ICI; - } - break; - case Intrinsic::ctpop: - // popcount(A) == 0 -> A == 0 and likewise for != - if (RHS->isZero()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, RHS); - return &ICI; - } - break; - default: - break; - } + Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); + return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); } } + return nullptr; } /// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so /// far. -Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { +Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); @@ -2485,92 +3384,6 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { return BinaryOperator::CreateNot(Result); } -/// The caller has matched a pattern of the form: -/// I = icmp ugt (add (add A, B), CI2), CI1 -/// If this is of the form: -/// sum = a + b -/// if (sum+128 >u 255) -/// Then replace it with llvm.sadd.with.overflow.i8. -/// -static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, - ConstantInt *CI2, ConstantInt *CI1, - InstCombiner &IC) { - // The transformation we're trying to do here is to transform this into an - // llvm.sadd.with.overflow. To do this, we have to replace the original add - // with a narrower add, and discard the add-with-constant that is part of the - // range check (if we can't eliminate it, this isn't profitable). - - // In order to eliminate the add-with-constant, the compare can be its only - // use. - Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); - if (!AddWithCst->hasOneUse()) return nullptr; - - // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. - if (!CI2->getValue().isPowerOf2()) return nullptr; - unsigned NewWidth = CI2->getValue().countTrailingZeros(); - if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; - - // The width of the new add formed is 1 more than the bias. - ++NewWidth; - - // Check to see that CI1 is an all-ones value with NewWidth bits. - if (CI1->getBitWidth() == NewWidth || - CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) - return nullptr; - - // This is only really a signed overflow check if the inputs have been - // sign-extended; check for that condition. For example, if CI2 is 2^31 and - // the operands of the add are 64 bits wide, we need at least 33 sign bits. - unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || - IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) - return nullptr; - - // In order to replace the original add with a narrower - // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant - // and truncates that discard the high bits of the add. Verify that this is - // the case. - Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); - for (User *U : OrigAdd->users()) { - if (U == AddWithCst) continue; - - // Only accept truncates for now. We would really like a nice recursive - // predicate like SimplifyDemandedBits, but which goes downwards the use-def - // chain to see which bits of a value are actually demanded. If the - // original add had another add which was then immediately truncated, we - // could still do the transformation. - TruncInst *TI = dyn_cast<TruncInst>(U); - if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) - return nullptr; - } - - // If the pattern matches, truncate the inputs to the narrower type and - // use the sadd_with_overflow intrinsic to efficiently compute both the - // result and the overflow bit. - Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); - Value *F = Intrinsic::getDeclaration(I.getModule(), - Intrinsic::sadd_with_overflow, NewType); - - InstCombiner::BuilderTy *Builder = IC.Builder; - - // Put the new code above the original add, in case there are any uses of the - // add between the add and the compare. - Builder->SetInsertPoint(OrigAdd); - - Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); - Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); - CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); - Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); - Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); - - // The inner add was the result of the narrow add, zero extended to the - // wider type. Replace it with the result computed by the intrinsic. - IC.replaceInstUsesWith(*OrigAdd, ZExt); - - // The original icmp gets replaced with the overflow value. - return ExtractValueInst::Create(Call, 1, "sadd.overflow"); -} - bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, Value *RHS, Instruction &OrigI, Value *&Result, Constant *&Overflow) { @@ -2603,8 +3416,10 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + + // Fall through uadd into sadd + LLVM_FALLTHROUGH; } - // FALL THROUGH uadd into sadd case OCF_SIGNED_ADD: { // X + 0 -> {X, false} if (match(RHS, m_Zero())) @@ -2644,7 +3459,8 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, true); if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); - } // FALL THROUGH + LLVM_FALLTHROUGH; + } case OCF_SIGNED_MUL: // X * undef -> undef if (isa<UndefValue>(RHS)) @@ -2682,7 +3498,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, /// \param OtherVal The other argument of compare instruction. /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. -static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, +static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, Value *OtherVal, InstCombiner &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. @@ -2906,8 +3722,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. -static APInt DemandedBitsLHSMask(ICmpInst &I, - unsigned BitWidth, bool isSignCheck) { +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, + bool isSignCheck) { if (isSignCheck) return APInt::getSignBit(BitWidth); @@ -2981,7 +3797,7 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0, } /// \brief Check that one use is in the same block as the definition and all -/// other uses are in blocks dominated by a given block +/// other uses are in blocks dominated by a given block. /// /// \param DI Definition /// \param UI Use @@ -2994,21 +3810,18 @@ bool InstCombiner::dominatesAllUses(const Instruction *DI, const Instruction *UI, const BasicBlock *DB) const { assert(DI && UI && "Instruction not defined\n"); - // ignore incomplete definitions + // Ignore incomplete definitions. if (!DI->getParent()) return false; - // DI and UI must be in the same block + // DI and UI must be in the same block. if (DI->getParent() != UI->getParent()) return false; - // Protect from self-referencing blocks + // Protect from self-referencing blocks. if (DI->getParent() == DB) return false; - // DominatorTree available? - if (!DT) - return false; for (const User *U : DI->users()) { auto *Usr = cast<Instruction>(U); - if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + if (Usr != UI && !DT.dominates(DB, Usr->getParent())) return false; } return true; @@ -3067,8 +3880,7 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { /// are equal, the optimization can work only for EQ predicates. This is not a /// major restriction since a NE compare should be 'normalized' to an equal /// compare, which usually happens in the combiner and test case -/// select-cmp-br.ll -/// checks for it. +/// select-cmp-br.ll checks for it. bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd) { @@ -3076,7 +3888,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); // The check for the unique predecessor is not the best that can be - // done. But it protects efficiently against cases like when SI's + // done. But it protects efficiently against cases like when SI's // home block has two successors, Succ and Succ1, and Succ1 predecessor // of Succ. Then SI can't be replaced by SIOpd because the use that gets // replaced can be reached on either path. So the uniqueness check @@ -3093,6 +3905,229 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// Try to fold the comparison based on range information we can get by checking +/// whether bits are known to be zero or one in the inputs. +Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = Op0->getType(); + ICmpInst::Predicate Pred = I.getPredicate(); + + // Get scalar or pointer size. + unsigned BitWidth = Ty->isIntOrIntVectorTy() + ? Ty->getScalarSizeInBits() + : DL.getTypeSizeInBits(Ty->getScalarType()); + + if (!BitWidth) + return nullptr; + + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool IsSignBit = false; + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + bool UnusedBit; + IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); + } + + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), + getDemandedBitsLHSMask(I, BitWidth, IsSignBit), + Op0KnownZero, Op0KnownOne, 0)) + return &I; + + if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (I.isSigned()) { + computeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } else { + computeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits + // figured out that the LHS is a constant. Constant fold this now, so that + // code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(Pred, ConstantInt::get(Op0->getType(), Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(Pred, Op0, ConstantInt::get(Op1->getType(), Op1Min)); + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. + switch (Pred) { + default: + llvm_unreachable("Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) { + return Pred == CmpInst::ICMP_EQ + ? replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())) + : replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + } + + // If all bits are known zero except for one, then we know at most one bit + // is set. If the comparison is against zero, then this is a check to see if + // *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = nullptr; + const APInt *LHSC; + if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) || + *LHSC != Op0KnownZeroInverted) + LHS = Op0; + + Value *X; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + APInt ValToCheck = Op0KnownZeroInverted; + Type *XTy = X->getType(); + if (ValToCheck.isPowerOf2()) { + // ((1 << X) & 8) == 0 -> X != 3 + // ((1 << X) & 8) != 0 -> X == 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, CmpC); + } else if ((++ValToCheck).isPowerOf2()) { + // ((1 << X) & 7) == 0 -> X >= 3 + // ((1 << X) & 7) != 0 -> X < 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, CmpC); + } + } + + // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { + // ((8 >>u X) & 1) == 0 -> X != 3 + // ((8 >>u X) & 1) != 0 -> X == 3 + unsigned CmpVal = CI->countTrailingZeros(); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, ConstantInt::get(X->getType(), CmpVal)); + } + } + break; + } + case ICmpInst::ICMP_ULT: { + if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (Op1Max == Op0Min + 1) { + Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); + } + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + } + break; + } + case ICmpInst::ICMP_SLT: + if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() - 1)); + } + break; + case ICmpInst::ICMP_SGT: + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() + 1)); + } + break; + case ICmpInst::ICMP_SGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_SLE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_UGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_ULE: + assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + + // Turn a signed comparison into an unsigned one if both operands are known to + // have the same sign. + if (I.isSigned() && + ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || + (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); + + return nullptr; +} + /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -3131,6 +4166,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (isa<UndefValue>(Elt)) continue; + // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); @@ -3167,7 +4203,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } if (Value *V = - SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I)) + SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -3202,28 +4238,28 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_UGT: std::swap(Op0, Op1); // Change icmp ugt -> icmp ult - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op1); } case ICmpInst::ICMP_SGT: std::swap(Op0, Op1); // Change icmp sgt -> icmp slt - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op0); } case ICmpInst::ICMP_UGE: std::swap(Op0, Op1); // Change icmp uge -> icmp ule - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op1); } case ICmpInst::ICMP_SGE: std::swap(Op0, Op1); // Change icmp sge -> icmp sle - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op0); @@ -3234,372 +4270,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; - unsigned BitWidth = 0; - if (Ty->isIntOrIntVectorTy()) - BitWidth = Ty->getScalarSizeInBits(); - else // Get pointer size. - BitWidth = DL.getTypeSizeInBits(Ty->getScalarType()); - - bool isSignBit = false; - - // See if we are doing a comparison with a constant. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - - // Match the following pattern, which is a common idiom when writing - // overflow-safe integer arithmetic function. The source performs an - // addition in wider type, and explicitly checks for overflow using - // comparisons against INT_MIN and INT_MAX. Simplify this by using the - // sadd_with_overflow intrinsic. - // - // TODO: This could probably be generalized to handle other overflow-safe - // operations if we worked out the formulas to compute the appropriate - // magic constants. - // - // sum = a + b - // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 - { - ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI - if (I.getPredicate() == ICmpInst::ICMP_UGT && - match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) - if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this)) - return Res; - } - - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT) - if (auto *SI = dyn_cast<SelectInst>(Op0)) { - SelectPatternResult SPR = matchSelectPattern(SI, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL)) - return new ICmpInst(I.getPredicate(), B, CI); - if (isKnownPositive(B, DL)) - return new ICmpInst(I.getPredicate(), A, CI); - } - } - - - // The following transforms are only 'worth it' if the only user of the - // subtraction is the icmp. - if (Op0->hasOneUse()) { - // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) - if (I.isEquality() && CI->isZero() && - match(Op0, m_Sub(m_Value(A), m_Value(B)))) - return new ICmpInst(I.getPredicate(), A, B); - - // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGE, A, B); - - // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGT, A, B); - - // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLT, A, B); - - // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLE, A, B); - } - - if (I.isEquality()) { - ConstantInt *CI2; - if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || - match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (ashr/lshr const2, A), const1) - if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2)) - return Inst; - } - if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (shl const2, A), const1) - if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2)) - return Inst; - } - } - - // If this comparison is a normal comparison, it demands all - // bits, if it is a sign bit comparison, it only demands the sign bit. - bool UnusedBit; - isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); - - // Canonicalize icmp instructions based on dominating conditions. - BasicBlock *Parent = I.getParent(); - BasicBlock *Dom = Parent->getSinglePredecessor(); - auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; - ICmpInst::Predicate Pred; - BasicBlock *TrueBB, *FalseBB; - ConstantInt *CI2; - if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)), - TrueBB, FalseBB)) && - TrueBB != FalseBB) { - ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(), - CI->getValue()); - ConstantRange DominatingCR = - (Parent == TrueBB) - ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue()) - : ConstantRange::makeExactICmpRegion( - CmpInst::getInversePredicate(Pred), CI2->getValue()); - ConstantRange Intersection = DominatingCR.intersectWith(CR); - ConstantRange Difference = DominatingCR.difference(CR); - if (Intersection.isEmptySet()) - return replaceInstUsesWith(I, Builder->getFalse()); - if (Difference.isEmptySet()) - return replaceInstUsesWith(I, Builder->getTrue()); - // Canonicalizing a sign bit comparison that gets used in a branch, - // pessimizes codegen by generating branch on zero instruction instead - // of a test and branch. So we avoid canonicalizing in such situations - // because test and branch instruction has better branch displacement - // than compare and branch instruction. - if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD)); - } - } - } - - // See if we can fold the comparison based on range information we can get - // by checking whether bits are known to be zero or one in the input. - if (BitWidth != 0) { - APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); - APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); - - if (SimplifyDemandedBits(I.getOperandUse(0), - DemandedBitsLHSMask(I, BitWidth, isSignBit), - Op0KnownZero, Op0KnownOne, 0)) - return &I; - if (SimplifyDemandedBits(I.getOperandUse(1), - APInt::getAllOnesValue(BitWidth), Op1KnownZero, - Op1KnownOne, 0)) - return &I; - - // Given the known and unknown bits, compute a range that the LHS could be - // in. Compute the Min, Max and RHS values based on the known bits. For the - // EQ and NE we use unsigned values. - APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); - APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); - if (I.isSigned()) { - ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } else { - ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } - - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Just constant fold this now so - // that code below can assume that Min != Max. - if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(I.getPredicate(), - ConstantInt::get(Op0->getType(), Op0Min), Op1); - if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(I.getPredicate(), Op0, - ConstantInt::get(Op1->getType(), Op1Min)); - - // Based on the range information we know about the LHS, see if we can - // simplify this comparison. For example, (x&4) < 8 is always true. - switch (I.getPredicate()) { - default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) == 0" into "x != 3". - // or turn "((1 << x)&7) == 0" into "x > 2". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros() - 1; - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) == 0" into "x != 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) != 0" into "x == 3". - // or turn "((1 << x)&7) != 0" into "x < 3". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_ULT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) != 0" into "x == 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_ULT: - if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - - // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear - if (CI->isMinValue(true)) - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, - Constant::getAllOnesValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_UGT: - if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - - // (x >u 2147483647) -> (x <s 0) -> true if sign bit set - if (CI->isMaxValue(true)) - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Constant::getNullValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_SLT: - if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - } - break; - case ICmpInst::ICMP_SGT: - if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - } - break; - case ICmpInst::ICMP_SGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); - if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_SLE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); - if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_UGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); - if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_ULE: - assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); - if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } + if (Instruction *Res = foldICmpWithConstant(I)) + return Res; - // Turn a signed comparison into an unsigned one if both operands - // are known to have the same sign. - if (I.isSigned() && - ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || - (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) - return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); - } + if (Instruction *Res = foldICmpUsingKnownBits(I)) + return Res; // Test if the ICmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing @@ -3614,122 +4289,39 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) return nullptr; - // See if we are doing a comparison between a constant and an instruction that - // can be folded into the comparison. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - // Since the RHS is a ConstantInt (CI), if the left hand side is an - // instruction, see if that instruction also has constants so that the - // instruction can be folded into the icmp - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) - return Res; + // FIXME: We only do this after checking for min/max to prevent infinite + // looping caused by a reverse canonicalization of these patterns for min/max. + // FIXME: The organization of folds is a mess. These would naturally go into + // canonicalizeCmpWithConstant(), but we can't move all of the above folds + // down here after the min/max restriction. + ICmpInst::Predicate Pred = I.getPredicate(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + // For i32: x >u 2147483647 -> x <s 0 -> true if sign bit set + if (Pred == ICmpInst::ICMP_UGT && C->isMaxSignedValue()) { + Constant *Zero = Constant::getNullValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Zero); + } - // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) - if (I.isEquality() && CI->isZero() && - match(Op0, m_UDiv(m_Value(A), m_Value(B)))) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE; - return new ICmpInst(Pred, B, A); + // For i32: x <u 2147483648 -> x >s -1 -> true if sign bit clear + if (Pred == ICmpInst::ICMP_ULT && C->isMinSignedValue()) { + Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); } } - // Handle icmp with constant (but not simple integer constant) RHS - if (Constant *RHSC = dyn_cast<Constant>(Op1)) { - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - switch (LHSI->getOpcode()) { - case Instruction::GetElementPtr: - // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null - if (RHSC->isNullValue() && - cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; - case Instruction::PHI: - // Only fold icmp into the PHI if the phi and icmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. - if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - break; - case Instruction::Select: { - // If either operand of the select is a constant, we can fold the - // comparison into the select arms, which will cause one to be - // constant folded and the select turned into a bitwise or. - Value *Op1 = nullptr, *Op2 = nullptr; - ConstantInt *CI = nullptr; - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { - Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op1); - } - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { - Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op2); - } - - // We only want to perform this transformation if it will not lead to - // additional code. This is true if either both sides of the select - // fold to a constant (in which case the icmp is replaced with a select - // which will usually simplify) or this is the only user of the - // select (in which case we are trading a select+icmp for a simpler - // select+icmp) or all uses of the select can be replaced based on - // dominance information ("Global cases"). - bool Transform = false; - if (Op1 && Op2) - Transform = true; - else if (Op1 || Op2) { - // Local case - if (LHSI->hasOneUse()) - Transform = true; - // Global cases - else if (CI && !CI->isZero()) - // When Op1 is constant try replacing select with second operand. - // Otherwise Op2 is constant and try replacing select with first - // operand. - Transform = replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, - Op1 ? 2 : 1); - } - if (Transform) { - if (!Op1) - Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), - RHSC, I.getName()); - if (!Op2) - Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), - RHSC, I.getName()); - return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); - } - break; - } - case Instruction::IntToPtr: - // icmp pred inttoptr(X), null -> icmp pred X, 0 - if (RHSC->isNullValue() && - DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; + if (Instruction *Res = foldICmpInstWithConstant(I)) + return Res; - case Instruction::Load: - // Try to optimize things like "A[i] > 4" to index computations. - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; - } - break; - } - } + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) + return Res; // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0)) - if (Instruction *NI = FoldGEPICmp(GEP, Op1, I.getPredicate(), I)) + if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) return NI; if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1)) - if (Instruction *NI = FoldGEPICmp(GEP, Op0, + if (Instruction *NI = foldGEPICmp(GEP, Op0, ICmpInst::getSwappedPredicate(I.getPredicate()), I)) return NI; @@ -3737,10 +4329,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op1)) return New; if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op0)) return New; } @@ -3780,318 +4372,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // For generality, we handle any zero-extension of any operand comparison // with a constant or another cast from the same type. if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = visitICmpInstWithCastAndCast(I)) + if (Instruction *R = foldICmpWithCastAndCast(I)) return R; } - // Special logic for binary operators. - BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); - BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); - if (BO0 || BO1) { - CmpInst::Predicate Pred = I.getPredicate(); - bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; - if (BO0 && isa<OverflowingBinaryOperator>(BO0)) - NoOp0WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); - if (BO1 && isa<OverflowingBinaryOperator>(BO1)) - NoOp1WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - - // Analyze the case when either Op0 or Op1 is an add instruction. - // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Add) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Add) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X+cst) < 0 --> X < -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) - if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); - - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. - if ((A == Op1 || B == Op1) && NoOp0WrapProblem) - return new ICmpInst(Pred, A == Op1 ? B : A, - Constant::getNullValue(Op1->getType())); - - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. - if ((C == Op0 || D == Op0) && NoOp1WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), - C == Op0 ? D : C); - - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && - NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { - // Determine Y and Z in the form icmp (X+Y), (X+Z). - Value *Y, *Z; - if (A == C) { - // C + B == C + D -> B == D - Y = B; - Z = D; - } else if (A == D) { - // D + B == C + D -> B == C - Y = B; - Z = C; - } else if (B == C) { - // A + C == C + D -> A == D - Y = A; - Z = D; - } else { - assert(B == D); - // A + D == C + D -> A == C - Y = A; - Z = C; - } - return new ICmpInst(Pred, Y, Z); - } - - // icmp slt (X + -1), Y -> icmp sle X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); - - // icmp sge (X + -1), Y -> icmp sgt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); - - // icmp sle (X + 1), Y -> icmp slt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); - - // icmp sgt (X + 1), Y -> icmp sge X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); - - // icmp sgt X, (Y + -1) -> icmp sge X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); - - // icmp sle X, (Y + -1) -> icmp slt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); - - // icmp sge X, (Y + 1) -> icmp sgt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); - - // icmp slt X, (Y + 1) -> icmp sle X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); - - // if C1 has greater magnitude than C2: - // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y - // s.t. C3 = C1 - C2 - // - // if C2 has greater magnitude than C1: - // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) - // s.t. C3 = C2 - C1 - if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && - (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) - if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) - if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { - const APInt &AP1 = C1->getValue(); - const APInt &AP2 = C2->getValue(); - if (AP1.isNegative() == AP2.isNegative()) { - APInt AP1Abs = C1->getValue().abs(); - APInt AP2Abs = C2->getValue().abs(); - if (AP1Abs.uge(AP2Abs)) { - ConstantInt *C3 = Builder->getInt(AP1 - AP2); - Value *NewAdd = Builder->CreateNSWAdd(A, C3); - return new ICmpInst(Pred, NewAdd, C); - } else { - ConstantInt *C3 = Builder->getInt(AP2 - AP1); - Value *NewAdd = Builder->CreateNSWAdd(C, C3); - return new ICmpInst(Pred, A, NewAdd); - } - } - } - - - // Analyze the case when either Op0 or Op1 is a sub instruction. - // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). - A = nullptr; - B = nullptr; - C = nullptr; - D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Sub) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Sub) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. - if (A == Op1 && NoOp0WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. - if (C == Op0 && NoOp1WrapProblem) - return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); - - // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. - if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, A, C); - - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. - if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, D, B); - - // icmp (0-X) < cst --> x > -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { - Value *X; - if (match(BO0, m_Neg(m_Value(X)))) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(I.getSwappedPredicate(), X, - ConstantExpr::getNeg(RHSC)); - } - - BinaryOperator *SRem = nullptr; - // icmp (srem X, Y), Y - if (BO0 && BO0->getOpcode() == Instruction::SRem && - Op1 == BO0->getOperand(1)) - SRem = BO0; - // icmp Y, (srem X, Y) - else if (BO1 && BO1->getOpcode() == Instruction::SRem && - Op0 == BO1->getOperand(1)) - SRem = BO1; - if (SRem) { - // We don't check hasOneUse to avoid increasing register pressure because - // the value we use is the same value this instruction was already using. - switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { - default: break; - case ICmpInst::ICMP_EQ: - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - case ICmpInst::ICMP_NE: - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), - Constant::getAllOnesValue(SRem->getType())); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), - Constant::getNullValue(SRem->getType())); - } - } - - if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && - BO0->hasOneUse() && BO1->hasOneUse() && - BO0->getOperand(1) == BO1->getOperand(1)) { - switch (BO0->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - - if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - } - break; - case Instruction::Mul: - if (!I.isEquality()) - break; - - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). - if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get(I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); - } - } - break; - case Instruction::UDiv: - case Instruction::LShr: - if (I.isSigned()) - break; - // fall-through - case Instruction::SDiv: - case Instruction::AShr: - if (!BO0->isExact() || !BO1->isExact()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - case Instruction::Shl: { - bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); - bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); - if (!NUW && !NSW) - break; - if (!NSW && I.isSigned()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - } - } - } - - if (BO0) { - // Transform A & (L - 1) `ult` L --> L != 0 - auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); - auto BitwiseAnd = - m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + if (Instruction *Res = foldICmpBinOp(I)) + return Res; - if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { - auto *Zero = Constant::getNullValue(BO0->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); - } - } - } + if (Instruction *Res = foldICmpWithMinMax(I)) + return Res; - { Value *A, *B; + { + Value *A, *B; // Transform (A & ~B) == 0 --> (A & B) != 0 // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Zero()) && - isKnownToBeAPowerOfTwo(A, DL, false, 0, AC, &I, DT) && I.isEquality()) + isKnownToBeAPowerOfTwo(A, DL, false, 0, &AC, &I, &DT) && I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -4120,149 +4418,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // (zext a) * (zext b) --> llvm.umul.with.overflow. if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) return R; } if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) return R; } } - if (I.isEquality()) { - Value *A, *B, *C, *D; - - if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { - if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 - Value *OtherVal = A == Op1 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { - // A^c1 == C^c2 --> A == C^(c1^c2) - ConstantInt *C1, *C2; - if (match(B, m_ConstantInt(C1)) && - match(D, m_ConstantInt(C2)) && Op1->hasOneUse()) { - Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); - Value *Xor = Builder->CreateXor(C, NC); - return new ICmpInst(I.getPredicate(), A, Xor); - } - - // A^B == A^D -> B == D - if (A == C) return new ICmpInst(I.getPredicate(), B, D); - if (A == D) return new ICmpInst(I.getPredicate(), B, C); - if (B == C) return new ICmpInst(I.getPredicate(), A, D); - if (B == D) return new ICmpInst(I.getPredicate(), A, C); - } - } - - if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && - (A == Op0 || B == Op0)) { - // A == (A^B) -> B == 0 - Value *OtherVal = A == Op0 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 - if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && - match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { - Value *X = nullptr, *Y = nullptr, *Z = nullptr; - - if (A == C) { - X = B; Y = D; Z = A; - } else if (A == D) { - X = B; Y = C; Z = A; - } else if (B == C) { - X = A; Y = D; Z = B; - } else if (B == D) { - X = A; Y = C; Z = B; - } - - if (X) { // Build (X^Y) & Z - Op1 = Builder->CreateXor(X, Y); - Op1 = Builder->CreateAnd(Op1, Z); - I.setOperand(0, Op1); - I.setOperand(1, Constant::getNullValue(Op1->getType())); - return &I; - } - } - - // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) - // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) - ConstantInt *Cst1; - if ((Op0->hasOneUse() && - match(Op0, m_ZExt(m_Value(A))) && - match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || - (Op1->hasOneUse() && - match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && - match(Op1, m_ZExt(m_Value(A))))) { - APInt Pow2 = Cst1->getValue() + 1; - if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && - Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) - return new ICmpInst(I.getPredicate(), A, - Builder->CreateTrunc(B, A->getType())); - } - - // (A >> C) == (B >> C) --> (A^B) u< (1 << C) - // For lshr and ashr pairs. - if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || - (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE - ? ICmpInst::ICMP_UGE - : ICmpInst::ICMP_ULT; - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); - } - } - - // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 - if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); - Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), - I.getName() + ".mask"); - return new ICmpInst(I.getPredicate(), And, - Constant::getNullValue(Cst1->getType())); - } - } - - // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to - // "icmp (and X, mask), cst" - uint64_t ShAmt = 0; - if (Op0->hasOneUse() && - match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), - m_ConstantInt(ShAmt))))) && - match(Op1, m_ConstantInt(Cst1)) && - // Only do this when A has multiple uses. This is most important to do - // when it exposes other optimizations. - !A->hasOneUse()) { - unsigned ASize =cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); - - if (ShAmt < ASize) { - APInt MaskV = - APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); - MaskV <<= ShAmt; - - APInt CmpV = Cst1->getValue().zext(ASize); - CmpV <<= ShAmt; - - Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); - return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); - } - } - } + if (Instruction *Res = foldICmpEquality(I)) + return Res; // The 'cmpxchg' instruction returns an aggregate containing the old value and // an i1 which indicates whether or not we successfully did the swap. @@ -4284,18 +4450,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Value *X; ConstantInt *Cst; // icmp X+Cst, X if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getPredicate()); // icmp X, X+Cst if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate()); } return Changed ? &I : nullptr; } /// Fold fcmp ([us]itofp x, cst) if possible. -Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, - Instruction *LHSI, +Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { if (!isa<ConstantFP>(RHSC)) return nullptr; const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); @@ -4339,21 +4504,21 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. unsigned InputSize = IntTy->getScalarSizeInBits(); - // Following test does NOT adjust InputSize downwards for signed inputs, - // because the most negative value still requires all the mantissa bits + // Following test does NOT adjust InputSize downwards for signed inputs, + // because the most negative value still requires all the mantissa bits // to distinguish it from one less than that value. if ((int)InputSize > MantissaWidth) { // Conversion would lose accuracy. Check if loss can impact comparison. int Exp = ilogb(RHS); if (Exp == APFloat::IEK_Inf) { int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics())); - if (MaxExponent < (int)InputSize - !LHSUnsigned) + if (MaxExponent < (int)InputSize - !LHSUnsigned) // Conversion could create infinity. return nullptr; } else { - // Note that if RHS is zero or NaN, then Exp is negative + // Note that if RHS is zero or NaN, then Exp is negative // and first condition is trivially false. - if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) + if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) // Conversion could affect comparison. return nullptr; } @@ -4547,7 +4712,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, - I.getFastMathFlags(), DL, TLI, DT, AC, &I)) + I.getFastMathFlags(), DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' @@ -4601,17 +4766,17 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { const fltSemantics *Sem; // FIXME: This shouldn't be here. if (LHSExt->getSrcTy()->isHalfTy()) - Sem = &APFloat::IEEEhalf; + Sem = &APFloat::IEEEhalf(); else if (LHSExt->getSrcTy()->isFloatTy()) - Sem = &APFloat::IEEEsingle; + Sem = &APFloat::IEEEsingle(); else if (LHSExt->getSrcTy()->isDoubleTy()) - Sem = &APFloat::IEEEdouble; + Sem = &APFloat::IEEEdouble(); else if (LHSExt->getSrcTy()->isFP128Ty()) - Sem = &APFloat::IEEEquad; + Sem = &APFloat::IEEEquad(); else if (LHSExt->getSrcTy()->isX86_FP80Ty()) - Sem = &APFloat::x87DoubleExtended; + Sem = &APFloat::x87DoubleExtended(); else if (LHSExt->getSrcTy()->isPPC_FP128Ty()) - Sem = &APFloat::PPCDoubleDouble; + Sem = &APFloat::PPCDoubleDouble(); else break; @@ -4641,7 +4806,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; case Instruction::SIToFP: case Instruction::UIToFP: - if (Instruction *NV = FoldFCmp_IntToFP_Cst(I, LHSI, RHSC)) + if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC)) return NV; break; case Instruction::FSub: { @@ -4658,7 +4823,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) if (GV->isConstant() && GV->hasDefinitiveInitializer() && !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) return Res; } break; @@ -4667,7 +4832,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; CallInst *CI = cast<CallInst>(LHSI); - Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, &TLI); if (IID != Intrinsic::fabs) break; |