Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 2506
1 file changed, 1612 insertions(+), 894 deletions(-)
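A recurring theme in the hunks below is the new proveNoWrapViaConstantRanges helper, which marks an add recurrence <nsw>/<nuw> when the recurrence's computed range sits entirely inside the region where adding the step cannot wrap (computed with ConstantRange::makeGuaranteedNoWrapRegion). As a rough illustration of that range idea, here is a standalone C++ sketch under our own simplified i32 model; none of these names (Range, noSignedWrapRegionForAdd, contains) are LLVM's, and LLVM does the real check with ConstantRange at arbitrary bit widths:

// Standalone sketch (not LLVM code): for a fixed step C, the set of values X
// for which X + C cannot sign-wrap in i32 is one contiguous range, so proving
// NSW for an add recurrence reduces to a single range-containment test.
#include <cstdint>
#include <cstdio>

struct Range { int64_t Lo, Hi; }; // half-open [Lo, Hi) of i32 values

// Region of X where X + C does not overflow in signed 32-bit arithmetic.
static Range noSignedWrapRegionForAdd(int32_t C) {
  const int64_t Min = INT32_MIN, Max = INT32_MAX;
  if (C >= 0)
    return {Min, Max - C + 1}; // X + C <= Max  <=>  X <= Max - C
  return {Min - C, Max + 1};   // X + C >= Min  <=>  X >= Min - C
}

static bool contains(Range Outer, Range Inner) {
  return Outer.Lo <= Inner.Lo && Inner.Hi <= Outer.Hi;
}

int main() {
  // An add recurrence {0,+,4} over a loop with at most 1000 iterations takes
  // values in [0, 4001); that range lies inside the no-wrap region for a step
  // of 4, so the recurrence could be marked <nsw>.
  Range AddRecRange = {0, 4001};
  Range NSW = noSignedWrapRegionForAdd(4);
  std::printf("can mark nsw: %s\n", contains(NSW, AddRecRange) ? "yes" : "no");
}

The payoff visible throughout the diff is that getZeroExtendExpr/getSignExtendExpr can now often skip the more expensive guard-based proofs entirely.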
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
index ef1bb3a..e42a4b5 100644
--- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -111,10 +111,14 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                         "derived loop"),
                 cl::init(100));
 
-// FIXME: Enable this with XDEBUG when the test suite is clean.
+// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
 static cl::opt<bool>
 VerifySCEV("verify-scev",
            cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
+static cl::opt<bool>
+    VerifySCEVMap("verify-scev-maps",
+                  cl::desc("Verify no dangling value in ScalarEvolution's "
+                           "ExprValueMap (slow)"));
 
 //===----------------------------------------------------------------------===//
 //                           SCEV class definitions
@@ -162,11 +166,11 @@ void SCEV::print(raw_ostream &OS) const {
     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
       OS << ",+," << *AR->getOperand(i);
     OS << "}<";
-    if (AR->getNoWrapFlags(FlagNUW))
+    if (AR->hasNoUnsignedWrap())
       OS << "nuw><";
-    if (AR->getNoWrapFlags(FlagNSW))
+    if (AR->hasNoSignedWrap())
       OS << "nsw><";
-    if (AR->getNoWrapFlags(FlagNW) &&
+    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
       OS << "nw><";
     AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -196,9 +200,9 @@ void SCEV::print(raw_ostream &OS) const {
     switch (NAry->getSCEVType()) {
     case scAddExpr:
     case scMulExpr:
-      if (NAry->getNoWrapFlags(FlagNUW))
+      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
-      if (NAry->getNoWrapFlags(FlagNSW))
+      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
     }
     return;
@@ -283,8 +287,6 @@ bool SCEV::isAllOnesValue() const {
   return false;
 }
 
-/// isNonConstantNegative - Return true if the specified scev is negated, but
-/// not a constant.
 bool SCEV::isNonConstantNegative() const {
   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
   if (!Mul) return false;
@@ -620,10 +622,10 @@ public:
 };
 }  // end anonymous namespace
 
-/// GroupByComplexity - Given a list of SCEV objects, order them by their
-/// complexity, and group objects of the same complexity together by value.
-/// When this routine is finished, we know that any duplicates in the vector are
-/// consecutive and that complexity is monotonically increasing.
+/// Given a list of SCEV objects, order them by their complexity, and group
+/// objects of the same complexity together by value. When this routine is
+/// finished, we know that any duplicates in the vector are consecutive and that
+/// complexity is monotonically increasing.
 ///
 /// Note that we go take special precautions to ensure that we get deterministic
 /// results from this routine.  In other words, we don't want the results of
@@ -723,7 +725,7 @@ public:
   }
 
   // Split the Denominator when it is a product.
-  if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
+  if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
     const SCEV *Q, *R;
     *Quotient = Numerator;
     for (const SCEV *Op : T->operands()) {
@@ -922,8 +924,7 @@ private:
 //                      Simple SCEV method implementations
 //===----------------------------------------------------------------------===//
 
-/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
-/// Assume, K > 0.
+/// Compute BC(It, K).  The result has width W.  Assume, K > 0.
 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                        ScalarEvolution &SE,
                                        Type *ResultTy) {
@@ -1034,10 +1035,10 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                         SE.getTruncateOrZeroExtend(DivResult, ResultTy));
 }
 
-/// evaluateAtIteration - Return the value of this chain of recurrences at
-/// the specified iteration number.  We can evaluate this recurrence by
-/// multiplying each element in the chain by the binomial coefficient
-/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
+/// Return the value of this chain of recurrences at the specified iteration
+/// number.  We can evaluate this recurrence by multiplying each element in the
+/// chain by the binomial coefficient corresponding to it.  In other words, we
+/// can evaluate {A,+,B,+,C,+,D} as:
 ///
 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
 ///
@@ -1450,9 +1451,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
     unsigned BitWidth = getTypeSizeInBits(AR->getType());
     const Loop *L = AR->getLoop();
 
+    if (!AR->hasNoUnsignedWrap()) {
+      auto NewFlags = proveNoWrapViaConstantRanges(AR);
+      const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+    }
+
     // If we have special knowledge that this addrec won't overflow,
     // we don't need to do any further analysis.
-    if (AR->getNoWrapFlags(SCEV::FlagNUW))
+    if (AR->hasNoUnsignedWrap())
       return getAddRecExpr(
           getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
           getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
@@ -1512,11 +1518,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
       }
     }
+  }
 
-    // If the backedge is guarded by a comparison with the pre-inc value
-    // the addrec is safe. Also, if the entry is guarded by a comparison
-    // with the start value and the backedge is guarded by a comparison
-    // with the post-inc value, the addrec is safe.
+    // Normally, in the cases we can prove no-overflow via a
+    // backedge guarding condition, we can also compute a backedge
+    // taken count for the loop.  The exceptions are assumptions and
+    // guards present in the loop -- SCEV is not great at exploiting
+    // these to compute max backedge taken counts, but can still use
+    // these to prove lack of overflow.  Use this fact to avoid
+    // doing extra work that may not pay off.
+    if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+        !AC.assumptions().empty()) {
+      // If the backedge is guarded by a comparison with the pre-inc
+      // value the addrec is safe.  Also, if the entry is guarded by
+      // a comparison with the start value and the backedge is
+      // guarded by a comparison with the post-inc value, the addrec
+      // is safe.
       if (isKnownPositive(Step)) {
         const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                     getUnsignedRange(Step).getUnsignedMax());
@@ -1524,7 +1541,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
              isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
                                          AR->getPostIncExpr(*this), N))) {
-          // Cache knowledge of AR NUW, which is propagated to this AddRec.
+          // Cache knowledge of AR NUW, which is propagated to this
+          // AddRec.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
           // Return the expression with the addrec on the outside.
           return getAddRecExpr(
@@ -1538,8 +1556,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
              isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
                                          AR->getPostIncExpr(*this), N))) {
-          // Cache knowledge of AR NW, which is propagated to this AddRec.
-          // Negative step causes unsigned wrap, but it still can't self-wrap.
+          // Cache knowledge of AR NW, which is propagated to this
+          // AddRec.  Negative step causes unsigned wrap, but it
+          // still can't self-wrap.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
           // Return the expression with the addrec on the outside.
           return getAddRecExpr(
@@ -1559,7 +1578,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
 
   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
     // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
-    if (SA->getNoWrapFlags(SCEV::FlagNUW)) {
+    if (SA->hasNoUnsignedWrap()) {
       // If the addition does not unsign overflow then we can, by definition,
       // commute the zero extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
@@ -1608,10 +1627,6 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   void *IP = nullptr;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
 
-  // If the input value is provably positive, build a zext instead.
-  if (isKnownNonNegative(Op))
-    return getZeroExtendExpr(Op, Ty);
-
   // sext(trunc(x)) --> sext(x) or x or trunc(x)
   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
     // It's possible the bits taken off by the truncate were all sign bits. If
@@ -1643,7 +1658,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
     }
 
     // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
-    if (SA->getNoWrapFlags(SCEV::FlagNSW)) {
+    if (SA->hasNoSignedWrap()) {
       // If the addition does not sign overflow then we can, by definition,
       // commute the sign extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
@@ -1663,9 +1678,14 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
     unsigned BitWidth = getTypeSizeInBits(AR->getType());
     const Loop *L = AR->getLoop();
 
+    if (!AR->hasNoSignedWrap()) {
+      auto NewFlags = proveNoWrapViaConstantRanges(AR);
+      const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+    }
+
     // If we have special knowledge that this addrec won't overflow,
     // we don't need to do any further analysis.
-    if (AR->getNoWrapFlags(SCEV::FlagNSW))
+    if (AR->hasNoSignedWrap())
       return getAddRecExpr(
           getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
           getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
@@ -1732,11 +1752,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
             getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
       }
     }
+  }
 
-    // If the backedge is guarded by a comparison with the pre-inc value
-    // the addrec is safe. Also, if the entry is guarded by a comparison
-    // with the start value and the backedge is guarded by a comparison
-    // with the post-inc value, the addrec is safe.
+    // Normally, in the cases we can prove no-overflow via a
+    // backedge guarding condition, we can also compute a backedge
+    // taken count for the loop.  The exceptions are assumptions and
+    // guards present in the loop -- SCEV is not great at exploiting
+    // these to compute max backedge taken counts, but can still use
+    // these to prove lack of overflow.  Use this fact to avoid
+    // doing extra work that may not pay off.
+
+    if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+        !AC.assumptions().empty()) {
+      // If the backedge is guarded by a comparison with the pre-inc
+      // value the addrec is safe.  Also, if the entry is guarded by
+      // a comparison with the start value and the backedge is
+      // guarded by a comparison with the post-inc value, the addrec
+      // is safe.
       ICmpInst::Predicate Pred;
       const SCEV *OverflowLimit =
           getSignedOverflowLimitForStep(Step, &Pred, this);
@@ -1752,6 +1784,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
             getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
       }
     }
+
     // If Start and Step are constants, check if we can apply this
     // transformation:
     // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
@@ -1777,6 +1810,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
     }
   }
 
+  // If the input value is provably positive and we could not simplify
+  // away the sext build a zext instead.
+  if (isKnownNonNegative(Op))
+    return getZeroExtendExpr(Op, Ty);
+
   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
@@ -1836,11 +1874,10 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
   return ZExt;
 }
 
-/// CollectAddOperandsWithScales - Process the given Ops list, which is
-/// a list of operands to be added under the given scale, update the given
-/// map. This is a helper function for getAddRecExpr. As an example of
-/// what it does, given a sequence of operands that would form an add
-/// expression like this:
+/// Process the given Ops list, which is a list of operands to be added under
+/// the given scale, update the given map. This is a helper function for
+/// getAddRecExpr. As an example of what it does, given a sequence of operands
+/// that would form an add expression like this:
 ///
 ///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
 ///
@@ -1899,7 +1936,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
       // the map.
       SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
       const SCEV *Key = SE.getMulExpr(MulOps);
-      auto Pair = M.insert(std::make_pair(Key, NewScale));
+      auto Pair = M.insert({Key, NewScale});
       if (Pair.second) {
         NewOps.push_back(Pair.first->first);
       } else {
@@ -1912,7 +1949,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
     } else {
       // An ordinary operand. Update the map.
       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
-        M.insert(std::make_pair(Ops[i], Scale));
+        M.insert({Ops[i], Scale});
       if (Pair.second) {
         NewOps.push_back(Pair.first->first);
       } else {
@@ -1965,15 +2002,14 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
     const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
     if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
-      auto NSWRegion =
-          ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
+      auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+          Instruction::Add, C, OBO::NoSignedWrap);
       if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
     }
     if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
-      auto NUWRegion =
-          ConstantRange::makeNoWrapRegion(Instruction::Add, C,
-                                          OBO::NoUnsignedWrap);
+      auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+          Instruction::Add, C, OBO::NoUnsignedWrap);
       if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
     }
@@ -1982,8 +2018,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
   return Flags;
 }
 
-/// getAddExpr - Get a canonical add expression, or something simpler if
-/// possible.
+/// Get a canonical add expression, or something simpler if possible.
 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                         SCEV::NoWrapFlags Flags) {
   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
@@ -2266,7 +2301,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
 
         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
                                                AddRec->op_end());
-        AddRecOps[0] = getAddExpr(LIOps);
+        // This follows from the fact that the no-wrap flags on the outer add
+        // expression are applicable on the 0th iteration, when the add
+        // recurrence will be equal to its start value.
+        AddRecOps[0] = getAddExpr(LIOps, Flags);
 
         // Build the new addrec. Propagate the NUW and NSW flags if both the
         // outer add and the inner addrec are guaranteed to have no overflow.
@@ -2391,8 +2429,7 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) {
   return false;
 }
 
-/// getMulExpr - Get a canonical multiply expression, or something simpler if
-/// possible.
+/// Get a canonical multiply expression, or something simpler if possible.
 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                         SCEV::NoWrapFlags Flags) {
   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
@@ -2632,8 +2669,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   return S;
 }
 
-/// getUDivExpr - Get a canonical unsigned division expression, or something
-/// simpler if possible.
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible.
 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
   assert(getEffectiveSCEVType(LHS->getType()) ==
@@ -2764,10 +2801,10 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
   return APIntOps::GreatestCommonDivisor(A, B);
 }
 
-/// getUDivExactExpr - Get a canonical unsigned division expression, or
-/// something simpler if possible. There is no representation for an exact udiv
-/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS.
-/// We can't do this when it's not exact because the udiv may be clearing bits.
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible.  There is no representation for an exact udiv in SCEV IR, but we
+/// can attempt to remove factors from the LHS and RHS.  We can't do this when
+/// it's not exact because the udiv may be clearing bits.
 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
                                               const SCEV *RHS) {
   // TODO: we could try to find factors in all sorts of things, but for now we
@@ -2821,8 +2858,8 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
   return getUDivExpr(LHS, RHS);
 }
 
-/// getAddRecExpr - Get an add recurrence expression for the specified loop.
-/// Simplify the expression as much as possible.
+/// Get an add recurrence expression for the specified loop.  Simplify the
+/// expression as much as possible.
 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
                                            const Loop *L,
                                            SCEV::NoWrapFlags Flags) {
@@ -2838,8 +2875,8 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
   return getAddRecExpr(Operands, L, Flags);
 }
 
-/// getAddRecExpr - Get an add recurrence expression for the specified loop.
-/// Simplify the expression as much as possible.
+/// Get an add recurrence expression for the specified loop.  Simplify the
+/// expression as much as possible.
 const SCEV *
 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
                                const Loop *L, SCEV::NoWrapFlags Flags) {
@@ -2985,9 +3022,7 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
 
 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  SmallVector<const SCEV *, 2> Ops;
-  Ops.push_back(LHS);
-  Ops.push_back(RHS);
+  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
   return getSMaxExpr(Ops);
 }
 
@@ -3088,9 +3123,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 
 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  SmallVector<const SCEV *, 2> Ops;
-  Ops.push_back(LHS);
-  Ops.push_back(RHS);
+  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
   return getUMaxExpr(Ops);
 }
 
@@ -3244,26 +3277,25 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
 //            Basic SCEV Analysis and PHI Idiom Recognition Code
 //
 
-/// isSCEVable - Test if values of the given type are analyzable within
-/// the SCEV framework. This primarily includes integer types, and it
-/// can optionally include pointer types if the ScalarEvolution class
-/// has access to target-specific information.
+/// Test if values of the given type are analyzable within the SCEV
+/// framework. This primarily includes integer types, and it can optionally
+/// include pointer types if the ScalarEvolution class has access to
+/// target-specific information.
 bool ScalarEvolution::isSCEVable(Type *Ty) const {
   // Integers and pointers are always SCEVable.
   return Ty->isIntegerTy() || Ty->isPointerTy();
 }
 
-/// getTypeSizeInBits - Return the size in bits of the specified type,
-/// for which isSCEVable must return true.
+/// Return the size in bits of the specified type, for which isSCEVable must
+/// return true.
 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
   return getDataLayout().getTypeSizeInBits(Ty);
 }
 
-/// getEffectiveSCEVType - Return a type with the same bitwidth as
-/// the given type and which represents how SCEV will treat the given
-/// type, for which isSCEVable must return true. For pointer types,
-/// this is the pointer-sized integer type.
+/// Return a type with the same bitwidth as the given type and which represents
+/// how SCEV will treat the given type, for which isSCEVable must return
+/// true. For pointer types, this is the pointer-sized integer type.
 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
@@ -3310,15 +3342,88 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const {
   return !F.FindOne;
 }
 
-/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
-/// expression and create a new one.
+namespace {
+// Helper class working with SCEVTraversal to figure out if a SCEV contains
+// a sub SCEV of scAddRecExpr type.  FindAddRecurrence::FoundOne is set
+// iff such a sub scAddRecExpr type SCEV is found.
+struct FindAddRecurrence {
+  bool FoundOne;
+  FindAddRecurrence() : FoundOne(false) {}
+
+  bool follow(const SCEV *S) {
+    switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+    case scAddRecExpr:
+      FoundOne = true;
+    case scConstant:
+    case scUnknown:
+    case scCouldNotCompute:
+      return false;
+    default:
+      return true;
+    }
+  }
+  bool isDone() const { return FoundOne; }
+};
+}
+
+bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
+  HasRecMapType::iterator I = HasRecMap.find_as(S);
+  if (I != HasRecMap.end())
+    return I->second;
+
+  FindAddRecurrence F;
+  SCEVTraversal<FindAddRecurrence> ST(F);
+  ST.visitAll(S);
+  HasRecMap.insert({S, F.FoundOne});
+  return F.FoundOne;
+}
+
+/// Return the Value set from S.
+SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) {
+  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
+  if (SI == ExprValueMap.end())
+    return nullptr;
+#ifndef NDEBUG
+  if (VerifySCEVMap) {
+    // Check there is no dangling Value in the set returned.
+    for (const auto &VE : SI->second)
+      assert(ValueExprMap.count(VE));
+  }
+#endif
+  return &SI->second;
+}
+
+/// Erase Value from ValueExprMap and ExprValueMap.  If ValueExprMap.erase(V)
+/// is not used together with forgetMemoizedResults(S), eraseValueFromMap
+/// should be used instead to ensure whenever V->S is removed from
+/// ValueExprMap, V is also removed from the set of ExprValueMap[S].
+void ScalarEvolution::eraseValueFromMap(Value *V) {
+  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
+  if (I != ValueExprMap.end()) {
+    const SCEV *S = I->second;
+    SetVector<Value *> *SV = getSCEVValues(S);
+    // Remove V from the set of ExprValueMap[S]
+    if (SV)
+      SV->remove(V);
+    ValueExprMap.erase(V);
+  }
+}
+
+/// Return an existing SCEV if it exists, otherwise analyze the expression and
+/// create a new one.
 const SCEV *ScalarEvolution::getSCEV(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
   const SCEV *S = getExistingSCEV(V);
   if (S == nullptr) {
     S = createSCEV(V);
-    ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
+    // During PHI resolution, it is possible to create two SCEVs for the same
+    // V, so it is needed to double check whether V->S is inserted into
+    // ValueExprMap before insert S->V into ExprValueMap.
+    std::pair<ValueExprMapType::iterator, bool> Pair =
+        ValueExprMap.insert({SCEVCallbackVH(V, this), S});
+    if (Pair.second)
+      ExprValueMap[S].insert(V);
   }
   return S;
 }
@@ -3331,12 +3436,13 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
     const SCEV *S = I->second;
     if (checkValidity(S))
       return S;
+    forgetMemoizedResults(S);
     ValueExprMap.erase(I);
   }
   return nullptr;
 }
 
-/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
+/// Return a SCEV corresponding to -V = -1*V
 ///
 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
                                              SCEV::NoWrapFlags Flags) {
@@ -3350,7 +3456,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
       V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
 }
 
-/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
+/// Return a SCEV corresponding to ~V = -1-V
 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
     return getConstant(
@@ -3363,7 +3469,6 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   return getMinusSCEV(AllOnes, V);
 }
 
-/// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           SCEV::NoWrapFlags Flags) {
   // Fast path: X - X --> 0.
@@ -3402,9 +3507,6 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
   return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags);
 }
 
-/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type.  If the type must be extended, it is zero
-/// extended.
 const SCEV *
 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -3418,9 +3520,6 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
   return getZeroExtendExpr(V, Ty);
 }
 
-/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type.  If the type must be extended, it is sign
-/// extended.
 const SCEV *
 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
                                          Type *Ty) {
@@ -3435,9 +3534,6 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
   return getSignExtendExpr(V, Ty);
 }
 
-/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type.  If the type must be extended, it is zero
-/// extended.  The conversion must not be narrowing.
 const SCEV *
 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -3451,9 +3547,6 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
   return getZeroExtendExpr(V, Ty);
 }
 
-/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type.  If the type must be extended, it is sign
-/// extended.  The conversion must not be narrowing.
 const SCEV *
 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -3467,10 +3560,6 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
   return getSignExtendExpr(V, Ty);
 }
 
-/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
-/// the input value to the specified type.  If the type must be extended,
-/// it is extended with unspecified bits.  The conversion must not be
-/// narrowing.
 const SCEV *
 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -3484,8 +3573,6 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
   return getAnyExtendExpr(V, Ty);
 }
 
-/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type.  The conversion must not be widening.
 const SCEV *
 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
@@ -3499,9 +3586,6 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
   return getTruncateExpr(V, Ty);
 }
 
-/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
-/// the types using zero-extension, and then perform a umax operation
-/// with them.
 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
                                                         const SCEV *RHS) {
   const SCEV *PromotedLHS = LHS;
@@ -3515,9 +3599,6 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
   return getUMaxExpr(PromotedLHS, PromotedRHS);
 }
 
-/// getUMinFromMismatchedTypes - Promote the operands to the wider of
-/// the types using zero-extension, and then perform a umin operation
-/// with them.
 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                         const SCEV *RHS) {
   const SCEV *PromotedLHS = LHS;
@@ -3531,10 +3612,6 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
   return getUMinExpr(PromotedLHS, PromotedRHS);
 }
 
-/// getPointerBase - Transitively follow the chain of pointer-type operands
-/// until reaching a SCEV that does not have a single pointer operand. This
-/// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
-/// but corner cases do exist.
 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
   // A pointer operand may evaluate to a nonpointer expression, such as null.
   if (!V->getType()->isPointerTy())
@@ -3559,8 +3636,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
   return V;
 }
 
-/// PushDefUseChildren - Push users of the given Instruction
-/// onto the given Worklist.
+/// Push users of the given Instruction onto the given Worklist.
 static void
 PushDefUseChildren(Instruction *I,
                    SmallVectorImpl<Instruction *> &Worklist) {
@@ -3569,12 +3645,7 @@ PushDefUseChildren(Instruction *I,
     Worklist.push_back(cast<Instruction>(U));
 }
 
-/// ForgetSymbolicValue - This looks up computed SCEV values for all
-/// instructions that depend on the given instruction and removes them from
-/// the ValueExprMapType map if they reference SymName. This is used during PHI
-/// resolution.
-void
-ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
+void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   SmallVector<Instruction *, 16> Worklist;
   PushDefUseChildren(PN, Worklist);
 
@@ -3616,10 +3687,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
 namespace {
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                              ScalarEvolution &SE) {
     SCEVInitRewriter Rewriter(L, SE);
-    const SCEV *Result = Rewriter.visit(Scev);
+    const SCEV *Result = Rewriter.visit(S);
     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
   }
 
@@ -3649,10 +3720,10 @@ private:
 
 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                              ScalarEvolution &SE) {
     SCEVShiftRewriter Rewriter(L, SE);
-    const SCEV *Result = Rewriter.visit(Scev);
+    const SCEV *Result = Rewriter.visit(S);
     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
   }
 
@@ -3680,6 +3751,167 @@ private:
 };
 }  // end anonymous namespace
 
+SCEV::NoWrapFlags
+ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
+  if (!AR->isAffine())
+    return SCEV::FlagAnyWrap;
+
+  typedef OverflowingBinaryOperator OBO;
+  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
+
+  if (!AR->hasNoSignedWrap()) {
+    ConstantRange AddRecRange = getSignedRange(AR);
+    ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
+
+    auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+        Instruction::Add, IncRange, OBO::NoSignedWrap);
+    if (NSWRegion.contains(AddRecRange))
+      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
+  }
+
+  if (!AR->hasNoUnsignedWrap()) {
+    ConstantRange AddRecRange = getUnsignedRange(AR);
+    ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
+
+    auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
+    if (NUWRegion.contains(AddRecRange))
+      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
+  }
+
+  return Result;
+}
+
+namespace {
+/// Represents an abstract binary operation.  This may exist as a
+/// normal instruction or constant expression, or may have been
+/// derived from an expression tree.
+struct BinaryOp {
+  unsigned Opcode;
+  Value *LHS;
+  Value *RHS;
+  bool IsNSW;
+  bool IsNUW;
+
+  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
+  /// constant expression.
+  Operator *Op;
+
+  explicit BinaryOp(Operator *Op)
+      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)),
+        RHS(Op->getOperand(1)), IsNSW(false), IsNUW(false), Op(Op) {
+    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
+      IsNSW = OBO->hasNoSignedWrap();
+      IsNUW = OBO->hasNoUnsignedWrap();
+    }
+  }
+
+  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
+                    bool IsNUW = false)
+      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
+        Op(nullptr) {}
+};
+}
+
+
+/// Try to map \p V into a BinaryOp, and return \c None on failure.
+static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
+  auto *Op = dyn_cast<Operator>(V);
+  if (!Op)
+    return None;
+
+  // Implementation detail: all the cleverness here should happen without
+  // creating new SCEV expressions -- our caller knows tricks to avoid creating
+  // SCEV expressions when possible, and we should not break that.
+
+  switch (Op->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::AShr:
+  case Instruction::Shl:
+    return BinaryOp(Op);
+
+  case Instruction::Xor:
+    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
+      // If the RHS of the xor is a signbit, then this is just an add.
+      // Instcombine turns add of signbit into xor as a strength reduction step.
+      if (RHSC->getValue().isSignBit())
+        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
+    return BinaryOp(Op);
+
+  case Instruction::LShr:
+    // Turn logical shift right of a constant into an unsigned divide.
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
+
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (SA->getValue().ult(BitWidth)) {
+        Constant *X =
+            ConstantInt::get(SA->getContext(),
+                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
+      }
+    }
+    return BinaryOp(Op);
+
+  case Instruction::ExtractValue: {
+    auto *EVI = cast<ExtractValueInst>(Op);
+    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
+      break;
+
+    auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand());
+    if (!CI)
+      break;
+
+    if (auto *F = CI->getCalledFunction())
+      switch (F->getIntrinsicID()) {
+      case Intrinsic::sadd_with_overflow:
+      case Intrinsic::uadd_with_overflow: {
+        if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
+          return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+                          CI->getArgOperand(1));
+
+        // Now that we know that all uses of the arithmetic-result component of
+        // CI are guarded by the overflow check, we can go ahead and pretend
+        // that the arithmetic is non-overflowing.
+        if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow)
+          return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+                          CI->getArgOperand(1), /* IsNSW = */ true,
+                          /* IsNUW = */ false);
+        else
+          return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+                          CI->getArgOperand(1), /* IsNSW = */ false,
+                          /* IsNUW = */ true);
+      }
+
+      case Intrinsic::ssub_with_overflow:
+      case Intrinsic::usub_with_overflow:
+        return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+                        CI->getArgOperand(1));
+
+      case Intrinsic::smul_with_overflow:
+      case Intrinsic::umul_with_overflow:
+        return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
+                        CI->getArgOperand(1));
      default:
        break;
      }
+  }
+
+  default:
+    break;
+  }
+
+  return None;
+}
+
 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const Loop *L = LI.getLoopFor(PN->getParent());
   if (!L || L->getHeader() != PN->getParent())
@@ -3710,7 +3942,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const SCEV *SymbolicName = getUnknown(PN);
   assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
          "PHI node already processed?");
-  ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
+  ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
 
   // Using this symbolic name for the PHI, analyze the value coming around
   // the back-edge.
@@ -3747,13 +3979,11 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
             cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
 
-        // If the increment doesn't overflow, then neither the addrec nor
-        // the post-increment will overflow.
-        if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
-          if (OBO->getOperand(0) == PN) {
-            if (OBO->hasNoUnsignedWrap())
+        if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+          if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
+            if (BO->IsNUW)
               Flags = setFlags(Flags, SCEV::FlagNUW);
-            if (OBO->hasNoSignedWrap())
+            if (BO->IsNSW)
              Flags = setFlags(Flags, SCEV::FlagNSW);
          }
        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
@@ -3779,16 +4009,19 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
           const SCEV *StartVal = getSCEV(StartValueV);
           const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
 
-          // Since the no-wrap flags are on the increment, they apply to the
-          // post-incremented value as well.
-          if (isLoopInvariant(Accum, L))
-            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
-
           // Okay, for the entire analysis of this edge we assumed the PHI
           // to be symbolic.  We now need to go back and purge all of the
           // entries for the scalars that use the symbolic expression.
-          ForgetSymbolicName(PN, SymbolicName);
+          forgetSymbolicName(PN, SymbolicName);
           ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+          // We can add Flags to the post-inc expression only if we
+          // know that it is *undefined behavior* for BEValueV to
+          // overflow.
+          if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+            if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+              (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
           return PHISCEV;
         }
       }
@@ -3811,12 +4044,18 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
       // Okay, for the entire analysis of this edge we assumed the PHI
       // to be symbolic.  We now need to go back and purge all of the
       // entries for the scalars that use the symbolic expression.
-      ForgetSymbolicName(PN, SymbolicName);
+      forgetSymbolicName(PN, SymbolicName);
       ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
       return Shifted;
     }
   }
+
+  // Remove the temporary PHI node SCEV that has been inserted while intending
+  // to create an AddRecExpr for this PHI node.  We can not keep this temporary
+  // as it will prevent later (possibly simpler) SCEV expressions to be added
+  // to the ValueExprMap.
+  ValueExprMap.erase(PN);
 
   return nullptr;
 }
@@ -4083,26 +4322,21 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
   return getUnknown(I);
 }
 
-/// createNodeForGEP - Expand GEP instructions into add and multiply
-/// operations. This allows them to be analyzed by regular SCEV code.
-///
+/// Expand GEP instructions into add and multiply operations.  This allows them
+/// to be analyzed by regular SCEV code.
 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
-  Value *Base = GEP->getOperand(0);
   // Don't attempt to analyze GEPs over unsized objects.
-  if (!Base->getType()->getPointerElementType()->isSized())
+  if (!GEP->getSourceElementType()->isSized())
     return getUnknown(GEP);
 
   SmallVector<const SCEV *, 4> IndexExprs;
   for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
     IndexExprs.push_back(getSCEV(*Index));
-  return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs,
-                    GEP->isInBounds());
+  return getGEPExpr(GEP->getSourceElementType(),
+                    getSCEV(GEP->getPointerOperand()),
+                    IndexExprs, GEP->isInBounds());
 }
 
-/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
-/// guaranteed to end in (at every loop iteration).  It is, at the same time,
-/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
-/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
 uint32_t
 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
@@ -4180,8 +4414,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
   return 0;
 }
 
-/// GetRangeFromMetadata - Helper method to assign a range to V from
-/// metadata present in the IR.
+/// Helper method to assign a range to V from metadata present in the IR.
 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
   if (Instruction *I = dyn_cast<Instruction>(V))
     if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
@@ -4190,10 +4423,9 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
   return None;
 }
 
-/// getRange - Determine the range for a particular SCEV.  If SignHint is
+/// Determine the range for a particular SCEV.  If SignHint is
 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
 /// with a "cleaner" unsigned (resp. signed) representation.
-///
 ConstantRange
 ScalarEvolution::getRange(const SCEV *S,
                           ScalarEvolution::RangeSignHint SignHint) {
@@ -4282,7 +4514,7 @@ ScalarEvolution::getRange(const SCEV *S,
   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
     // If there's no unsigned wrap, the value will never be less than its
     // initial value.
-    if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
+    if (AddRec->hasNoUnsignedWrap())
       if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
        if (!C->getValue()->isZero())
          ConservativeResult = ConservativeResult.intersectWith(
@@ -4290,7 +4522,7 @@ ScalarEvolution::getRange(const SCEV *S,
 
     // If there's no signed wrap, and all the operands have the same sign or
     // zero, the value won't ever change sign.
-    if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
+    if (AddRec->hasNoSignedWrap()) {
       bool AllNonNeg = true;
       bool AllNonPos = true;
       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
@@ -4309,66 +4541,22 @@ ScalarEvolution::getRange(const SCEV *S,
 
     // TODO: non-affine addrec
     if (AddRec->isAffine()) {
-      Type *Ty = AddRec->getType();
       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
-
-        // Check for overflow.  This must be done with ConstantRange arithmetic
-        // because we could be called from within the ScalarEvolution overflow
-        // checking code.
-
-        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
-        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
-        ConstantRange ZExtMaxBECountRange =
-            MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
-
-        const SCEV *Start = AddRec->getStart();
-        const SCEV *Step = AddRec->getStepRecurrence(*this);
-        ConstantRange StepSRange = getSignedRange(Step);
-        ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
-
-        ConstantRange StartURange = getUnsignedRange(Start);
-        ConstantRange EndURange =
-            StartURange.add(MaxBECountRange.multiply(StepSRange));
-
-        // Check for unsigned overflow.
-        ConstantRange ZExtStartURange =
-            StartURange.zextOrTrunc(BitWidth * 2 + 1);
-        ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
-        if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
-            ZExtEndURange) {
-          APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
-                                     EndURange.getUnsignedMin());
-          APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
-                                     EndURange.getUnsignedMax());
-          bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
-          if (!IsFullRange)
-            ConservativeResult =
-                ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
-        }
-
-        ConstantRange StartSRange = getSignedRange(Start);
-        ConstantRange EndSRange =
-            StartSRange.add(MaxBECountRange.multiply(StepSRange));
-
-        // Check for signed overflow. This must be done with ConstantRange
-        // arithmetic because we could be called from within the ScalarEvolution
-        // overflow checking code.
-        ConstantRange SExtStartSRange =
-            StartSRange.sextOrTrunc(BitWidth * 2 + 1);
-        ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
-        if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
-            SExtEndSRange) {
-          APInt Min = APIntOps::smin(StartSRange.getSignedMin(),
-                                     EndSRange.getSignedMin());
-          APInt Max = APIntOps::smax(StartSRange.getSignedMax(),
-                                     EndSRange.getSignedMax());
-          bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
-          if (!IsFullRange)
-            ConservativeResult =
-                ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
-        }
+        auto RangeFromAffine = getRangeForAffineAR(
+            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+            BitWidth);
+        if (!RangeFromAffine.isFullSet())
+          ConservativeResult =
+              ConservativeResult.intersectWith(RangeFromAffine);
+
+        auto RangeFromFactoring = getRangeViaFactoring(
+            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+            BitWidth);
+        if (!RangeFromFactoring.isFullSet())
+          ConservativeResult =
+              ConservativeResult.intersectWith(RangeFromFactoring);
       }
     }
 
@@ -4408,6 +4596,186 @@ ScalarEvolution::getRange(const SCEV *S,
   return setRange(S, SignHint, ConservativeResult);
 }
 
+ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
+                                                   const SCEV *Step,
+                                                   const SCEV *MaxBECount,
+                                                   unsigned BitWidth) {
+  assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
+         getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
+         "Precondition!");
+
+  ConstantRange Result(BitWidth, /* isFullSet = */ true);
+
+  // Check for overflow.  This must be done with ConstantRange arithmetic
+  // because we could be called from within the ScalarEvolution overflow
+  // checking code.
+
+  MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
+  ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
+  ConstantRange ZExtMaxBECountRange =
+      MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
+
+  ConstantRange StepSRange = getSignedRange(Step);
+  ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
+
+  ConstantRange StartURange = getUnsignedRange(Start);
+  ConstantRange EndURange =
+      StartURange.add(MaxBECountRange.multiply(StepSRange));
+
+  // Check for unsigned overflow.
+  ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1);
+  ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
+  if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+      ZExtEndURange) {
+    APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
+                               EndURange.getUnsignedMin());
+    APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
+                               EndURange.getUnsignedMax());
+    bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
+    if (!IsFullRange)
+      Result =
+          Result.intersectWith(ConstantRange(Min, Max + 1));
+  }
+
+  ConstantRange StartSRange = getSignedRange(Start);
+  ConstantRange EndSRange =
+      StartSRange.add(MaxBECountRange.multiply(StepSRange));
+
+  // Check for signed overflow.  This must be done with ConstantRange
+  // arithmetic because we could be called from within the ScalarEvolution
+  // overflow checking code.
+  ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1);
+  ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
+  if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+      SExtEndSRange) {
+    APInt Min =
+        APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin());
+    APInt Max =
+        APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax());
+    bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
+    if (!IsFullRange)
+      Result =
+          Result.intersectWith(ConstantRange(Min, Max + 1));
+  }
+
+  return Result;
+}
+
+ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
+                                                    const SCEV *Step,
+                                                    const SCEV *MaxBECount,
+                                                    unsigned BitWidth) {
+  // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
+  //                          == RangeOf({A,+,P}) union RangeOf({B,+,Q})
+
+  struct SelectPattern {
+    Value *Condition = nullptr;
+    APInt TrueValue;
+    APInt FalseValue;
+
+    explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
+                           const SCEV *S) {
+      Optional<unsigned> CastOp;
+      APInt Offset(BitWidth, 0);
+
+      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
+             "Should be!");
+
+      // Peel off a constant offset:
+      if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
+        // In the future we could consider being smarter here and handle
+        // {Start+Step,+,Step} too.
+        if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
+          return;
+
+        Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
+        S = SA->getOperand(1);
+      }
+
+      // Peel off a cast operation
+      if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
+        CastOp = SCast->getSCEVType();
+        S = SCast->getOperand();
+      }
+
+      using namespace llvm::PatternMatch;
+
+      auto *SU = dyn_cast<SCEVUnknown>(S);
+      const APInt *TrueVal, *FalseVal;
+      if (!SU ||
+          !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
+                                          m_APInt(FalseVal)))) {
+        Condition = nullptr;
+        return;
+      }
+
+      TrueValue = *TrueVal;
+      FalseValue = *FalseVal;
+
+      // Re-apply the cast we peeled off earlier
+      if (CastOp.hasValue())
+        switch (*CastOp) {
+        default:
+          llvm_unreachable("Unknown SCEV cast type!");
+
+        case scTruncate:
+          TrueValue = TrueValue.trunc(BitWidth);
+          FalseValue = FalseValue.trunc(BitWidth);
+          break;
+        case scZeroExtend:
+          TrueValue = TrueValue.zext(BitWidth);
+          FalseValue = FalseValue.zext(BitWidth);
+          break;
+        case scSignExtend:
+          TrueValue = TrueValue.sext(BitWidth);
+          FalseValue = FalseValue.sext(BitWidth);
+          break;
+        }
+
+      // Re-apply the constant offset we peeled off earlier
+      TrueValue += Offset;
+      FalseValue += Offset;
+    }
+
+    bool isRecognized() { return Condition != nullptr; }
+  };
+
+  SelectPattern StartPattern(*this, BitWidth, Start);
+  if (!StartPattern.isRecognized())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  SelectPattern StepPattern(*this, BitWidth, Step);
+  if (!StepPattern.isRecognized())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  if (StartPattern.Condition != StepPattern.Condition) {
+    // We don't handle this case today; but we could, by considering four
+    // possibilities below instead of two. I'm not sure if there are cases
+    // where that will help over what getRange already does, though.
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+  }
+
+  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
+  // construct arbitrary general SCEV expressions here.  This function is
+  // called from deep in the call stack, and calling getSCEV (on a sext
+  // instruction, say) can end up caching a suboptimal value.
+
+  // FIXME: without the explicit `this` receiver below, MSVC errors out with
+  // C2352 and C2512 (otherwise it isn't needed).
+
+  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
+  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
+  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
+  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
+
+  ConstantRange TrueRange =
+      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
+  ConstantRange FalseRange =
+      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
+
+  return TrueRange.unionWith(FalseRange);
+}
+
 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
   if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
   const BinaryOperator *BinOp = cast<BinaryOperator>(V);
@@ -4418,273 +4786,363 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
   if (BinOp->hasNoSignedWrap())
     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
-  if (Flags == SCEV::FlagAnyWrap) {
+  if (Flags == SCEV::FlagAnyWrap)
     return SCEV::FlagAnyWrap;
-  }
 
-  // Here we check that BinOp is in the header of the innermost loop
-  // containing BinOp, since we only deal with instructions in the loop
-  // header. The actual loop we need to check later will come from an add
-  // recurrence, but getting that requires computing the SCEV of the operands,
-  // which can be expensive. This check we can do cheaply to rule out some
-  // cases early.
-  Loop *innermostContainingLoop = LI.getLoopFor(BinOp->getParent());
-  if (innermostContainingLoop == nullptr ||
-      innermostContainingLoop->getHeader() != BinOp->getParent())
-    return SCEV::FlagAnyWrap;
+  return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
+}
 
-  // Only proceed if we can prove that BinOp does not yield poison.
-  if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap;
+bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
+  // Here we check that I is in the header of the innermost loop containing I,
+  // since we only deal with instructions in the loop header.  The actual loop
+  // we need to check later will come from an add recurrence, but getting that
+  // requires computing the SCEV of the operands, which can be expensive.  This
+  // check we can do cheaply to rule out some cases early.
+  Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent());
+  if (InnermostContainingLoop == nullptr ||
+      InnermostContainingLoop->getHeader() != I->getParent())
+    return false;
+
+  // Only proceed if we can prove that I does not yield poison.
+  if (!isKnownNotFullPoison(I)) return false;
 
-  // At this point we know that if V is executed, then it does not wrap
-  // according to at least one of NSW or NUW. If V is not executed, then we do
-  // not know if the calculation that V represents would wrap. Multiple
-  // instructions can map to the same SCEV. If we apply NSW or NUW from V to
+  // At this point we know that if I is executed, then it does not wrap
+  // according to at least one of NSW or NUW. If I is not executed, then we do
+  // not know if the calculation that I represents would wrap. Multiple
+  // instructions can map to the same SCEV. If we apply NSW or NUW from I to
   // the SCEV, we must guarantee no wrapping for that SCEV also when it is
   // derived from other instructions that map to the same SCEV. We cannot make
-  // that guarantee for cases where V is not executed. So we need to find the
-  // loop that V is considered in relation to and prove that V is executed for
-  // every iteration of that loop. That implies that the value that V
-  // calculates does not wrap anywhere in the loop, so then we can apply the
-  // flags to the SCEV.
+  // that guarantee for cases where I is not executed. So we need to find the
+  // loop that I is considered in relation to and prove that I is executed for
+  // every iteration of that loop. That implies that the value that I
+  // calculates does not wrap anywhere in the loop, so then we can apply the
+  // flags to the SCEV.
   //
-  // We check isLoopInvariant to disambiguate in case we are adding two
-  // recurrences from different loops, so that we know which loop to prove
-  // that V is executed in.
-  for (int OpIndex = 0; OpIndex < 2; ++OpIndex) {
-    const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex));
+  // We check isLoopInvariant to disambiguate in case we are adding recurrences
+  // from different loops, so that we know which loop to prove that I is
+  // executed in.
+  for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) {
+    // I could be an extractvalue from a call to an overflow intrinsic.
+    // TODO: We can do better here in some cases.
+    if (!isSCEVable(I->getOperand(OpIndex)->getType()))
+      return false;
+    const SCEV *Op = getSCEV(I->getOperand(OpIndex));
     if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
-      const int OtherOpIndex = 1 - OpIndex;
-      const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex));
-      if (isLoopInvariant(OtherOp, AddRec->getLoop()) &&
-          isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop()))
-        return Flags;
+      bool AllOtherOpsLoopInvariant = true;
+      for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands();
+           ++OtherOpIndex) {
+        if (OtherOpIndex != OpIndex) {
+          const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex));
+          if (!isLoopInvariant(OtherOp, AddRec->getLoop())) {
+            AllOtherOpsLoopInvariant = false;
+            break;
+          }
+        }
+      }
+      if (AllOtherOpsLoopInvariant &&
+          isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop()))
+        return true;
     }
   }
-  return SCEV::FlagAnyWrap;
+  return false;
+}
+
+bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
+  // If we know that \c I can never be poison period, then that's enough.
+  if (isSCEVExprNeverPoison(I))
+    return true;
+
+  // For an add recurrence specifically, we assume that infinite loops without
+  // side effects are undefined behavior, and then reason as follows:
+  //
+  // If the add recurrence is poison in any iteration, it is poison on all
+  // future iterations (since incrementing poison yields poison).  If the
+  // result of the add recurrence is fed into the loop latch condition and the
+  // loop does not contain any throws or exiting blocks other than the latch,
+  // we now have the ability to "choose" whether the backedge is taken or not
+  // (by choosing a sufficiently evil value for the poison feeding into the
+  // branch) for every iteration including and after the one in which \p I
+  // first became poison.  There are two possibilities (let's call the
+  // iteration in which \p I first became poison as K):
+  //
+  // 1. In the set of iterations including and after K, the loop body executes
+  //    no side effects.  In this case executing the backedge an infinite
+  //    number of times will yield undefined behavior.
+  //
+  // 2. In the set of iterations including and after K, the loop body executes
+  //    at least one side effect.  In this case, that specific instance of side
+  //    effect is control dependent on poison, which also yields undefined
+  //    behavior.
+ + auto *ExitingBB = L->getExitingBlock(); + auto *LatchBB = L->getLoopLatch(); + if (!ExitingBB || !LatchBB || ExitingBB != LatchBB) + return false; + + SmallPtrSet<const Instruction *, 16> Pushed; + SmallVector<const Instruction *, 8> PoisonStack; + + // We start by assuming \c I, the post-inc add recurrence, is poison. Only + // things that are known to be fully poison under that assumption go on the + // PoisonStack. + Pushed.insert(I); + PoisonStack.push_back(I); + + bool LatchControlDependentOnPoison = false; + while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { + const Instruction *Poison = PoisonStack.pop_back_val(); + + for (auto *PoisonUser : Poison->users()) { + if (propagatesFullPoison(cast<Instruction>(PoisonUser))) { + if (Pushed.insert(cast<Instruction>(PoisonUser)).second) + PoisonStack.push_back(cast<Instruction>(PoisonUser)); + } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { + assert(BI->isConditional() && "Only possibility!"); + if (BI->getParent() == LatchBB) { + LatchControlDependentOnPoison = true; + break; + } + } + } + } + + return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); +} + +bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { + auto Itr = LoopHasNoAbnormalExits.find(L); + if (Itr == LoopHasNoAbnormalExits.end()) { + auto NoAbnormalExitInBB = [&](BasicBlock *BB) { + return all_of(*BB, [](Instruction &I) { + return isGuaranteedToTransferExecutionToSuccessor(&I); + }); + }; + + auto InsertPair = LoopHasNoAbnormalExits.insert( + {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); + assert(InsertPair.second && "We just checked!"); + Itr = InsertPair.first; + } + + return Itr->second; } -/// createSCEV - We know that there is no SCEV for the specified value. Analyze -/// the expression. -/// const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); - unsigned Opcode = Instruction::UserOp1; if (Instruction *I = dyn_cast<Instruction>(V)) { - Opcode = I->getOpcode(); - // Don't attempt to analyze instructions in blocks that aren't // reachable. Such instructions don't matter, and they aren't required // to obey basic rules for definitions dominating uses which this // analysis depends on. if (!DT.isReachableFromEntry(I->getParent())) return getUnknown(V); - } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) - Opcode = CE->getOpcode(); - else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); else if (isa<ConstantPointerNull>(V)) return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) - return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); - else + return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); + else if (!isa<ConstantExpr>(V)) return getUnknown(V); Operator *U = cast<Operator>(V); - switch (Opcode) { - case Instruction::Add: { - // The simple thing to do would be to just call getSCEV on both operands - // and call getAddExpr with the result. However if we're looking at a - // bunch of things all added together, this can be quite inefficient, - // because it leads to N-1 getAddExpr calls for N ultimate operands. - // Instead, gather up all the operands and make a single getAddExpr call. - // LLVM IR canonical form means we need only traverse the left operands. - SmallVector<const SCEV *, 4> AddOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - unsigned Opcode = U ? 
U->getOpcode() : 0; - if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) { - assert(Op != V && "V should be an add"); - AddOps.push_back(getSCEV(Op)); - break; - } + if (auto BO = MatchBinaryOp(U, DT)) { + switch (BO->Opcode) { + case Instruction::Add: { + // The simple thing to do would be to just call getSCEV on both operands + // and call getAddExpr with the result. However if we're looking at a + // bunch of things all added together, this can be quite inefficient, + // because it leads to N-1 getAddExpr calls for N ultimate operands. + // Instead, gather up all the operands and make a single getAddExpr call. + // LLVM IR canonical form means we need only traverse the left operands. + SmallVector<const SCEV *, 4> AddOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + AddOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - AddOps.push_back(OpSCEV); - break; - } + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + const SCEV *RHS = getSCEV(BO->RHS); + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + const SCEV *LHS = getSCEV(BO->LHS); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + else + AddOps.push_back(getAddExpr(LHS, RHS, Flags)); + break; + } + } - // If a NUW or NSW flag can be applied to the SCEV for this - // addition, then compute the SCEV for this addition by itself - // with a separate call to getAddExpr. We need to do that - // instead of pushing the operands of the addition onto AddOps, - // since the flags are only known to apply to this particular - // addition - they may not apply to other additions that can be - // formed with operands from AddOps. 
- const SCEV *RHS = getSCEV(U->getOperand(1)); - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); else - AddOps.push_back(getAddExpr(LHS, RHS, Flags)); - break; - } + AddOps.push_back(getSCEV(BO->RHS)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getNegativeSCEV(RHS)); - else - AddOps.push_back(RHS); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || (NewBO->Opcode != Instruction::Add && + NewBO->Opcode != Instruction::Sub)) { + AddOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getAddExpr(AddOps); } - return getAddExpr(AddOps); - } - case Instruction::Mul: { - SmallVector<const SCEV *, 4> MulOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - if (!U || U->getOpcode() != Instruction::Mul) { - assert(Op != V && "V should be a mul"); - MulOps.push_back(getSCEV(Op)); - break; - } + case Instruction::Mul: { + SmallVector<const SCEV *, 4> MulOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + MulOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - MulOps.push_back(OpSCEV); - break; - } + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + MulOps.push_back( + getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); + break; + } + } - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1)), Flags)); - break; + MulOps.push_back(getSCEV(BO->RHS)); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || NewBO->Opcode != Instruction::Mul) { + MulOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getMulExpr(MulOps); + } + case Instruction::UDiv: + return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::Sub: { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BO->Op) + Flags = getNoWrapFlagsFromUB(BO->Op); + return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); + } + case Instruction::And: + // For an expression like x&255 that merely masks off the high bits, + // use zext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + if (CI->isNullValue()) + return getSCEV(BO->RHS); + if (CI->isAllOnesValue()) + return getSCEV(BO->LHS); + const APInt &A = CI->getValue(); + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use computeKnownBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. 
+ unsigned LZ = A.countLeadingZeros(); + unsigned TZ = A.countTrailingZeros(); + unsigned BitWidth = A.getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); + + APInt EffectiveMask = + APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); + if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(ConstantInt::get( + getContext(), APInt::getOneBitSet(BitWidth, TZ))); + return getMulExpr( + getZeroExtendExpr( + getTruncateExpr( + getUDivExactExpr(getSCEV(BO->LHS), MulCount), + IntegerType::get(getContext(), BitWidth - LZ - TZ)), + BO->LHS->getType()), + MulCount); + } } + break; - MulOps.push_back(getSCEV(U->getOperand(1))); - } - return getMulExpr(MulOps); - } - case Instruction::UDiv: - return getUDivExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - case Instruction::Sub: - return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)), - getNoWrapFlagsFromUB(U)); - case Instruction::And: - // For an expression like x&255 that merely masks off the high bits, - // use zext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - if (CI->isNullValue()) - return getSCEV(U->getOperand(1)); - if (CI->isAllOnesValue()) - return getSCEV(U->getOperand(0)); - const APInt &A = CI->getValue(); - - // Instcombine's ShrinkDemandedConstant may strip bits out of - // constants, obscuring what would otherwise be a low-bits mask. - // Use computeKnownBits to compute what ShrinkDemandedConstant - // knew about to reconstruct a low-bits mask value. - unsigned LZ = A.countLeadingZeros(); - unsigned TZ = A.countTrailingZeros(); - unsigned BitWidth = A.getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), - 0, &AC, nullptr, &DT); - - APInt EffectiveMask = - APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); - if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant( - ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ))); - return getMulExpr( - getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount), - IntegerType::get(getContext(), BitWidth - LZ - TZ)), - U->getType()), - MulCount); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + const SCEV *LHS = getSCEV(BO->LHS); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { + // Build a plain add SCEV. + const SCEV *S = getAddExpr(LHS, getSCEV(CI)); + // If the LHS of the add was an addrec and it has no-wrap flags, + // transfer the no-wrap flags, since an or won't introduce a wrap. 
+ if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); + const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( + OldAR->getNoWrapFlags()); + } + return S; + } } - } - break; + break; - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { - // Build a plain add SCEV. - const SCEV *S = getAddExpr(LHS, getSCEV(CI)); - // If the LHS of the add was an addrec and it has no-wrap flags, - // transfer the no-wrap flags, since an or won't introduce a wrap. - if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); - const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( - OldAR->getNoWrapFlags()); - } - return S; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + // If the RHS of xor is -1, then this is a not operation. + if (CI->isAllOnesValue()) + return getNotSCEV(getSCEV(BO->LHS)); + + // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. + // This is a variant of the check for xor with -1, and it handles + // the case where instcombine has trimmed non-demanded bits out + // of an xor with -1. + if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) + if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) + if (LBO->getOpcode() == Instruction::And && + LCI->getValue() == CI->getValue()) + if (const SCEVZeroExtendExpr *Z = + dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { + Type *UTy = BO->LHS->getType(); + const SCEV *Z0 = Z->getOperand(); + Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is serving to + // mask off the high bits. Complement the operand and + // re-apply the zext. + if (APIntOps::isMask(Z0TySize, CI->getValue())) + return getZeroExtendExpr(getNotSCEV(Z0), UTy); + + // If C is a single bit, it may be in the sign-bit position + // before the zero-extend. In this case, represent the xor + // using an add, which is equivalent, and re-apply the zext. + APInt Trunc = CI->getValue().trunc(Z0TySize); + if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && + Trunc.isSignBit()) + return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), + UTy); + } } - } - break; - case Instruction::Xor: - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (CI->getValue().isSignBit()) - return getAddExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - - // If the RHS of xor is -1, then this is a not operation. - if (CI->isAllOnesValue()) - return getNotSCEV(getSCEV(U->getOperand(0))); - - // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. - // This is a variant of the check for xor with -1, and it handles - // the case where instcombine has trimmed non-demanded bits out - // of an xor with -1. 
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) - if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1))) - if (BO->getOpcode() == Instruction::And && - LCI->getValue() == CI->getValue()) - if (const SCEVZeroExtendExpr *Z = - dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) { - Type *UTy = U->getType(); - const SCEV *Z0 = Z->getOperand(); - Type *Z0Ty = Z0->getType(); - unsigned Z0TySize = getTypeSizeInBits(Z0Ty); - - // If C is a low-bits mask, the zero extend is serving to - // mask off the high bits. Complement the operand and - // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) - return getZeroExtendExpr(getNotSCEV(Z0), UTy); - - // If C is a single bit, it may be in the sign-bit position - // before the zero-extend. In this case, represent the xor - // using an add, which is equivalent, and re-apply the zext. - APInt Trunc = CI->getValue().trunc(Z0TySize); - if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && - Trunc.isSignBit()) - return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), - UTy); - } - } - break; + break; case Instruction::Shl: // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); + if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { + uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the @@ -4700,58 +5158,43 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html // and http://reviews.llvm.org/D8890 . auto Flags = SCEV::FlagAnyWrap; - if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U); + if (BO->Op && SA->getValue().ult(BitWidth - 1)) + Flags = getNoWrapFlagsFromUB(BO->Op); Constant *X = ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags); + return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); } break; - case Instruction::LShr: - // Turn logical shift right of a constant into a unsigned divide. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (SA->getValue().uge(BitWidth)) - break; + case Instruction::AShr: + // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) + if (Operator *L = dyn_cast<Operator>(BO->LHS)) + if (L->getOpcode() == Instruction::Shl && + L->getOperand(1) == BO->RHS) { + uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. 
+ if (CI->getValue().uge(BitWidth)) + break; - Constant *X = ConstantInt::get(getContext(), - APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + uint64_t Amt = BitWidth - CI->getZExtValue(); + if (Amt == BitWidth) + return getSCEV(L->getOperand(0)); // shift by zero --> noop + return getSignExtendExpr( + getTruncateExpr(getSCEV(L->getOperand(0)), + IntegerType::get(getContext(), Amt)), + BO->LHS->getType()); + } + break; } - break; - - case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) - if (Operator *L = dyn_cast<Operator>(U->getOperand(0))) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == U->getOperand(1)) { - uint64_t BitWidth = getTypeSizeInBits(U->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; - - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop - return - getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), - Amt)), - U->getType()); - } - break; + } + switch (U->getOpcode()) { case Instruction::Trunc: return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); @@ -4786,8 +5229,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (isa<Instruction>(U)) return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0), U->getOperand(1), U->getOperand(2)); + break; - default: // We cannot analyze this expression. + case Instruction::Call: + case Instruction::Invoke: + if (Value *RV = CallSite(U).getReturnedArgOperand()) + return getSCEV(RV); break; } @@ -4808,16 +5255,6 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a -/// normal unsigned value. Returns 0 if the trip count is unknown or not -/// constant. Will also return 0 if the maximum trip count is very large (>= -/// 2^32). -/// -/// This "trip count" assumes that control exits via ExitingBlock. More -/// precisely, it is the number of times that control may reach ExitingBlock -/// before taking the branch. For loops with multiple exits, it may not be the -/// number times that the loop header executes because the loop may exit -/// prematurely via another branch. unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); @@ -4846,10 +5283,10 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { return 0; } -/// getSmallConstantTripMultiple - Returns the largest constant divisor of the -/// trip count of this loop as a normal unsigned value, if possible. This -/// means that the actual trip count is always a multiple of the returned -/// value (don't forget the trip count could very well be zero as well!). +/// Returns the largest constant divisor of the trip count of this loop as a +/// normal unsigned value, if possible. This means that the actual trip count is +/// always a multiple of the returned value (don't forget the trip count could +/// very well be zero as well!). 
/// /// Returns 1 if the trip count is unknown or not guaranteed to be the /// multiple of a constant (which is also the case if the trip count is simply @@ -4891,37 +5328,30 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return (unsigned)Result->getZExtValue(); } -// getExitCount - Get the expression for the number of loop iterations for which -// this loop is guaranteed not to exit via ExitingBlock. Otherwise return -// SCEVCouldNotCompute. +/// Get the expression for the number of loop iterations for which this loop is +/// guaranteed not to exit via ExitingBlock. Otherwise return +/// SCEVCouldNotCompute. const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } -/// getBackedgeTakenCount - If the specified loop has a predictable -/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute -/// object. The backedge-taken count is the number of times the loop header -/// will be branched to from within the loop. This is one less than the -/// trip count of the loop, since it doesn't count the first iteration, -/// when the header is branched to from outside the loop. -/// -/// Note that it is not valid to call this method on a loop without a -/// loop-invariant backedge-taken count (see -/// hasLoopInvariantBackedgeTakenCount). -/// +const SCEV * +ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, + SCEVUnionPredicate &Preds) { + return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds); +} + const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getExact(this); } -/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except -/// return the least SCEV value that is known never to be less than the -/// actual backedge taken count. +/// Similar to getBackedgeTakenCount, except return the least SCEV value that is +/// known never to be less than the actual backedge taken count. const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getMax(this); } -/// PushLoopPHIs - Push PHI nodes in the header of the given loop -/// onto the given Worklist. +/// Push PHI nodes in the header of the given loop onto the given Worklist. static void PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { BasicBlock *Header = L->getHeader(); @@ -4933,6 +5363,23 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { } const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { + auto &BTI = getBackedgeTakenInfo(L); + if (BTI.hasFullInfo()) + return BTI; + + auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); + + if (!Pair.second) + return Pair.first->second; + + BackedgeTakenInfo Result = + computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + + return PredicatedBackedgeTakenCounts.find(L)->second = Result; +} + +const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and @@ -4940,7 +5387,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. 
 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
-      BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
+      BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
   if (!Pair.second)
     return Pair.first->second;
@@ -5007,17 +5454,19 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
   return BackedgeTakenCounts.find(L)->second = Result;
 }
 
-/// forgetLoop - This method should be called by the client when it has
-/// changed a loop in a way that may effect ScalarEvolution's ability to
-/// compute a trip count, or if the loop is deleted.
 void ScalarEvolution::forgetLoop(const Loop *L) {
   // Drop any stored trip count value.
-  DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
-    BackedgeTakenCounts.find(L);
-  if (BTCPos != BackedgeTakenCounts.end()) {
-    BTCPos->second.clear();
-    BackedgeTakenCounts.erase(BTCPos);
-  }
+  auto RemoveLoopFromBackedgeMap =
+      [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+        auto BTCPos = Map.find(L);
+        if (BTCPos != Map.end()) {
+          BTCPos->second.clear();
+          Map.erase(BTCPos);
+        }
+      };
+
+  RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
+  RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
 
   // Drop information about expressions based on loop-header PHIs.
   SmallVector<Instruction *, 16> Worklist;
@@ -5043,13 +5492,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
 
   // Forget all contained loops too, to avoid dangling entries in the
   // ValuesAtScopes map.
-  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
-    forgetLoop(*I);
+  for (Loop *I : *L)
+    forgetLoop(I);
+
+  LoopHasNoAbnormalExits.erase(L);
 }
 
-/// forgetValue - This method should be called by the client when it has
-/// changed a value in a way that may effect its value, or which may
-/// disconnect it from a def-use chain linking it to a loop.
 void ScalarEvolution::forgetValue(Value *V) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) return;
@@ -5077,16 +5525,17 @@ void ScalarEvolution::forgetValue(Value *V) {
   }
 }
 
-/// getExact - Get the exact loop backedge taken count considering all loop
-/// exits. A computable result can only be returned for loops with a single
-/// exit. Returning the minimum taken count among all exits is incorrect
-/// because one of the loop's exit limit's may have been skipped. HowFarToZero
-/// assumes that the limit of each loop test is never skipped. This is a valid
-/// assumption as long as the loop exits via that test. For precise results, it
-/// is the caller's responsibility to specify the relevant loop exit using
+/// Get the exact loop backedge taken count considering all loop exits. A
+/// computable result can only be returned for loops with a single exit.
+/// Returning the minimum taken count among all exits is incorrect because one
+/// of the loop's exit limits may have been skipped. howFarToZero assumes that
+/// the limit of each loop test is never skipped. This is a valid assumption as
+/// long as the loop exits via that test. For precise results, it is the
+/// caller's responsibility to specify the relevant loop exit using
 /// getExact(ExitingBlock, SE).
 const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
+ScalarEvolution::BackedgeTakenInfo::getExact(
+    ScalarEvolution *SE, SCEVUnionPredicate *Preds) const {
   // If any exits were not computable, the loop is not computable.
if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); @@ -5095,36 +5544,42 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); const SCEV *BECount = nullptr; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + for (auto &ENT : ExitNotTaken) { + assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); if (!BECount) - BECount = ENT->ExactNotTaken; - else if (BECount != ENT->ExactNotTaken) + BECount = ENT.ExactNotTaken; + else if (BECount != ENT.ExactNotTaken) return SE->getCouldNotCompute(); + if (Preds && ENT.getPred()) + Preds->add(ENT.getPred()); + + assert((Preds || ENT.hasAlwaysTruePred()) && + "Predicate should be always true!"); } + assert(BECount && "Invalid not taken count for loop exit"); return BECount; } -/// getExact - Get the exact not taken count for this loop exit. +/// Get the exact not taken count for this loop exit. const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) + return ENT.ExactNotTaken; - if (ENT->ExitingBlock == ExitingBlock) - return ENT->ExactNotTaken; - } return SE->getCouldNotCompute(); } /// getMax - Get the max backedge taken count for the loop. const SCEV * ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { + for (auto &ENT : ExitNotTaken) + if (!ENT.hasAlwaysTruePred()) + return SE->getCouldNotCompute(); + return Max ? Max : SE->getCouldNotCompute(); } @@ -5136,22 +5591,19 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, if (!ExitNotTaken.ExitingBlock) return false; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - if (ENT->ExactNotTaken != SE->getCouldNotCompute() - && SE->hasOperand(ENT->ExactNotTaken, S)) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE->getCouldNotCompute() && + SE->hasOperand(ENT.ExactNotTaken, S)) return true; - } - } + return false; } /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( - SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts, - bool Complete, const SCEV *MaxCount) : Max(MaxCount) { + SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount) + : Max(MaxCount) { if (!Complete) ExitNotTaken.setIncomplete(); @@ -5159,36 +5611,63 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( unsigned NumExits = ExitCounts.size(); if (NumExits == 0) return; - ExitNotTaken.ExitingBlock = ExitCounts[0].first; - ExitNotTaken.ExactNotTaken = ExitCounts[0].second; - if (NumExits == 1) return; + ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; + ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; + + // Determine the number of ExitNotTakenExtras structures that we need. 
+ unsigned ExtraInfoSize = 0; + if (NumExits > 1) + ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), + ExitCounts.end(), [](EdgeInfo &Entry) { + return !Entry.Pred.isAlwaysTrue(); + }); + else if (!ExitCounts[0].Pred.isAlwaysTrue()) + ExtraInfoSize = 1; + + ExitNotTakenExtras *ENT = nullptr; + + // Allocate the ExitNotTakenExtras structures and initialize the first + // element (ExitNotTaken). + if (ExtraInfoSize > 0) { + ENT = new ExitNotTakenExtras[ExtraInfoSize]; + ExitNotTaken.ExtraInfo = &ENT[0]; + *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); + } + + if (NumExits == 1) + return; + + assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); + + auto &Exits = ExitNotTaken.ExtraInfo->Exits; // Handle the rare case of multiple computable exits. - ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1]; + for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { + ExitNotTakenExtras *Ptr = nullptr; + if (!ExitCounts[i].Pred.isAlwaysTrue()) { + Ptr = &ENT[PredPos++]; + Ptr->Pred = std::move(ExitCounts[i].Pred); + } - ExitNotTakenInfo *PrevENT = &ExitNotTaken; - for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) { - PrevENT->setNextExit(ENT); - ENT->ExitingBlock = ExitCounts[i].first; - ENT->ExactNotTaken = ExitCounts[i].second; + Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); } } -/// clear - Invalidate this result and free the ExitNotTakenInfo array. +/// Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { ExitNotTaken.ExitingBlock = nullptr; ExitNotTaken.ExactNotTaken = nullptr; - delete[] ExitNotTaken.getNextExit(); + delete[] ExitNotTaken.ExtraInfo; } -/// computeBackedgeTakenCount - Compute the number of times the backedge -/// of the specified loop will execute. +/// Compute the number of times the backedge of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { +ScalarEvolution::computeBackedgeTakenCount(const Loop *L, + bool AllowPredicates) { SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts; + SmallVector<EdgeInfo, 4> ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. const SCEV *MustExitMaxBECount = nullptr; @@ -5196,9 +5675,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. + // Do a union of all the predicates here. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitBB = ExitingBlocks[i]; - ExitLimit EL = computeExitLimit(L, ExitBB); + ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); + + assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && + "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. @@ -5207,7 +5690,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else - ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact)); + ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); // 2. Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. 
Partition the loop exits into two kinds: @@ -5241,20 +5724,20 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { } ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + bool AllowPredicates) { // Okay, we've chosen an exiting block. See what condition causes us to exit // at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; BasicBlock *Exit = nullptr; - for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock); - SI != SE; ++SI) - if (!L->contains(*SI)) { + for (auto *SBB : successors(ExitingBlock)) + if (!L->contains(SBB)) { if (Exit) // Multiple exit successors. return getCouldNotCompute(); - Exit = *SI; - } else if (*SI != L->getHeader()) { + Exit = SBB; + } else if (SBB != L->getHeader()) { MustExecuteLoopHeader = false; } @@ -5307,9 +5790,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), - BI->getSuccessor(1), - /*ControlsExit=*/IsOnlyExit); + return computeExitLimitFromCond( + L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), + /*ControlsExit=*/IsOnlyExit, AllowPredicates); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) @@ -5319,29 +5802,24 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } -/// computeExitLimitFromCond - Compute the number of times the -/// backedge of the specified loop will execute if its exit condition -/// were a conditional branch of ExitCond, TBB, and FBB. -/// -/// @param ControlsExit is true if ExitCond directly controls the exit -/// branch. In this case, we can assume that the loop exits only if the -/// condition is true and can infer that failing to meet the condition prior to -/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5368,6 +5846,9 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing // MaxBECount. 
In these cases it is possible for EL0.Exact and EL1.Exact @@ -5376,15 +5857,17 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, !isa<SCEVCouldNotCompute>(BECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount, NP); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5411,14 +5894,25 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } - return ExitLimit(BECount, MaxBECount); + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); + return ExitLimit(BECount, MaxBECount, NP); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. - if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) - return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { + ExitLimit EL = + computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (EL.hasFullInfo() || !AllowPredicates) + return EL; + + // Try again, but use SCEV predicates this time. + return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit, + /*AllowPredicates=*/true); + } // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -5442,7 +5936,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -5460,11 +5955,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ItCnt; } - ExitLimit ShiftEL = computeShiftCompareExitLimit( - ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); - if (ShiftEL.hasAnyInfo()) - return ShiftEL; - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5499,34 +5989,46 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) - ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool 
IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = + howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } default: break; } - return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + auto *ExhaustiveCount = + computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) + return ExhaustiveCount; + + return computeShiftCompareExitLimit(ExitCond->getOperand(0), + ExitCond->getOperand(1), L, Cond); } ScalarEvolution::ExitLimit @@ -5546,7 +6048,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5563,9 +6065,8 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, return cast<SCEVConstant>(Val)->getValue(); } -/// computeLoadConstantCompareExitLimit - Given an exit condition of -/// 'icmp op load X, cst', try to see if we can compute the backedge -/// execution count. +/// Given an exit condition of 'icmp op load X, cst', try to see if we can +/// compute the backedge execution count. ScalarEvolution::ExitLimit ScalarEvolution::computeLoadConstantCompareExitLimit( LoadInst *LI, @@ -5781,14 +6282,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - return ExitLimit(getCouldNotCompute(), UpperBound); + SCEVUnionPredicate P; + return ExitLimit(getCouldNotCompute(), UpperBound, P); } return getCouldNotCompute(); } -/// CanConstantFold - Return true if we can constant fold an instruction of the -/// specified type, assuming that all operands were constants. +/// Return true if we can constant fold an instruction of the specified type, +/// assuming that all operands were constants. static bool CanConstantFold(const Instruction *I) { if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || @@ -5916,10 +6418,9 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, Operands[1], DL, TLI); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(Operands[0], DL); + return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL, - TLI); + return ConstantFoldInstOperands(I, Operands, DL, TLI); } @@ -6107,16 +6608,6 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, return getCouldNotCompute(); } -/// getSCEVAtScope - Return a SCEV expression for the specified value -/// at the specified scope in the program. The L value specifies a loop -/// nest to evaluate the expression at, where null is the top-level or a -/// specified loop is immediately inside of the loop. -/// -/// This method can be used to compute the exit value for a variable defined -/// in a loop by querying what the value will hold in the parent loop. -/// -/// In the case that a relevant loop exit value cannot be computed, the -/// original value V is returned. 
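The removed getSCEVAtScope comment above described a contract that still holds: the L argument selects the scope to evaluate at (null meaning the top level), and when no exit value can be computed the original expression comes back unchanged. A minimal usage sketch, assuming SE is a ScalarEvolution reference and V an instruction defined inside loop L:

    // Hedged sketch: ask for the value V holds after L finishes, as seen
    // from the parent loop scope.
    const SCEV *S = SE.getSCEV(V);
    const SCEV *AtParent = SE.getSCEVAtScope(S, L->getParentLoop());
    if (AtParent != S) {
      // An exit value was computed; it is invariant in L and can be used
      // to rewrite users of V that live outside the loop.
    }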
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V]; @@ -6305,10 +6796,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { Operands[1], DL, &TLI); else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - C = ConstantFoldLoadFromConstPtr(Operands[0], DL); + C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, - DL, &TLI); + C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; return getSCEV(C); } @@ -6428,14 +6918,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { llvm_unreachable("Unknown SCEV type!"); } -/// getSCEVAtScope - This is a convenience function which does -/// getSCEVAtScope(getSCEV(V), L). const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } -/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the -/// following equation: +/// Finds the minimum unsigned root of the following equation: /// /// A * X = B (mod N) /// @@ -6482,11 +6969,11 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, return SE.getConstant(Result.trunc(BW)); } -/// SolveQuadraticEquation - Find the roots of the quadratic equation for the -/// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which -/// might be the same) or two SCEVCouldNotCompute objects. +/// Find the roots of the quadratic equation for the given quadratic chrec +/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or +/// two SCEVCouldNotCompute objects. /// -static std::pair<const SCEV *,const SCEV *> +static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); @@ -6494,10 +6981,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. - if (!LC || !MC || !NC) { - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); - } + if (!LC || !MC || !NC) + return None; uint32_t BitWidth = LC->getAPInt().getBitWidth(); const APInt &L = LC->getAPInt(); @@ -6524,8 +7009,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { if (SqrtTerm.isNegative()) { // The loop is provably infinite. - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); + return None; } // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest @@ -6536,10 +7020,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // The divisions must be performed as signed divisions. 
 APInt NegB(-B);
   APInt TwoA(A << 1);
-  if (TwoA.isMinValue()) {
-    const SCEV *CNC = SE.getCouldNotCompute();
-    return std::make_pair(CNC, CNC);
-  }
+  if (TwoA.isMinValue())
+    return None;
 
   LLVMContext &Context = SE.getContext();
 
   ConstantInt *Solution1 =
       ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
   ConstantInt *Solution2 =
       ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
 
-  return std::make_pair(SE.getConstant(Solution1),
-                        SE.getConstant(Solution2));
+  return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)),
+                        cast<SCEVConstant>(SE.getConstant(Solution2)));
 } // end APIntOps namespace
 }
 
-/// HowFarToZero - Return the number of times a backedge comparing the specified
-/// value to zero will execute. If not computable, return CouldNotCompute.
-///
-/// This is only used for loops with a "x != y" exit test. The exit condition is
-/// now expressed as a single expression, V = x-y. So the exit test is
-/// effectively V != 0. We know and take advantage of the fact that this
-/// expression only being used in a comparison by zero context.
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
+                              bool AllowPredicates) {
+
+  // This is only used for loops with a "x != y" exit test. The exit condition
+  // is now expressed as a single expression, V = x-y. So the exit test is
+  // effectively V != 0. We know and take advantage of the fact that this
+  // expression is only ever used in a comparison-with-zero context.
+
+  SCEVUnionPredicate P;
   // If the value is a constant
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     // If the value is already zero, the branch will execute zero times.
@@ -6570,31 +7053,33 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
   }
 
   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
+  if (!AddRec && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
+
   if (!AddRec || AddRec->getLoop() != L)
     return getCouldNotCompute();
 
   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
   // the quadratic equation to solve it.
   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
-    std::pair<const SCEV *,const SCEV *> Roots =
-      SolveQuadraticEquation(AddRec, *this);
-    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
-    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
-    if (R1 && R2) {
+    if (auto Roots = SolveQuadraticEquation(AddRec, *this)) {
+      const SCEVConstant *R1 = Roots->first;
+      const SCEVConstant *R2 = Roots->second;
       // Pick the smallest positive root value.
-      if (ConstantInt *CB =
-          dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
-                                                      R1->getValue(),
-                                                      R2->getValue()))) {
+      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+              CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
         if (!CB->getZExtValue())
-          std::swap(R1, R2);   // R1 is the minimum root now.
+          std::swap(R1, R2); // R1 is the minimum root now.
 
         // We can only use this value if the chrec ends up with an exact zero
         // value at this index. When solving for "X*X != 5", for example, we
         // should not accept a root of 2.
         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
         if (Val->isZero())
-          return R1;  // We found a quadratic root!
+ return ExitLimit(R1, R1, P); // We found a quadratic root! } } return getCouldNotCompute(); @@ -6651,7 +7136,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount); + return ExitLimit(Distance, MaxBECount, P); } // As a special case, handle the instance where Step is a positive power of @@ -6704,7 +7189,9 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); auto *WideTy = Distance->getType(); - return getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + const SCEV *Limit = + getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + return ExitLimit(Limit, Limit, P); } } @@ -6713,24 +7200,24 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // compute the backedge count. In this case, the step may not divide the // distance, but we don't care because if the condition is "missed" the loop // will have undefined behavior due to wrapping. - if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + if (ControlsExit && AddRec->hasNoSelfWrap() && + loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact); + return ExitLimit(Exact, Exact, P); } // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) - return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(), - *this); + if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { + const SCEV *E = SolveLinEquationWithOverflow( + StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); + return ExitLimit(E, E, P); + } return getCouldNotCompute(); } -/// HowFarToNonZero - Return the number of times a backedge checking the -/// specified value for nonzero will execute. If not computable, return -/// CouldNotCompute ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { +ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -6748,33 +7235,27 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { return getCouldNotCompute(); } -/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB -/// (which may not be an immediate predecessor) which has exactly one -/// successor from which BB is reachable, or null if no such block is -/// found. -/// std::pair<BasicBlock *, BasicBlock *> ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // If the block has a unique predecessor, then there is no path from the // predecessor to the block that does not go through the direct edge // from the predecessor to the block. if (BasicBlock *Pred = BB->getSinglePredecessor()) - return std::make_pair(Pred, BB); + return {Pred, BB}; // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. 
if (Loop *L = LI.getLoopFor(BB)) - return std::make_pair(L->getLoopPredecessor(), L->getHeader()); + return {L->getLoopPredecessor(), L->getHeader()}; - return std::pair<BasicBlock *, BasicBlock *>(); + return {nullptr, nullptr}; } -/// HasSameValue - SCEV structural equivalence is usually sufficient for -/// testing whether two expressions are equal, however for the purposes of -/// looking for a condition guarding a loop, it can be useful to be a little -/// more general, since a front-end may have replicated the controlling -/// expression. +/// SCEV structural equivalence is usually sufficient for testing whether two +/// expressions are equal, however for the purposes of looking for a condition +/// guarding a loop, it can be useful to be a little more general, since a +/// front-end may have replicated the controlling expression. /// static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. @@ -6800,9 +7281,6 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { return false; } -/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with -/// predicate Pred. Return true iff any changes were made. -/// bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth) { @@ -7134,7 +7612,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; // Otherwise see what can be done with known constant ranges. - return isKnownPredicateWithRanges(Pred, LHS, RHS); + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS); } bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, @@ -7180,7 +7658,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, case ICmpInst::ICMP_UGE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (!LHS->getNoWrapFlags(SCEV::FlagNUW)) + if (!LHS->hasNoUnsignedWrap()) return false; Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; @@ -7190,7 +7668,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: { - if (!LHS->getNoWrapFlags(SCEV::FlagNSW)) + if (!LHS->hasNoSignedWrap()) return false; const SCEV *Step = LHS->getStepRecurrence(*this); @@ -7264,78 +7742,34 @@ bool ScalarEvolution::isLoopInvariantPredicate( return true; } -bool -ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownPredicateViaConstantRanges( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { if (HasSameValue(LHS, RHS)) return ICmpInst::isTrueWhenEqual(Pred); // This code is split out from isKnownPredicate because it is called from // within isLoopEntryGuardedByCond. 
-  switch (Pred) {
-  default:
-    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
-  case ICmpInst::ICMP_SGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLT: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_SGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLE: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULT: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULE: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_NE: {
-    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
-      return true;
-    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
-      return true;
-
-    const SCEV *Diff = getMinusSCEV(LHS, RHS);
-    if (isKnownNonZero(Diff))
-      return true;
-    break;
-  }
-  case ICmpInst::ICMP_EQ:
-    // The check at the top of the function catches the case where
-    // the values are known to be equal.
-    break;
-  }
-  return false;
+  auto CheckRanges =
+      [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
+    return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
+        .contains(RangeLHS);
+  };
+
+  // The check at the top of the function catches the case where the values are
+  // known to be equal.
+  if (Pred == CmpInst::ICMP_EQ)
+    return false;
+
+  if (Pred == CmpInst::ICMP_NE)
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
+           CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
+           isKnownNonZero(getMinusSCEV(LHS, RHS));
+
+  if (CmpInst::isSigned(Pred))
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
+
+  return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
 }
 
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
@@ -7416,6 +7850,23 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
 }
 
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  // No need to even try if we know the module has no guards.
+  if (!HasGuards)
+    return false;
+
+  return any_of(*BB, [&](Instruction &I) {
+    using namespace llvm::PatternMatch;
+
+    Value *Condition;
+    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+                         m_Value(Condition))) &&
+           isImpliedCond(Pred, LHS, RHS, Condition, false);
+  });
+}
+
 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
 /// protected by a conditional between LHS and RHS.  This is used to
 /// to eliminate casts.
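The rewritten isKnownPredicateViaConstantRanges above folds the old six-way switch into a single query on ConstantRange::makeSatisfyingICmpRegion. An illustrative standalone use of the same helper, with made-up i8 ranges (not part of the patch):

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include "llvm/IR/InstrTypes.h"
  using namespace llvm;

  // makeSatisfyingICmpRegion(Pred, R) is the largest set X such that
  // "x Pred r" holds for every x in X and every r in R, so the predicate
  // is proven whenever the whole LHS range fits inside that region.
  static bool knownULT(const ConstantRange &L, const ConstantRange &R) {
    return ConstantRange::makeSatisfyingICmpRegion(CmpInst::ICMP_ULT, R)
        .contains(L);
  }

  // knownULT([0,10), [10,20)) is true: every value in [0,9] is u< every
  // value in [10,19]. Widening the LHS to [0,11) makes it false, since 10
  // then becomes possible on both sides.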
@@ -7427,7 +7878,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return true;
 
-  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+    return true;
 
   BasicBlock *Latch = L->getLoopLatch();
   if (!Latch)
@@ -7482,12 +7934,18 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   if (!DT.isReachableFromEntry(L->getHeader()))
     return false;
 
+  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+    return true;
+
   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
        DTN != HeaderDTN; DTN = DTN->getIDom()) {
     assert(DTN && "should reach the loop header before reaching the root!");
 
     BasicBlock *BB = DTN->getBlock();
+    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+      return true;
+
     BasicBlock *PBB = BB->getSinglePredecessor();
     if (!PBB)
       continue;
@@ -7518,9 +7976,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   return false;
 }
 
-/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
-/// by a conditional between LHS and RHS.  This is used to help avoid max
-/// expressions in loop trip counts, and to eliminate casts.
 bool
 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                           ICmpInst::Predicate Pred,
@@ -7529,7 +7984,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return false;
 
-  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+    return true;
 
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
@@ -7539,6 +7995,9 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
        Pair.first;
        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
 
+    if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
+      return true;
+
     BranchInst *LoopEntryPredicate =
       dyn_cast<BranchInst>(Pair.first->getTerminator());
     if (!LoopEntryPredicate ||
@@ -7586,8 +8045,6 @@ struct MarkPendingLoopPredicate {
 };
 } // end anonymous namespace
 
-/// isImpliedCond - Test whether the condition described by Pred, LHS,
-/// and RHS is true whenever the given Cond value evaluates to true.
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS, Value *FoundCondValue,
@@ -7910,9 +8367,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
                                   getConstant(FoundRHSLimit));
 }
 
-/// isImpliedCondOperands - Test whether the condition described by Pred,
-/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
-/// and FoundRHS is true.
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
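For reference, the guard calls searched for by isImpliedViaGuard (used in both guarded-by-cond walks above) have the form call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]; %cond is known true on every path that continues past the call. A hypothetical helper in the same style as that scan, collecting all guarded conditions of a block:

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Intrinsics.h"
  #include "llvm/IR/PatternMatch.h"
  using namespace llvm;

  // Invented utility: gather the i1 conditions enforced by the guards in BB.
  static void guardedConditions(BasicBlock &BB,
                                SmallVectorImpl<Value *> &Conds) {
    using namespace llvm::PatternMatch;
    for (Instruction &I : BB) {
      Value *C;
      if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(C))))
        Conds.push_back(C);
    }
  }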
@@ -8037,9 +8491,6 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
   llvm_unreachable("covered switch fell through?!");
 }
 
-/// isImpliedCondOperandsHelper - Test whether the condition described by
-/// Pred, LHS, and RHS is true whenever the condition described by Pred,
-/// FoundLHS, and FoundRHS is true.
 bool
 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *LHS, const SCEV *RHS,
@@ -8047,7 +8498,7 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *FoundRHS) {
   auto IsKnownPredicateFull =
       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
-    return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
+    return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
           IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
@@ -8089,8 +8540,6 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   return false;
 }
 
-/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands.
-/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1".
 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
                                                      const SCEV *LHS,
                                                      const SCEV *RHS,
@@ -8129,9 +8578,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
   return SatisfyingLHSRange.contains(LHSRange);
 }
 
-// Verify if an linear IV with positive stride can overflow when in a
-// less-than comparison, knowing the invariant term of the comparison, the
-// stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
   if (NoWrap) return false;
@@ -8158,9 +8604,6 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
 }
 
-// Verify if an linear IV with negative stride can overflow when in a
-// greater-than comparison, knowing the invariant term of the comparison,
-// the stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
   if (NoWrap) return false;
@@ -8187,8 +8630,6 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
   return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
 }
 
-// Compute the backedge taken count knowing the interval difference, the
-// stride and presence of the equality in the comparison.
 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
                                             bool Equality) {
   const SCEV *One = getOne(Step->getType());
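computeBECount above is a rounding division: for a strict bound it computes ceil(Delta / Step), and for a non-strict bound it first adds a full Step to account for the extra iteration. The same arithmetic on plain unsigned integers, with invented numbers:

  #include <cstdint>

  // Count of iterations for which an IV advancing by Step still satisfies the
  // exit comparison over an interval of length Delta (Step > 0 assumed, as in
  // the SCEV version).
  static uint64_t computeBECount(uint64_t Delta, uint64_t Step, bool Equality) {
    return (Delta + (Equality ? Step : Step - 1)) / Step;
  }

  // computeBECount(10, 3, false) == 4: for i < 10 starting at 0, the
  // comparison holds for IV values 0, 3, 6, 9 and first fails at 12.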
@@ -8197,22 +8638,21 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
   return getUDivExpr(Delta, Step);
 }
 
-/// HowManyLessThans - Return the number of times a backedge containing the
-/// specified less-than comparison will execute.  If not computable, return
-/// CouldNotCompute.
-///
-/// @param ControlsExit is true when the LHS < RHS condition directly controls
-/// the branch (loops exits only if condition is true). In this case, we can use
-/// NoWrapFlags to skip overflow checks.
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
-                                  bool ControlsExit) {
+                                  bool ControlsExit, bool AllowPredicates) {
+  SCEVUnionPredicate P;
   // We handle only IV < Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
 
   // Avoid weird loops
   if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8238,19 +8678,8 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                       : ICmpInst::ICMP_ULT;
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
-  if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
-    const SCEV *Diff = getMinusSCEV(RHS, Start);
-    // If we have NoWrap set, then we can assume that the increment won't
-    // overflow, in which case if RHS - Start is a constant, we don't need to
-    // do a max operation since we can just figure it out statically
-    if (NoWrap && isa<SCEVConstant>(Diff)) {
-      APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
-      if (D.isNegative())
-        End = Start;
-    } else
-      End = IsSigned ? getSMaxExpr(RHS, Start)
-                     : getUMaxExpr(RHS, Start);
-  }
+  if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+    End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
 
   const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
 
@@ -8281,18 +8710,24 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
 }
 
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
                                      const Loop *L, bool IsSigned,
-                                     bool ControlsExit) {
+                                     bool ControlsExit, bool AllowPredicates) {
+  SCEVUnionPredicate P;
   // We handle only IV > Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
 
   // Avoid weird loops
   if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8319,19 +8754,8 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
 
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
-  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
-    const SCEV *Diff = getMinusSCEV(RHS, Start);
-    // If we have NoWrap set, then we can assume that the increment won't
-    // overflow, in which case if RHS - Start is a constant, we don't need to
-    // do a max operation since we can just figure it out statically
-    if (NoWrap && isa<SCEVConstant>(Diff)) {
-      APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
-      if (!D.isNegative())
-        End = Start;
-    } else
-      End = IsSigned ? getSMinExpr(RHS, Start)
-                     : getUMinExpr(RHS, Start);
-  }
+  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
+    End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
 
   const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
 
@@ -8363,15 +8787,10 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
 }
-/// getNumIterationsInRange - Return the number of iterations of this loop that
-/// produce values in the specified constant range.  Another way of looking at
-/// this is that it returns the first iteration number where the value is not in
-/// the condition, thus computing the exit count. If the iteration count can't
-/// be computed, an instance of SCEVCouldNotCompute is returned.
-const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
+const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
                                                     ScalarEvolution &SE) const {
   if (Range.isFullSet())  // Infinite loop.
     return SE.getCouldNotCompute();
@@ -8445,22 +8864,21 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                              FlagAnyWrap);
 
   // Next, solve the constructed addrec
-  auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
-  const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
-  const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
-  if (R1) {
+  if (auto Roots =
+          SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) {
+    const SCEVConstant *R1 = Roots->first;
+    const SCEVConstant *R2 = Roots->second;
     // Pick the smallest positive root value.
     if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
             ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
       if (!CB->getZExtValue())
-        std::swap(R1, R2);   // R1 is the minimum root now.
+        std::swap(R1, R2); // R1 is the minimum root now.
 
       // Make sure the root is not off by one.  The returned iteration should
       // not be in the range, but the previous one should be.  When solving
       // for "X*X < 5", for example, we should not return a root of 2.
-      ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
-                                                           R1->getValue(),
-                                                           SE);
+      ConstantInt *R1Val =
+          EvaluateConstantChrecAtConstant(this, R1->getValue(), SE);
       if (Range.contains(R1Val->getValue())) {
         // The next iteration must be out of the range...
         ConstantInt *NextVal =
@@ -8469,7 +8887,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
         if (!Range.contains(R1Val->getValue()))
           return SE.getConstant(NextVal);
-        return SE.getCouldNotCompute();  // Something strange happened
+        return SE.getCouldNotCompute(); // Something strange happened
       }
 
       // If R1 was not in the range, then it is a good return value.  Make
@@ -8479,7 +8897,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
       R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
       if (Range.contains(R1Val->getValue()))
         return R1;
-      return SE.getCouldNotCompute();  // Something strange happened
+      return SE.getCouldNotCompute(); // Something strange happened
     }
   }
 
@@ -8789,12 +9207,9 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
   return getSizeOfExpr(ETy, Ty);
 }
 
-/// Second step of delinearization: compute the array dimensions Sizes from the
-/// set of Terms extracted from the memory access function of this SCEVAddRec.
 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
                                           SmallVectorImpl<const SCEV *> &Sizes,
                                           const SCEV *ElementSize) const {
-
   if (Terms.size() < 1 || !ElementSize)
     return;
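To make the contract of getNumIterationsInRange concrete: it returns the first iteration number whose value falls outside Range. A brute-force model of the affine case over a hypothetical i8 recurrence {Start,+,Step}, for a non-wrapped range [Lo, Hi) (invented helper; assumes Lo < Hi and that some iteration does escape the range):

  #include <cstdint>

  // First I such that Start + I*Step, wrapping like an i8 IV, leaves [Lo, Hi).
  static uint64_t firstIterationOutside(uint8_t Start, uint8_t Step,
                                        uint8_t Lo, uint8_t Hi) {
    uint8_t V = Start;
    for (uint64_t I = 0;; ++I, V += Step)
      if (V < Lo || V >= Hi)
        return I;
  }

  // firstIterationOutside(0, 1, 0, 5) == 5: values 0..4 lie in [0,5), 5 does
  // not, matching the "previous iteration in range, returned one not" rule.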
@@ -8858,8 +9273,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
   });
 }
 
-/// Third step of delinearization: compute the access functions for the
-/// Subscripts based on the dimensions in Sizes.
 void ScalarEvolution::computeAccessFunctions(
     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
     SmallVectorImpl<const SCEV *> &Sizes) {
@@ -9012,7 +9425,7 @@ void ScalarEvolution::SCEVCallbackVH::deleted() {
   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
     SE->ConstantEvolutionLoopExitValue.erase(PN);
-  SE->ValueExprMap.erase(getValPtr());
+  SE->eraseValueFromMap(getValPtr());
   // this now dangles!
 }
 
@@ -9035,13 +9448,13 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
       continue;
     if (PHINode *PN = dyn_cast<PHINode>(U))
       SE->ConstantEvolutionLoopExitValue.erase(PN);
-    SE->ValueExprMap.erase(U);
+    SE->eraseValueFromMap(U);
     Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
   }
   // Delete the Old value.
   if (PHINode *PN = dyn_cast<PHINode>(Old))
     SE->ConstantEvolutionLoopExitValue.erase(PN);
-  SE->ValueExprMap.erase(Old);
+  SE->eraseValueFromMap(Old);
   // this now dangles!
 }
 
@@ -9059,14 +9472,31 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
       CouldNotCompute(new SCEVCouldNotCompute()),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
-      FirstUnknown(nullptr) {}
+      FirstUnknown(nullptr) {
+
+  // To use guards for proving predicates, we need to scan every instruction in
+  // relevant basic blocks, and not just terminators.  Doing this is a waste of
+  // time if the IR does not actually contain any calls to
+  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
+  //
+  // This pessimizes the case where a pass that preserves ScalarEvolution wants
+  // to _add_ guards to the module when there weren't any before, and wants
+  // ScalarEvolution to optimize based on those guards.  For now we prefer to be
+  // efficient in lieu of being smart in that rather obscure case.
+
+  auto *GuardDecl = F.getParent()->getFunction(
+      Intrinsic::getName(Intrinsic::experimental_guard));
+  HasGuards = GuardDecl && !GuardDecl->use_empty();
+}
 
 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
-    : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI),
-      CouldNotCompute(std::move(Arg.CouldNotCompute)),
+    : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
+      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
+      PredicatedBackedgeTakenCounts(
+          std::move(Arg.PredicatedBackedgeTakenCounts)),
       ConstantEvolutionLoopExitValue(
           std::move(Arg.ConstantEvolutionLoopExitValue)),
       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
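A guard appears in IR as call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ], so the whole-module test in the constructor above reduces to asking whether that declaration exists and has users. Restated as a standalone sketch (hypothetical helper, same logic as the HasGuards initialization):

  #include "llvm/IR/Intrinsics.h"
  #include "llvm/IR/Module.h"
  using namespace llvm;

  // Guards can only be present if the intrinsic is declared and actually
  // called somewhere in the module.
  static bool moduleHasGuards(const Module &M) {
    Function *GuardDecl =
        M.getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
    return GuardDecl && !GuardDecl->use_empty();
  }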
@@ -9091,12 +9521,16 @@ ScalarEvolution::~ScalarEvolution() {
   }
   FirstUnknown = nullptr;
 
+  ExprValueMap.clear();
   ValueExprMap.clear();
+  HasRecMap.clear();
 
   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
   // that a loop had multiple computable exits.
   for (auto &BTCI : BackedgeTakenCounts)
     BTCI.second.clear();
+  for (auto &BTCI : PredicatedBackedgeTakenCounts)
+    BTCI.second.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(!WalkingBEDominatingConds &&
          "isLoopBackedgeGuardedByCond garbage!");
@@ -9110,8 +9544,8 @@ bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                           const Loop *L) {
   // Print all inner loops first
-  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
-    PrintLoopInfo(OS, SE, *I);
+  for (Loop *I : *L)
+    PrintLoopInfo(OS, SE, I);
 
   OS << "Loop ";
   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -9139,9 +9573,35 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << "Unpredictable max backedge-taken count. ";
   }
 
+  OS << "\n"
+        "Loop ";
+  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+  OS << ": ";
+
+  SCEVUnionPredicate Pred;
+  auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
+  if (!isa<SCEVCouldNotCompute>(PBT)) {
+    OS << "Predicated backedge-taken count is " << *PBT << "\n";
+    OS << " Predicates:\n";
+    Pred.print(OS, 4);
+  } else {
+    OS << "Unpredictable predicated backedge-taken count. ";
+  }
   OS << "\n";
 }
 
+static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
+  switch (LD) {
+  case ScalarEvolution::LoopVariant:
+    return "Variant";
+  case ScalarEvolution::LoopInvariant:
+    return "Invariant";
+  case ScalarEvolution::LoopComputable:
+    return "Computable";
+  }
+  llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
+}
+
 void ScalarEvolution::print(raw_ostream &OS) const {
   // ScalarEvolution's implementation of the print method is to print
   // out SCEV values of all instructions that are interesting. Doing
@@ -9189,6 +9649,35 @@ void ScalarEvolution::print(raw_ostream &OS) const {
         } else {
           OS << *ExitValue;
         }
+
+        bool First = true;
+        for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
+          if (First) {
+            OS << "\t\t" "LoopDispositions: { ";
+            First = false;
+          } else {
+            OS << ", ";
+          }
+
+          Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
+        }
+
+        for (auto *InnerL : depth_first(L)) {
+          if (InnerL == L)
+            continue;
+          if (First) {
+            OS << "\t\t" "LoopDispositions: { ";
+            First = false;
+          } else {
+            OS << ", ";
+          }
+
+          InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
+        }
+
+        OS << " }";
       }
 
       OS << "\n";
@@ -9197,8 +9686,8 @@ void ScalarEvolution::print(raw_ostream &OS) const {
   OS << "Determining loop execution counts for: ";
   F.printAsOperand(OS, /*PrintType=*/false);
   OS << "\n";
-  for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
-    PrintLoopInfo(OS, &SE, *I);
+  for (Loop *I : LI)
+    PrintLoopInfo(OS, &SE, I);
 }
 
 ScalarEvolution::LoopDisposition
@@ -9420,17 +9909,23 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   BlockDispositions.erase(S);
   UnsignedRanges.erase(S);
   SignedRanges.erase(S);
+  ExprValueMap.erase(S);
+  HasRecMap.erase(S);
+
+  auto RemoveSCEVFromBackedgeMap =
+      [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+        for (auto I = Map.begin(), E = Map.end(); I != E;) {
+          BackedgeTakenInfo &BEInfo = I->second;
+          if (BEInfo.hasOperand(S, this)) {
+            BEInfo.clear();
+            Map.erase(I++);
+          } else
+            ++I;
+        }
+      };
 
-  for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
-         BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) {
-    BackedgeTakenInfo &BEInfo = I->second;
-    if (BEInfo.hasOperand(S, this)) {
-      BEInfo.clear();
-      BackedgeTakenCounts.erase(I++);
-    }
-    else
-      ++I;
-  }
+  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
+  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
 typedef DenseMap<const Loop *, std::string> VerifyMap;
@@ -9516,16 +10011,16 @@ void ScalarEvolution::verify() const {
 char ScalarEvolutionAnalysis::PassID;
 
 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
-                                             AnalysisManager<Function> *AM) {
-  return ScalarEvolution(F, AM->getResult<TargetLibraryAnalysis>(F),
-                         AM->getResult<AssumptionAnalysis>(F),
-                         AM->getResult<DominatorTreeAnalysis>(F),
-                         AM->getResult<LoopAnalysis>(F));
+                                             AnalysisManager<Function> &AM) {
+  return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
+                         AM.getResult<AssumptionAnalysis>(F),
+                         AM.getResult<DominatorTreeAnalysis>(F),
+                         AM.getResult<LoopAnalysis>(F));
 }
 
 PreservedAnalyses
-ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> *AM) {
-  AM->getResult<ScalarEvolutionAnalysis>(F).print(OS);
+ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) {
+  AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
   return PreservedAnalyses::all();
 }
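The run() signature change above follows the new pass manager's convention of passing the AnalysisManager by reference instead of by pointer. Client code under that convention looks roughly like this (hypothetical snippet; assumes the required analyses are already registered):

  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/Support/raw_ostream.h"
  using namespace llvm;

  void printSCEVInfo(Function &F, FunctionAnalysisManager &FAM,
                     raw_ostream &OS) {
    // Note the '.' rather than '->': the manager can no longer be null.
    ScalarEvolution &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
    SE.print(OS);
  }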
@@ -9590,36 +10085,121 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
   return Eq;
 }
 
+const SCEVPredicate *ScalarEvolution::getWrapPredicate(
+    const SCEVAddRecExpr *AR,
+    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Wrap);
+  ID.AddPointer(AR);
+  ID.AddInteger(AddedFlags);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  auto *OF = new (SCEVAllocator)
+      SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
+  UniquePreds.InsertNode(OF, IP);
+  return OF;
+}
+
 namespace {
+
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
-                             SCEVUnionPredicate &A) {
-    SCEVPredicateRewriter Rewriter(SE, A);
-    return Rewriter.visit(Scev);
+  // Rewrites \p S in the context of a loop L and the predicate A.
+  // If Assume is true, rewrite is free to add further predicates to A
+  // such that the result will be an AddRecExpr.
+  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A, bool Assume) {
+    SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
+    return Rewriter.visit(S);
   }
 
-  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
-      : SCEVRewriteVisitor(SE), P(P) {}
+  SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+                        SCEVUnionPredicate &P, bool Assume)
+      : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     auto ExprPreds = P.getPredicatesForExpr(Expr);
     for (auto *Pred : ExprPreds)
-      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+      if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
        if (IPred->getLHS() == Expr)
          return IPred->getRHS();
 
     return Expr;
   }
 
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nuw
+      // flag. Add the nusw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getZeroExtendExpr(Operand, Expr->getType());
+  }
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nsw
+      // flag. Add the nssw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getSignExtendExpr(Operand, Expr->getType());
+  }
+
 private:
+  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+    auto *A = SE.getWrapPredicate(AR, AddedFlags);
+    if (!Assume) {
+      // Check if we've already made this assumption.
+      if (P.implies(A))
+        return true;
+      return false;
+    }
+    P.add(A);
+    return true;
+  }
+
   SCEVUnionPredicate &P;
+  const Loop *L;
+  bool Assume;
 };
 } // end anonymous namespace
 
-const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                                    SCEVUnionPredicate &Preds) {
-  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+  return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+}
+
+const SCEVAddRecExpr *
+ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
+                                                   SCEVUnionPredicate &Preds) {
+  SCEVUnionPredicate TransformPreds;
+  S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+
+  if (!AddRec)
+    return nullptr;
+
+  // Since the transformation was successful, we can now transfer the SCEV
+  // predicates.
+  Preds.add(&TransformPreds);
+  return AddRec;
 }
 
 /// SCEV predicates
@@ -9633,7 +10213,7 @@ SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
     : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
 
 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
-  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+  const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
 
   if (!Op)
     return false;
@@ -9649,6 +10229,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
 }
 
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+                                     const SCEVAddRecExpr *AR,
+                                     IncrementWrapFlags Flags)
+    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+  IncrementWrapFlags IFlags = Flags;
+
+  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+    IFlags = clearFlags(IFlags, IncrementNSSW);
+
+  return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << *getExpr() << " Added Flags: ";
+  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+    OS << "<nusw>";
+  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+    OS << "<nssw>";
+  OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+                                   ScalarEvolution &SE) {
+  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+  // We can safely transfer the NSW flag as NSSW.
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+    ImpliedFlags = IncrementNSSW;
+
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+    // If the increment is positive, the SCEV NUW flag will also imply the
+    // WrapPredicate NUSW flag.
+    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+      if (Step->getValue()->getValue().isNonNegative())
+        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+  }
+
+  return ImpliedFlags;
+}
+
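A worked example of getImpliedFlags, with an invented recurrence: for AR = {0,+,1}<nuw><nsw>, NSW transfers directly as IncrementNSSW, and since the step is a non-negative constant, NUW additionally implies IncrementNUSW. A wrap predicate that only asks for those flags is then statically true and costs no runtime check:

  // Sketch only; SE and AR (an affine {0,+,1}<nuw><nsw>) built elsewhere.
  auto Implied = SCEVWrapPredicate::getImpliedFlags(AR, SE);
  // Implied == IncrementNUSW | IncrementNSSW for this AR, so:
  auto StillNeeded = SCEVWrapPredicate::clearFlags(
      SCEVWrapPredicate::IncrementNUSW, Implied);
  bool NoRuntimeCheck =
      StillNeeded == SCEVWrapPredicate::IncrementAnyWrap; // true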
+/// Union predicates don't get cached so create a dummy set ID for it.
 SCEVUnionPredicate::SCEVUnionPredicate()
     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
 
@@ -9667,7 +10300,7 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
 }
 
 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
-  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
     return all_of(Set->Preds,
                   [this](const SCEVPredicate *I) { return this->implies(I); });
 
@@ -9688,7 +10321,7 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
 }
 
 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
-  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
    for (auto Pred : Set->Preds)
      add(Pred);
    return;
@@ -9705,8 +10338,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
   Preds.push_back(N);
 }
 
-PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
-    : SE(SE), Generation(0) {}
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+                                                     Loop &L)
+    : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
 
 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   const SCEV *Expr = SE.getSCEV(V);
@@ -9721,12 +10355,21 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   if (Entry.second)
     Expr = Entry.second;
 
-  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
   Entry = {Generation, NewSCEV};
 
   return NewSCEV;
 }
 
+const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
+  if (!BackedgeCount) {
+    SCEVUnionPredicate BackedgePred;
+    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
+    addPredicate(BackedgePred);
+  }
+  return BackedgeCount;
+}
+
 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
   if (Preds.implies(&Pred))
     return;
@@ -9743,7 +10386,82 @@ void PredicatedScalarEvolution::updateGeneration() {
   if (++Generation == 0) {
     for (auto &II : RewriteMap) {
       const SCEV *Rewritten = II.second.second;
-      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
     }
   }
 }
+
+void PredicatedScalarEvolution::setNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+  // Clear the statically implied flags.
+  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+  addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+  auto II = FlagsMap.insert({V, Flags});
+  if (!II.second)
+    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  Flags = SCEVWrapPredicate::clearFlags(
+      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+  auto II = FlagsMap.find(V);
+
+  if (II != FlagsMap.end())
+    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+  const SCEV *Expr = this->getSCEV(V);
+  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+
+  if (!New)
+    return nullptr;
+
+  updateGeneration();
+  RewriteMap[SE.getSCEV(V)] = {Generation, New};
+  return New;
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(
+    const PredicatedScalarEvolution &Init)
+    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
+      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+  for (const auto &I : Init.FlagsMap)
+    FlagsMap.insert(I);
+}
+
+void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
+  // For each block.
+  for (auto *BB : L.getBlocks())
+    for (auto &I : *BB) {
+      if (!SE.isSCEVable(I.getType()))
+        continue;
+
+      auto *Expr = SE.getSCEV(&I);
+      auto II = RewriteMap.find(Expr);
+
+      if (II == RewriteMap.end())
+        continue;
+
+      // Don't print things that are not interesting.
+      if (II->second.second == Expr)
+        continue;
+
+      OS.indent(Depth) << "[PSE]" << I << ":\n";
+      OS.indent(Depth + 2) << *Expr << "\n";
+      OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
    }
+}
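Finally, a hypothetical end-to-end use of the widened PredicatedScalarEvolution interface (SE, a Loop L and an IR value V are assumed to exist; this is the shape in which clients such as the loop vectorizer's runtime-check machinery drive it):

  PredicatedScalarEvolution PSE(SE, L);

  // May add predicates to PSE that make the count computable.
  const SCEV *BTC = PSE.getBackedgeTakenCount();

  // Strengthen V to an AddRec under runtime assumptions, then record that
  // its increment must not wrap (minus whatever is already implied).
  if (const SCEVAddRecExpr *AR = PSE.getAsAddRec(V)) {
    PSE.setNoOverflow(V, SCEVWrapPredicate::IncrementNUSW);
    bool Known = PSE.hasNoOverflow(V, SCEVWrapPredicate::IncrementNUSW); // true
  }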