diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 471 |
1 files changed, 307 insertions, 164 deletions
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6c6e7d8..b2ff96f 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -25,7 +25,8 @@ using namespace PatternMatch; /// simplifyValueKnownNonZero - The specific integer value is used in a context /// where it is known to be non-zero. If this allows us to simplify the /// computation, do so and return the new operand, otherwise return null. -static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { +static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, + Instruction *CxtI) { // If V has multiple uses, then we would have to do more analysis to determine // if this is safe. For example, the use could be in dynamically unreached // code. @@ -35,22 +36,23 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { // ((1 << A) >>u B) --> (1 << (A-B)) // Because V cannot be zero, we know that B is less than A. - Value *A = nullptr, *B = nullptr, *PowerOf2 = nullptr; - if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))), - m_Value(B))) && - // The "1" can be any value known to be a power of 2. - isKnownToBeAPowerOfTwo(PowerOf2)) { + Value *A = nullptr, *B = nullptr, *One = nullptr; + if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) && + match(One, m_One())) { A = IC.Builder->CreateSub(A, B); - return IC.Builder->CreateShl(PowerOf2, A); + return IC.Builder->CreateShl(One, A); } // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0))) { + if (I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, + IC.getAssumptionCache(), CxtI, + IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) { + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { I->setOperand(0, V2); MadeChange = true; } @@ -76,25 +78,30 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { /// MultiplyOverflows - True if the multiply can not be expressed in an int /// this size. -static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) { - uint32_t W = C1->getBitWidth(); - APInt LHSExt = C1->getValue(), RHSExt = C2->getValue(); - if (sign) { - LHSExt = LHSExt.sext(W * 2); - RHSExt = RHSExt.sext(W * 2); - } else { - LHSExt = LHSExt.zext(W * 2); - RHSExt = RHSExt.zext(W * 2); - } +static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + if (IsSigned) + Product = C1.smul_ov(C2, Overflow); + else + Product = C1.umul_ov(C2, Overflow); + + return Overflow; +} - APInt MulExt = LHSExt * RHSExt; +/// \brief True if C2 is a multiple of C1. Quotient contains C2/C1. +static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && + "Inconsistent width of constants!"); - if (!sign) - return MulExt.ugt(APInt::getLowBitsSet(W * 2, W)); + APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); - APInt Min = APInt::getSignedMinValue(W).sext(W * 2); - APInt Max = APInt::getSignedMaxValue(W).sext(W * 2); - return MulExt.slt(Min) || MulExt.sgt(Max); + return Remainder.isMinValue(); } /// \brief A helper routine of InstCombiner::visitMul(). @@ -116,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) { return ConstantVector::get(Elts); } +/// \brief Return true if we can prove that: +/// (mul LHS, RHS) === (mul nsw LHS, RHS) +bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, + Instruction *CxtI) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) + + ComputeNumSignBits(RHS, 0, CxtI); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow. + if (SignBits > BitWidth + 1) + return true; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. + bool LHSNonNegative, LHSNegative; + bool RHSNonNegative, RHSNegative; + ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI); + if (LHSNonNegative || RHSNonNegative) + return true; + } + return false; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -123,14 +172,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - if (match(Op1, m_AllOnes())) // X * -1 == 0 - X - return BinaryOperator::CreateNeg(Op0, I.getName()); + // X * -1 == 0 - X + if (match(Op1, m_AllOnes())) { + BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); + if (I.hasNoSignedWrap()) + BO->setHasNoSignedWrap(); + return BO; + } // Also allow combining multiply instructions on vectors. { @@ -139,9 +193,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { const APInt *IVal; if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)), m_Constant(C1))) && - match(C1, m_APInt(IVal))) - // ((X << C1)*C2) == (X * (C2 << C1)) - return BinaryOperator::CreateMul(NewOp, ConstantExpr::getShl(C1, C2)); + match(C1, m_APInt(IVal))) { + // ((X << C2)*C1) == (X * (C1 << C2)) + Constant *Shl = ConstantExpr::getShl(C1, C2); + BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); + BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); + if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() && + Shl->isNotMinSignedValue()) + BO->setHasNoSignedWrap(); + return BO; + } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { Constant *NewCst = nullptr; @@ -155,8 +218,12 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (NewCst) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); - if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap(); - if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); + + if (I.hasNoUnsignedWrap()) + Shl->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && NewCst->isNotMinSignedValue()) + Shl->setHasNoSignedWrap(); + return Shl; } } @@ -212,9 +279,16 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y - if (Value *Op1v = dyn_castNegVal(Op1)) - return BinaryOperator::CreateMul(Op0v, Op1v); + if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y + if (Value *Op1v = dyn_castNegVal(Op1)) { + BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v); + if (I.hasNoSignedWrap() && + match(Op0, m_NSWSub(m_Value(), m_Value())) && + match(Op1, m_NSWSub(m_Value(), m_Value()))) + BO->setHasNoSignedWrap(); + return BO; + } + } // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X @@ -263,10 +337,22 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // (1 << Y)*X --> X << Y { Value *Y; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op1, Y); - if (match(Op1, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op0, Y); + BinaryOperator *BO = nullptr; + bool ShlNSW = false; + if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op1, Y); + ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); + } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op0, Y); + ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); + } + if (BO) { + if (I.hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && ShlNSW) + BO->setHasNoSignedWrap(); + return BO; + } } // If one of the operands of the multiply is a cast from a boolean value, then @@ -277,9 +363,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2)) + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2)) + else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) BoolCast = Op1, OtherOp = Op0; if (BoolCast) { @@ -289,43 +375,47 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + + if (!I.hasNoUnsignedWrap() && + computeOverflowForUnsignedMul(Op0, Op1, &I) == + OverflowResult::NeverOverflows) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + return Changed ? &I : nullptr; } -// -// Detect pattern: -// -// log2(Y*0.5) -// -// And check for corresponding fast math flags -// - +/// Detect pattern log2(Y * 0.5) with corresponding fast math flags. static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { - - if (!Op->hasOneUse()) - return; - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); - if (!II) - return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) - return; - Log2 = II; - - Value *OpLog2Of = II->getArgOperand(0); - if (!OpLog2Of->hasOneUse()) - return; - - Instruction *I = dyn_cast<Instruction>(OpLog2Of); - if (!I) - return; - if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) - return; - - if (match(I->getOperand(0), m_SpecificFP(0.5))) - Y = I->getOperand(1); - else if (match(I->getOperand(1), m_SpecificFP(0.5))) - Y = I->getOperand(0); + if (!Op->hasOneUse()) + return; + + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); + if (!II) + return; + if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + return; + Log2 = II; + + Value *OpLog2Of = II->getArgOperand(0); + if (!OpLog2Of->hasOneUse()) + return; + + Instruction *I = dyn_cast<Instruction>(OpLog2Of); + if (!I) + return; + if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + return; + + if (match(I->getOperand(0), m_SpecificFP(0.5))) + Y = I->getOperand(1); + else if (match(I->getOperand(1), m_SpecificFP(0.5))) + Y = I->getOperand(0); } static bool isFiniteNonZeroFp(Constant *C) { @@ -440,7 +530,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -510,10 +601,15 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && (Op0 == Op1)) + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) + if (II->getIntrinsicID() == Intrinsic::sqrt) + return ReplaceInstUsesWith(I, II->getOperand(0)); // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X - if (I.hasUnsafeAlgebra()) { + if (AllowReassociate) { Value *OpX = nullptr; Value *OpY = nullptr; IntrinsicInst *Log2; @@ -596,36 +692,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // B * (uitofp i1 C) -> select C, B, 0 - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *B, *C; - if (!match(RHS, m_UIToFP(m_Value(C)))) - std::swap(LHS, RHS); - - if (match(RHS, m_UIToFP(m_Value(C))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - B = LHS; - Value *Zero = ConstantFP::getNegativeZero(B->getType()); - return SelectInst::Create(C, B, Zero); - } - } - - // A * (1 - uitofp i1 C) -> select C, 0, A - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *A, *C; - if (!match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C))))) - std::swap(LHS, RHS); - - if (match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C)))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - A = LHS; - Value *Zero = ConstantFP::getNegativeZero(A->getType()); - return SelectInst::Create(C, Zero, A); - } - } - if (!isa<Constant>(Op1)) std::swap(Opnd0, Opnd1); else @@ -714,7 +780,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -724,25 +790,83 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) return &I; - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // (X / C1) / C2 -> X / (C1*C2) - if (Instruction *LHS = dyn_cast<Instruction>(Op0)) - if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode()) - if (ConstantInt *LHSRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) { - if (MultiplyOverflows(RHS, LHSRHS, - I.getOpcode() == Instruction::SDiv)) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); - return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0), - ConstantExpr::getMul(RHS, LHSRHS)); + if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { + const APInt *C2; + if (match(Op1, m_APInt(C2))) { + Value *X; + const APInt *C1; + bool IsSigned = I.getOpcode() == Instruction::SDiv; + + // (X / C1) / C2 -> X / (C1*C2) + if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) { + APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + if (!MultiplyOverflows(*C1, *C2, Product, IsSigned)) + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(I.getType(), Product)); + } + + if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + + // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, *C1, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; } - if (!RHS->isZero()) { // avoid X udiv 0 - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. + if (IsMultiple(*C1, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) && + *C1 != C1->getBitWidth() - 1) || + (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt C1Shifted = APInt::getOneBitSet( + C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + + // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; + } + + // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2. + if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if (*C2 != 0) { // avoid X udiv 0 + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } } @@ -828,7 +952,8 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, const APInt &C = cast<Constant>(Op1)->getUniqueInteger(); BinaryOperator *LShr = BinaryOperator::CreateLShr( Op0, ConstantInt::get(Op0->getType(), C.logBase2())); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -856,7 +981,8 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) N = IC.Builder->CreateZExt(N, Z->getDestTy()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -893,10 +1019,10 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) - if (size_t LHSIdx = visitUDivOperand(Op0, SI->getOperand(1), I, Actions)) - if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions)) { - Actions.push_back(UDivFoldAction((FoldUDivOperandCb)nullptr, Op1, - LHSIdx-1)); + if (size_t LHSIdx = + visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) + if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { + Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1)); return Actions.size(); } @@ -909,7 +1035,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -917,19 +1043,30 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { return Common; // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - if (Constant *C2 = dyn_cast<Constant>(Op1)) { + { Value *X; - Constant *C1; - if (match(Op0, m_LShr(m_Value(X), m_Constant(C1)))) - return BinaryOperator::CreateUDiv(X, ConstantExpr::getShl(C2, C1)); + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && + match(Op1, m_APInt(C2))) { + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) { + bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); + BinaryOperator *BO = BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + if (IsExact) + BO->setIsExact(); + return BO; + } + } } // (zext A) udiv (zext B) --> zext (A udiv B) if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", - I.isExact()), - I.getType()); + return new ZExtInst( + Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), + I.getType()); // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; @@ -971,7 +1108,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -998,28 +1135,34 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType()); // -X/C --> X/-C provided the negation doesn't overflow. - if (SubOperator *Sub = dyn_cast<SubOperator>(Op0)) - if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap()) - return BinaryOperator::CreateSDiv(Sub->getOperand(1), - ConstantExpr::getNeg(RHS)); + Value *X; + if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { + auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS)); + BO->setIsExact(I.isExact()); + return BO; + } } // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op0, Mask)) { - if (MaskedValueIsZero(Op1, Mask)) { + if (MaskedValueIsZero(Op0, Mask, 0, &I)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } - if (match(Op1, m_Shl(m_Power2(), m_Value()))) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have // the sign bit set. - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } } } @@ -1034,8 +1177,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { /// If the conversion was successful, the simplified expression "X * 1/C" is /// returned; otherwise, NULL is returned. /// -static Instruction *CvtFDivConstToReciprocal(Value *Dividend, - Constant *Divisor, +static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, bool AllowReciprocal) { if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. return nullptr; @@ -1064,7 +1206,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1195,7 +1337,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -1229,7 +1371,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1242,7 +1384,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1264,28 +1406,29 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) return Common; - if (Value *RHSNeg = dyn_castNegVal(Op1)) - if (!isa<Constant>(RHSNeg) || - (isa<ConstantInt>(RHSNeg) && - cast<ConstantInt>(RHSNeg)->getValue().isStrictlyPositive())) { - // X % -Y -> X % Y + { + const APInt *Y; + // X % -Y -> X % Y + if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) { Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, RHSNeg); + I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); return &I; } + } // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I) && + MaskedValueIsZero(Op0, Mask, 0, &I)) { // X srem Y -> X urem Y, iff X and Y don't have sign bit set return BinaryOperator::CreateURem(Op0, Op1, I.getName()); } @@ -1338,7 +1481,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFRemInst(Op0, Op1, DL)) + if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) |