diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 772 |
1 files changed, 542 insertions, 230 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index d7e2b72..999de34 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -22,13 +22,17 @@ using namespace llvm; using namespace PatternMatch; +static ConstantInt *getOne(Constant *C) { + return ConstantInt::get(cast<IntegerType>(C->getType()), 1); +} + /// AddOne - Add one to a ConstantInt static Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } /// SubOne - Subtract one from a ConstantInt -static Constant *SubOne(ConstantInt *C) { - return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); +static Constant *SubOne(Constant *C) { + return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { @@ -160,8 +164,8 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, Max = KnownOne|UnknownBits; if (UnknownBits.isNegative()) { // Sign bit is unknown - Min.set(Min.getBitWidth()-1); - Max.clear(Max.getBitWidth()-1); + Min.setBit(Min.getBitWidth()-1); + Max.clearBit(Max.getBitWidth()-1); } } @@ -694,13 +698,6 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, if (Pred == ICmpInst::ICMP_NE) return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext())); - // If this is an instruction (as opposed to constantexpr) get NUW/NSW info. - bool isNUW = false, isNSW = false; - if (BinaryOperator *Add = dyn_cast<BinaryOperator>(TheAdd)) { - isNUW = Add->hasNoUnsignedWrap(); - isNSW = Add->hasNoSignedWrap(); - } - // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similiarly for all other "or equals" // operators. @@ -709,10 +706,6 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+2) <u X --> X >u (MAXUINT-2) --> X > 253 // (X+MAXUINT) <u X --> X >u (MAXUINT-MAXUINT) --> X != 0 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { - // If this is an NUW add, then this is always false. - if (isNUW) - return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(X->getContext())); - Value *R = ConstantExpr::getSub(ConstantInt::getAllOnesValue(CI->getType()), CI); return new ICmpInst(ICmpInst::ICMP_UGT, X, R); @@ -721,12 +714,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+1) >u X --> X <u (0-1) --> X != 255 // (X+2) >u X --> X <u (0-2) --> X <u 254 // (X+MAXUINT) >u X --> X <u (0-MAXUINT) --> X <u 1 --> X == 0 - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { - // If this is an NUW add, then this is always true. - if (isNUW) - return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext())); + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantExpr::getNeg(CI)); - } unsigned BitWidth = CI->getType()->getPrimitiveSizeInBits(); ConstantInt *SMax = ConstantInt::get(X->getContext(), @@ -738,16 +727,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+MINSINT) <s X --> X >s (MAXSINT-MINSINT) --> X >s -1 // (X+ -2) <s X --> X >s (MAXSINT- -2) --> X >s 126 // (X+ -1) <s X --> X >s (MAXSINT- -1) --> X != 127 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { - // If this is an NSW add, then we have two cases: if the constant is - // positive, then this is always false, if negative, this is always true. - if (isNSW) { - bool isTrue = CI->getValue().isNegative(); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantExpr::getSub(SMax, CI)); - } // (X+ 1) >s X --> X <s (MAXSINT-(1-1)) --> X != 127 // (X+ 2) >s X --> X <s (MAXSINT-(2-1)) --> X <s 126 @@ -756,13 +737,6 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+ -2) >s X --> X <s (MAXSINT-(-2-1)) --> X <s -126 // (X+ -1) >s X --> X <s (MAXSINT-(-1-1)) --> X == -128 - // If this is an NSW add, then we have two cases: if the constant is - // positive, then this is always true, if negative, this is always false. - if (isNSW) { - bool isTrue = !CI->getValue().isNegative(); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE); Constant *C = ConstantInt::get(X->getContext(), CI->getValue()-1); return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); @@ -782,7 +756,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even // (x /u C1) <u C2. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails - // if it finds it. + // if it finds it. bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) return 0; @@ -790,9 +764,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, return 0; // The ProdOV computation fails on divide by zero. if (DivIsSigned && DivRHS->isAllOnesValue()) return 0; // The overflow computation also screws up here - if (DivRHS->isOne()) - return 0; // Not worth bothering, and eliminates some funny cases - // with INT_MIN. + if (DivRHS->isOne()) { + // This eliminates some funny cases with INT_MIN. + ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. + return &ICI; + } // Compute Prod = CI * DivRHS. We are essentially solving an equation // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and @@ -809,6 +785,10 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // Get the ICmp opcode ICmpInst::Predicate Pred = ICI.getPredicate(); + /// If the division is known to be exact, then there is no remainder from the + /// divide, so the covered range size is unit, otherwise it is the divisor. + ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). // Compute this interval based on the constants involved and the signedness of @@ -818,38 +798,43 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; Constant *LoBound = 0, *HiBound = 0; - + if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) LoBound = Prod; HiOverflow = LoOverflow = ProdOV; - if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false); + if (!HiOverflow) { + // If this is not an exact divide, then many values in the range collapse + // to the same result value. + HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + } + } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. if (CmpRHSV == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS))); - HiBound = DivRHS; + LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + HiBound = RangeSize; } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt* DivNeg = - cast<ConstantInt>(ConstantExpr::getNeg(DivRHS)); + ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; - } + } } } else if (DivRHS->getValue().isNegative()) { // Divisor is < 0. + if (DivI->isExact()) + RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); if (CmpRHSV == 0) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(DivRHS); - HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS)); + LoBound = AddOne(RangeSize); + HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN @@ -859,12 +844,12 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0; + LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT @@ -883,9 +868,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, HiBound); - return ReplaceInstUsesWith(ICI, - InsertRangeTest(X, LoBound, HiBound, DivIsSigned, - true)); + return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); @@ -908,13 +892,100 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext())); - else if (HiOverflow == -1) // High bound less than input range. + if (HiOverflow == -1) // High bound less than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - else - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + } +} + +/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)". +Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, + ConstantInt *ShAmt) { + const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + uint32_t TypeBits = CmpRHSV.getBitWidth(); + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return 0; + + if (!ICI.isEquality()) { + // If we have an unsigned comparison and an ashr, we can't simplify this. + // Similarly for signed comparisons with lshr. + if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) + return 0; + + // Otherwise, all lshr and all exact ashr's are equivalent to a udiv/sdiv by + // a power of 2. Since we already have logic to simplify these, transform + // to div and then simplify the resultant comparison. + if (Shr->getOpcode() == Instruction::AShr && + !Shr->isExact()) + return 0; + + // Revisit the shift (to delete it). + Worklist.Add(Shr); + + Constant *DivCst = + ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + + Value *Tmp = + Shr->getOpcode() == Instruction::AShr ? + Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) : + Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()); + + ICI.setOperand(0, Tmp); + + // If the builder folded the binop, just return it. + BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); + if (TheDiv == 0) + return &ICI; + + // Otherwise, fold this div/compare. + assert(TheDiv->getOpcode() == Instruction::SDiv || + TheDiv->getOpcode() == Instruction::UDiv); + + Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst)); + assert(Res && "This div/cst should have folded!"); + return Res; + } + + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = CmpRHSV << ShAmtVal; + ConstantInt *ShiftedCmpRHS = ConstantInt::get(ICI.getContext(), Comp); + if (Shr->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), + IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + // Otherwise, check to see if the bits shifted out are known to be zero. + // If so, we can compare against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + if (Shr->hasOneUse() && Shr->isExact()) + return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + + if (Shr->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(ICI.getContext(), Val); + + Value *And = Builder->CreateAnd(Shr->getOperand(0), + Mask, Shr->getName()+".mask"); + return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); } + return 0; } @@ -939,8 +1010,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { // Pull in the high bits from known-ones set. - APInt NewRHS(RHS->getValue()); - NewRHS.zext(SrcBits); + APInt NewRHS = RHS->getValue().zext(SrcBits); NewRHS |= KnownOne; return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), ConstantInt::get(ICI.getContext(), NewRHS)); @@ -1022,10 +1092,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, (AndCST->getValue().isNonNegative() && RHSV.isNonNegative()))) { uint32_t BitWidth = cast<IntegerType>(Cast->getOperand(0)->getType())->getBitWidth(); - APInt NewCST = AndCST->getValue(); - NewCST.zext(BitWidth); - APInt NewCI = RHSV; - NewCI.zext(BitWidth); + APInt NewCST = AndCST->getValue().zext(BitWidth); + APInt NewCI = RHSV.zext(BitWidth); Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), ConstantInt::get(ICI.getContext(), NewCST), @@ -1145,7 +1213,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, Constant::getNullValue(P->getType())); Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, @@ -1185,6 +1252,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return ReplaceInstUsesWith(ICI, Cst); } + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap()) + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getLShr(RHS, ShAmt)); + if (LHSI->hasOneUse()) { // Otherwise strength reduce the shift into an and. uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); @@ -1195,8 +1268,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Value *And = Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); return new ICmpInst(ICI.getPredicate(), And, - ConstantInt::get(ICI.getContext(), - RHSV.lshr(ShAmtVal))); + ConstantExpr::getLShr(RHS, ShAmt)); } } @@ -1205,8 +1277,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (LHSI->hasOneUse() && isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { // (X << 31) <s 0 --> (X&1) != 0 - Constant *Mask = ConstantInt::get(ICI.getContext(), APInt(TypeBits, 1) << - (TypeBits-ShAmt->getZExtValue()-1)); + Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), + APInt::getOneBitSet(TypeBits, + TypeBits-ShAmt->getZExtValue()-1)); Value *And = Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, @@ -1216,57 +1289,13 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { + case Instruction::AShr: // Only handle equality comparisons of shift-by-constant. - ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!ShAmt || !ICI.isEquality()) break; - - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = RHSV.getBitWidth(); - if (ShAmt->uge(TypeBits)) - break; - - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = RHSV << ShAmtVal; - if (LHSI->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); - - if (Comp != RHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), - IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); - } - - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (LHSI->hasOneUse() && - MaskedValueIsZero(LHSI->getOperand(0), - APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) { - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getShl(RHS, ShAmt)); - } - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(ICI.getContext(), Val); - - Value *And = Builder->CreateAnd(LHSI->getOperand(0), - Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getShl(RHS, ShAmt)); - } + if (ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1))) + if (Instruction *Res = FoldICmpShrCst(ICI, cast<BinaryOperator>(LHSI), + ShAmt)) + return Res; break; - } case Instruction::SDiv: case Instruction::UDiv: @@ -1543,50 +1572,174 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // The re-extended constant changed so the constant cannot be represented // in the shorter type. Consequently, we cannot emit a simple comparison. + // All the cases that fold to true or false will have already been handled + // by SimplifyICmpInst, so only deal with the tricky case. - // First, handle some easy cases. We know the result cannot be equal at this - // point so handle the ICI.isEquality() cases - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext())); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); + if (isSignedCmp || !isSignedExt) + return 0; // Evaluate the comparison for LT (we invert for GT below). LE and GE cases // should have been folded away previously and not enter in here. - Value *Result; - if (isSignedCmp) { - // We're performing a signed comparison. - if (cast<ConstantInt>(CI)->getValue().isNegative()) - Result = ConstantInt::getFalse(ICI.getContext()); // X < (small) --> false - else - Result = ConstantInt::getTrue(ICI.getContext()); // X < (large) --> true - } else { - // We're performing an unsigned comparison. - if (isSignedExt) { - // We're performing an unsigned comp with a sign extended value. - // This is true if the input is >= 0. [aka >s -1] - Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); - } else { - // Unsigned extend & unsigned compare -> always true. - Result = ConstantInt::getTrue(ICI.getContext()); - } - } + + // We're performing an unsigned comp with a sign extended value. + // This is true if the input is >= 0. [aka >s -1] + Constant *NegOne = Constant::getAllOnesValue(SrcTy); + Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); // Finally, return the value computed. - if (ICI.getPredicate() == ICmpInst::ICMP_ULT || - ICI.getPredicate() == ICmpInst::ICMP_SLT) + if (ICI.getPredicate() == ICmpInst::ICMP_ULT) return ReplaceInstUsesWith(ICI, Result); - assert((ICI.getPredicate()==ICmpInst::ICMP_UGT || - ICI.getPredicate()==ICmpInst::ICMP_SGT) && - "ICmp should be folded!"); - if (Constant *CI = dyn_cast<Constant>(Result)) - return ReplaceInstUsesWith(ICI, ConstantExpr::getNot(CI)); + assert(ICI.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); return BinaryOperator::CreateNot(Result); } +/// ProcessUGT_ADDCST_ADD - The caller has matched a pattern of the form: +/// I = icmp ugt (add (add A, B), CI2), CI1 +/// If this is of the form: +/// sum = a + b +/// if (sum+128 >u 255) +/// Then replace it with llvm.sadd.with.overflow.i8. +/// +static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, + ConstantInt *CI2, ConstantInt *CI1, + InstCombiner &IC) { + // The transformation we're trying to do here is to transform this into an + // llvm.sadd.with.overflow. To do this, we have to replace the original add + // with a narrower add, and discard the add-with-constant that is part of the + // range check (if we can't eliminate it, this isn't profitable). + + // In order to eliminate the add-with-constant, the compare can be its only + // use. + Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); + if (!AddWithCst->hasOneUse()) return 0; + + // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. + if (!CI2->getValue().isPowerOf2()) return 0; + unsigned NewWidth = CI2->getValue().countTrailingZeros(); + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return 0; + + // The width of the new add formed is 1 more than the bias. + ++NewWidth; + + // Check to see that CI1 is an all-ones value with NewWidth bits. + if (CI1->getBitWidth() == NewWidth || + CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) + return 0; + + // In order to replace the original add with a narrower + // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant + // and truncates that discard the high bits of the add. Verify that this is + // the case. + Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); + for (Value::use_iterator UI = OrigAdd->use_begin(), E = OrigAdd->use_end(); + UI != E; ++UI) { + if (*UI == AddWithCst) continue; + + // Only accept truncates for now. We would really like a nice recursive + // predicate like SimplifyDemandedBits, but which goes downwards the use-def + // chain to see which bits of a value are actually demanded. If the + // original add had another add which was then immediately truncated, we + // could still do the transformation. + TruncInst *TI = dyn_cast<TruncInst>(*UI); + if (TI == 0 || + TI->getType()->getPrimitiveSizeInBits() > NewWidth) return 0; + } + + // If the pattern matches, truncate the inputs to the narrower type and + // use the sadd_with_overflow intrinsic to efficiently compute both the + // result and the overflow bit. + Module *M = I.getParent()->getParent()->getParent(); + + const Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); + Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow, + &NewType, 1); + + InstCombiner::BuilderTy *Builder = IC.Builder; + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + Builder->SetInsertPoint(OrigAdd); + + Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); + Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); + CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd"); + Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + + // The inner add was the result of the narrow add, zero extended to the + // wider type. Replace it with the result computed by the intrinsic. + IC.ReplaceInstUsesWith(*OrigAdd, ZExt); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "sadd.overflow"); +} + +static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, + InstCombiner &IC) { + // Don't bother doing this transformation for pointers, don't do it for + // vectors. + if (!isa<IntegerType>(OrigAddV->getType())) return 0; + + // If the add is a constant expr, then we don't bother transforming it. + Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV); + if (OrigAdd == 0) return 0; + + Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1); + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + InstCombiner::BuilderTy *Builder = IC.Builder; + Builder->SetInsertPoint(OrigAdd); + + Module *M = I.getParent()->getParent()->getParent(); + const Type *Ty = LHS->getType(); + Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, &Ty,1); + CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd"); + Value *Add = Builder->CreateExtractValue(Call, 0); + IC.ReplaceInstUsesWith(*OrigAdd, Add); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "uadd.overflow"); +} + +// DemandedBitsLHSMask - When performing a comparison against a constant, +// it is possible that not all the bits in the LHS are demanded. This helper +// method computes the mask that IS demanded. +static APInt DemandedBitsLHSMask(ICmpInst &I, + unsigned BitWidth, bool isSignCheck) { + if (isSignCheck) + return APInt::getSignBit(BitWidth); + + ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1)); + if (!CI) return APInt::getAllOnesValue(BitWidth); + const APInt &RHS = CI->getValue(); + + switch (I.getPredicate()) { + // For a UGT comparison, we don't care about any bits that + // correspond to the trailing ones of the comparand. The value of these + // bits doesn't impact the outcome of the comparison, because any value + // greater than the RHS must differ in a bit higher than these due to carry. + case ICmpInst::ICMP_UGT: { + unsigned trailingOnes = RHS.countTrailingOnes(); + APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes); + return ~lowBitsSet; + } + + // Similarly, for a ULT comparison, we don't care about the trailing zeros. + // Any value less than the RHS must differ in a higher bit because of carries. + case ICmpInst::ICMP_ULT: { + unsigned trailingZeros = RHS.countTrailingZeros(); + APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros); + return ~lowBitsSet; + } + + default: + return APInt::getAllOnesValue(BitWidth); + } + +} Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; @@ -1649,17 +1802,37 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } unsigned BitWidth = 0; - if (TD) - BitWidth = TD->getTypeSizeInBits(Ty->getScalarType()); - else if (Ty->isIntOrIntVectorTy()) + if (Ty->isIntOrIntVectorTy()) BitWidth = Ty->getScalarSizeInBits(); - + else if (TD) // Pointers require TD info to get their size. + BitWidth = TD->getTypeSizeInBits(Ty->getScalarType()); + bool isSignBit = false; // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { Value *A = 0, *B = 0; + // Match the following pattern, which is a common idiom when writing + // overflow-safe integer arithmetic function. The source performs an + // addition in wider type, and explicitly checks for overflow using + // comparisons against INT_MIN and INT_MAX. Simplify this by using the + // sadd_with_overflow intrinsic. + // + // TODO: This could probably be generalized to handle other overflow-safe + // operations if we worked out the formulas to compute the appropriate + // magic constants. + // + // sum = a + b + // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 + { + ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (I.getPredicate() == ICmpInst::ICMP_UGT && + match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this)) + return Res; + } + // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) if (I.isEquality() && CI->isZero() && match(Op0, m_Sub(m_Value(A), m_Value(B)))) { @@ -1704,8 +1877,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); if (SimplifyDemandedBits(I.getOperandUse(0), - isSignBit ? APInt::getSignBit(BitWidth) - : APInt::getAllOnesValue(BitWidth), + DemandedBitsLHSMask(I, BitWidth, isSignBit), Op0KnownZero, Op0KnownOne, 0)) return &I; if (SimplifyDemandedBits(I.getOperandUse(1), @@ -1744,14 +1916,80 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // simplify this comparison. For example, (x&4) < 8 is always true. switch (I.getPredicate()) { default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_EQ: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = 0; + ConstantInt *LHSC = 0; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) == 0" into "x != 3". + Value *X = 0; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), CmpVal)); + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) == 0" into "x != 3". + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); + } + break; - case ICmpInst::ICMP_NE: + } + case ICmpInst::ICMP_NE: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = 0; + ConstantInt *LHSC = 0; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) != 0" into "x == 3". + Value *X = 0; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), CmpVal)); + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) != 0" into "x == 3". + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); + } + break; + } case ICmpInst::ICMP_ULT: if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); @@ -1894,7 +2132,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I, true)) + if (Instruction *NV = FoldOpIntoPhi(I)) return NV; break; case Instruction::Select: { @@ -1995,79 +2233,163 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *R = visitICmpInstWithCastAndCast(I)) return R; } - - // See if it's the same type of instruction on the left and right. - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) { - if (Op0I->getOpcode() == Op1I->getOpcode() && Op0I->hasOneUse() && - Op1I->hasOneUse() && Op0I->getOperand(1) == Op1I->getOperand(1)) { - switch (Op0I->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), Op0I->getOperand(0), - Op1I->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - return new ICmpInst(Pred, Op0I->getOperand(0), - Op1I->getOperand(0)); - } - - if (CI->getValue().isMaxSignedValue()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, Op0I->getOperand(0), - Op1I->getOperand(0)); - } + + // Special logic for binary operators. + BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); + BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); + if (BO0 || BO1) { + CmpInst::Predicate Pred = I.getPredicate(); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; + if (BO0 && isa<OverflowingBinaryOperator>(BO0)) + NoOp0WrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); + if (BO1 && isa<OverflowingBinaryOperator>(BO1)) + NoOp1WrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); + + // Analyze the case when either Op0 or Op1 is an add instruction. + // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). + Value *A = 0, *B = 0, *C = 0, *D = 0; + if (BO0 && BO0->getOpcode() == Instruction::Add) + A = BO0->getOperand(0), B = BO0->getOperand(1); + if (BO1 && BO1->getOpcode() == Instruction::Add) + C = BO1->getOperand(0), D = BO1->getOperand(1); + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == Op1 || B == Op1) && NoOp0WrapProblem) + return new ICmpInst(Pred, A == Op1 ? B : A, + Constant::getNullValue(Op1->getType())); + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == Op0 || D == Op0) && NoOp1WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), + C == Op0 ? D : C); + + // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && + NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y = (A == C || A == D) ? B : A; + Value *Z = (C == A || C == B) ? D : C; + return new ICmpInst(Pred, Y, Z); + } + + // Analyze the case when either Op0 or Op1 is a sub instruction. + // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). + A = 0; B = 0; C = 0; D = 0; + if (BO0 && BO0->getOpcode() == Instruction::Sub) + A = BO0->getOperand(0), B = BO0->getOperand(1); + if (BO1 && BO1->getOpcode() == Instruction::Sub) + C = BO1->getOperand(0), D = BO1->getOperand(1); + + // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + if (A == Op1 && NoOp0WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); + + // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + if (C == Op0 && NoOp1WrapProblem) + return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, A, C); + + // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, D, B); + + if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && + BO0->hasOneUse() && BO1->hasOneUse() && + BO0->getOperand(1) == BO1->getOperand(1)) { + switch (BO0->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = I.isSigned() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + return new ICmpInst(Pred, BO0->getOperand(0), + BO1->getOperand(0)); + } + + if (CI->getValue().isMaxSignedValue()) { + ICmpInst::Predicate Pred = I.isSigned() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, BO0->getOperand(0), + BO1->getOperand(0)); } + } + break; + case Instruction::Mul: + if (!I.isEquality()) break; - case Instruction::Mul: - if (!I.isEquality()) - break; - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). - if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get(I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(Op0I->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(Op1I->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); - } + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). + if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get(I.getContext(), + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - + AP.countTrailingZeros())); + Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(I.getPredicate(), And1, And2); } - break; } + break; } } } - // ~x < ~y --> y < x { Value *A, *B; - if (match(Op0, m_Not(m_Value(A))) && - match(Op1, m_Not(m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, A); + // ~x < ~y --> y < x + // ~x < cst --> ~cst < x + if (match(Op0, m_Not(m_Value(A)))) { + if (match(Op1, m_Not(m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, A); + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) + return new ICmpInst(I.getPredicate(), ConstantExpr::getNot(RHSC), A); + } + + // (a+b) <u a --> llvm.uadd.with.overflow. + // (a+b) <u b --> llvm.uadd.with.overflow. + if (I.getPredicate() == ICmpInst::ICMP_ULT && + match(Op0, m_Add(m_Value(A), m_Value(B))) && + (Op1 == A || Op1 == B)) + if (Instruction *R = ProcessUAddIdiom(I, Op0, *this)) + return R; + + // a >u (a+b) --> llvm.uadd.with.overflow. + // b >u (a+b) --> llvm.uadd.with.overflow. + if (I.getPredicate() == ICmpInst::ICMP_UGT && + match(Op1, m_Add(m_Value(A), m_Value(B))) && + (Op0 == A || Op0 == B)) + if (Instruction *R = ProcessUAddIdiom(I, Op1, *this)) + return R; } if (I.isEquality()) { Value *A, *B, *C, *D; - - // -x == -y --> x == y - if (match(Op0, m_Neg(m_Value(A))) && - match(Op1, m_Neg(m_Value(B)))) - return new ICmpInst(I.getPredicate(), A, B); - + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 Value *OtherVal = A == Op1 ? B : A; @@ -2102,16 +2424,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Constant::getNullValue(A->getType())); } - // (A-B) == A -> B == 0 - if (match(Op0, m_Sub(m_Specific(Op1), m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, - Constant::getNullValue(B->getType())); - - // A == (A-B) -> B == 0 - if (match(Op1, m_Sub(m_Specific(Op0), m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, - Constant::getNullValue(B->getType())); - // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 if (Op0->hasOneUse() && Op1->hasOneUse() && match(Op0, m_And(m_Value(A), m_Value(B))) && @@ -2397,7 +2709,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I, true)) + if (Instruction *NV = FoldOpIntoPhi(I)) return NV; break; case Instruction::SIToFP: |