Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--  contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 1462
1 file changed, 732 insertions(+), 730 deletions(-)
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
index 8fefada..ed328f1 100644
--- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -61,6 +61,8 @@
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -120,6 +122,21 @@ static cl::opt<bool>
cl::desc("Verify no dangling value in ScalarEvolution's "
"ExprValueMap (slow)"));
+static cl::opt<unsigned> MulOpsInlineThreshold(
+ "scev-mulops-inline-threshold", cl::Hidden,
+ cl::desc("Threshold for inlining multiplication operands into a SCEV"),
+ cl::init(1000));
+
+static cl::opt<unsigned> MaxSCEVCompareDepth(
+ "scalar-evolution-max-scev-compare-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
+ cl::init(32));
+
+static cl::opt<unsigned> MaxValueCompareDepth(
+ "scalar-evolution-max-value-compare-depth", cl::Hidden,
+ cl::desc("Maximum depth of recursive value complexity comparisons"),
+ cl::init(2));
+
//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//
@@ -447,180 +464,233 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
// SCEV Utilities
//===----------------------------------------------------------------------===//
-namespace {
-/// SCEVComplexityCompare - Return true if the complexity of the LHS is less
-/// than the complexity of the RHS. This comparator is used to canonicalize
-/// expressions.
-class SCEVComplexityCompare {
- const LoopInfo *const LI;
-public:
- explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
+/// Compare the two values \p LV and \p RV in terms of their "complexity" where
+/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
+/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
+/// have been previously deemed to be "equally complex" by this routine. It is
+/// intended to avoid exponential time complexity in cases like:
+///
+/// %a = f(%x, %y)
+/// %b = f(%a, %a)
+/// %c = f(%b, %b)
+///
+/// %d = f(%x, %y)
+/// %e = f(%d, %d)
+/// %f = f(%e, %e)
+///
+/// CompareValueComplexity(%f, %c)
+///
+/// Since we do not continue running this routine on expression trees once we
+/// have seen unequal values, there is no need to track them in the cache.
+static int
+CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
+ const LoopInfo *const LI, Value *LV, Value *RV,
+ unsigned Depth) {
+ if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV}))
+ return 0;
+
+ // Order pointer values after integer values. This helps SCEVExpander form
+ // GEPs.
+ bool LIsPointer = LV->getType()->isPointerTy(),
+ RIsPointer = RV->getType()->isPointerTy();
+ if (LIsPointer != RIsPointer)
+ return (int)LIsPointer - (int)RIsPointer;
- // Return true or false if LHS is less than, or at least RHS, respectively.
- bool operator()(const SCEV *LHS, const SCEV *RHS) const {
- return compare(LHS, RHS) < 0;
+ // Compare getValueID values.
+ unsigned LID = LV->getValueID(), RID = RV->getValueID();
+ if (LID != RID)
+ return (int)LID - (int)RID;
+
+ // Sort arguments by their position.
+ if (const auto *LA = dyn_cast<Argument>(LV)) {
+ const auto *RA = cast<Argument>(RV);
+ unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
+ return (int)LArgNo - (int)RArgNo;
}
- // Return negative, zero, or positive, if LHS is less than, equal to, or
- // greater than RHS, respectively. A three-way result allows recursive
- // comparisons to be more efficient.
- int compare(const SCEV *LHS, const SCEV *RHS) const {
- // Fast-path: SCEVs are uniqued so we can do a quick equality check.
- if (LHS == RHS)
- return 0;
-
- // Primarily, sort the SCEVs by their getSCEVType().
- unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
- if (LType != RType)
- return (int)LType - (int)RType;
-
- // Aside from the getSCEVType() ordering, the particular ordering
- // isn't very important except that it's beneficial to be consistent,
- // so that (a + b) and (b + a) don't end up as different expressions.
- switch (static_cast<SCEVTypes>(LType)) {
- case scUnknown: {
- const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
- const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
-
- // Sort SCEVUnknown values with some loose heuristics. TODO: This is
- // not as complete as it could be.
- const Value *LV = LU->getValue(), *RV = RU->getValue();
-
- // Order pointer values after integer values. This helps SCEVExpander
- // form GEPs.
- bool LIsPointer = LV->getType()->isPointerTy(),
- RIsPointer = RV->getType()->isPointerTy();
- if (LIsPointer != RIsPointer)
- return (int)LIsPointer - (int)RIsPointer;
-
- // Compare getValueID values.
- unsigned LID = LV->getValueID(),
- RID = RV->getValueID();
- if (LID != RID)
- return (int)LID - (int)RID;
-
- // Sort arguments by their position.
- if (const Argument *LA = dyn_cast<Argument>(LV)) {
- const Argument *RA = cast<Argument>(RV);
- unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
- return (int)LArgNo - (int)RArgNo;
- }
+ if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
+ const auto *RGV = cast<GlobalValue>(RV);
- // For instructions, compare their loop depth, and their operand
- // count. This is pretty loose.
- if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
- const Instruction *RInst = cast<Instruction>(RV);
-
- // Compare loop depths.
- const BasicBlock *LParent = LInst->getParent(),
- *RParent = RInst->getParent();
- if (LParent != RParent) {
- unsigned LDepth = LI->getLoopDepth(LParent),
- RDepth = LI->getLoopDepth(RParent);
- if (LDepth != RDepth)
- return (int)LDepth - (int)RDepth;
- }
+ const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
+ auto LT = GV->getLinkage();
+ return !(GlobalValue::isPrivateLinkage(LT) ||
+ GlobalValue::isInternalLinkage(LT));
+ };
- // Compare the number of operands.
- unsigned LNumOps = LInst->getNumOperands(),
- RNumOps = RInst->getNumOperands();
- return (int)LNumOps - (int)RNumOps;
- }
+ // Use the names to distinguish the two values, but only if the
+ // names are semantically important.
+ if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
+ return LGV->getName().compare(RGV->getName());
+ }
+
+ // For instructions, compare their loop depth, and their operand count. This
+ // is pretty loose.
+ if (const auto *LInst = dyn_cast<Instruction>(LV)) {
+ const auto *RInst = cast<Instruction>(RV);
- return 0;
+ // Compare loop depths.
+ const BasicBlock *LParent = LInst->getParent(),
+ *RParent = RInst->getParent();
+ if (LParent != RParent) {
+ unsigned LDepth = LI->getLoopDepth(LParent),
+ RDepth = LI->getLoopDepth(RParent);
+ if (LDepth != RDepth)
+ return (int)LDepth - (int)RDepth;
}
- case scConstant: {
- const SCEVConstant *LC = cast<SCEVConstant>(LHS);
- const SCEVConstant *RC = cast<SCEVConstant>(RHS);
+ // Compare the number of operands.
+ unsigned LNumOps = LInst->getNumOperands(),
+ RNumOps = RInst->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
- // Compare constant values.
- const APInt &LA = LC->getAPInt();
- const APInt &RA = RC->getAPInt();
- unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
- if (LBitWidth != RBitWidth)
- return (int)LBitWidth - (int)RBitWidth;
- return LA.ult(RA) ? -1 : 1;
+ for (unsigned Idx : seq(0u, LNumOps)) {
+ int Result =
+ CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
+ RInst->getOperand(Idx), Depth + 1);
+ if (Result != 0)
+ return Result;
}
+ }
- case scAddRecExpr: {
- const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
- const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+ EqCache.insert({LV, RV});
+ return 0;
+}
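
The EqCache above is what keeps this comparison linear on the diamond-shaped use chains shown in the function comment. A minimal standalone sketch of the same memoized three-way comparison, using a hypothetical Node type instead of LLVM's Value hierarchy:

    #include <cstddef>
    #include <set>
    #include <utility>
    #include <vector>

    struct Node {
      int Kind;                      // stands in for Value::getValueID()
      std::vector<const Node *> Ops; // stands in for instruction operands
    };

    static const unsigned MaxDepth = 2; // mirrors scalar-evolution-max-value-compare-depth

    static int compareNodes(std::set<std::pair<const Node *, const Node *>> &EqCache,
                            const Node *L, const Node *R, unsigned Depth) {
      if (Depth > MaxDepth || EqCache.count({L, R}))
        return 0;                    // depth cutoff, or already proven "equally complex"
      if (L->Kind != R->Kind)
        return L->Kind - R->Kind;
      if (L->Ops.size() != R->Ops.size())
        return (int)L->Ops.size() - (int)R->Ops.size();
      for (std::size_t I = 0; I != L->Ops.size(); ++I)
        if (int X = compareNodes(EqCache, L->Ops[I], R->Ops[I], Depth + 1))
          return X;
      EqCache.insert({L, R});        // shared subtrees are compared only once
      return 0;
    }
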
- // Compare addrec loop depths.
- const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
- if (LLoop != RLoop) {
- unsigned LDepth = LLoop->getLoopDepth(),
- RDepth = RLoop->getLoopDepth();
- if (LDepth != RDepth)
- return (int)LDepth - (int)RDepth;
- }
+// Return negative, zero, or positive, if LHS is less than, equal to, or greater
+// than RHS, respectively. A three-way result allows recursive comparisons to be
+// more efficient.
+static int CompareSCEVComplexity(
+ SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
+ const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
+ unsigned Depth = 0) {
+ // Fast-path: SCEVs are uniqued so we can do a quick equality check.
+ if (LHS == RHS)
+ return 0;
- // Addrec complexity grows with operand count.
- unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
+ // Primarily, sort the SCEVs by their getSCEVType().
+ unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+ if (LType != RType)
+ return (int)LType - (int)RType;
- // Lexicographically compare.
- for (unsigned i = 0; i != LNumOps; ++i) {
- long X = compare(LA->getOperand(i), RA->getOperand(i));
- if (X != 0)
- return X;
- }
+ if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS}))
+ return 0;
+ // Aside from the getSCEVType() ordering, the particular ordering
+ // isn't very important except that it's beneficial to be consistent,
+ // so that (a + b) and (b + a) don't end up as different expressions.
+ switch (static_cast<SCEVTypes>(LType)) {
+ case scUnknown: {
+ const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
+ const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
+
+ SmallSet<std::pair<Value *, Value *>, 8> EqCache;
+ int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
+ Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.insert({LHS, RHS});
+ return X;
+ }
- return 0;
+ case scConstant: {
+ const SCEVConstant *LC = cast<SCEVConstant>(LHS);
+ const SCEVConstant *RC = cast<SCEVConstant>(RHS);
+
+ // Compare constant values.
+ const APInt &LA = LC->getAPInt();
+ const APInt &RA = RC->getAPInt();
+ unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
+ if (LBitWidth != RBitWidth)
+ return (int)LBitWidth - (int)RBitWidth;
+ return LA.ult(RA) ? -1 : 1;
+ }
+
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
+ const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+
+ // Compare addrec loop depths.
+ const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
+ if (LLoop != RLoop) {
+ unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth();
+ if (LDepth != RDepth)
+ return (int)LDepth - (int)RDepth;
}
- case scAddExpr:
- case scMulExpr:
- case scSMaxExpr:
- case scUMaxExpr: {
- const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
- const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
-
- // Lexicographically compare n-ary expressions.
- unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
-
- for (unsigned i = 0; i != LNumOps; ++i) {
- if (i >= RNumOps)
- return 1;
- long X = compare(LC->getOperand(i), RC->getOperand(i));
- if (X != 0)
- return X;
- }
+ // Addrec complexity grows with operand count.
+ unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
+ if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
+
+ // Lexicographically compare.
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
+ RA->getOperand(i), Depth + 1);
+ if (X != 0)
+ return X;
}
+ EqCacheSCEV.insert({LHS, RHS});
+ return 0;
+ }
- case scUDivExpr: {
- const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
- const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
+ case scAddExpr:
+ case scMulExpr:
+ case scSMaxExpr:
+ case scUMaxExpr: {
+ const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
+ const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
+
+ // Lexicographically compare n-ary expressions.
+ unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
- // Lexicographically compare udiv expressions.
- long X = compare(LC->getLHS(), RC->getLHS());
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ if (i >= RNumOps)
+ return 1;
+ int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
+ RC->getOperand(i), Depth + 1);
if (X != 0)
return X;
- return compare(LC->getRHS(), RC->getRHS());
}
+ EqCacheSCEV.insert({LHS, RHS});
+ return 0;
+ }
- case scTruncate:
- case scZeroExtend:
- case scSignExtend: {
- const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
- const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+ case scUDivExpr: {
+ const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
+ const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
- // Compare cast expressions by operand.
- return compare(LC->getOperand(), RC->getOperand());
- }
+ // Lexicographically compare udiv expressions.
+ int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
+ Depth + 1);
+ if (X != 0)
+ return X;
+ X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(),
+ Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.insert({LHS, RHS});
+ return X;
+ }
- case scCouldNotCompute:
- llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
- }
- llvm_unreachable("Unknown SCEV kind!");
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend: {
+ const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
+ const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+
+ // Compare cast expressions by operand.
+ int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
+ RC->getOperand(), Depth + 1);
+ if (X == 0)
+ EqCacheSCEV.insert({LHS, RHS});
+ return X;
}
-};
-} // end anonymous namespace
+
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+}
/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
@@ -635,17 +705,22 @@ public:
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
LoopInfo *LI) {
if (Ops.size() < 2) return; // Noop
+
+ SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
if (Ops.size() == 2) {
// This is the common case, which also happens to be trivially simple.
// Special case it.
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
- if (SCEVComplexityCompare(LI)(RHS, LHS))
+ if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0)
std::swap(LHS, RHS);
return;
}
// Do the rough sort by complexity.
- std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
+ std::stable_sort(Ops.begin(), Ops.end(),
+ [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) {
+ return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0;
+ });
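
Since CompareSCEVComplexity is three-way, the sort above wraps it in the boolean predicate std::stable_sort expects. A tiny illustration of that adapter pattern, with plain ints standing in for const SCEV *:

    #include <algorithm>
    #include <vector>

    static int threeWayCompare(int L, int R) { return (L > R) - (L < R); }

    static void sortByComplexity(std::vector<int> &Ops) {
      std::stable_sort(Ops.begin(), Ops.end(),
                       [](int L, int R) { return threeWayCompare(L, R) < 0; });
    }
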
// Now that we are sorted by complexity, group elements of the same
// complexity. Note that this is, at worst, N^2, but the vector is likely to
@@ -2518,6 +2593,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
if (Idx < Ops.size()) {
bool DeletedMul = false;
while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+ if (Ops.size() > MulOpsInlineThreshold)
+ break;
// If we have an mul, expand the mul operands onto the end of the operands
// list.
Ops.erase(Ops.begin()+Idx);
@@ -2970,9 +3047,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
}
const SCEV *
-ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
- const SmallVectorImpl<const SCEV *> &IndexExprs,
- bool InBounds) {
+ScalarEvolution::getGEPExpr(GEPOperator *GEP,
+ const SmallVectorImpl<const SCEV *> &IndexExprs) {
+ const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
// getSCEV(Base)->getType() has the same address space as Base->getType()
// because SCEV::getType() preserves the address space.
Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType());
@@ -2981,12 +3058,13 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
// flow and the no-overflow bits may not be valid for the expression in any
// context. This can be fixed similarly to how these flags are handled for
// adds.
- SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
+ SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
+ : SCEV::FlagAnyWrap;
const SCEV *TotalOffset = getZero(IntPtrTy);
- // The address space is unimportant. The first thing we do on CurTy is getting
+ // The array size is unimportant. The first thing we do on CurTy is getting
// its element type.
- Type *CurTy = PointerType::getUnqual(PointeeType);
+ Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0);
for (const SCEV *IndexExpr : IndexExprs) {
// Compute the (potentially symbolic) offset in bytes for this index.
if (StructType *STy = dyn_cast<StructType>(CurTy)) {
@@ -3311,71 +3389,23 @@ const SCEV *ScalarEvolution::getCouldNotCompute() {
return CouldNotCompute.get();
}
-
bool ScalarEvolution::checkValidity(const SCEV *S) const {
- // Helper class working with SCEVTraversal to figure out if a SCEV contains
- // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
- // is set iff if find such SCEVUnknown.
- //
- struct FindInvalidSCEVUnknown {
- bool FindOne;
- FindInvalidSCEVUnknown() { FindOne = false; }
- bool follow(const SCEV *S) {
- switch (static_cast<SCEVTypes>(S->getSCEVType())) {
- case scConstant:
- return false;
- case scUnknown:
- if (!cast<SCEVUnknown>(S)->getValue())
- FindOne = true;
- return false;
- default:
- return true;
- }
- }
- bool isDone() const { return FindOne; }
- };
-
- FindInvalidSCEVUnknown F;
- SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
- ST.visitAll(S);
+ bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+ auto *SU = dyn_cast<SCEVUnknown>(S);
+ return SU && SU->getValue() == nullptr;
+ });
- return !F.FindOne;
-}
-
-namespace {
-// Helper class working with SCEVTraversal to figure out if a SCEV contains
-// a sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set
-// iff if such sub scAddRecExpr type SCEV is found.
-struct FindAddRecurrence {
- bool FoundOne;
- FindAddRecurrence() : FoundOne(false) {}
-
- bool follow(const SCEV *S) {
- switch (static_cast<SCEVTypes>(S->getSCEVType())) {
- case scAddRecExpr:
- FoundOne = true;
- case scConstant:
- case scUnknown:
- case scCouldNotCompute:
- return false;
- default:
- return true;
- }
- }
- bool isDone() const { return FoundOne; }
-};
+ return !ContainsNulls;
}
bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
- HasRecMapType::iterator I = HasRecMap.find_as(S);
+ HasRecMapType::iterator I = HasRecMap.find(S);
if (I != HasRecMap.end())
return I->second;
- FindAddRecurrence F;
- SCEVTraversal<FindAddRecurrence> ST(F);
- ST.visitAll(S);
- HasRecMap.insert({S, F.FoundOne});
- return F.FoundOne;
+ bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>);
+ HasRecMap.insert({S, FoundAddRec});
+ return FoundAddRec;
}
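
containsAddRecurrence is now a predicate-based walk (SCEVExprContains) whose answer is cached for the queried expression in HasRecMap. A standalone sketch of that cached-containment idea, with a hypothetical Expr tree in place of SCEV; it caches every visited node, a slight generalization of the single root-level insert in the patch:

    #include <map>
    #include <vector>

    struct Expr {
      bool IsAddRec;
      std::vector<const Expr *> Ops;
    };

    static bool containsAddRec(std::map<const Expr *, bool> &Cache, const Expr *E) {
      auto It = Cache.find(E);
      if (It != Cache.end())
        return It->second;           // answered before: reuse the cached result
      bool Found = E->IsAddRec;
      for (const Expr *Op : E->Ops)
        if (!Found)
          Found = containsAddRec(Cache, Op);
      Cache.insert({E, Found});
      return Found;
    }
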
/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
@@ -4210,7 +4240,9 @@ static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
}
const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
- if (PN->getNumIncomingValues() == 2) {
+ auto IsReachable =
+ [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
+ if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
const Loop *L = LI.getLoopFor(PN->getParent());
// We don't want to break LCSSA, even in a SCEV expression tree.
@@ -4286,7 +4318,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
- // fall through
+ LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
// a >s b ? a+x : b+x -> smax(a, b)+x
@@ -4309,7 +4341,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
- // fall through
+ LLVM_FALLTHROUGH;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
// a >u b ? a+x : b+x -> umax(a, b)+x
@@ -4374,9 +4406,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
SmallVector<const SCEV *, 4> IndexExprs;
for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
IndexExprs.push_back(getSCEV(*Index));
- return getGEPExpr(GEP->getSourceElementType(),
- getSCEV(GEP->getPointerOperand()),
- IndexExprs, GEP->isInBounds());
+ return getGEPExpr(GEP, IndexExprs);
}
uint32_t
@@ -4654,19 +4684,18 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
- ConstantRange ZExtMaxBECountRange =
- MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2);
ConstantRange StepSRange = getSignedRange(Step);
- ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2);
ConstantRange StartURange = getUnsignedRange(Start);
ConstantRange EndURange =
StartURange.add(MaxBECountRange.multiply(StepSRange));
// Check for unsigned overflow.
- ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1);
- ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2);
+ ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2);
if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
ZExtEndURange) {
APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
@@ -4686,8 +4715,8 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
// Check for signed overflow. This must be done with ConstantRange
// arithmetic because we could be called from within the ScalarEvolution
// overflow checking code.
- ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1);
- ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2);
+ ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2);
if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
SExtEndSRange) {
APInt Min =
@@ -4951,17 +4980,33 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
}
-bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) {
- auto Itr = LoopHasNoAbnormalExits.find(L);
- if (Itr == LoopHasNoAbnormalExits.end()) {
- auto NoAbnormalExitInBB = [&](BasicBlock *BB) {
- return all_of(*BB, [](Instruction &I) {
- return isGuaranteedToTransferExecutionToSuccessor(&I);
- });
+ScalarEvolution::LoopProperties
+ScalarEvolution::getLoopProperties(const Loop *L) {
+ typedef ScalarEvolution::LoopProperties LoopProperties;
+
+ auto Itr = LoopPropertiesCache.find(L);
+ if (Itr == LoopPropertiesCache.end()) {
+ auto HasSideEffects = [](Instruction *I) {
+ if (auto *SI = dyn_cast<StoreInst>(I))
+ return !SI->isSimple();
+
+ return I->mayHaveSideEffects();
};
- auto InsertPair = LoopHasNoAbnormalExits.insert(
- {L, all_of(L->getBlocks(), NoAbnormalExitInBB)});
+ LoopProperties LP = {/* HasNoAbnormalExits */ true,
+ /*HasNoSideEffects*/ true};
+
+ for (auto *BB : L->getBlocks())
+ for (auto &I : *BB) {
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+ LP.HasNoAbnormalExits = false;
+ if (HasSideEffects(&I))
+ LP.HasNoSideEffects = false;
+ if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
+ break; // We're already as pessimistic as we can get.
+ }
+
+ auto InsertPair = LoopPropertiesCache.insert({L, LP});
assert(InsertPair.second && "We just checked!");
Itr = InsertPair.first;
}
@@ -5289,6 +5334,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
// Iteration Count Computation Code
//
+static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
+ if (!ExitCount)
+ return 0;
+
+ ConstantInt *ExitConst = ExitCount->getValue();
+
+ // Guard against huge trip counts.
+ if (ExitConst->getValue().getActiveBits() > 32)
+ return 0;
+
+ // In case of integer overflow, this returns 0, which is correct.
+ return ((unsigned)ExitConst->getZExtValue()) + 1;
+}
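
getConstantTripCount folds the exit (backedge-taken) count into a 32-bit trip count, returning 0 for anything unknown or too large. The same rule in standalone form, with uint64_t standing in for the APInt exit count:

    #include <cstdint>

    static unsigned constantTripCount(uint64_t ExitCount) {
      // More than 32 significant bits: treat the trip count as unknown (0).
      if (ExitCount >> 32)
        return 0;
      // For ExitCount == 0xffffffff the +1 wraps to 0, which is the desired
      // "unknown" answer, exactly as the comment in the patch notes.
      return (unsigned)ExitCount + 1;
    }
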
+
unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
if (BasicBlock *ExitingBB = L->getExitingBlock())
return getSmallConstantTripCount(L, ExitingBB);
@@ -5304,17 +5363,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
"Exiting block must actually branch out of the loop!");
const SCEVConstant *ExitCount =
dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
- if (!ExitCount)
- return 0;
-
- ConstantInt *ExitConst = ExitCount->getValue();
-
- // Guard against huge trip counts.
- if (ExitConst->getValue().getActiveBits() > 32)
- return 0;
+ return getConstantTripCount(ExitCount);
+}
- // In case of integer overflow, this returns 0, which is correct.
- return ((unsigned)ExitConst->getZExtValue()) + 1;
+unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) {
+ const auto *MaxExitCount =
+ dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L));
+ return getConstantTripCount(MaxExitCount);
}
unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
@@ -5393,6 +5448,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
return getBackedgeTakenInfo(L).getMax(this);
}
+bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
+ return getBackedgeTakenInfo(L).isMaxOrZero(this);
+}
+
/// Push PHI nodes in the header of the given loop onto the given Worklist.
static void
PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
@@ -5418,7 +5477,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
BackedgeTakenInfo Result =
computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
- return PredicatedBackedgeTakenCounts.find(L)->second = Result;
+ return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
}
const ScalarEvolution::BackedgeTakenInfo &
@@ -5493,7 +5552,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// recusive call to getBackedgeTakenInfo (on a different
// loop), which would invalidate the iterator computed
// earlier.
- return BackedgeTakenCounts.find(L)->second = Result;
+ return BackedgeTakenCounts.find(L)->second = std::move(Result);
}
void ScalarEvolution::forgetLoop(const Loop *L) {
@@ -5537,7 +5596,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
for (Loop *I : *L)
forgetLoop(I);
- LoopHasNoAbnormalExits.erase(L);
+ LoopPropertiesCache.erase(L);
}
void ScalarEvolution::forgetValue(Value *V) {
@@ -5576,14 +5635,11 @@ void ScalarEvolution::forgetValue(Value *V) {
/// caller's responsibility to specify the relevant loop exit using
/// getExact(ExitingBlock, SE).
const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(
- ScalarEvolution *SE, SCEVUnionPredicate *Preds) const {
+ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
+ SCEVUnionPredicate *Preds) const {
// If any exits were not computable, the loop is not computable.
- if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
-
- // We need exactly one computable exit.
- if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
- assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
+ if (!isComplete() || ExitNotTaken.empty())
+ return SE->getCouldNotCompute();
const SCEV *BECount = nullptr;
for (auto &ENT : ExitNotTaken) {
@@ -5593,10 +5649,10 @@ ScalarEvolution::BackedgeTakenInfo::getExact(
BECount = ENT.ExactNotTaken;
else if (BECount != ENT.ExactNotTaken)
return SE->getCouldNotCompute();
- if (Preds && ENT.getPred())
- Preds->add(ENT.getPred());
+ if (Preds && !ENT.hasAlwaysTruePredicate())
+ Preds->add(ENT.Predicate.get());
- assert((Preds || ENT.hasAlwaysTruePred()) &&
+ assert((Preds || ENT.hasAlwaysTruePredicate()) &&
"Predicate should be always true!");
}
@@ -5609,7 +5665,7 @@ const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
ScalarEvolution *SE) const {
for (auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred())
+ if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
return ENT.ExactNotTaken;
return SE->getCouldNotCompute();
@@ -5618,21 +5674,29 @@ ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
/// getMax - Get the max backedge taken count for the loop.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
- for (auto &ENT : ExitNotTaken)
- if (!ENT.hasAlwaysTruePred())
- return SE->getCouldNotCompute();
+ auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+ return !ENT.hasAlwaysTruePredicate();
+ };
- return Max ? Max : SE->getCouldNotCompute();
+ if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax())
+ return SE->getCouldNotCompute();
+
+ return getMax();
+}
+
+bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
+ auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+ return !ENT.hasAlwaysTruePredicate();
+ };
+ return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
}
bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution *SE) const {
- if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S))
+ if (getMax() && getMax() != SE->getCouldNotCompute() &&
+ SE->hasOperand(getMax(), S))
return true;
- if (!ExitNotTaken.ExitingBlock)
- return false;
-
for (auto &ENT : ExitNotTaken)
if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&
SE->hasOperand(ENT.ExactNotTaken, S))
@@ -5644,62 +5708,31 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
- SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount)
- : Max(MaxCount) {
-
- if (!Complete)
- ExitNotTaken.setIncomplete();
-
- unsigned NumExits = ExitCounts.size();
- if (NumExits == 0) return;
-
- ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock;
- ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken;
-
- // Determine the number of ExitNotTakenExtras structures that we need.
- unsigned ExtraInfoSize = 0;
- if (NumExits > 1)
- ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()),
- ExitCounts.end(), [](EdgeInfo &Entry) {
- return !Entry.Pred.isAlwaysTrue();
- });
- else if (!ExitCounts[0].Pred.isAlwaysTrue())
- ExtraInfoSize = 1;
-
- ExitNotTakenExtras *ENT = nullptr;
-
- // Allocate the ExitNotTakenExtras structures and initialize the first
- // element (ExitNotTaken).
- if (ExtraInfoSize > 0) {
- ENT = new ExitNotTakenExtras[ExtraInfoSize];
- ExitNotTaken.ExtraInfo = &ENT[0];
- *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred);
- }
-
- if (NumExits == 1)
- return;
-
- assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit");
-
- auto &Exits = ExitNotTaken.ExtraInfo->Exits;
-
- // Handle the rare case of multiple computable exits.
- for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) {
- ExitNotTakenExtras *Ptr = nullptr;
- if (!ExitCounts[i].Pred.isAlwaysTrue()) {
- Ptr = &ENT[PredPos++];
- Ptr->Pred = std::move(ExitCounts[i].Pred);
- }
-
- Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr);
- }
+ SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
+ &&ExitCounts,
+ bool Complete, const SCEV *MaxCount, bool MaxOrZero)
+ : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
+ typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+ ExitNotTaken.reserve(ExitCounts.size());
+ std::transform(
+ ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
+ [&](const EdgeExitInfo &EEI) {
+ BasicBlock *ExitBB = EEI.first;
+ const ExitLimit &EL = EEI.second;
+ if (EL.Predicates.empty())
+ return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
+
+ std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
+ for (auto *Pred : EL.Predicates)
+ Predicate->add(Pred);
+
+ return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
+ });
}
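
The rewritten constructor materializes ExitNotTaken with reserve, std::transform, and a back_inserter instead of a hand-managed array. A minimal sketch of that idiom with simplified stand-in types:

    #include <algorithm>
    #include <iterator>
    #include <string>
    #include <utility>
    #include <vector>

    struct ExitInfo { std::string Block; long Count; };

    static std::vector<ExitInfo>
    buildExitInfos(const std::vector<std::pair<std::string, long>> &Edges) {
      std::vector<ExitInfo> Out;
      Out.reserve(Edges.size());
      std::transform(Edges.begin(), Edges.end(), std::back_inserter(Out),
                     [](const std::pair<std::string, long> &E) {
                       return ExitInfo{E.first, E.second};
                     });
      return Out;
    }
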
/// Invalidate this result and free the ExitNotTakenInfo array.
void ScalarEvolution::BackedgeTakenInfo::clear() {
- ExitNotTaken.ExitingBlock = nullptr;
- ExitNotTaken.ExactNotTaken = nullptr;
- delete[] ExitNotTaken.ExtraInfo;
+ ExitNotTaken.clear();
}
/// Compute the number of times the backedge of the specified loop will execute.
@@ -5709,11 +5742,14 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
- SmallVector<EdgeInfo, 4> ExitCounts;
+ typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+
+ SmallVector<EdgeExitInfo, 4> ExitCounts;
bool CouldComputeBECount = true;
BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
const SCEV *MustExitMaxBECount = nullptr;
const SCEV *MayExitMaxBECount = nullptr;
+ bool MustExitMaxOrZero = false;
// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
// and compute maxBECount.
@@ -5722,17 +5758,17 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
BasicBlock *ExitBB = ExitingBlocks[i];
ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
- assert((AllowPredicates || EL.Pred.isAlwaysTrue()) &&
+ assert((AllowPredicates || EL.Predicates.empty()) &&
"Predicated exit limit when predicates are not allowed!");
// 1. For each exit that can be computed, add an entry to ExitCounts.
// CouldComputeBECount is true only if all exits can be computed.
- if (EL.Exact == getCouldNotCompute())
+ if (EL.ExactNotTaken == getCouldNotCompute())
// We couldn't compute an exact value for this exit, so
// we won't be able to compute an exact value for the loop.
CouldComputeBECount = false;
else
- ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred));
+ ExitCounts.emplace_back(ExitBB, EL);
// 2. Derive the loop's MaxBECount from each exit's max number of
// non-exiting iterations. Partition the loop exits into two kinds:
@@ -5740,29 +5776,35 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
//
// If the exit dominates the loop latch, it is a LoopMustExit otherwise it
// is a LoopMayExit. If any computable LoopMustExit is found, then
- // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise,
- // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is
- // considered greater than any computable EL.Max.
- if (EL.Max != getCouldNotCompute() && Latch &&
+ // MaxBECount is the minimum EL.MaxNotTaken of computable
+ // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
+ // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
+ // computable EL.MaxNotTaken.
+ if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
DT.dominates(ExitBB, Latch)) {
- if (!MustExitMaxBECount)
- MustExitMaxBECount = EL.Max;
- else {
+ if (!MustExitMaxBECount) {
+ MustExitMaxBECount = EL.MaxNotTaken;
+ MustExitMaxOrZero = EL.MaxOrZero;
+ } else {
MustExitMaxBECount =
- getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max);
+ getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
}
} else if (MayExitMaxBECount != getCouldNotCompute()) {
- if (!MayExitMaxBECount || EL.Max == getCouldNotCompute())
- MayExitMaxBECount = EL.Max;
+ if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
+ MayExitMaxBECount = EL.MaxNotTaken;
else {
MayExitMaxBECount =
- getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max);
+ getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
}
}
}
const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
(MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
- return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
+ // The loop backedge will be taken the maximum or zero times if there's
+ // a single exit that must be taken the maximum or zero times.
+ bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
+ return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
+ MaxBECount, MaxOrZero);
}
ScalarEvolution::ExitLimit
@@ -5867,39 +5909,40 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
if (EitherMayExit) {
// Both conditions must be true for the loop to continue executing.
// Choose the less conservative count.
- if (EL0.Exact == getCouldNotCompute() ||
- EL1.Exact == getCouldNotCompute())
+ if (EL0.ExactNotTaken == getCouldNotCompute() ||
+ EL1.ExactNotTaken == getCouldNotCompute())
BECount = getCouldNotCompute();
else
- BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
- if (EL0.Max == getCouldNotCompute())
- MaxBECount = EL1.Max;
- else if (EL1.Max == getCouldNotCompute())
- MaxBECount = EL0.Max;
+ BECount =
+ getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
+ if (EL0.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL1.MaxNotTaken;
+ else if (EL1.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL0.MaxNotTaken;
else
- MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
+ MaxBECount =
+ getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
} else {
// Both conditions must be true at the same time for the loop to exit.
// For now, be conservative.
assert(L->contains(FBB) && "Loop block has no successor in loop!");
- if (EL0.Max == EL1.Max)
- MaxBECount = EL0.Max;
- if (EL0.Exact == EL1.Exact)
- BECount = EL0.Exact;
+ if (EL0.MaxNotTaken == EL1.MaxNotTaken)
+ MaxBECount = EL0.MaxNotTaken;
+ if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+ BECount = EL0.ExactNotTaken;
}
- SCEVUnionPredicate NP;
- NP.add(&EL0.Pred);
- NP.add(&EL1.Pred);
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
// to be more aggressive when computing BECount than when computing
- // MaxBECount. In these cases it is possible for EL0.Exact and EL1.Exact
- // to match, but for EL0.Max and EL1.Max to not.
+ // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
+ // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
+ // to not.
if (isa<SCEVCouldNotCompute>(MaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, NP);
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
@@ -5915,31 +5958,31 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
if (EitherMayExit) {
// Both conditions must be false for the loop to continue executing.
// Choose the less conservative count.
- if (EL0.Exact == getCouldNotCompute() ||
- EL1.Exact == getCouldNotCompute())
+ if (EL0.ExactNotTaken == getCouldNotCompute() ||
+ EL1.ExactNotTaken == getCouldNotCompute())
BECount = getCouldNotCompute();
else
- BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
- if (EL0.Max == getCouldNotCompute())
- MaxBECount = EL1.Max;
- else if (EL1.Max == getCouldNotCompute())
- MaxBECount = EL0.Max;
+ BECount =
+ getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
+ if (EL0.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL1.MaxNotTaken;
+ else if (EL1.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL0.MaxNotTaken;
else
- MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
+ MaxBECount =
+ getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
} else {
// Both conditions must be false at the same time for the loop to exit.
// For now, be conservative.
assert(L->contains(TBB) && "Loop block has no successor in loop!");
- if (EL0.Max == EL1.Max)
- MaxBECount = EL0.Max;
- if (EL0.Exact == EL1.Exact)
- BECount = EL0.Exact;
+ if (EL0.MaxNotTaken == EL1.MaxNotTaken)
+ MaxBECount = EL0.MaxNotTaken;
+ if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+ BECount = EL0.ExactNotTaken;
}
- SCEVUnionPredicate NP;
- NP.add(&EL0.Pred);
- NP.add(&EL1.Pred);
- return ExitLimit(BECount, MaxBECount, NP);
+ return ExitLimit(BECount, MaxBECount, false,
+ {&EL0.Predicates, &EL1.Predicates});
}
}
@@ -6021,8 +6064,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
if (AddRec->getLoop() == L) {
// Form the constant range.
- ConstantRange CompRange(
- ICmpInst::makeConstantRange(Cond, RHSC->getAPInt()));
+ ConstantRange CompRange =
+ ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt());
const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
@@ -6226,7 +6269,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
// %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
// %iv.shifted = lshr i32 %iv, <positive constant>
//
- // Return true on a succesful match. Return the corresponding PHI node (%iv
+ // Return true on a successful match. Return the corresponding PHI node (%iv
// above) in PNOut and the opcode of the shift operation in OpCodeOut.
auto MatchShiftRecurrence =
[&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
@@ -6324,8 +6367,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
- SCEVUnionPredicate P;
- return ExitLimit(getCouldNotCompute(), UpperBound, P);
+ return ExitLimit(getCouldNotCompute(), UpperBound, false);
}
return getCouldNotCompute();
@@ -6995,20 +7037,21 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
// 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
// modulo (N / D).
//
- // (N / D) may need BW+1 bits in its representation. Hence, we'll use this
- // bit width during computations.
+ // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
+ // (N / D) in general. The inverse itself always fits into BW bits, though,
+ // so we immediately truncate it.
APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
APInt Mod(BW + 1, 0);
Mod.setBit(BW - Mult2); // Mod = N / D
- APInt I = AD.multiplicativeInverse(Mod);
+ APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
// 4. Compute the minimum unsigned root of the equation:
// I * (B / D) mod (N / D)
- APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
+ // To simplify the computation, we factor out the divide by D:
+ // (I * B mod N) / D
+ APInt Result = (I * B).lshr(Mult2);
- // The result is guaranteed to be less than 2^BW so we may truncate it to BW
- // bits.
- return SE.getConstant(Result.trunc(BW));
+ return SE.getConstant(Result);
}
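
The change above avoids the BW+1-bit arithmetic by truncating the inverse and folding the divide-by-D into a final shift. A self-contained 32-bit version of the same computation; Newton's iteration gives the inverse of the odd part modulo 2^32, and the result is valid only when B is divisible by 2^Mult2 (otherwise the equation has no solution):

    #include <cstdint>

    // Inverse of an odd value modulo 2^32 via Newton's iteration: each step
    // doubles the number of correct low bits (3 -> 6 -> 12 -> 24 -> 48).
    static uint32_t inverseOddMod2_32(uint32_t A) {
      uint32_t X = A;                // correct to 3 bits for odd A
      for (int I = 0; I < 4; ++I)
        X *= 2 - A * X;
      return X;
    }

    // Minimum unsigned root of A*X == B (mod 2^32), assuming A != 0 and B is
    // divisible by 2^Mult2, where Mult2 is the number of trailing zeros of A.
    static uint32_t solveLinMod2_32(uint32_t A, uint32_t B) {
      unsigned Mult2 = 0;
      while (((A >> Mult2) & 1) == 0)
        ++Mult2;                     // Mult2 = trailing zero count of A
      uint32_t I = inverseOddMod2_32(A >> Mult2);
      return (I * B) >> Mult2;       // matches (I * B).lshr(Mult2) in the patch
    }
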
/// Find the roots of the quadratic equation for the given quadratic chrec
@@ -7086,7 +7129,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// effectively V != 0. We know and take advantage of the fact that this
// expression only being used in a comparison by zero context.
- SCEVUnionPredicate P;
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times.
@@ -7099,7 +7142,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
- AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
+ AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
if (!AddRec || AddRec->getLoop() != L)
return getCouldNotCompute();
@@ -7121,7 +7164,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero())
- return ExitLimit(R1, R1, P); // We found a quadratic root!
+ // We found a quadratic root!
+ return ExitLimit(R1, R1, false, Predicates);
}
}
return getCouldNotCompute();
@@ -7168,17 +7212,25 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
- ConstantRange CR = getUnsignedRange(Start);
- const SCEV *MaxBECount;
- if (!CountDown && CR.getUnsignedMin().isMinValue())
- // When counting up, the worst starting value is 1, not 0.
- MaxBECount = CR.getUnsignedMax().isMinValue()
- ? getConstant(APInt::getMinValue(CR.getBitWidth()))
- : getConstant(APInt::getMaxValue(CR.getBitWidth()));
- else
- MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
- : -CR.getUnsignedMin());
- return ExitLimit(Distance, MaxBECount, P);
+ APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax();
+
+ // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
+ // we end up with a loop whose backedge-taken count is n - 1. Detect this
+ // case, and see if we can improve the bound.
+ //
+ // Explicitly handling this here is necessary because getUnsignedRange
+ // isn't context-sensitive; it doesn't know that we only care about the
+ // range inside the loop.
+ const SCEV *Zero = getZero(Distance->getType());
+ const SCEV *One = getOne(Distance->getType());
+ const SCEV *DistancePlusOne = getAddExpr(Distance, One);
+ if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
+ // If Distance + 1 doesn't overflow, we can compute the maximum distance
+ // as "unsigned_max(Distance + 1) - 1".
+ ConstantRange CR = getUnsignedRange(DistancePlusOne);
+ MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
+ }
+ return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
}
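
The guard-based refinement above targets rotated counting loops. In source form, the pattern the comment describes looks like this: the guard proves Distance + 1 (that is, n) is nonzero, so the backedge-taken count is at most n - 1.

    // "for (int i = 0; i != n; ++i) { /* body */ }" after loop rotation:
    void rotatedLoop(unsigned n) {
      if (n != 0) {        // loop guard: Distance + 1 != 0
        unsigned i = 0;
        do {
          /* body */
          ++i;
        } while (i != n);  // backedge taken n - 1 times
      }
    }
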
// As a special case, handle the instance where Step is a positive power of
@@ -7233,7 +7285,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
const SCEV *Limit =
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
- return ExitLimit(Limit, Limit, P);
+ return ExitLimit(Limit, Limit, false, Predicates);
}
}
@@ -7246,14 +7298,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
- return ExitLimit(Exact, Exact, P);
+ return ExitLimit(Exact, Exact, false, Predicates);
}
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
- return ExitLimit(E, E, P);
+ return ExitLimit(E, E, false, Predicates);
}
return getCouldNotCompute();
}
@@ -7365,149 +7417,77 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
// cases, and canonicalize *-or-equal comparisons to regular comparisons.
if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
const APInt &RA = RC->getAPInt();
- switch (Pred) {
- default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
- case ICmpInst::ICMP_EQ:
- case ICmpInst::ICMP_NE:
- // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
- if (!RA)
- if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
- if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
- if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
- ME->getOperand(0)->isAllOnesValue()) {
- RHS = AE->getOperand(1);
- LHS = ME->getOperand(1);
- Changed = true;
- }
- break;
- case ICmpInst::ICMP_UGE:
- if ((RA - 1).isMinValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA - 1);
- Changed = true;
- break;
- }
- if (RA.isMaxValue()) {
- Pred = ICmpInst::ICMP_EQ;
- Changed = true;
- break;
- }
- if (RA.isMinValue()) goto trivially_true;
- Pred = ICmpInst::ICMP_UGT;
- RHS = getConstant(RA - 1);
- Changed = true;
- break;
- case ICmpInst::ICMP_ULE:
- if ((RA + 1).isMaxValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA + 1);
- Changed = true;
- break;
- }
- if (RA.isMinValue()) {
- Pred = ICmpInst::ICMP_EQ;
- Changed = true;
- break;
- }
- if (RA.isMaxValue()) goto trivially_true;
+ bool SimplifiedByConstantRange = false;
- Pred = ICmpInst::ICMP_ULT;
- RHS = getConstant(RA + 1);
- Changed = true;
- break;
- case ICmpInst::ICMP_SGE:
- if ((RA - 1).isMinSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA - 1);
- Changed = true;
- break;
- }
- if (RA.isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- Changed = true;
- break;
+ if (!ICmpInst::isEquality(Pred)) {
+ ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
+ if (ExactCR.isFullSet())
+ goto trivially_true;
+ else if (ExactCR.isEmptySet())
+ goto trivially_false;
+
+ APInt NewRHS;
+ CmpInst::Predicate NewPred;
+ if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
+ ICmpInst::isEquality(NewPred)) {
+ // We were able to convert an inequality to an equality.
+ Pred = NewPred;
+ RHS = getConstant(NewRHS);
+ Changed = SimplifiedByConstantRange = true;
}
- if (RA.isMinSignedValue()) goto trivially_true;
+ }
- Pred = ICmpInst::ICMP_SGT;
- RHS = getConstant(RA - 1);
- Changed = true;
- break;
- case ICmpInst::ICMP_SLE:
- if ((RA + 1).isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA + 1);
- Changed = true;
+ if (!SimplifiedByConstantRange) {
+ switch (Pred) {
+ default:
break;
- }
- if (RA.isMinSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- Changed = true;
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
+ if (!RA)
+ if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
+ if (const SCEVMulExpr *ME =
+ dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
+ if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
+ ME->getOperand(0)->isAllOnesValue()) {
+ RHS = AE->getOperand(1);
+ LHS = ME->getOperand(1);
+ Changed = true;
+ }
break;
- }
- if (RA.isMaxSignedValue()) goto trivially_true;
- Pred = ICmpInst::ICMP_SLT;
- RHS = getConstant(RA + 1);
- Changed = true;
- break;
- case ICmpInst::ICMP_UGT:
- if (RA.isMinValue()) {
- Pred = ICmpInst::ICMP_NE;
+
+ // The "Should have been caught earlier!" messages refer to the fact
+ // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
+ // should have fired on the corresponding cases, and canonicalized the
+ // check to trivially_true or trivially_false.
+
+ case ICmpInst::ICMP_UGE:
+ assert(!RA.isMinValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_UGT;
+ RHS = getConstant(RA - 1);
Changed = true;
break;
- }
- if ((RA + 1).isMaxValue()) {
- Pred = ICmpInst::ICMP_EQ;
+ case ICmpInst::ICMP_ULE:
+ assert(!RA.isMaxValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_ULT;
RHS = getConstant(RA + 1);
Changed = true;
break;
- }
- if (RA.isMaxValue()) goto trivially_false;
- break;
- case ICmpInst::ICMP_ULT:
- if (RA.isMaxValue()) {
- Pred = ICmpInst::ICMP_NE;
- Changed = true;
- break;
- }
- if ((RA - 1).isMinValue()) {
- Pred = ICmpInst::ICMP_EQ;
+ case ICmpInst::ICMP_SGE:
+ assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_SGT;
RHS = getConstant(RA - 1);
Changed = true;
break;
- }
- if (RA.isMinValue()) goto trivially_false;
- break;
- case ICmpInst::ICMP_SGT:
- if (RA.isMinSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- Changed = true;
- break;
- }
- if ((RA + 1).isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
+ case ICmpInst::ICMP_SLE:
+ assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
+ Pred = ICmpInst::ICMP_SLT;
RHS = getConstant(RA + 1);
Changed = true;
break;
}
- if (RA.isMaxSignedValue()) goto trivially_false;
- break;
- case ICmpInst::ICMP_SLT:
- if (RA.isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- Changed = true;
- break;
- }
- if ((RA - 1).isMinSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- RHS = getConstant(RA - 1);
- Changed = true;
- break;
- }
- if (RA.isMinSignedValue()) goto trivially_false;
- break;
}
}
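
The ConstantRange-based rewrite replaces most of the old per-predicate case analysis: makeExactICmpRegion catches the trivially true/false and equality-convertible constants up front, and only the remaining *-or-equal predicates are strength-reduced by hand. For the unsigned >= case, the combined behaviour is, in standalone form:

    #include <cstdint>

    enum class Pred { UGT, EQ, TriviallyTrue };
    struct Canonical { Pred P; uint32_t RHS; };

    // Canonicalize "x >=u C" for a 32-bit constant C, mirroring what the
    // exact-range check plus the ICMP_UGE case in the patch produce.
    static Canonical canonicalizeUGE(uint32_t C) {
      if (C == 0)
        return {Pred::TriviallyTrue, 0};  // full range: always true
      if (C == UINT32_MAX)
        return {Pred::EQ, UINT32_MAX};    // single value: becomes an equality
      return {Pred::UGT, C - 1};          // otherwise: x >u C - 1
    }
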
@@ -8067,34 +8047,16 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
return false;
}
-namespace {
-/// RAII wrapper to prevent recursive application of isImpliedCond.
-/// ScalarEvolution's PendingLoopPredicates set must be empty unless we are
-/// currently evaluating isImpliedCond.
-struct MarkPendingLoopPredicate {
- Value *Cond;
- DenseSet<Value*> &LoopPreds;
- bool Pending;
-
- MarkPendingLoopPredicate(Value *C, DenseSet<Value*> &LP)
- : Cond(C), LoopPreds(LP) {
- Pending = !LoopPreds.insert(Cond).second;
- }
- ~MarkPendingLoopPredicate() {
- if (!Pending)
- LoopPreds.erase(Cond);
- }
-};
-} // end anonymous namespace
-
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
Value *FoundCondValue,
bool Inverse) {
- MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates);
- if (Mark.Pending)
+ if (!PendingLoopPredicates.insert(FoundCondValue).second)
return false;
+ auto ClearOnExit =
+ make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
+
// Recursively handle And and Or conditions.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
if (BO->getOpcode() == Instruction::And) {
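
The pending-predicate bookkeeping in this hunk now uses llvm::make_scope_exit (hence the new llvm/ADT/ScopeExit.h include) instead of the hand-written MarkPendingLoopPredicate RAII struct. A standalone sketch of the insert-then-erase-on-exit pattern, with a small hand-rolled guard so the example needs no LLVM headers:

    #include <set>

    template <typename Fn> struct ScopeGuard {
      Fn F;
      ~ScopeGuard() { F(); }   // runs on every exit path, like make_scope_exit
    };

    static bool isImpliedCondSketch(std::set<int> &Pending, int Cond) {
      if (!Pending.insert(Cond).second)
        return false;          // Cond is already being evaluated: avoid recursion
      auto Erase = [&] { Pending.erase(Cond); };
      ScopeGuard<decltype(Erase)> ClearOnExit{Erase};
      // ... recursive evaluation that may re-enter with the same Cond ...
      return true;
    }
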
@@ -8279,9 +8241,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
return true;
}
-bool ScalarEvolution::computeConstantDifference(const SCEV *Less,
- const SCEV *More,
- APInt &C) {
+Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
+ const SCEV *Less) {
// We avoid subtracting expressions here because this function is usually
// fairly deep in the call stack (i.e. is called many times).
@@ -8290,15 +8251,15 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less,
const auto *MAR = cast<SCEVAddRecExpr>(More);
if (LAR->getLoop() != MAR->getLoop())
- return false;
+ return None;
// We look at affine expressions only; not for correctness but to keep
// getStepRecurrence cheap.
if (!LAR->isAffine() || !MAR->isAffine())
- return false;
+ return None;
if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
- return false;
+ return None;
Less = LAR->getStart();
More = MAR->getStart();
@@ -8309,27 +8270,22 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less,
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
const auto &M = cast<SCEVConstant>(More)->getAPInt();
const auto &L = cast<SCEVConstant>(Less)->getAPInt();
- C = M - L;
- return true;
+ return M - L;
}
const SCEV *L, *R;
SCEV::NoWrapFlags Flags;
if (splitBinaryAdd(Less, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
- if (R == More) {
- C = -(LC->getAPInt());
- return true;
- }
+ if (R == More)
+ return -(LC->getAPInt());
if (splitBinaryAdd(More, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
- if (R == Less) {
- C = LC->getAPInt();
- return true;
- }
+ if (R == Less)
+ return LC->getAPInt();
- return false;
+ return None;
}
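
computeConstantDifference now reports its result through the return value instead of a bool plus an out-parameter, which is what lets the callers below compare the two differences directly. The shape of that API change, sketched with std::optional and plain integers in place of llvm::Optional and APInt:

    #include <optional>

    // Hypothetical stand-in: the real helper returns None when the two SCEVs
    // do not differ by a constant; this sketch always succeeds.
    static std::optional<long> constantDifference(long More, long Less) {
      return More - Less;
    }

    static bool sameConstantDifference(long L, long FoundL, long R, long FoundR) {
      std::optional<long> LDiff = constantDifference(L, FoundL);
      std::optional<long> RDiff = constantDifference(R, FoundR);
      return LDiff && RDiff && *LDiff == *RDiff;
    }
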
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
@@ -8386,22 +8342,21 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
// neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
// C)".
- APInt LDiff, RDiff;
- if (!computeConstantDifference(FoundLHS, LHS, LDiff) ||
- !computeConstantDifference(FoundRHS, RHS, RDiff) ||
- LDiff != RDiff)
+ Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
+ Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
+ if (!LDiff || !RDiff || *LDiff != *RDiff)
return false;
- if (LDiff == 0)
+ if (LDiff->isMinValue())
return true;
APInt FoundRHSLimit;
if (Pred == CmpInst::ICMP_ULT) {
- FoundRHSLimit = -RDiff;
+ FoundRHSLimit = -(*RDiff);
} else {
assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
- FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff;
+ FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
}
// Try to prove (1) or (2), as needed.
@@ -8511,7 +8466,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
- // fall through
+ LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLE:
return
// min(A, ...) <= A
@@ -8521,7 +8476,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
- // fall through
+ LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULE:
return
// min(A, ...) <= A
@@ -8592,9 +8547,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
// reduce the compile time impact of this optimization.
return false;
- const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS);
- if (!AddLHS || AddLHS->getOperand(1) != FoundLHS ||
- !isa<SCEVConstant>(AddLHS->getOperand(0)))
+ Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
+ if (!Addend)
return false;
APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
@@ -8604,10 +8558,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
ConstantRange FoundLHSRange =
ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
- // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range
- // for `LHS`:
- APInt Addend = cast<SCEVConstant>(AddLHS->getOperand(0))->getAPInt();
- ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend));
+ // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
+ ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
// We can also compute the range of values for `LHS` that satisfy the
// consequent, "`LHS` `Pred` `RHS`":
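// A worked example (plain unsigned arithmetic with made-up values, not
// llvm::ConstantRange) of the range reasoning above, assuming Pred is
// ICMP_ULT: knowing FoundLHS u< FoundRHS = 8 puts FoundLHS in [0, 8);
// with Addend = 3, LHS = FoundLHS + 3 lies in [3, 11); since [3, 11) is
// contained in the region satisfying LHS u< RHS = 12, the consequent holds
// for every possible LHS.
#include <cassert>

int main() {
  for (unsigned FoundLHS = 0; FoundLHS < 8; ++FoundLHS) { // FoundLHS u< 8
    unsigned LHS = FoundLHS + 3;                          // Addend = 3
    assert(LHS < 12 && "LHS u< RHS holds across the whole range");
  }
  return 0;
}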
@@ -8622,6 +8574,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned, bool NoWrap) {
+ assert(isKnownPositive(Stride) && "Positive stride expected!");
+
if (NoWrap) return false;
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
@@ -8684,17 +8638,21 @@ ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
- SCEVUnionPredicate P;
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV < Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
- if (!IV && AllowPredicates)
+ bool PredicatedIV = false;
+
+ if (!IV && AllowPredicates) {
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
- IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
+ PredicatedIV = true;
+ }
// Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8705,61 +8663,144 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const SCEV *Stride = IV->getStepRecurrence(*this);
- // Avoid negative or zero stride values
- if (!isKnownPositive(Stride))
- return getCouldNotCompute();
+ bool PositiveStride = isKnownPositive(Stride);
- // Avoid proven overflow cases: this will ensure that the backedge taken count
- // will not generate any unsigned overflow. Relaxed no-overflow conditions
- // exploit NoWrapFlags, allowing to optimize in presence of undefined
- // behaviors like the case of C language.
- if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
+ // Avoid negative or zero stride values.
+ if (!PositiveStride) {
+ // We can compute the correct backedge taken count for loops with unknown
+ // strides if we can prove that the loop is not an infinite loop with side
+ // effects. Here's the loop structure we are trying to handle -
+ //
+ // i = start
+ // do {
+ // A[i] = i;
+ // i += s;
+ // } while (i < end);
+ //
+ // The backedge taken count for such loops is evaluated as -
+ // (max(end, start + stride) - start - 1) /u stride
+ //
+ // The additional preconditions that we need to check to prove correctness
+ // of the above formula are as follows -
+ //
+ // a) IV is either nuw or nsw depending upon signedness (indicated by the
+ // NoWrap flag).
+ // b) loop is single exit with no side effects.
+ //
+ //
+ // Precondition a) implies that if the stride is negative, this is a single
+ // trip loop. The backedge taken count formula reduces to zero in this case.
+ //
+ // Precondition b) implies that the unknown stride cannot be zero otherwise
+ // we have UB.
+ //
+ // The positive stride case is the same as isKnownPositive(Stride) returning
+ // true (original behavior of the function).
+ //
+ // We want to make sure that the stride is truly unknown as there are edge
+ // cases where ScalarEvolution propagates no wrap flags to the
+ // post-increment/decrement IV even though the increment/decrement operation
+ // itself is wrapping. The computed backedge taken count may be wrong in
+ // such cases. This is prevented by checking that the stride is not known to
+ // be either positive or non-positive. For example, no wrap flags are
+ // propagated to the post-increment IV of this loop with a trip count of 2 -
+ //
+ // unsigned char i;
+ // for(i=127; i<128; i+=129)
+ // A[i] = i;
+ //
+ if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
+ !loopHasNoSideEffects(L))
+ return getCouldNotCompute();
+
+ } else if (!Stride->isOne() &&
+ doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
+ // Avoid proven overflow cases: this will ensure that the backedge taken
+ // count will not generate any unsigned overflow. Relaxed no-overflow
+ // conditions exploit NoWrapFlags, allowing optimization in the presence
+ // of undefined behavior, as in the C language.
return getCouldNotCompute();
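// A standalone check (arbitrary example values; the simulation assumes the
// no-wrap and no-side-effect preconditions listed in the comment above) that
// the documented formula
//   (max(end, start + stride) - start - 1) /u stride
// matches a direct simulation of the do-while loop shape it describes:
#include <cassert>
#include <cstdint>

static uint64_t formulaBTC(uint64_t Start, uint64_t End, uint64_t Stride) {
  uint64_t MaxEnd = End > Start + Stride ? End : Start + Stride;
  return (MaxEnd - Start - 1) / Stride; // /u : unsigned division
}

// i = Start; do { ...; i += Stride; } while (i < End);  counts backedges.
static uint64_t simulateBTC(uint64_t Start, uint64_t End, uint64_t Stride) {
  uint64_t Count = 0, I = Start;
  do {
    I += Stride;
    if (I < End)
      ++Count; // the backedge is taken
  } while (I < End);
  return Count;
}

int main() {
  assert(formulaBTC(0, 10, 3) == 3 && simulateBTC(0, 10, 3) == 3);
  assert(formulaBTC(0, 9, 3) == 2 && simulateBTC(0, 9, 3) == 2);
  assert(formulaBTC(7, 5, 3) == 0 && simulateBTC(7, 5, 3) == 0); // zero backedges
  return 0;
}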
ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
: ICmpInst::ICMP_ULT;
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+ // If the backedge is taken at least once, then it will be taken
+ // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
+ // is the LHS value of the less-than comparison the first time it is evaluated
+ // and End is the RHS.
+ const SCEV *BECountIfBackedgeTaken =
+ computeBECount(getMinusSCEV(End, Start), Stride, false);
+ // If the loop entry is guarded by the result of the backedge test of the
+ // first loop iteration, then we know the backedge will be taken at least
+ // once and so the backedge taken count is as above. If not, we use the
+ // expression (max(End,Start)-Start)/Stride instead: if the backedge is
+ // taken at least once, max(End,Start) is End and the result is as above;
+ // otherwise max(End,Start) is Start, giving a backedge taken count of zero.
+ const SCEV *BECount;
+ if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+ BECount = BECountIfBackedgeTaken;
+ else {
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
+ BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
+ }
- const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
+ const SCEV *MaxBECount;
+ bool MaxOrZero = false;
+ if (isa<SCEVConstant>(BECount))
+ MaxBECount = BECount;
+ else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+ // If we know exactly how many times the backedge will be taken if it's
+ // taken at least once, then the backedge count will either be that or
+ // zero.
+ MaxBECount = BECountIfBackedgeTaken;
+ MaxOrZero = true;
+ } else {
+ // Calculate the maximum backedge count based on the range of values
+ // permitted by Start, End, and Stride.
+ APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
+ : getUnsignedRange(Start).getUnsignedMin();
- APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
- : getUnsignedRange(Start).getUnsignedMin();
+ unsigned BitWidth = getTypeSizeInBits(LHS->getType());
- APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
- : getUnsignedRange(Stride).getUnsignedMin();
+ APInt StrideForMaxBECount;
- unsigned BitWidth = getTypeSizeInBits(LHS->getType());
- APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1)
- : APInt::getMaxValue(BitWidth) - (MinStride - 1);
+ if (PositiveStride)
+ StrideForMaxBECount =
+ IsSigned ? getSignedRange(Stride).getSignedMin()
+ : getUnsignedRange(Stride).getUnsignedMin();
+ else
+ // Using a stride of 1 is safe when computing max backedge taken count for
+ // a loop with unknown stride.
+ StrideForMaxBECount = APInt(BitWidth, 1, IsSigned);
- // Although End can be a MAX expression we estimate MaxEnd considering only
- // the case End = RHS. This is safe because in the other case (End - Start)
- // is zero, leading to a zero maximum backedge taken count.
- APInt MaxEnd =
- IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit)
- : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit);
+ APInt Limit =
+ IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1)
+ : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1);
+
+ // Although End can be a MAX expression we estimate MaxEnd considering only
+ // the case End = RHS. This is safe because in the other case (End - Start)
+ // is zero, leading to a zero maximum backedge taken count.
+ APInt MaxEnd =
+ IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit)
+ : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit);
- const SCEV *MaxBECount;
- if (isa<SCEVConstant>(BECount))
- MaxBECount = BECount;
- else
MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
- getConstant(MinStride), false);
+ getConstant(StrideForMaxBECount), false);
+ }
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, P);
+ return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}
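// A standalone illustration (made-up source shape and constants; whether SCEV
// derives exactly this result for it is an assumption) of the dichotomy the
// MaxOrZero flag describes: the compared value starts at I0 + 4 and the bound
// is I0 + 16, a constant distance of 12, so if the backedge is taken at all it
// is taken 12/4 = 3 times; but because "I0 + 16" has no no-wrap guarantee the
// entry test cannot be proven, and the actual count is either 3 or 0:
#include <cassert>
#include <cstdint>

static unsigned backedges(uint32_t I0) {
  uint32_t I = I0, N = I0 + 16; // N may wrap around I0
  unsigned Count = 0;
  do {
    I += 4;
    if (I < N)
      ++Count; // backedge taken
    else
      break;
  } while (true);
  return Count;
}

int main() {
  assert(backedges(100u) == 3);        // no wrap: the "max" count
  assert(backedges(0xFFFFFFF4u) == 0); // I0 + 16 wrapped: zero backedges
  return 0;
}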
ScalarEvolution::ExitLimit
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
- SCEVUnionPredicate P;
+ SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV > Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
@@ -8769,7 +8810,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
- IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
// Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8829,7 +8870,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, P);
+ return ExitLimit(BECount, MaxBECount, false, Predicates);
}
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@@ -8901,9 +8942,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
// Range.getUpper() is crossed.
SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
- const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
- // getNoWrapFlags(FlagNW)
- FlagAnyWrap);
+ const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap);
// Next, solve the constructed addrec
if (auto Roots =
@@ -8947,38 +8986,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
return SE.getCouldNotCompute();
}
-namespace {
-struct FindUndefs {
- bool Found;
- FindUndefs() : Found(false) {}
-
- bool follow(const SCEV *S) {
- if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
- if (isa<UndefValue>(C->getValue()))
- Found = true;
- } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (isa<UndefValue>(C->getValue()))
- Found = true;
- }
-
- // Keep looking if we haven't found it yet.
- return !Found;
- }
- bool isDone() const {
- // Stop recursion if we have found an undef.
- return Found;
- }
-};
-}
-
// Return true when S contains at least an undef value.
-static inline bool
-containsUndefs(const SCEV *S) {
- FindUndefs F;
- SCEVTraversal<FindUndefs> ST(F);
- ST.visitAll(S);
-
- return F.Found;
+static inline bool containsUndefs(const SCEV *S) {
+ return SCEVExprContains(S, [](const SCEV *S) {
+ if (const auto *SU = dyn_cast<SCEVUnknown>(S))
+ return isa<UndefValue>(SU->getValue());
+ else if (const auto *SC = dyn_cast<SCEVConstant>(S))
+ return isa<UndefValue>(SC->getValue());
+ return false;
+ });
}
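// A standalone sketch (toy Node tree and names, not the real SCEV class
// hierarchy or llvm::SCEVExprContains) of the predicate-driven traversal the
// rewritten containsUndefs relies on: walk the expression tree and stop as
// soon as the predicate matches any node.
#include <vector>

struct Node {
  bool IsUndef = false; // stand-in for "wraps an UndefValue"
  std::vector<const Node *> Operands;
};

template <typename Pred> static bool exprContains(const Node *Root, Pred P) {
  if (P(Root))
    return true; // early exit, mirroring the isDone()-style cutoff
  for (const Node *Op : Root->Operands)
    if (exprContains(Op, P))
      return true;
  return false;
}

static bool containsUndefNodes(const Node *Root) {
  return exprContains(Root, [](const Node *N) { return N->IsUndef; });
}

int main() {
  Node Undef; Undef.IsUndef = true;
  Node Add;   Add.Operands = {&Undef};
  return containsUndefNodes(&Add) ? 0 : 1;
}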
namespace {
@@ -9006,7 +9022,8 @@ struct SCEVCollectTerms {
: Terms(T) {}
bool follow(const SCEV *S) {
- if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S)) {
+ if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
+ isa<SCEVSignExtendExpr>(S)) {
if (!containsUndefs(S))
Terms.push_back(S);
@@ -9158,10 +9175,9 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
}
// Remove all SCEVConstants.
- Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) {
- return isa<SCEVConstant>(E);
- }),
- Terms.end());
+ Terms.erase(
+ remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
+ Terms.end());
if (Terms.size() > 0)
if (!findArrayDimensionsRec(SE, Terms, Sizes))
@@ -9171,40 +9187,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
return true;
}
-// Returns true when S contains at least a SCEVUnknown parameter.
-static inline bool
-containsParameters(const SCEV *S) {
- struct FindParameter {
- bool FoundParameter;
- FindParameter() : FoundParameter(false) {}
-
- bool follow(const SCEV *S) {
- if (isa<SCEVUnknown>(S)) {
- FoundParameter = true;
- // Stop recursion: we found a parameter.
- return false;
- }
- // Keep looking.
- return true;
- }
- bool isDone() const {
- // Stop recursion if we have found a parameter.
- return FoundParameter;
- }
- };
-
- FindParameter F;
- SCEVTraversal<FindParameter> ST(F);
- ST.visitAll(S);
-
- return F.FoundParameter;
-}
// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
-static inline bool
-containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
+static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
for (const SCEV *T : Terms)
- if (containsParameters(T))
+ if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>))
return true;
return false;
}
@@ -9535,6 +9522,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
: F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
+ PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
@@ -9543,6 +9531,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
std::move(Arg.ConstantEvolutionLoopExitValue)),
ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
LoopDispositions(std::move(Arg.LoopDispositions)),
+ LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
BlockDispositions(std::move(Arg.BlockDispositions)),
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
@@ -9611,6 +9600,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
+ if (SE->isBackedgeTakenCountMaxOrZero(L))
+ OS << ", actual taken count either this or zero.";
} else {
OS << "Unpredictable max backedge-taken count. ";
}
@@ -9871,8 +9862,10 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
if (!DT.dominates(AR->getLoop()->getHeader(), BB))
return DoesNotDominateBlock;
+
+ // Fall through into SCEVNAryExpr handling.
+ LLVM_FALLTHROUGH;
}
- // FALL THROUGH into SCEVNAryExpr handling.
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
@@ -9925,24 +9918,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
}
bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
- // Search for a SCEV expression node within an expression tree.
- // Implements SCEVTraversal::Visitor.
- struct SCEVSearch {
- const SCEV *Node;
- bool IsFound;
-
- SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
-
- bool follow(const SCEV *S) {
- IsFound |= (S == Node);
- return !IsFound;
- }
- bool isDone() const { return IsFound; }
- };
-
- SCEVSearch Search(Op);
- visitAll(S, Search);
- return Search.IsFound;
+ return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}
void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
@@ -10050,10 +10026,22 @@ void ScalarEvolution::verify() const {
// TODO: Verify more things.
}
-char ScalarEvolutionAnalysis::PassID;
+bool ScalarEvolution::invalidate(
+ Function &F, const PreservedAnalyses &PA,
+ FunctionAnalysisManager::Invalidator &Inv) {
+ // Invalidate the ScalarEvolution object whenever it isn't preserved or one
+ // of its dependencies is invalidated.
+ auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
+ return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
+ Inv.invalidate<AssumptionAnalysis>(F, PA) ||
+ Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
+ Inv.invalidate<LoopAnalysis>(F, PA);
+}
+
+AnalysisKey ScalarEvolutionAnalysis::Key;
ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
- AnalysisManager<Function> &AM) {
+ FunctionAnalysisManager &AM) {
return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
AM.getResult<AssumptionAnalysis>(F),
AM.getResult<DominatorTreeAnalysis>(F),
@@ -10061,7 +10049,7 @@ ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
}
PreservedAnalyses
-ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) {
+ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
return PreservedAnalyses::all();
}
@@ -10148,25 +10136,34 @@ namespace {
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
- // Rewrites \p S in the context of a loop L and the predicate A.
- // If Assume is true, rewrite is free to add further predicates to A
- // such that the result will be an AddRecExpr.
+ /// Rewrites \p S in the context of a loop L and the SCEV predication
+ /// infrastructure.
+ ///
+ /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
+ /// equivalences present in \p Pred.
+ ///
+ /// If \p NewPreds is non-null, rewrite is free to add further predicates to
+ /// \p NewPreds such that the result will be an AddRecExpr.
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
- SCEVUnionPredicate &A, bool Assume) {
- SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+ SCEVUnionPredicate *Pred) {
+ SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
return Rewriter.visit(S);
}
SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
- SCEVUnionPredicate &P, bool Assume)
- : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+ SCEVUnionPredicate *Pred)
+ : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
- auto ExprPreds = P.getPredicatesForExpr(Expr);
- for (auto *Pred : ExprPreds)
- if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
- if (IPred->getLHS() == Expr)
- return IPred->getRHS();
+ if (Pred) {
+ auto ExprPreds = Pred->getPredicatesForExpr(Expr);
+ for (auto *Pred : ExprPreds)
+ if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
+ if (IPred->getLHS() == Expr)
+ return IPred->getRHS();
+ }
return Expr;
}
@@ -10207,32 +10204,31 @@ private:
bool addOverflowAssumption(const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
auto *A = SE.getWrapPredicate(AR, AddedFlags);
- if (!Assume) {
+ if (!NewPreds) {
// Check if we've already made this assumption.
- if (P.implies(A))
- return true;
- return false;
+ return Pred && Pred->implies(A);
}
- P.add(A);
+ NewPreds->insert(A);
return true;
}
- SCEVUnionPredicate &P;
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
+ SCEVUnionPredicate *Pred;
const Loop *L;
- bool Assume;
};
} // end anonymous namespace
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) {
- return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+ return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
}
-const SCEVAddRecExpr *
-ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
- SCEVUnionPredicate &Preds) {
- SCEVUnionPredicate TransformPreds;
- S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
+ const SCEV *S, const Loop *L,
+ SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
+
+ SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
+ S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
if (!AddRec)
@@ -10240,7 +10236,9 @@ ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
// Since the transformation was successful, we can now transfer the SCEV
// predicates.
- Preds.add(&TransformPreds);
+ for (auto *P : TransformPreds)
+ Preds.insert(P);
+
return AddRec;
}
@@ -10393,7 +10391,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
return Entry.second;
// We found an entry but it's stale. Rewrite the stale entry
- // acording to the current predicate.
+ // according to the current predicate.
if (Entry.second)
Expr = Entry.second;
@@ -10467,11 +10465,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
const SCEV *Expr = this->getSCEV(V);
- auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+ SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
+ auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
if (!New)
return nullptr;
+ for (auto *P : NewPreds)
+ Preds.add(P);
+
updateGeneration();
RewriteMap[SE.getSCEV(V)] = {Generation, New};
return New;