diff options
Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 1462 |
1 files changed, 732 insertions, 730 deletions
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp index 8fefada..ed328f1 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp @@ -61,6 +61,8 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -120,6 +122,21 @@ static cl::opt<bool> cl::desc("Verify no dangling value in ScalarEvolution's " "ExprValueMap (slow)")); +static cl::opt<unsigned> MulOpsInlineThreshold( + "scev-mulops-inline-threshold", cl::Hidden, + cl::desc("Threshold for inlining multiplication operands into a SCEV"), + cl::init(1000)); + +static cl::opt<unsigned> MaxSCEVCompareDepth( + "scalar-evolution-max-scev-compare-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV complexity comparisons"), + cl::init(32)); + +static cl::opt<unsigned> MaxValueCompareDepth( + "scalar-evolution-max-value-compare-depth", cl::Hidden, + cl::desc("Maximum depth of recursive value complexity comparisons"), + cl::init(2)); + //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -447,180 +464,233 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { // SCEV Utilities //===----------------------------------------------------------------------===// -namespace { -/// SCEVComplexityCompare - Return true if the complexity of the LHS is less -/// than the complexity of the RHS. This comparator is used to canonicalize -/// expressions. -class SCEVComplexityCompare { - const LoopInfo *const LI; -public: - explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} +/// Compare the two values \p LV and \p RV in terms of their "complexity" where +/// "complexity" is a partial (and somewhat ad-hoc) relation used to order +/// operands in SCEV expressions. \p EqCache is a set of pairs of values that +/// have been previously deemed to be "equally complex" by this routine. It is +/// intended to avoid exponential time complexity in cases like: +/// +/// %a = f(%x, %y) +/// %b = f(%a, %a) +/// %c = f(%b, %b) +/// +/// %d = f(%x, %y) +/// %e = f(%d, %d) +/// %f = f(%e, %e) +/// +/// CompareValueComplexity(%f, %c) +/// +/// Since we do not continue running this routine on expression trees once we +/// have seen unequal values, there is no need to track them in the cache. +static int +CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, + const LoopInfo *const LI, Value *LV, Value *RV, + unsigned Depth) { + if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV})) + return 0; + + // Order pointer values after integer values. This helps SCEVExpander form + // GEPs. + bool LIsPointer = LV->getType()->isPointerTy(), + RIsPointer = RV->getType()->isPointerTy(); + if (LIsPointer != RIsPointer) + return (int)LIsPointer - (int)RIsPointer; - // Return true or false if LHS is less than, or at least RHS, respectively. - bool operator()(const SCEV *LHS, const SCEV *RHS) const { - return compare(LHS, RHS) < 0; + // Compare getValueID values. + unsigned LID = LV->getValueID(), RID = RV->getValueID(); + if (LID != RID) + return (int)LID - (int)RID; + + // Sort arguments by their position. + if (const auto *LA = dyn_cast<Argument>(LV)) { + const auto *RA = cast<Argument>(RV); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; } - // Return negative, zero, or positive, if LHS is less than, equal to, or - // greater than RHS, respectively. A three-way result allows recursive - // comparisons to be more efficient. - int compare(const SCEV *LHS, const SCEV *RHS) const { - // Fast-path: SCEVs are uniqued so we can do a quick equality check. - if (LHS == RHS) - return 0; - - // Primarily, sort the SCEVs by their getSCEVType(). - unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); - if (LType != RType) - return (int)LType - (int)RType; - - // Aside from the getSCEVType() ordering, the particular ordering - // isn't very important except that it's beneficial to be consistent, - // so that (a + b) and (b + a) don't end up as different expressions. - switch (static_cast<SCEVTypes>(LType)) { - case scUnknown: { - const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); - const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); - - // Sort SCEVUnknown values with some loose heuristics. TODO: This is - // not as complete as it could be. - const Value *LV = LU->getValue(), *RV = RU->getValue(); - - // Order pointer values after integer values. This helps SCEVExpander - // form GEPs. - bool LIsPointer = LV->getType()->isPointerTy(), - RIsPointer = RV->getType()->isPointerTy(); - if (LIsPointer != RIsPointer) - return (int)LIsPointer - (int)RIsPointer; - - // Compare getValueID values. - unsigned LID = LV->getValueID(), - RID = RV->getValueID(); - if (LID != RID) - return (int)LID - (int)RID; - - // Sort arguments by their position. - if (const Argument *LA = dyn_cast<Argument>(LV)) { - const Argument *RA = cast<Argument>(RV); - unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); - return (int)LArgNo - (int)RArgNo; - } + if (const auto *LGV = dyn_cast<GlobalValue>(LV)) { + const auto *RGV = cast<GlobalValue>(RV); - // For instructions, compare their loop depth, and their operand - // count. This is pretty loose. - if (const Instruction *LInst = dyn_cast<Instruction>(LV)) { - const Instruction *RInst = cast<Instruction>(RV); - - // Compare loop depths. - const BasicBlock *LParent = LInst->getParent(), - *RParent = RInst->getParent(); - if (LParent != RParent) { - unsigned LDepth = LI->getLoopDepth(LParent), - RDepth = LI->getLoopDepth(RParent); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; - } + const auto IsGVNameSemantic = [&](const GlobalValue *GV) { + auto LT = GV->getLinkage(); + return !(GlobalValue::isPrivateLinkage(LT) || + GlobalValue::isInternalLinkage(LT)); + }; - // Compare the number of operands. - unsigned LNumOps = LInst->getNumOperands(), - RNumOps = RInst->getNumOperands(); - return (int)LNumOps - (int)RNumOps; - } + // Use the names to distinguish the two values, but only if the + // names are semantically important. + if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV)) + return LGV->getName().compare(RGV->getName()); + } + + // For instructions, compare their loop depth, and their operand count. This + // is pretty loose. + if (const auto *LInst = dyn_cast<Instruction>(LV)) { + const auto *RInst = cast<Instruction>(RV); - return 0; + // Compare loop depths. + const BasicBlock *LParent = LInst->getParent(), + *RParent = RInst->getParent(); + if (LParent != RParent) { + unsigned LDepth = LI->getLoopDepth(LParent), + RDepth = LI->getLoopDepth(RParent); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; } - case scConstant: { - const SCEVConstant *LC = cast<SCEVConstant>(LHS); - const SCEVConstant *RC = cast<SCEVConstant>(RHS); + // Compare the number of operands. + unsigned LNumOps = LInst->getNumOperands(), + RNumOps = RInst->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; - // Compare constant values. - const APInt &LA = LC->getAPInt(); - const APInt &RA = RC->getAPInt(); - unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); - if (LBitWidth != RBitWidth) - return (int)LBitWidth - (int)RBitWidth; - return LA.ult(RA) ? -1 : 1; + for (unsigned Idx : seq(0u, LNumOps)) { + int Result = + CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx), + RInst->getOperand(Idx), Depth + 1); + if (Result != 0) + return Result; } + } - case scAddRecExpr: { - const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); - const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + EqCache.insert({LV, RV}); + return 0; +} - // Compare addrec loop depths. - const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); - if (LLoop != RLoop) { - unsigned LDepth = LLoop->getLoopDepth(), - RDepth = RLoop->getLoopDepth(); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; - } +// Return negative, zero, or positive, if LHS is less than, equal to, or greater +// than RHS, respectively. A three-way result allows recursive comparisons to be +// more efficient. +static int CompareSCEVComplexity( + SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV, + const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, + unsigned Depth = 0) { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return 0; - // Addrec complexity grows with operand count. - unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; + // Primarily, sort the SCEVs by their getSCEVType(). + unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + if (LType != RType) + return (int)LType - (int)RType; - // Lexicographically compare. - for (unsigned i = 0; i != LNumOps; ++i) { - long X = compare(LA->getOperand(i), RA->getOperand(i)); - if (X != 0) - return X; - } + if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS})) + return 0; + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + switch (static_cast<SCEVTypes>(LType)) { + case scUnknown: { + const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); + const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); + + SmallSet<std::pair<Value *, Value *>, 8> EqCache; + int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(), + Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; + } - return 0; + case scConstant: { + const SCEVConstant *LC = cast<SCEVConstant>(LHS); + const SCEVConstant *RC = cast<SCEVConstant>(RHS); + + // Compare constant values. + const APInt &LA = LC->getAPInt(); + const APInt &RA = RC->getAPInt(); + unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); + if (LBitWidth != RBitWidth) + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; + } + + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); + const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + + // Compare addrec loop depths. + const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); + if (LLoop != RLoop) { + unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; } - case scAddExpr: - case scMulExpr: - case scSMaxExpr: - case scUMaxExpr: { - const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); - const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); - - // Lexicographically compare n-ary expressions. - unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; - - for (unsigned i = 0; i != LNumOps; ++i) { - if (i >= RNumOps) - return 1; - long X = compare(LC->getOperand(i), RC->getOperand(i)); - if (X != 0) - return X; - } + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), + RA->getOperand(i), Depth + 1); + if (X != 0) + return X; } + EqCacheSCEV.insert({LHS, RHS}); + return 0; + } - case scUDivExpr: { - const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); - const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: { + const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); + const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); + + // Lexicographically compare n-ary expressions. + unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; - // Lexicographically compare udiv expressions. - long X = compare(LC->getLHS(), RC->getLHS()); + for (unsigned i = 0; i != LNumOps; ++i) { + if (i >= RNumOps) + return 1; + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), + RC->getOperand(i), Depth + 1); if (X != 0) return X; - return compare(LC->getRHS(), RC->getRHS()); } + EqCacheSCEV.insert({LHS, RHS}); + return 0; + } - case scTruncate: - case scZeroExtend: - case scSignExtend: { - const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS); - const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); + case scUDivExpr: { + const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); + const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); - // Compare cast expressions by operand. - return compare(LC->getOperand(), RC->getOperand()); - } + // Lexicographically compare udiv expressions. + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), + Depth + 1); + if (X != 0) + return X; + X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), + Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; + } - case scCouldNotCompute: - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - } - llvm_unreachable("Unknown SCEV kind!"); + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS); + const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); + + // Compare cast expressions by operand. + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), + RC->getOperand(), Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; } -}; -} // end anonymous namespace + + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); +} /// Given a list of SCEV objects, order them by their complexity, and group /// objects of the same complexity together by value. When this routine is @@ -635,17 +705,22 @@ public: static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop + + SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache; if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; - if (SCEVComplexityCompare(LI)(RHS, LHS)) + if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. - std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI)); + std::stable_sort(Ops.begin(), Ops.end(), + [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) { + return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0; + }); // Now that we are sorted by complexity, group elements of the same // complexity. Note that this is, at worst, N^2, but the vector is likely to @@ -2518,6 +2593,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, if (Idx < Ops.size()) { bool DeletedMul = false; while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { + if (Ops.size() > MulOpsInlineThreshold) + break; // If we have an mul, expand the mul operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); @@ -2970,9 +3047,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, } const SCEV * -ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, - const SmallVectorImpl<const SCEV *> &IndexExprs, - bool InBounds) { +ScalarEvolution::getGEPExpr(GEPOperator *GEP, + const SmallVectorImpl<const SCEV *> &IndexExprs) { + const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); // getSCEV(Base)->getType() has the same address space as Base->getType() // because SCEV::getType() preserves the address space. Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); @@ -2981,12 +3058,13 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, // flow and the no-overflow bits may not be valid for the expression in any // context. This can be fixed similarly to how these flags are handled for // adds. - SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW + : SCEV::FlagAnyWrap; const SCEV *TotalOffset = getZero(IntPtrTy); - // The address space is unimportant. The first thing we do on CurTy is getting + // The array size is unimportant. The first thing we do on CurTy is getting // its element type. - Type *CurTy = PointerType::getUnqual(PointeeType); + Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0); for (const SCEV *IndexExpr : IndexExprs) { // Compute the (potentially symbolic) offset in bytes for this index. if (StructType *STy = dyn_cast<StructType>(CurTy)) { @@ -3311,71 +3389,23 @@ const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } - bool ScalarEvolution::checkValidity(const SCEV *S) const { - // Helper class working with SCEVTraversal to figure out if a SCEV contains - // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne - // is set iff if find such SCEVUnknown. - // - struct FindInvalidSCEVUnknown { - bool FindOne; - FindInvalidSCEVUnknown() { FindOne = false; } - bool follow(const SCEV *S) { - switch (static_cast<SCEVTypes>(S->getSCEVType())) { - case scConstant: - return false; - case scUnknown: - if (!cast<SCEVUnknown>(S)->getValue()) - FindOne = true; - return false; - default: - return true; - } - } - bool isDone() const { return FindOne; } - }; - - FindInvalidSCEVUnknown F; - SCEVTraversal<FindInvalidSCEVUnknown> ST(F); - ST.visitAll(S); + bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { + auto *SU = dyn_cast<SCEVUnknown>(S); + return SU && SU->getValue() == nullptr; + }); - return !F.FindOne; -} - -namespace { -// Helper class working with SCEVTraversal to figure out if a SCEV contains -// a sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set -// iff if such sub scAddRecExpr type SCEV is found. -struct FindAddRecurrence { - bool FoundOne; - FindAddRecurrence() : FoundOne(false) {} - - bool follow(const SCEV *S) { - switch (static_cast<SCEVTypes>(S->getSCEVType())) { - case scAddRecExpr: - FoundOne = true; - case scConstant: - case scUnknown: - case scCouldNotCompute: - return false; - default: - return true; - } - } - bool isDone() const { return FoundOne; } -}; + return !ContainsNulls; } bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { - HasRecMapType::iterator I = HasRecMap.find_as(S); + HasRecMapType::iterator I = HasRecMap.find(S); if (I != HasRecMap.end()) return I->second; - FindAddRecurrence F; - SCEVTraversal<FindAddRecurrence> ST(F); - ST.visitAll(S); - HasRecMap.insert({S, F.FoundOne}); - return F.FoundOne; + bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>); + HasRecMap.insert({S, FoundAddRec}); + return FoundAddRec; } /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. @@ -4210,7 +4240,9 @@ static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, } const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { - if (PN->getNumIncomingValues() == 2) { + auto IsReachable = + [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; + if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { const Loop *L = LI.getLoopFor(PN->getParent()); // We don't want to break LCSSA, even in a SCEV expression tree. @@ -4286,7 +4318,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); - // fall through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x @@ -4309,7 +4341,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: std::swap(LHS, RHS); - // fall through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x @@ -4374,9 +4406,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { SmallVector<const SCEV *, 4> IndexExprs; for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) IndexExprs.push_back(getSCEV(*Index)); - return getGEPExpr(GEP->getSourceElementType(), - getSCEV(GEP->getPointerOperand()), - IndexExprs, GEP->isInBounds()); + return getGEPExpr(GEP, IndexExprs); } uint32_t @@ -4654,19 +4684,18 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2); ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2); ConstantRange StartURange = getUnsignedRange(Start); ConstantRange EndURange = StartURange.add(MaxBECountRange.multiply(StepSRange)); // Check for unsigned overflow. - ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); + ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2); if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == ZExtEndURange) { APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), @@ -4686,8 +4715,8 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, // Check for signed overflow. This must be done with ConstantRange // arithmetic because we could be called from within the ScalarEvolution // overflow checking code. - ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); + ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2); if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == SExtEndSRange) { APInt Min = @@ -4951,17 +4980,33 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); } -bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { - auto Itr = LoopHasNoAbnormalExits.find(L); - if (Itr == LoopHasNoAbnormalExits.end()) { - auto NoAbnormalExitInBB = [&](BasicBlock *BB) { - return all_of(*BB, [](Instruction &I) { - return isGuaranteedToTransferExecutionToSuccessor(&I); - }); +ScalarEvolution::LoopProperties +ScalarEvolution::getLoopProperties(const Loop *L) { + typedef ScalarEvolution::LoopProperties LoopProperties; + + auto Itr = LoopPropertiesCache.find(L); + if (Itr == LoopPropertiesCache.end()) { + auto HasSideEffects = [](Instruction *I) { + if (auto *SI = dyn_cast<StoreInst>(I)) + return !SI->isSimple(); + + return I->mayHaveSideEffects(); }; - auto InsertPair = LoopHasNoAbnormalExits.insert( - {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); + LoopProperties LP = {/* HasNoAbnormalExits */ true, + /*HasNoSideEffects*/ true}; + + for (auto *BB : L->getBlocks()) + for (auto &I : *BB) { + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) + LP.HasNoAbnormalExits = false; + if (HasSideEffects(&I)) + LP.HasNoSideEffects = false; + if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects) + break; // We're already as pessimistic as we can get. + } + + auto InsertPair = LoopPropertiesCache.insert({L, LP}); assert(InsertPair.second && "We just checked!"); Itr = InsertPair.first; } @@ -5289,6 +5334,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // +static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { + if (!ExitCount) + return 0; + + ConstantInt *ExitConst = ExitCount->getValue(); + + // Guard against huge trip counts. + if (ExitConst->getValue().getActiveBits() > 32) + return 0; + + // In case of integer overflow, this returns 0, which is correct. + return ((unsigned)ExitConst->getZExtValue()) + 1; +} + unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripCount(L, ExitingBB); @@ -5304,17 +5363,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); - if (!ExitCount) - return 0; - - ConstantInt *ExitConst = ExitCount->getValue(); - - // Guard against huge trip counts. - if (ExitConst->getValue().getActiveBits() > 32) - return 0; + return getConstantTripCount(ExitCount); +} - // In case of integer overflow, this returns 0, which is correct. - return ((unsigned)ExitConst->getZExtValue()) + 1; +unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) { + const auto *MaxExitCount = + dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); + return getConstantTripCount(MaxExitCount); } unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { @@ -5393,6 +5448,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getMax(this); } +bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { + return getBackedgeTakenInfo(L).isMaxOrZero(this); +} + /// Push PHI nodes in the header of the given loop onto the given Worklist. static void PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { @@ -5418,7 +5477,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { BackedgeTakenInfo Result = computeBackedgeTakenCount(L, /*AllowPredicates=*/true); - return PredicatedBackedgeTakenCounts.find(L)->second = Result; + return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); } const ScalarEvolution::BackedgeTakenInfo & @@ -5493,7 +5552,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // recusive call to getBackedgeTakenInfo (on a different // loop), which would invalidate the iterator computed // earlier. - return BackedgeTakenCounts.find(L)->second = Result; + return BackedgeTakenCounts.find(L)->second = std::move(Result); } void ScalarEvolution::forgetLoop(const Loop *L) { @@ -5537,7 +5596,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { for (Loop *I : *L) forgetLoop(I); - LoopHasNoAbnormalExits.erase(L); + LoopPropertiesCache.erase(L); } void ScalarEvolution::forgetValue(Value *V) { @@ -5576,14 +5635,11 @@ void ScalarEvolution::forgetValue(Value *V) { /// caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact( - ScalarEvolution *SE, SCEVUnionPredicate *Preds) const { +ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE, + SCEVUnionPredicate *Preds) const { // If any exits were not computable, the loop is not computable. - if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); - - // We need exactly one computable exit. - if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute(); - assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); + if (!isComplete() || ExitNotTaken.empty()) + return SE->getCouldNotCompute(); const SCEV *BECount = nullptr; for (auto &ENT : ExitNotTaken) { @@ -5593,10 +5649,10 @@ ScalarEvolution::BackedgeTakenInfo::getExact( BECount = ENT.ExactNotTaken; else if (BECount != ENT.ExactNotTaken) return SE->getCouldNotCompute(); - if (Preds && ENT.getPred()) - Preds->add(ENT.getPred()); + if (Preds && !ENT.hasAlwaysTruePredicate()) + Preds->add(ENT.Predicate.get()); - assert((Preds || ENT.hasAlwaysTruePred()) && + assert((Preds || ENT.hasAlwaysTruePredicate()) && "Predicate should be always true!"); } @@ -5609,7 +5665,7 @@ const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { for (auto &ENT : ExitNotTaken) - if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) + if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) return ENT.ExactNotTaken; return SE->getCouldNotCompute(); @@ -5618,21 +5674,29 @@ ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, /// getMax - Get the max backedge taken count for the loop. const SCEV * ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { - for (auto &ENT : ExitNotTaken) - if (!ENT.hasAlwaysTruePred()) - return SE->getCouldNotCompute(); + auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { + return !ENT.hasAlwaysTruePredicate(); + }; - return Max ? Max : SE->getCouldNotCompute(); + if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) + return SE->getCouldNotCompute(); + + return getMax(); +} + +bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const { + auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { + return !ENT.hasAlwaysTruePredicate(); + }; + return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); } bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, ScalarEvolution *SE) const { - if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S)) + if (getMax() && getMax() != SE->getCouldNotCompute() && + SE->hasOperand(getMax(), S)) return true; - if (!ExitNotTaken.ExitingBlock) - return false; - for (auto &ENT : ExitNotTaken) if (ENT.ExactNotTaken != SE->getCouldNotCompute() && SE->hasOperand(ENT.ExactNotTaken, S)) @@ -5644,62 +5708,31 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( - SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount) - : Max(MaxCount) { - - if (!Complete) - ExitNotTaken.setIncomplete(); - - unsigned NumExits = ExitCounts.size(); - if (NumExits == 0) return; - - ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; - ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; - - // Determine the number of ExitNotTakenExtras structures that we need. - unsigned ExtraInfoSize = 0; - if (NumExits > 1) - ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), - ExitCounts.end(), [](EdgeInfo &Entry) { - return !Entry.Pred.isAlwaysTrue(); - }); - else if (!ExitCounts[0].Pred.isAlwaysTrue()) - ExtraInfoSize = 1; - - ExitNotTakenExtras *ENT = nullptr; - - // Allocate the ExitNotTakenExtras structures and initialize the first - // element (ExitNotTaken). - if (ExtraInfoSize > 0) { - ENT = new ExitNotTakenExtras[ExtraInfoSize]; - ExitNotTaken.ExtraInfo = &ENT[0]; - *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); - } - - if (NumExits == 1) - return; - - assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); - - auto &Exits = ExitNotTaken.ExtraInfo->Exits; - - // Handle the rare case of multiple computable exits. - for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { - ExitNotTakenExtras *Ptr = nullptr; - if (!ExitCounts[i].Pred.isAlwaysTrue()) { - Ptr = &ENT[PredPos++]; - Ptr->Pred = std::move(ExitCounts[i].Pred); - } - - Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); - } + SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> + &&ExitCounts, + bool Complete, const SCEV *MaxCount, bool MaxOrZero) + : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) { + typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; + ExitNotTaken.reserve(ExitCounts.size()); + std::transform( + ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken), + [&](const EdgeExitInfo &EEI) { + BasicBlock *ExitBB = EEI.first; + const ExitLimit &EL = EEI.second; + if (EL.Predicates.empty()) + return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); + + std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate); + for (auto *Pred : EL.Predicates) + Predicate->add(Pred); + + return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); + }); } /// Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { - ExitNotTaken.ExitingBlock = nullptr; - ExitNotTaken.ExactNotTaken = nullptr; - delete[] ExitNotTaken.ExtraInfo; + ExitNotTaken.clear(); } /// Compute the number of times the backedge of the specified loop will execute. @@ -5709,11 +5742,14 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - SmallVector<EdgeInfo, 4> ExitCounts; + typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; + + SmallVector<EdgeExitInfo, 4> ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. const SCEV *MustExitMaxBECount = nullptr; const SCEV *MayExitMaxBECount = nullptr; + bool MustExitMaxOrZero = false; // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. @@ -5722,17 +5758,17 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, BasicBlock *ExitBB = ExitingBlocks[i]; ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); - assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && + assert((AllowPredicates || EL.Predicates.empty()) && "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. - if (EL.Exact == getCouldNotCompute()) + if (EL.ExactNotTaken == getCouldNotCompute()) // We couldn't compute an exact value for this exit, so // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else - ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); + ExitCounts.emplace_back(ExitBB, EL); // 2. Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. Partition the loop exits into two kinds: @@ -5740,29 +5776,35 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, // // If the exit dominates the loop latch, it is a LoopMustExit otherwise it // is a LoopMayExit. If any computable LoopMustExit is found, then - // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, - // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is - // considered greater than any computable EL.Max. - if (EL.Max != getCouldNotCompute() && Latch && + // MaxBECount is the minimum EL.MaxNotTaken of computable + // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum + // EL.MaxNotTaken, where CouldNotCompute is considered greater than any + // computable EL.MaxNotTaken. + if (EL.MaxNotTaken != getCouldNotCompute() && Latch && DT.dominates(ExitBB, Latch)) { - if (!MustExitMaxBECount) - MustExitMaxBECount = EL.Max; - else { + if (!MustExitMaxBECount) { + MustExitMaxBECount = EL.MaxNotTaken; + MustExitMaxOrZero = EL.MaxOrZero; + } else { MustExitMaxBECount = - getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max); + getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken); } } else if (MayExitMaxBECount != getCouldNotCompute()) { - if (!MayExitMaxBECount || EL.Max == getCouldNotCompute()) - MayExitMaxBECount = EL.Max; + if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute()) + MayExitMaxBECount = EL.MaxNotTaken; else { MayExitMaxBECount = - getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max); + getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken); } } } const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); - return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); + // The loop backedge will be taken the maximum or zero times if there's + // a single exit that must be taken the maximum or zero times. + bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); + return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, + MaxBECount, MaxOrZero); } ScalarEvolution::ExitLimit @@ -5867,39 +5909,40 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. - if (EL0.Exact == getCouldNotCompute() || - EL1.Exact == getCouldNotCompute()) + if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) BECount = getCouldNotCompute(); else - BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); - if (EL0.Max == getCouldNotCompute()) - MaxBECount = EL1.Max; - else if (EL1.Max == getCouldNotCompute()) - MaxBECount = EL0.Max; + BECount = + getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + if (EL0.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.MaxNotTaken; + else if (EL1.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.MaxNotTaken; else - MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); + MaxBECount = + getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. assert(L->contains(FBB) && "Loop block has no successor in loop!"); - if (EL0.Max == EL1.Max) - MaxBECount = EL0.Max; - if (EL0.Exact == EL1.Exact) - BECount = EL0.Exact; + if (EL0.MaxNotTaken == EL1.MaxNotTaken) + MaxBECount = EL0.MaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; } - SCEVUnionPredicate NP; - NP.add(&EL0.Pred); - NP.add(&EL1.Pred); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing - // MaxBECount. In these cases it is possible for EL0.Exact and EL1.Exact - // to match, but for EL0.Max and EL1.Max to not. + // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and + // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken + // to not. if (isa<SCEVCouldNotCompute>(MaxBECount) && !isa<SCEVCouldNotCompute>(BECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, NP); + return ExitLimit(BECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. @@ -5915,31 +5958,31 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. - if (EL0.Exact == getCouldNotCompute() || - EL1.Exact == getCouldNotCompute()) + if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) BECount = getCouldNotCompute(); else - BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); - if (EL0.Max == getCouldNotCompute()) - MaxBECount = EL1.Max; - else if (EL1.Max == getCouldNotCompute()) - MaxBECount = EL0.Max; + BECount = + getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + if (EL0.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.MaxNotTaken; + else if (EL1.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.MaxNotTaken; else - MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); + MaxBECount = + getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. assert(L->contains(TBB) && "Loop block has no successor in loop!"); - if (EL0.Max == EL1.Max) - MaxBECount = EL0.Max; - if (EL0.Exact == EL1.Exact) - BECount = EL0.Exact; + if (EL0.MaxNotTaken == EL1.MaxNotTaken) + MaxBECount = EL0.MaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; } - SCEVUnionPredicate NP; - NP.add(&EL0.Pred); - NP.add(&EL1.Pred); - return ExitLimit(BECount, MaxBECount, NP); + return ExitLimit(BECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); } } @@ -6021,8 +6064,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) if (AddRec->getLoop() == L) { // Form the constant range. - ConstantRange CompRange( - ICmpInst::makeConstantRange(Cond, RHSC->getAPInt())); + ConstantRange CompRange = + ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt()); const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; @@ -6226,7 +6269,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] // %iv.shifted = lshr i32 %iv, <positive constant> // - // Return true on a succesful match. Return the corresponding PHI node (%iv + // Return true on a successful match. Return the corresponding PHI node (%iv // above) in PNOut and the opcode of the shift operation in OpCodeOut. auto MatchShiftRecurrence = [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { @@ -6324,8 +6367,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - SCEVUnionPredicate P; - return ExitLimit(getCouldNotCompute(), UpperBound, P); + return ExitLimit(getCouldNotCompute(), UpperBound, false); } return getCouldNotCompute(); @@ -6995,20 +7037,21 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic // modulo (N / D). // - // (N / D) may need BW+1 bits in its representation. Hence, we'll use this - // bit width during computations. + // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent + // (N / D) in general. The inverse itself always fits into BW bits, though, + // so we immediately truncate it. APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D APInt Mod(BW + 1, 0); Mod.setBit(BW - Mult2); // Mod = N / D - APInt I = AD.multiplicativeInverse(Mod); + APInt I = AD.multiplicativeInverse(Mod).trunc(BW); // 4. Compute the minimum unsigned root of the equation: // I * (B / D) mod (N / D) - APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + // To simplify the computation, we factor out the divide by D: + // (I * B mod N) / D + APInt Result = (I * B).lshr(Mult2); - // The result is guaranteed to be less than 2^BW so we may truncate it to BW - // bits. - return SE.getConstant(Result.trunc(BW)); + return SE.getConstant(Result); } /// Find the roots of the quadratic equation for the given quadratic chrec @@ -7086,7 +7129,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // effectively V != 0. We know and take advantage of the fact that this // expression only being used in a comparison by zero context. - SCEVUnionPredicate P; + SmallPtrSet<const SCEVPredicate *, 4> Predicates; // If the value is a constant if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { // If the value is already zero, the branch will execute zero times. @@ -7099,7 +7142,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - AddRec = convertSCEVToAddRecWithPredicates(V, L, P); + AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates); if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); @@ -7121,7 +7164,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // should not accept a root of 2. const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) - return ExitLimit(R1, R1, P); // We found a quadratic root! + // We found a quadratic root! + return ExitLimit(R1, R1, false, Predicates); } } return getCouldNotCompute(); @@ -7168,17 +7212,25 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) { - ConstantRange CR = getUnsignedRange(Start); - const SCEV *MaxBECount; - if (!CountDown && CR.getUnsignedMin().isMinValue()) - // When counting up, the worst starting value is 1, not 0. - MaxBECount = CR.getUnsignedMax().isMinValue() - ? getConstant(APInt::getMinValue(CR.getBitWidth())) - : getConstant(APInt::getMaxValue(CR.getBitWidth())); - else - MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() - : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, P); + APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax(); + + // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, + // we end up with a loop whose backedge-taken count is n - 1. Detect this + // case, and see if we can improve the bound. + // + // Explicitly handling this here is necessary because getUnsignedRange + // isn't context-sensitive; it doesn't know that we only care about the + // range inside the loop. + const SCEV *Zero = getZero(Distance->getType()); + const SCEV *One = getOne(Distance->getType()); + const SCEV *DistancePlusOne = getAddExpr(Distance, One); + if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) { + // If Distance + 1 doesn't overflow, we can compute the maximum distance + // as "unsigned_max(Distance + 1) - 1". + ConstantRange CR = getUnsignedRange(DistancePlusOne); + MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1); + } + return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); } // As a special case, handle the instance where Step is a positive power of @@ -7233,7 +7285,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, const SCEV *Limit = getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); - return ExitLimit(Limit, Limit, P); + return ExitLimit(Limit, Limit, false, Predicates); } } @@ -7246,14 +7298,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, P); + return ExitLimit(Exact, Exact, false, Predicates); } // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { const SCEV *E = SolveLinEquationWithOverflow( StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); - return ExitLimit(E, E, P); + return ExitLimit(E, E, false, Predicates); } return getCouldNotCompute(); } @@ -7365,149 +7417,77 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // cases, and canonicalize *-or-equal comparisons to regular comparisons. if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) { const APInt &RA = RC->getAPInt(); - switch (Pred) { - default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_NE: - // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. - if (!RA) - if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) - if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(AE->getOperand(0))) - if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && - ME->getOperand(0)->isAllOnesValue()) { - RHS = AE->getOperand(1); - LHS = ME->getOperand(1); - Changed = true; - } - break; - case ICmpInst::ICMP_UGE: - if ((RA - 1).isMinValue()) { - Pred = ICmpInst::ICMP_NE; - RHS = getConstant(RA - 1); - Changed = true; - break; - } - if (RA.isMaxValue()) { - Pred = ICmpInst::ICMP_EQ; - Changed = true; - break; - } - if (RA.isMinValue()) goto trivially_true; - Pred = ICmpInst::ICMP_UGT; - RHS = getConstant(RA - 1); - Changed = true; - break; - case ICmpInst::ICMP_ULE: - if ((RA + 1).isMaxValue()) { - Pred = ICmpInst::ICMP_NE; - RHS = getConstant(RA + 1); - Changed = true; - break; - } - if (RA.isMinValue()) { - Pred = ICmpInst::ICMP_EQ; - Changed = true; - break; - } - if (RA.isMaxValue()) goto trivially_true; + bool SimplifiedByConstantRange = false; - Pred = ICmpInst::ICMP_ULT; - RHS = getConstant(RA + 1); - Changed = true; - break; - case ICmpInst::ICMP_SGE: - if ((RA - 1).isMinSignedValue()) { - Pred = ICmpInst::ICMP_NE; - RHS = getConstant(RA - 1); - Changed = true; - break; - } - if (RA.isMaxSignedValue()) { - Pred = ICmpInst::ICMP_EQ; - Changed = true; - break; + if (!ICmpInst::isEquality(Pred)) { + ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); + if (ExactCR.isFullSet()) + goto trivially_true; + else if (ExactCR.isEmptySet()) + goto trivially_false; + + APInt NewRHS; + CmpInst::Predicate NewPred; + if (ExactCR.getEquivalentICmp(NewPred, NewRHS) && + ICmpInst::isEquality(NewPred)) { + // We were able to convert an inequality to an equality. + Pred = NewPred; + RHS = getConstant(NewRHS); + Changed = SimplifiedByConstantRange = true; } - if (RA.isMinSignedValue()) goto trivially_true; + } - Pred = ICmpInst::ICMP_SGT; - RHS = getConstant(RA - 1); - Changed = true; - break; - case ICmpInst::ICMP_SLE: - if ((RA + 1).isMaxSignedValue()) { - Pred = ICmpInst::ICMP_NE; - RHS = getConstant(RA + 1); - Changed = true; + if (!SimplifiedByConstantRange) { + switch (Pred) { + default: break; - } - if (RA.isMinSignedValue()) { - Pred = ICmpInst::ICMP_EQ; - Changed = true; + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. + if (!RA) + if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) + if (const SCEVMulExpr *ME = + dyn_cast<SCEVMulExpr>(AE->getOperand(0))) + if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && + ME->getOperand(0)->isAllOnesValue()) { + RHS = AE->getOperand(1); + LHS = ME->getOperand(1); + Changed = true; + } break; - } - if (RA.isMaxSignedValue()) goto trivially_true; - Pred = ICmpInst::ICMP_SLT; - RHS = getConstant(RA + 1); - Changed = true; - break; - case ICmpInst::ICMP_UGT: - if (RA.isMinValue()) { - Pred = ICmpInst::ICMP_NE; + + // The "Should have been caught earlier!" messages refer to the fact + // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above + // should have fired on the corresponding cases, and canonicalized the + // check to trivially_true or trivially_false. + + case ICmpInst::ICMP_UGE: + assert(!RA.isMinValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_UGT; + RHS = getConstant(RA - 1); Changed = true; break; - } - if ((RA + 1).isMaxValue()) { - Pred = ICmpInst::ICMP_EQ; + case ICmpInst::ICMP_ULE: + assert(!RA.isMaxValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_ULT; RHS = getConstant(RA + 1); Changed = true; break; - } - if (RA.isMaxValue()) goto trivially_false; - break; - case ICmpInst::ICMP_ULT: - if (RA.isMaxValue()) { - Pred = ICmpInst::ICMP_NE; - Changed = true; - break; - } - if ((RA - 1).isMinValue()) { - Pred = ICmpInst::ICMP_EQ; + case ICmpInst::ICMP_SGE: + assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_SGT; RHS = getConstant(RA - 1); Changed = true; break; - } - if (RA.isMinValue()) goto trivially_false; - break; - case ICmpInst::ICMP_SGT: - if (RA.isMinSignedValue()) { - Pred = ICmpInst::ICMP_NE; - Changed = true; - break; - } - if ((RA + 1).isMaxSignedValue()) { - Pred = ICmpInst::ICMP_EQ; + case ICmpInst::ICMP_SLE: + assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); + Pred = ICmpInst::ICMP_SLT; RHS = getConstant(RA + 1); Changed = true; break; } - if (RA.isMaxSignedValue()) goto trivially_false; - break; - case ICmpInst::ICMP_SLT: - if (RA.isMaxSignedValue()) { - Pred = ICmpInst::ICMP_NE; - Changed = true; - break; - } - if ((RA - 1).isMinSignedValue()) { - Pred = ICmpInst::ICMP_EQ; - RHS = getConstant(RA - 1); - Changed = true; - break; - } - if (RA.isMinSignedValue()) goto trivially_false; - break; } } @@ -8067,34 +8047,16 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return false; } -namespace { -/// RAII wrapper to prevent recursive application of isImpliedCond. -/// ScalarEvolution's PendingLoopPredicates set must be empty unless we are -/// currently evaluating isImpliedCond. -struct MarkPendingLoopPredicate { - Value *Cond; - DenseSet<Value*> &LoopPreds; - bool Pending; - - MarkPendingLoopPredicate(Value *C, DenseSet<Value*> &LP) - : Cond(C), LoopPreds(LP) { - Pending = !LoopPreds.insert(Cond).second; - } - ~MarkPendingLoopPredicate() { - if (!Pending) - LoopPreds.erase(Cond); - } -}; -} // end anonymous namespace - bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Value *FoundCondValue, bool Inverse) { - MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates); - if (Mark.Pending) + if (!PendingLoopPredicates.insert(FoundCondValue).second) return false; + auto ClearOnExit = + make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); }); + // Recursively handle And and Or conditions. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) { if (BO->getOpcode() == Instruction::And) { @@ -8279,9 +8241,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, return true; } -bool ScalarEvolution::computeConstantDifference(const SCEV *Less, - const SCEV *More, - APInt &C) { +Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, + const SCEV *Less) { // We avoid subtracting expressions here because this function is usually // fairly deep in the call stack (i.e. is called many times). @@ -8290,15 +8251,15 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, const auto *MAR = cast<SCEVAddRecExpr>(More); if (LAR->getLoop() != MAR->getLoop()) - return false; + return None; // We look at affine expressions only; not for correctness but to keep // getStepRecurrence cheap. if (!LAR->isAffine() || !MAR->isAffine()) - return false; + return None; if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) - return false; + return None; Less = LAR->getStart(); More = MAR->getStart(); @@ -8309,27 +8270,22 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) { const auto &M = cast<SCEVConstant>(More)->getAPInt(); const auto &L = cast<SCEVConstant>(Less)->getAPInt(); - C = M - L; - return true; + return M - L; } const SCEV *L, *R; SCEV::NoWrapFlags Flags; if (splitBinaryAdd(Less, L, R, Flags)) if (const auto *LC = dyn_cast<SCEVConstant>(L)) - if (R == More) { - C = -(LC->getAPInt()); - return true; - } + if (R == More) + return -(LC->getAPInt()); if (splitBinaryAdd(More, L, R, Flags)) if (const auto *LC = dyn_cast<SCEVConstant>(L)) - if (R == Less) { - C = LC->getAPInt(); - return true; - } + if (R == Less) + return LC->getAPInt(); - return false; + return None; } bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( @@ -8386,22 +8342,21 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + // C)". - APInt LDiff, RDiff; - if (!computeConstantDifference(FoundLHS, LHS, LDiff) || - !computeConstantDifference(FoundRHS, RHS, RDiff) || - LDiff != RDiff) + Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS); + Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS); + if (!LDiff || !RDiff || *LDiff != *RDiff) return false; - if (LDiff == 0) + if (LDiff->isMinValue()) return true; APInt FoundRHSLimit; if (Pred == CmpInst::ICMP_ULT) { - FoundRHSLimit = -RDiff; + FoundRHSLimit = -(*RDiff); } else { assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); - FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff; + FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff; } // Try to prove (1) or (2), as needed. @@ -8511,7 +8466,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, case ICmpInst::ICMP_SGE: std::swap(LHS, RHS); - // fall through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: return // min(A, ...) <= A @@ -8521,7 +8476,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, case ICmpInst::ICMP_UGE: std::swap(LHS, RHS); - // fall through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULE: return // min(A, ...) <= A @@ -8592,9 +8547,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // reduce the compile time impact of this optimization. return false; - const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS); - if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || - !isa<SCEVConstant>(AddLHS->getOperand(0))) + Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS); + if (!Addend) return false; APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); @@ -8604,10 +8558,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, ConstantRange FoundLHSRange = ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); - // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range - // for `LHS`: - APInt Addend = cast<SCEVConstant>(AddLHS->getOperand(0))->getAPInt(); - ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); + // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`: + ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend)); // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": @@ -8622,6 +8574,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { + assert(isKnownPositive(Stride) && "Positive stride expected!"); + if (NoWrap) return false; unsigned BitWidth = getTypeSizeInBits(RHS->getType()); @@ -8684,17 +8638,21 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { - SCEVUnionPredicate P; + SmallPtrSet<const SCEVPredicate *, 4> Predicates; // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); - if (!IV && AllowPredicates) + bool PredicatedIV = false; + + if (!IV && AllowPredicates) { // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, P); + IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); + PredicatedIV = true; + } // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8705,61 +8663,144 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Stride = IV->getStepRecurrence(*this); - // Avoid negative or zero stride values - if (!isKnownPositive(Stride)) - return getCouldNotCompute(); + bool PositiveStride = isKnownPositive(Stride); - // Avoid proven overflow cases: this will ensure that the backedge taken count - // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined - // behaviors like the case of C language. - if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) + // Avoid negative or zero stride values. + if (!PositiveStride) { + // We can compute the correct backedge taken count for loops with unknown + // strides if we can prove that the loop is not an infinite loop with side + // effects. Here's the loop structure we are trying to handle - + // + // i = start + // do { + // A[i] = i; + // i += s; + // } while (i < end); + // + // The backedge taken count for such loops is evaluated as - + // (max(end, start + stride) - start - 1) /u stride + // + // The additional preconditions that we need to check to prove correctness + // of the above formula is as follows - + // + // a) IV is either nuw or nsw depending upon signedness (indicated by the + // NoWrap flag). + // b) loop is single exit with no side effects. + // + // + // Precondition a) implies that if the stride is negative, this is a single + // trip loop. The backedge taken count formula reduces to zero in this case. + // + // Precondition b) implies that the unknown stride cannot be zero otherwise + // we have UB. + // + // The positive stride case is the same as isKnownPositive(Stride) returning + // true (original behavior of the function). + // + // We want to make sure that the stride is truly unknown as there are edge + // cases where ScalarEvolution propagates no wrap flags to the + // post-increment/decrement IV even though the increment/decrement operation + // itself is wrapping. The computed backedge taken count may be wrong in + // such cases. This is prevented by checking that the stride is not known to + // be either positive or non-positive. For example, no wrap flags are + // propagated to the post-increment IV of this loop with a trip count of 2 - + // + // unsigned char i; + // for(i=127; i<128; i+=129) + // A[i] = i; + // + if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) || + !loopHasNoSideEffects(L)) + return getCouldNotCompute(); + + } else if (!Stride->isOne() && + doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) + // Avoid proven overflow cases: this will ensure that the backedge taken + // count will not generate any unsigned overflow. Relaxed no-overflow + // conditions exploit NoWrapFlags, allowing to optimize in presence of + // undefined behaviors like the case of C language. return getCouldNotCompute(); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) + // If the backedge is taken at least once, then it will be taken + // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start + // is the LHS value of the less-than comparison the first time it is evaluated + // and End is the RHS. + const SCEV *BECountIfBackedgeTaken = + computeBECount(getMinusSCEV(End, Start), Stride, false); + // If the loop entry is guarded by the result of the backedge test of the + // first loop iteration, then we know the backedge will be taken at least + // once and so the backedge taken count is as above. If not then we use the + // expression (max(End,Start)-Start)/Stride to describe the backedge count, + // as if the backedge is taken at least once max(End,Start) is End and so the + // result is as above, and if not max(End,Start) is Start so we get a backedge + // count of zero. + const SCEV *BECount; + if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) + BECount = BECountIfBackedgeTaken; + else { End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); + BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); + } - const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); + const SCEV *MaxBECount; + bool MaxOrZero = false; + if (isa<SCEVConstant>(BECount)) + MaxBECount = BECount; + else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) { + // If we know exactly how many times the backedge will be taken if it's + // taken at least once, then the backedge count will either be that or + // zero. + MaxBECount = BECountIfBackedgeTaken; + MaxOrZero = true; + } else { + // Calculate the maximum backedge count based on the range of values + // permitted by Start, End, and Stride. + APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() + : getUnsignedRange(Start).getUnsignedMin(); - APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() - : getUnsignedRange(Start).getUnsignedMin(); + unsigned BitWidth = getTypeSizeInBits(LHS->getType()); - APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() - : getUnsignedRange(Stride).getUnsignedMin(); + APInt StrideForMaxBECount; - unsigned BitWidth = getTypeSizeInBits(LHS->getType()); - APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1) - : APInt::getMaxValue(BitWidth) - (MinStride - 1); + if (PositiveStride) + StrideForMaxBECount = + IsSigned ? getSignedRange(Stride).getSignedMin() + : getUnsignedRange(Stride).getUnsignedMin(); + else + // Using a stride of 1 is safe when computing max backedge taken count for + // a loop with unknown stride. + StrideForMaxBECount = APInt(BitWidth, 1, IsSigned); - // Although End can be a MAX expression we estimate MaxEnd considering only - // the case End = RHS. This is safe because in the other case (End - Start) - // is zero, leading to a zero maximum backedge taken count. - APInt MaxEnd = - IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) - : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); + APInt Limit = + IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1) + : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1); + + // Although End can be a MAX expression we estimate MaxEnd considering only + // the case End = RHS. This is safe because in the other case (End - Start) + // is zero, leading to a zero maximum backedge taken count. + APInt MaxEnd = + IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) + : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); - const SCEV *MaxBECount; - if (isa<SCEVConstant>(BECount)) - MaxBECount = BECount; - else MaxBECount = computeBECount(getConstant(MaxEnd - MinStart), - getConstant(MinStride), false); + getConstant(StrideForMaxBECount), false); + } if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, P); + return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); } ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { - SCEVUnionPredicate P; + SmallPtrSet<const SCEVPredicate *, 4> Predicates; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -8769,7 +8810,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, P); + IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8829,7 +8870,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, P); + return ExitLimit(BECount, MaxBECount, false, Predicates); } const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, @@ -8901,9 +8942,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // Range.getUpper() is crossed. SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), - // getNoWrapFlags(FlagNW) - FlagAnyWrap); + const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); // Next, solve the constructed addrec if (auto Roots = @@ -8947,38 +8986,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, return SE.getCouldNotCompute(); } -namespace { -struct FindUndefs { - bool Found; - FindUndefs() : Found(false) {} - - bool follow(const SCEV *S) { - if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) { - if (isa<UndefValue>(C->getValue())) - Found = true; - } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - if (isa<UndefValue>(C->getValue())) - Found = true; - } - - // Keep looking if we haven't found it yet. - return !Found; - } - bool isDone() const { - // Stop recursion if we have found an undef. - return Found; - } -}; -} - // Return true when S contains at least an undef value. -static inline bool -containsUndefs(const SCEV *S) { - FindUndefs F; - SCEVTraversal<FindUndefs> ST(F); - ST.visitAll(S); - - return F.Found; +static inline bool containsUndefs(const SCEV *S) { + return SCEVExprContains(S, [](const SCEV *S) { + if (const auto *SU = dyn_cast<SCEVUnknown>(S)) + return isa<UndefValue>(SU->getValue()); + else if (const auto *SC = dyn_cast<SCEVConstant>(S)) + return isa<UndefValue>(SC->getValue()); + return false; + }); } namespace { @@ -9006,7 +9022,8 @@ struct SCEVCollectTerms { : Terms(T) {} bool follow(const SCEV *S) { - if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S)) { + if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) || + isa<SCEVSignExtendExpr>(S)) { if (!containsUndefs(S)) Terms.push_back(S); @@ -9158,10 +9175,9 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, } // Remove all SCEVConstants. - Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) { - return isa<SCEVConstant>(E); - }), - Terms.end()); + Terms.erase( + remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }), + Terms.end()); if (Terms.size() > 0) if (!findArrayDimensionsRec(SE, Terms, Sizes)) @@ -9171,40 +9187,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, return true; } -// Returns true when S contains at least a SCEVUnknown parameter. -static inline bool -containsParameters(const SCEV *S) { - struct FindParameter { - bool FoundParameter; - FindParameter() : FoundParameter(false) {} - - bool follow(const SCEV *S) { - if (isa<SCEVUnknown>(S)) { - FoundParameter = true; - // Stop recursion: we found a parameter. - return false; - } - // Keep looking. - return true; - } - bool isDone() const { - // Stop recursion if we have found a parameter. - return FoundParameter; - } - }; - - FindParameter F; - SCEVTraversal<FindParameter> ST(F); - ST.visitAll(S); - - return F.FoundParameter; -} // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. -static inline bool -containsParameters(SmallVectorImpl<const SCEV *> &Terms) { +static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) { for (const SCEV *T : Terms) - if (containsParameters(T)) + if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>)) return true; return false; } @@ -9535,6 +9522,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), + PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( @@ -9543,6 +9531,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) std::move(Arg.ConstantEvolutionLoopExitValue)), ValuesAtScopes(std::move(Arg.ValuesAtScopes)), LoopDispositions(std::move(Arg.LoopDispositions)), + LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)), BlockDispositions(std::move(Arg.BlockDispositions)), UnsignedRanges(std::move(Arg.UnsignedRanges)), SignedRanges(std::move(Arg.SignedRanges)), @@ -9611,6 +9600,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) { OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); + if (SE->isBackedgeTakenCountMaxOrZero(L)) + OS << ", actual taken count either this or zero."; } else { OS << "Unpredictable max backedge-taken count. "; } @@ -9871,8 +9862,10 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); if (!DT.dominates(AR->getLoop()->getHeader(), BB)) return DoesNotDominateBlock; + + // Fall through into SCEVNAryExpr handling. + LLVM_FALLTHROUGH; } - // FALL THROUGH into SCEVNAryExpr handling. case scAddExpr: case scMulExpr: case scUMaxExpr: @@ -9925,24 +9918,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { } bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { - // Search for a SCEV expression node within an expression tree. - // Implements SCEVTraversal::Visitor. - struct SCEVSearch { - const SCEV *Node; - bool IsFound; - - SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} - - bool follow(const SCEV *S) { - IsFound |= (S == Node); - return !IsFound; - } - bool isDone() const { return IsFound; } - }; - - SCEVSearch Search(Op); - visitAll(S, Search); - return Search.IsFound; + return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); } void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { @@ -10050,10 +10026,22 @@ void ScalarEvolution::verify() const { // TODO: Verify more things. } -char ScalarEvolutionAnalysis::PassID; +bool ScalarEvolution::invalidate( + Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // Invalidate the ScalarEvolution object whenever it isn't preserved or one + // of its dependencies is invalidated. + auto PAC = PA.getChecker<ScalarEvolutionAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || + Inv.invalidate<AssumptionAnalysis>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA) || + Inv.invalidate<LoopAnalysis>(F, PA); +} + +AnalysisKey ScalarEvolutionAnalysis::Key; ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), AM.getResult<AssumptionAnalysis>(F), AM.getResult<DominatorTreeAnalysis>(F), @@ -10061,7 +10049,7 @@ ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, } PreservedAnalyses -ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) { +ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); return PreservedAnalyses::all(); } @@ -10148,25 +10136,34 @@ namespace { class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { public: - // Rewrites \p S in the context of a loop L and the predicate A. - // If Assume is true, rewrite is free to add further predicates to A - // such that the result will be an AddRecExpr. + /// Rewrites \p S in the context of a loop L and the SCEV predication + /// infrastructure. + /// + /// If \p Pred is non-null, the SCEV expression is rewritten to respect the + /// equivalences present in \p Pred. + /// + /// If \p NewPreds is non-null, rewrite is free to add further predicates to + /// \p NewPreds such that the result will be an AddRecExpr. static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, - SCEVUnionPredicate &A, bool Assume) { - SCEVPredicateRewriter Rewriter(L, SE, A, Assume); + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, + SCEVUnionPredicate *Pred) { + SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); } SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, - SCEVUnionPredicate &P, bool Assume) - : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, + SCEVUnionPredicate *Pred) + : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { - auto ExprPreds = P.getPredicatesForExpr(Expr); - for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) - if (IPred->getLHS() == Expr) - return IPred->getRHS(); + if (Pred) { + auto ExprPreds = Pred->getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } return Expr; } @@ -10207,32 +10204,31 @@ private: bool addOverflowAssumption(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { auto *A = SE.getWrapPredicate(AR, AddedFlags); - if (!Assume) { + if (!NewPreds) { // Check if we've already made this assumption. - if (P.implies(A)) - return true; - return false; + return Pred && Pred->implies(A); } - P.add(A); + NewPreds->insert(A); return true; } - SCEVUnionPredicate &P; + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; + SCEVUnionPredicate *Pred; const Loop *L; - bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false); + return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); } -const SCEVAddRecExpr * -ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, - SCEVUnionPredicate &Preds) { - SCEVUnionPredicate TransformPreds; - S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true); +const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( + const SCEV *S, const Loop *L, + SmallPtrSetImpl<const SCEVPredicate *> &Preds) { + + SmallPtrSet<const SCEVPredicate *, 4> TransformPreds; + S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); if (!AddRec) @@ -10240,7 +10236,9 @@ ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, // Since the transformation was successful, we can now transfer the SCEV // predicates. - Preds.add(&TransformPreds); + for (auto *P : TransformPreds) + Preds.insert(P); + return AddRec; } @@ -10393,7 +10391,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { return Entry.second; // We found an entry but it's stale. Rewrite the stale entry - // acording to the current predicate. + // according to the current predicate. if (Entry.second) Expr = Entry.second; @@ -10467,11 +10465,15 @@ bool PredicatedScalarEvolution::hasNoOverflow( const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { const SCEV *Expr = this->getSCEV(V); - auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); + SmallPtrSet<const SCEVPredicate *, 4> NewPreds; + auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); if (!New) return nullptr; + for (auto *P : NewPreds) + Preds.add(P); + updateGeneration(); RewriteMap[SE.getSCEV(V)] = {Generation, New}; return New; |