path: root/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--  contrib/llvm/lib/Analysis/ScalarEvolution.cpp  2506
1 file changed, 1612 insertions(+), 894 deletions(-)
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
index ef1bb3a..e42a4b5 100644
--- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -111,10 +111,14 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
"derived loop"),
cl::init(100));
-// FIXME: Enable this with XDEBUG when the test suite is clean.
+// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
static cl::opt<bool>
VerifySCEV("verify-scev",
cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
+static cl::opt<bool>
+ VerifySCEVMap("verify-scev-maps",
+ cl::desc("Verify no dangling value in ScalarEvolution's "
+ "ExprValueMap (slow)"));
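A hypothetical invocation exercising the new flag (assuming the stock opt driver and an assertions-enabled build): opt -analyze -scalar-evolution -verify-scev-maps input.ll. With the flag set, every ExprValueMap lookup asserts that no Value in the returned set has gone stale in ValueExprMap.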
//===----------------------------------------------------------------------===//
// SCEV class definitions
@@ -162,11 +166,11 @@ void SCEV::print(raw_ostream &OS) const {
for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
OS << ",+," << *AR->getOperand(i);
OS << "}<";
- if (AR->getNoWrapFlags(FlagNUW))
+ if (AR->hasNoUnsignedWrap())
OS << "nuw><";
- if (AR->getNoWrapFlags(FlagNSW))
+ if (AR->hasNoSignedWrap())
OS << "nsw><";
- if (AR->getNoWrapFlags(FlagNW) &&
+ if (AR->hasNoSelfWrap() &&
!AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
OS << "nw><";
AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
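With the predicate form above, an affine recurrence carrying both flags prints as, for example, {0,+,4}<nuw><nsw><%for.body> (the block name is hypothetical).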
@@ -196,9 +200,9 @@ void SCEV::print(raw_ostream &OS) const {
switch (NAry->getSCEVType()) {
case scAddExpr:
case scMulExpr:
- if (NAry->getNoWrapFlags(FlagNUW))
+ if (NAry->hasNoUnsignedWrap())
OS << "<nuw>";
- if (NAry->getNoWrapFlags(FlagNSW))
+ if (NAry->hasNoSignedWrap())
OS << "<nsw>";
}
return;
@@ -283,8 +287,6 @@ bool SCEV::isAllOnesValue() const {
return false;
}
-/// isNonConstantNegative - Return true if the specified scev is negated, but
-/// not a constant.
bool SCEV::isNonConstantNegative() const {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
if (!Mul) return false;
@@ -620,10 +622,10 @@ public:
};
} // end anonymous namespace
-/// GroupByComplexity - Given a list of SCEV objects, order them by their
-/// complexity, and group objects of the same complexity together by value.
-/// When this routine is finished, we know that any duplicates in the vector are
-/// consecutive and that complexity is monotonically increasing.
+/// Given a list of SCEV objects, order them by their complexity, and group
+/// objects of the same complexity together by value. When this routine is
+/// finished, we know that any duplicates in the vector are consecutive and that
+/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
@@ -723,7 +725,7 @@ public:
}
// Split the Denominator when it is a product.
- if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
+ if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
const SCEV *Q, *R;
*Quotient = Numerator;
for (const SCEV *Op : T->operands()) {
@@ -922,8 +924,7 @@ private:
// Simple SCEV method implementations
//===----------------------------------------------------------------------===//
-/// BinomialCoefficient - Compute BC(It, K). The result has width W.
-/// Assume, K > 0.
+/// Compute BC(It, K). The result has width W. Assume K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
ScalarEvolution &SE,
Type *ResultTy) {
@@ -1034,10 +1035,10 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
-/// evaluateAtIteration - Return the value of this chain of recurrences at
-/// the specified iteration number. We can evaluate this recurrence by
-/// multiplying each element in the chain by the binomial coefficient
-/// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as:
+/// Return the value of this chain of recurrences at the specified iteration
+/// number. We can evaluate this recurrence by multiplying each element in the
+/// chain by the binomial coefficient corresponding to it. In other words, we
+/// can evaluate {A,+,B,+,C,+,D} as:
///
/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
@@ -1450,9 +1451,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
unsigned BitWidth = getTypeSizeInBits(AR->getType());
const Loop *L = AR->getLoop();
+ if (!AR->hasNoUnsignedWrap()) {
+ auto NewFlags = proveNoWrapViaConstantRanges(AR);
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ }
+
// If we have special knowledge that this addrec won't overflow,
// we don't need to do any further analysis.
- if (AR->getNoWrapFlags(SCEV::FlagNUW))
+ if (AR->hasNoUnsignedWrap())
return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
@@ -1512,11 +1518,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
}
}
+ }
- // If the backedge is guarded by a comparison with the pre-inc value
- // the addrec is safe. Also, if the entry is guarded by a comparison
- // with the start value and the backedge is guarded by a comparison
- // with the post-inc value, the addrec is safe.
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+ if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+ !AC.assumptions().empty()) {
+ // If the backedge is guarded by a comparison with the pre-inc
+ // value the addrec is safe. Also, if the entry is guarded by
+ // a comparison with the start value and the backedge is
+ // guarded by a comparison with the post-inc value, the addrec
+ // is safe.
if (isKnownPositive(Step)) {
const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
getUnsignedRange(Step).getUnsignedMax());
@@ -1524,7 +1541,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
(isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
AR->getPostIncExpr(*this), N))) {
- // Cache knowledge of AR NUW, which is propagated to this AddRec.
+ // Cache knowledge of AR NUW, which is propagated to this
+ // AddRec.
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
@@ -1538,8 +1556,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
(isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
AR->getPostIncExpr(*this), N))) {
- // Cache knowledge of AR NW, which is propagated to this AddRec.
- // Negative step causes unsigned wrap, but it still can't self-wrap.
+ // Cache knowledge of AR NW, which is propagated to this
+ // AddRec. Negative step causes unsigned wrap, but it
+ // still can't self-wrap.
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
@@ -1559,7 +1578,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
// zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
- if (SA->getNoWrapFlags(SCEV::FlagNUW)) {
+ if (SA->hasNoUnsignedWrap()) {
// If the addition does not unsign overflow then we can, by definition,
// commute the zero extension with the addition operation.
SmallVector<const SCEV *, 4> Ops;
@@ -1608,10 +1627,6 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
- // If the input value is provably positive, build a zext instead.
- if (isKnownNonNegative(Op))
- return getZeroExtendExpr(Op, Ty);
-
// sext(trunc(x)) --> sext(x) or x or trunc(x)
if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
// It's possible the bits taken off by the truncate were all sign bits. If
@@ -1643,7 +1658,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
}
// sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
- if (SA->getNoWrapFlags(SCEV::FlagNSW)) {
+ if (SA->hasNoSignedWrap()) {
// If the addition does not sign overflow then we can, by definition,
// commute the sign extension with the addition operation.
SmallVector<const SCEV *, 4> Ops;
@@ -1663,9 +1678,14 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
unsigned BitWidth = getTypeSizeInBits(AR->getType());
const Loop *L = AR->getLoop();
+ if (!AR->hasNoSignedWrap()) {
+ auto NewFlags = proveNoWrapViaConstantRanges(AR);
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ }
+
// If we have special knowledge that this addrec won't overflow,
// we don't need to do any further analysis.
- if (AR->getNoWrapFlags(SCEV::FlagNSW))
+ if (AR->hasNoSignedWrap())
return getAddRecExpr(
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
@@ -1732,11 +1752,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
}
}
+ }
- // If the backedge is guarded by a comparison with the pre-inc value
- // the addrec is safe. Also, if the entry is guarded by a comparison
- // with the start value and the backedge is guarded by a comparison
- // with the post-inc value, the addrec is safe.
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+
+ if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
+ !AC.assumptions().empty()) {
+ // If the backedge is guarded by a comparison with the pre-inc
+ // value the addrec is safe. Also, if the entry is guarded by
+ // a comparison with the start value and the backedge is
+ // guarded by a comparison with the post-inc value, the addrec
+ // is safe.
ICmpInst::Predicate Pred;
const SCEV *OverflowLimit =
getSignedOverflowLimitForStep(Step, &Pred, this);
@@ -1752,6 +1784,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
}
}
+
// If Start and Step are constants, check if we can apply this
// transformation:
// sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
@@ -1777,6 +1810,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
}
}
+ // If the input value is provably positive and we could not simplify
+ // away the sext build a zext instead.
+ if (isKnownNonNegative(Op))
+ return getZeroExtendExpr(Op, Ty);
+
// The cast wasn't folded; create an explicit cast node.
// Recompute the insert position, as it may have been invalidated.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
@@ -1836,11 +1874,10 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
return ZExt;
}
-/// CollectAddOperandsWithScales - Process the given Ops list, which is
-/// a list of operands to be added under the given scale, update the given
-/// map. This is a helper function for getAddRecExpr. As an example of
-/// what it does, given a sequence of operands that would form an add
-/// expression like this:
+/// Process the given Ops list, which is a list of operands to be added under
+/// the given scale, update the given map. This is a helper function for
+/// getAddRecExpr. As an example of what it does, given a sequence of operands
+/// that would form an add expression like this:
///
/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
///
@@ -1899,7 +1936,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
// the map.
SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
const SCEV *Key = SE.getMulExpr(MulOps);
- auto Pair = M.insert(std::make_pair(Key, NewScale));
+ auto Pair = M.insert({Key, NewScale});
if (Pair.second) {
NewOps.push_back(Pair.first->first);
} else {
@@ -1912,7 +1949,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
} else {
// An ordinary operand. Update the map.
std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
- M.insert(std::make_pair(Ops[i], Scale));
+ M.insert({Ops[i], Scale});
if (Pair.second) {
NewOps.push_back(Pair.first->first);
} else {
@@ -1965,15 +2002,14 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
- auto NSWRegion =
- ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
+ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, C, OBO::NoSignedWrap);
if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
}
if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
- auto NUWRegion =
- ConstantRange::makeNoWrapRegion(Instruction::Add, C,
- OBO::NoUnsignedWrap);
+ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, C, OBO::NoUnsignedWrap);
if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
}
@@ -1982,8 +2018,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
return Flags;
}
-/// getAddExpr - Get a canonical add expression, or something simpler if
-/// possible.
+/// Get a canonical add expression, or something simpler if possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags) {
assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
@@ -2266,7 +2301,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
AddRec->op_end());
- AddRecOps[0] = getAddExpr(LIOps);
+ // This follows from the fact that the no-wrap flags on the outer add
+ // expression are applicable on the 0th iteration, when the add recurrence
+ // will be equal to its start value.
+ AddRecOps[0] = getAddExpr(LIOps, Flags);
// Build the new addrec. Propagate the NUW and NSW flags if both the
// outer add and the inner addrec are guaranteed to have no overflow.
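A concrete instance of the reasoning above: given S = (A + {B,+,1}<L>)<nuw>, on the 0th iteration the recurrence equals its start B, so the outer add's <nuw> covers A + B; passing Flags through to getAddExpr(LIOps, Flags) is therefore sound for the new start.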
@@ -2391,8 +2429,7 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) {
return false;
}
-/// getMulExpr - Get a canonical multiply expression, or something simpler if
-/// possible.
+/// Get a canonical multiply expression, or something simpler if possible.
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags) {
assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
@@ -2632,8 +2669,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
return S;
}
-/// getUDivExpr - Get a canonical unsigned division expression, or something
-/// simpler if possible.
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible.
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
const SCEV *RHS) {
assert(getEffectiveSCEVType(LHS->getType()) ==
@@ -2764,10 +2801,10 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
return APIntOps::GreatestCommonDivisor(A, B);
}
-/// getUDivExactExpr - Get a canonical unsigned division expression, or
-/// something simpler if possible. There is no representation for an exact udiv
-/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS.
-/// We can't do this when it's not exact because the udiv may be clearing bits.
+/// Get a canonical unsigned division expression, or something simpler if
+/// possible. There is no representation for an exact udiv in SCEV IR, but we
+/// can attempt to remove factors from the LHS and RHS. We can't do this when
+/// it's not exact because the udiv may be clearing bits.
const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
const SCEV *RHS) {
// TODO: we could try to find factors in all sorts of things, but for now we
@@ -2821,8 +2858,8 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
return getUDivExpr(LHS, RHS);
}
-/// getAddRecExpr - Get an add recurrence expression for the specified loop.
-/// Simplify the expression as much as possible.
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
const Loop *L,
SCEV::NoWrapFlags Flags) {
@@ -2838,8 +2875,8 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
return getAddRecExpr(Operands, L, Flags);
}
-/// getAddRecExpr - Get an add recurrence expression for the specified loop.
-/// Simplify the expression as much as possible.
+/// Get an add recurrence expression for the specified loop. Simplify the
+/// expression as much as possible.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
const Loop *L, SCEV::NoWrapFlags Flags) {
@@ -2985,9 +3022,7 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
const SCEV *RHS) {
- SmallVector<const SCEV *, 2> Ops;
- Ops.push_back(LHS);
- Ops.push_back(RHS);
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getSMaxExpr(Ops);
}
@@ -3088,9 +3123,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
const SCEV *RHS) {
- SmallVector<const SCEV *, 2> Ops;
- Ops.push_back(LHS);
- Ops.push_back(RHS);
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getUMaxExpr(Ops);
}
@@ -3244,26 +3277,25 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
// Basic SCEV Analysis and PHI Idiom Recognition Code
//
-/// isSCEVable - Test if values of the given type are analyzable within
-/// the SCEV framework. This primarily includes integer types, and it
-/// can optionally include pointer types if the ScalarEvolution class
-/// has access to target-specific information.
+/// Test if values of the given type are analyzable within the SCEV
+/// framework. This primarily includes integer types, and it can optionally
+/// include pointer types if the ScalarEvolution class has access to
+/// target-specific information.
bool ScalarEvolution::isSCEVable(Type *Ty) const {
// Integers and pointers are always SCEVable.
return Ty->isIntegerTy() || Ty->isPointerTy();
}
-/// getTypeSizeInBits - Return the size in bits of the specified type,
-/// for which isSCEVable must return true.
+/// Return the size in bits of the specified type, for which isSCEVable must
+/// return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
return getDataLayout().getTypeSizeInBits(Ty);
}
-/// getEffectiveSCEVType - Return a type with the same bitwidth as
-/// the given type and which represents how SCEV will treat the given
-/// type, for which isSCEVable must return true. For pointer types,
-/// this is the pointer-sized integer type.
+/// Return a type with the same bitwidth as the given type and which represents
+/// how SCEV will treat the given type, for which isSCEVable must return
+/// true. For pointer types, this is the pointer-sized integer type.
Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
@@ -3310,15 +3342,88 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const {
return !F.FindOne;
}
-/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
-/// expression and create a new one.
+namespace {
+// Helper class working with SCEVTraversal to figure out if a SCEV contains
+// a sub-SCEV of scAddRecExpr type. FindAddRecurrence::FoundOne is set iff
+// such a sub-SCEV is found.
+struct FindAddRecurrence {
+ bool FoundOne;
+ FindAddRecurrence() : FoundOne(false) {}
+
+ bool follow(const SCEV *S) {
+ switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+ case scAddRecExpr:
+ FoundOne = true;
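+      // fall through: stop descending, an addrec has been found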
+ case scConstant:
+ case scUnknown:
+ case scCouldNotCompute:
+ return false;
+ default:
+ return true;
+ }
+ }
+ bool isDone() const { return FoundOne; }
+};
+} // end anonymous namespace
+
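+/// Return true if \p S is a scAddRecExpr or it contains scAddRecExpr. The
+/// result will be cached in HasRecMap.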
+bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
+ HasRecMapType::iterator I = HasRecMap.find_as(S);
+ if (I != HasRecMap.end())
+ return I->second;
+
+ FindAddRecurrence F;
+ SCEVTraversal<FindAddRecurrence> ST(F);
+ ST.visitAll(S);
+ HasRecMap.insert({S, F.FoundOne});
+ return F.FoundOne;
+}
+
+/// Return the Value set from S.
+SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) {
+ ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
+ if (SI == ExprValueMap.end())
+ return nullptr;
+#ifndef NDEBUG
+ if (VerifySCEVMap) {
+ // Check there is no dangling Value in the set returned.
+ for (const auto &VE : SI->second)
+ assert(ValueExprMap.count(VE));
+ }
+#endif
+ return &SI->second;
+}
+
+/// Erase Value from ValueExprMap and ExprValueMap. Unless ValueExprMap.erase(V)
+/// is paired with forgetMemoizedResults(S), use eraseValueFromMap instead, so
+/// that whenever V->S is removed from ValueExprMap, V is also removed from the
+/// set ExprValueMap[S].
+void ScalarEvolution::eraseValueFromMap(Value *V) {
+ ValueExprMapType::iterator I = ValueExprMap.find_as(V);
+ if (I != ValueExprMap.end()) {
+ const SCEV *S = I->second;
+ SetVector<Value *> *SV = getSCEVValues(S);
+ // Remove V from the set of ExprValueMap[S]
+ if (SV)
+ SV->remove(V);
+ ValueExprMap.erase(V);
+ }
+}
+
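In other words (a usage note, not part of the patch): callers that drop a stale V->S mapping should go through eraseValueFromMap so the reverse S->{V} set never holds a dangling Value; that invariant is exactly what -verify-scev-maps checks in getSCEVValues.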
+/// Return an existing SCEV if it exists, otherwise analyze the expression and
+/// create a new one.
const SCEV *ScalarEvolution::getSCEV(Value *V) {
assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
const SCEV *S = getExistingSCEV(V);
if (S == nullptr) {
S = createSCEV(V);
- ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
+    // During PHI resolution, it is possible to create two SCEVs for the same
+    // V, so we need to double-check that V->S was actually inserted into
+    // ValueExprMap before inserting S->V into ExprValueMap.
+ std::pair<ValueExprMapType::iterator, bool> Pair =
+ ValueExprMap.insert({SCEVCallbackVH(V, this), S});
+ if (Pair.second)
+ ExprValueMap[S].insert(V);
}
return S;
}
@@ -3331,12 +3436,13 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
const SCEV *S = I->second;
if (checkValidity(S))
return S;
+ forgetMemoizedResults(S);
ValueExprMap.erase(I);
}
return nullptr;
}
-/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
+/// Return a SCEV corresponding to -V = -1*V
///
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
SCEV::NoWrapFlags Flags) {
@@ -3350,7 +3456,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
}
-/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
+/// Return a SCEV corresponding to ~V = -1-V
const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
@@ -3363,7 +3469,6 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
return getMinusSCEV(AllOnes, V);
}
-/// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
SCEV::NoWrapFlags Flags) {
// Fast path: X - X --> 0.
@@ -3402,9 +3507,6 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags);
}
-/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type. If the type must be extended, it is zero
-/// extended.
const SCEV *
ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
@@ -3418,9 +3520,6 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
return getZeroExtendExpr(V, Ty);
}
-/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type. If the type must be extended, it is sign
-/// extended.
const SCEV *
ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
Type *Ty) {
@@ -3435,9 +3534,6 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
return getSignExtendExpr(V, Ty);
}
-/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type. If the type must be extended, it is zero
-/// extended. The conversion must not be narrowing.
const SCEV *
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
@@ -3451,9 +3547,6 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
return getZeroExtendExpr(V, Ty);
}
-/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type. If the type must be extended, it is sign
-/// extended. The conversion must not be narrowing.
const SCEV *
ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
@@ -3467,10 +3560,6 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
return getSignExtendExpr(V, Ty);
}
-/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
-/// the input value to the specified type. If the type must be extended,
-/// it is extended with unspecified bits. The conversion must not be
-/// narrowing.
const SCEV *
ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
@@ -3484,8 +3573,6 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
return getAnyExtendExpr(V, Ty);
}
-/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
-/// input value to the specified type. The conversion must not be widening.
const SCEV *
ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
@@ -3499,9 +3586,6 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
return getTruncateExpr(V, Ty);
}
-/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
-/// the types using zero-extension, and then perform a umax operation
-/// with them.
const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
const SCEV *RHS) {
const SCEV *PromotedLHS = LHS;
@@ -3515,9 +3599,6 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
return getUMaxExpr(PromotedLHS, PromotedRHS);
}
-/// getUMinFromMismatchedTypes - Promote the operands to the wider of
-/// the types using zero-extension, and then perform a umin operation
-/// with them.
const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
const SCEV *RHS) {
const SCEV *PromotedLHS = LHS;
@@ -3531,10 +3612,6 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
return getUMinExpr(PromotedLHS, PromotedRHS);
}
-/// getPointerBase - Transitively follow the chain of pointer-type operands
-/// until reaching a SCEV that does not have a single pointer operand. This
-/// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
-/// but corner cases do exist.
const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
// A pointer operand may evaluate to a nonpointer expression, such as null.
if (!V->getType()->isPointerTy())
@@ -3559,8 +3636,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
return V;
}
-/// PushDefUseChildren - Push users of the given Instruction
-/// onto the given Worklist.
+/// Push users of the given Instruction onto the given Worklist.
static void
PushDefUseChildren(Instruction *I,
SmallVectorImpl<Instruction *> &Worklist) {
@@ -3569,12 +3645,7 @@ PushDefUseChildren(Instruction *I,
Worklist.push_back(cast<Instruction>(U));
}
-/// ForgetSymbolicValue - This looks up computed SCEV values for all
-/// instructions that depend on the given instruction and removes them from
-/// the ValueExprMapType map if they reference SymName. This is used during PHI
-/// resolution.
-void
-ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
+void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
SmallVector<Instruction *, 16> Worklist;
PushDefUseChildren(PN, Worklist);
@@ -3616,10 +3687,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
namespace {
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
- static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ static const SCEV *rewrite(const SCEV *S, const Loop *L,
ScalarEvolution &SE) {
SCEVInitRewriter Rewriter(L, SE);
- const SCEV *Result = Rewriter.visit(Scev);
+ const SCEV *Result = Rewriter.visit(S);
return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
}
@@ -3649,10 +3720,10 @@ private:
class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
public:
- static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ static const SCEV *rewrite(const SCEV *S, const Loop *L,
ScalarEvolution &SE) {
SCEVShiftRewriter Rewriter(L, SE);
- const SCEV *Result = Rewriter.visit(Scev);
+ const SCEV *Result = Rewriter.visit(S);
return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
}
@@ -3680,6 +3751,167 @@ private:
};
} // end anonymous namespace
+SCEV::NoWrapFlags
+ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
+ if (!AR->isAffine())
+ return SCEV::FlagAnyWrap;
+
+ typedef OverflowingBinaryOperator OBO;
+ SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
+
+ if (!AR->hasNoSignedWrap()) {
+ ConstantRange AddRecRange = getSignedRange(AR);
+ ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
+
+ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, IncRange, OBO::NoSignedWrap);
+ if (NSWRegion.contains(AddRecRange))
+ Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
+ }
+
+ if (!AR->hasNoUnsignedWrap()) {
+ ConstantRange AddRecRange = getUnsignedRange(AR);
+ ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
+
+ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::Add, IncRange, OBO::NoUnsignedWrap);
+ if (NUWRegion.contains(AddRecRange))
+ Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
+ }
+
+ return Result;
+}
+
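The primitive the new code leans on can be exercised standalone. A minimal sketch, assuming the LLVM 3.9-era ConstantRange API (the function name is invented):

    #include "llvm/ADT/APInt.h"
    #include "llvm/IR/ConstantRange.h"
    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/Operator.h"
    using namespace llvm;

    // For an i8 recurrence whose step is exactly 1, the guaranteed-NSW region
    // for "add" is [-128, 127): every value except SCHAR_MAX can be
    // incremented without signed wrap.
    static bool stepOneCannotSignedWrap(const ConstantRange &AddRecRange) {
      ConstantRange StepRange(APInt(8, 1));
      ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
          Instruction::Add, StepRange, OverflowingBinaryOperator::NoSignedWrap);
      return NSWRegion.contains(AddRecRange);
    }

proveNoWrapViaConstantRanges above is the same check with the recurrence's own signed (resp. unsigned) range playing the role of AddRecRange.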
+namespace {
+/// Represents an abstract binary operation. This may exist as a
+/// normal instruction or constant expression, or may have been
+/// derived from an expression tree.
+struct BinaryOp {
+ unsigned Opcode;
+ Value *LHS;
+ Value *RHS;
+ bool IsNSW;
+ bool IsNUW;
+
+ /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
+ /// constant expression.
+ Operator *Op;
+
+ explicit BinaryOp(Operator *Op)
+ : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
+ IsNSW(false), IsNUW(false), Op(Op) {
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
+ IsNSW = OBO->hasNoSignedWrap();
+ IsNUW = OBO->hasNoUnsignedWrap();
+ }
+ }
+
+ explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
+ bool IsNUW = false)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
+ Op(nullptr) {}
+};
+} // end anonymous namespace
+
+/// Try to map \p V into a BinaryOp, and return \c None on failure.
+static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
+ auto *Op = dyn_cast<Operator>(V);
+ if (!Op)
+ return None;
+
+ // Implementation detail: all the cleverness here should happen without
+  // creating new SCEV expressions -- our caller knows tricks to avoid creating
+ // SCEV expressions when possible, and we should not break that.
+
+ switch (Op->getOpcode()) {
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Mul:
+ case Instruction::UDiv:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::AShr:
+ case Instruction::Shl:
+ return BinaryOp(Op);
+
+ case Instruction::Xor:
+ if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
+ // If the RHS of the xor is a signbit, then this is just an add.
+ // Instcombine turns add of signbit into xor as a strength reduction step.
+ if (RHSC->getValue().isSignBit())
+ return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
+ return BinaryOp(Op);
+
+ case Instruction::LShr:
+    // Turn logical shift right by a constant into an unsigned divide.
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+ uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().ult(BitWidth)) {
+ Constant *X =
+ ConstantInt::get(SA->getContext(),
+ APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+ return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
+ }
+ }
+ return BinaryOp(Op);
+
+ case Instruction::ExtractValue: {
+ auto *EVI = cast<ExtractValueInst>(Op);
+ if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
+ break;
+
+ auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand());
+ if (!CI)
+ break;
+
+ if (auto *F = CI->getCalledFunction())
+ switch (F->getIntrinsicID()) {
+ case Intrinsic::sadd_with_overflow:
+ case Intrinsic::uadd_with_overflow: {
+ if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
+ return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+ CI->getArgOperand(1));
+
+ // Now that we know that all uses of the arithmetic-result component of
+ // CI are guarded by the overflow check, we can go ahead and pretend
+ // that the arithmetic is non-overflowing.
+ if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow)
+ return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+ CI->getArgOperand(1), /* IsNSW = */ true,
+ /* IsNUW = */ false);
+ else
+ return BinaryOp(Instruction::Add, CI->getArgOperand(0),
+ CI->getArgOperand(1), /* IsNSW = */ false,
+                               /* IsNUW = */ true);
+ }
+
+ case Intrinsic::ssub_with_overflow:
+ case Intrinsic::usub_with_overflow:
+ return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+ CI->getArgOperand(1));
+
+ case Intrinsic::smul_with_overflow:
+ case Intrinsic::umul_with_overflow:
+ return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
+ CI->getArgOperand(1));
+ default:
+ break;
+ }
+ }
+
+ default:
+ break;
+ }
+
+ return None;
+}
+
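For illustration, two of the normalizations above undo common instcombine strength reductions on a hypothetical i32 value x: x ^ 0x80000000 is re-read as the wrapping add x + 0x80000000 (the constant is the sign bit), and x lshr 3 is re-read as the unsigned divide x /u 8, which is only sound because the shift amount 3 is below the 32-bit width.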
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const Loop *L = LI.getLoopFor(PN->getParent());
if (!L || L->getHeader() != PN->getParent())
@@ -3710,7 +3942,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const SCEV *SymbolicName = getUnknown(PN);
assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
"PHI node already processed?");
- ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
+ ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
// Using this symbolic name for the PHI, analyze the value coming around
// the back-edge.
@@ -3747,13 +3979,11 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
- // If the increment doesn't overflow, then neither the addrec nor
- // the post-increment will overflow.
- if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
- if (OBO->getOperand(0) == PN) {
- if (OBO->hasNoUnsignedWrap())
+ if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+ if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
+ if (BO->IsNUW)
Flags = setFlags(Flags, SCEV::FlagNUW);
- if (OBO->hasNoSignedWrap())
+ if (BO->IsNSW)
Flags = setFlags(Flags, SCEV::FlagNSW);
}
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
@@ -3779,16 +4009,19 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
- // Since the no-wrap flags are on the increment, they apply to the
- // post-incremented value as well.
- if (isLoopInvariant(Accum, L))
- (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
-
// Okay, for the entire analysis of this edge we assumed the PHI
// to be symbolic. We now need to go back and purge all of the
// entries for the scalars that use the symbolic expression.
- ForgetSymbolicName(PN, SymbolicName);
+ forgetSymbolicName(PN, SymbolicName);
ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+ // We can add Flags to the post-inc expression only if we
+    // know that it is *undefined behavior* for BEValueV to
+ // overflow.
+ if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+ if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+ (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
return PHISCEV;
}
}
@@ -3811,12 +4044,18 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
// Okay, for the entire analysis of this edge we assumed the PHI
// to be symbolic. We now need to go back and purge all of the
// entries for the scalars that use the symbolic expression.
- ForgetSymbolicName(PN, SymbolicName);
+ forgetSymbolicName(PN, SymbolicName);
ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
return Shifted;
}
}
}
+
+ // Remove the temporary PHI node SCEV that has been inserted while intending
+  // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
+  // as it would prevent later (possibly simpler) SCEV expressions from being
+  // added to the ValueExprMap.
+ ValueExprMap.erase(PN);
}
return nullptr;
@@ -4083,26 +4322,21 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
return getUnknown(I);
}
-/// createNodeForGEP - Expand GEP instructions into add and multiply
-/// operations. This allows them to be analyzed by regular SCEV code.
-///
+/// Expand GEP instructions into add and multiply operations. This allows them
+/// to be analyzed by regular SCEV code.
const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
- Value *Base = GEP->getOperand(0);
// Don't attempt to analyze GEPs over unsized objects.
- if (!Base->getType()->getPointerElementType()->isSized())
+ if (!GEP->getSourceElementType()->isSized())
return getUnknown(GEP);
SmallVector<const SCEV *, 4> IndexExprs;
for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
IndexExprs.push_back(getSCEV(*Index));
- return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs,
- GEP->isInBounds());
+ return getGEPExpr(GEP->getSourceElementType(),
+ getSCEV(GEP->getPointerOperand()),
+ IndexExprs, GEP->isInBounds());
}
-/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
-/// guaranteed to end in (at every loop iteration). It is, at the same time,
-/// the minimum number of times S is divisible by 2. For example, given {4,+,8}
-/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
uint32_t
ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
@@ -4180,8 +4414,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
return 0;
}
-/// GetRangeFromMetadata - Helper method to assign a range to V from
-/// metadata present in the IR.
+/// Helper method to assign a range to V from metadata present in the IR.
static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
if (Instruction *I = dyn_cast<Instruction>(V))
if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
@@ -4190,10 +4423,9 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
return None;
}
-/// getRange - Determine the range for a particular SCEV. If SignHint is
+/// Determine the range for a particular SCEV. If SignHint is
/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
/// with a "cleaner" unsigned (resp. signed) representation.
-///
ConstantRange
ScalarEvolution::getRange(const SCEV *S,
ScalarEvolution::RangeSignHint SignHint) {
@@ -4282,7 +4514,7 @@ ScalarEvolution::getRange(const SCEV *S,
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
// If there's no unsigned wrap, the value will never be less than its
// initial value.
- if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
+ if (AddRec->hasNoUnsignedWrap())
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
if (!C->getValue()->isZero())
ConservativeResult = ConservativeResult.intersectWith(
@@ -4290,7 +4522,7 @@ ScalarEvolution::getRange(const SCEV *S,
// If there's no signed wrap, and all the operands have the same sign or
// zero, the value won't ever change sign.
- if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
+ if (AddRec->hasNoSignedWrap()) {
bool AllNonNeg = true;
bool AllNonPos = true;
for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
@@ -4309,66 +4541,22 @@ ScalarEvolution::getRange(const SCEV *S,
// TODO: non-affine addrec
if (AddRec->isAffine()) {
- Type *Ty = AddRec->getType();
const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
-
- // Check for overflow. This must be done with ConstantRange arithmetic
- // because we could be called from within the ScalarEvolution overflow
- // checking code.
-
- MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
- ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
- ConstantRange ZExtMaxBECountRange =
- MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
-
- const SCEV *Start = AddRec->getStart();
- const SCEV *Step = AddRec->getStepRecurrence(*this);
- ConstantRange StepSRange = getSignedRange(Step);
- ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
-
- ConstantRange StartURange = getUnsignedRange(Start);
- ConstantRange EndURange =
- StartURange.add(MaxBECountRange.multiply(StepSRange));
-
- // Check for unsigned overflow.
- ConstantRange ZExtStartURange =
- StartURange.zextOrTrunc(BitWidth * 2 + 1);
- ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
- if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
- ZExtEndURange) {
- APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
- EndURange.getUnsignedMin());
- APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
- EndURange.getUnsignedMax());
- bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
- if (!IsFullRange)
- ConservativeResult =
- ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
- }
-
- ConstantRange StartSRange = getSignedRange(Start);
- ConstantRange EndSRange =
- StartSRange.add(MaxBECountRange.multiply(StepSRange));
-
- // Check for signed overflow. This must be done with ConstantRange
- // arithmetic because we could be called from within the ScalarEvolution
- // overflow checking code.
- ConstantRange SExtStartSRange =
- StartSRange.sextOrTrunc(BitWidth * 2 + 1);
- ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
- if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
- SExtEndSRange) {
- APInt Min = APIntOps::smin(StartSRange.getSignedMin(),
- EndSRange.getSignedMin());
- APInt Max = APIntOps::smax(StartSRange.getSignedMax(),
- EndSRange.getSignedMax());
- bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
- if (!IsFullRange)
- ConservativeResult =
- ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
- }
+ auto RangeFromAffine = getRangeForAffineAR(
+ AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+ BitWidth);
+ if (!RangeFromAffine.isFullSet())
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromAffine);
+
+ auto RangeFromFactoring = getRangeViaFactoring(
+ AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
+ BitWidth);
+ if (!RangeFromFactoring.isFullSet())
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromFactoring);
}
}
@@ -4408,6 +4596,186 @@ ScalarEvolution::getRange(const SCEV *S,
return setRange(S, SignHint, ConservativeResult);
}
+ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
+ const SCEV *Step,
+ const SCEV *MaxBECount,
+ unsigned BitWidth) {
+ assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
+ getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
+ "Precondition!");
+
+ ConstantRange Result(BitWidth, /* isFullSet = */ true);
+
+ // Check for overflow. This must be done with ConstantRange arithmetic
+ // because we could be called from within the ScalarEvolution overflow
+ // checking code.
+
+ MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
+ ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
+ ConstantRange ZExtMaxBECountRange =
+ MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
+
+ ConstantRange StepSRange = getSignedRange(Step);
+ ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
+
+ ConstantRange StartURange = getUnsignedRange(Start);
+ ConstantRange EndURange =
+ StartURange.add(MaxBECountRange.multiply(StepSRange));
+
+ // Check for unsigned overflow.
+ ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
+ if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+ ZExtEndURange) {
+ APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
+ EndURange.getUnsignedMin());
+ APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
+ EndURange.getUnsignedMax());
+ bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
+ if (!IsFullRange)
+ Result =
+ Result.intersectWith(ConstantRange(Min, Max + 1));
+ }
+
+ ConstantRange StartSRange = getSignedRange(Start);
+ ConstantRange EndSRange =
+ StartSRange.add(MaxBECountRange.multiply(StepSRange));
+
+ // Check for signed overflow. This must be done with ConstantRange
+ // arithmetic because we could be called from within the ScalarEvolution
+ // overflow checking code.
+ ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1);
+ ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
+ if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
+ SExtEndSRange) {
+ APInt Min =
+ APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin());
+ APInt Max =
+ APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax());
+ bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
+ if (!IsFullRange)
+ Result =
+ Result.intersectWith(ConstantRange(Min, Max + 1));
+ }
+
+ return Result;
+}
+
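A plausible reading of the 2 * BitWidth + 1 width used throughout (my inference, not stated in the code): the product of an unsigned BitWidth-bit trip count and a signed BitWidth-bit step needs at most 2 * BitWidth bits plus a sign bit, so the extended-precision end value is exact, and comparing it against the truncated-width end value detects whether any wrap occurred.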
+ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
+ const SCEV *Step,
+ const SCEV *MaxBECount,
+ unsigned BitWidth) {
+ // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
+ // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
+
+ struct SelectPattern {
+ Value *Condition = nullptr;
+ APInt TrueValue;
+ APInt FalseValue;
+
+ explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
+ const SCEV *S) {
+ Optional<unsigned> CastOp;
+ APInt Offset(BitWidth, 0);
+
+ assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
+ "Should be!");
+
+ // Peel off a constant offset:
+ if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
+ // In the future we could consider being smarter here and handle
+ // {Start+Step,+,Step} too.
+ if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
+ return;
+
+ Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
+ S = SA->getOperand(1);
+ }
+
+ // Peel off a cast operation
+ if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
+ CastOp = SCast->getSCEVType();
+ S = SCast->getOperand();
+ }
+
+ using namespace llvm::PatternMatch;
+
+ auto *SU = dyn_cast<SCEVUnknown>(S);
+ const APInt *TrueVal, *FalseVal;
+ if (!SU ||
+ !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
+ m_APInt(FalseVal)))) {
+ Condition = nullptr;
+ return;
+ }
+
+ TrueValue = *TrueVal;
+ FalseValue = *FalseVal;
+
+ // Re-apply the cast we peeled off earlier
+ if (CastOp.hasValue())
+ switch (*CastOp) {
+ default:
+ llvm_unreachable("Unknown SCEV cast type!");
+
+ case scTruncate:
+ TrueValue = TrueValue.trunc(BitWidth);
+ FalseValue = FalseValue.trunc(BitWidth);
+ break;
+ case scZeroExtend:
+ TrueValue = TrueValue.zext(BitWidth);
+ FalseValue = FalseValue.zext(BitWidth);
+ break;
+ case scSignExtend:
+ TrueValue = TrueValue.sext(BitWidth);
+ FalseValue = FalseValue.sext(BitWidth);
+ break;
+ }
+
+ // Re-apply the constant offset we peeled off earlier
+ TrueValue += Offset;
+ FalseValue += Offset;
+ }
+
+ bool isRecognized() { return Condition != nullptr; }
+ };
+
+ SelectPattern StartPattern(*this, BitWidth, Start);
+ if (!StartPattern.isRecognized())
+ return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+ SelectPattern StepPattern(*this, BitWidth, Step);
+ if (!StepPattern.isRecognized())
+ return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+ if (StartPattern.Condition != StepPattern.Condition) {
+ // We don't handle this case today; but we could, by considering four
+ // possibilities below instead of two. I'm not sure if there are cases where
+ // that will help over what getRange already does, though.
+ return ConstantRange(BitWidth, /* isFullSet = */ true);
+ }
+
+ // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
+ // construct arbitrary general SCEV expressions here. This function is called
+ // from deep in the call stack, and calling getSCEV (on a sext instruction,
+ // say) can end up caching a suboptimal value.
+
+ // FIXME: without the explicit `this` receiver below, MSVC errors out with
+ // C2352 and C2512 (otherwise it isn't needed).
+
+ const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
+ const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
+ const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
+ const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
+
+ ConstantRange TrueRange =
+ this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
+ ConstantRange FalseRange =
+ this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
+
+ return TrueRange.unionWith(FalseRange);
+}
+
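A worked instance of the identity at the top of this function: with Start = C ? 2 : 10, Step = C ? 1 : 2, and a max backedge-taken count of 3, {2,+,1} covers [2,6) and {10,+,2} covers [10,17), so the recurrence's range is their union rather than the full set a direct computation over the select-typed operands would produce.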
SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
const BinaryOperator *BinOp = cast<BinaryOperator>(V);
@@ -4418,273 +4786,363 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
if (BinOp->hasNoSignedWrap())
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
- if (Flags == SCEV::FlagAnyWrap) {
+ if (Flags == SCEV::FlagAnyWrap)
return SCEV::FlagAnyWrap;
- }
- // Here we check that BinOp is in the header of the innermost loop
- // containing BinOp, since we only deal with instructions in the loop
- // header. The actual loop we need to check later will come from an add
- // recurrence, but getting that requires computing the SCEV of the operands,
- // which can be expensive. This check we can do cheaply to rule out some
- // cases early.
- Loop *innermostContainingLoop = LI.getLoopFor(BinOp->getParent());
- if (innermostContainingLoop == nullptr ||
- innermostContainingLoop->getHeader() != BinOp->getParent())
- return SCEV::FlagAnyWrap;
+ return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
+}
- // Only proceed if we can prove that BinOp does not yield poison.
- if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap;
+bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
+ // Here we check that I is in the header of the innermost loop containing I,
+ // since we only deal with instructions in the loop header. The actual loop we
+ // need to check later will come from an add recurrence, but getting that
+ // requires computing the SCEV of the operands, which can be expensive. This
+ // check we can do cheaply to rule out some cases early.
+ Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent());
+ if (InnermostContainingLoop == nullptr ||
+ InnermostContainingLoop->getHeader() != I->getParent())
+ return false;
+
+ // Only proceed if we can prove that I does not yield poison.
+ if (!isKnownNotFullPoison(I)) return false;
- // At this point we know that if V is executed, then it does not wrap
- // according to at least one of NSW or NUW. If V is not executed, then we do
- // not know if the calculation that V represents would wrap. Multiple
- // instructions can map to the same SCEV. If we apply NSW or NUW from V to
+ // At this point we know that if I is executed, then it does not wrap
+ // according to at least one of NSW or NUW. If I is not executed, then we do
+ // not know if the calculation that I represents would wrap. Multiple
+ // instructions can map to the same SCEV. If we apply NSW or NUW from I to
// the SCEV, we must guarantee no wrapping for that SCEV also when it is
// derived from other instructions that map to the same SCEV. We cannot make
- // that guarantee for cases where V is not executed. So we need to find the
- // loop that V is considered in relation to and prove that V is executed for
- // every iteration of that loop. That implies that the value that V
+ // that guarantee for cases where I is not executed. So we need to find the
+ // loop that I is considered in relation to and prove that I is executed for
+ // every iteration of that loop. That implies that the value that I
// calculates does not wrap anywhere in the loop, so then we can apply the
// flags to the SCEV.
//
- // We check isLoopInvariant to disambiguate in case we are adding two
- // recurrences from different loops, so that we know which loop to prove
- // that V is executed in.
- for (int OpIndex = 0; OpIndex < 2; ++OpIndex) {
- const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex));
+ // We check isLoopInvariant to disambiguate in case we are adding recurrences
+ // from different loops, so that we know which loop to prove that I is
+ // executed in.
+ for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) {
+ // I could be an extractvalue from a call to an overflow intrinsic.
+ // TODO: We can do better here in some cases.
+ if (!isSCEVable(I->getOperand(OpIndex)->getType()))
+ return false;
+ const SCEV *Op = getSCEV(I->getOperand(OpIndex));
if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
- const int OtherOpIndex = 1 - OpIndex;
- const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex));
- if (isLoopInvariant(OtherOp, AddRec->getLoop()) &&
- isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop()))
- return Flags;
+ bool AllOtherOpsLoopInvariant = true;
+ for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands();
+ ++OtherOpIndex) {
+ if (OtherOpIndex != OpIndex) {
+ const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex));
+ if (!isLoopInvariant(OtherOp, AddRec->getLoop())) {
+ AllOtherOpsLoopInvariant = false;
+ break;
+ }
+ }
+ }
+ if (AllOtherOpsLoopInvariant &&
+ isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop()))
+ return true;
}
}
- return SCEV::FlagAnyWrap;
+ return false;
+}
+
+bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
+  // If we know that \c I can never be poison, period, then that's enough.
+ if (isSCEVExprNeverPoison(I))
+ return true;
+
+ // For an add recurrence specifically, we assume that infinite loops without
+ // side effects are undefined behavior, and then reason as follows:
+ //
+ // If the add recurrence is poison in any iteration, it is poison on all
+ // future iterations (since incrementing poison yields poison). If the result
+ // of the add recurrence is fed into the loop latch condition and the loop
+ // does not contain any throws or exiting blocks other than the latch, we now
+ // have the ability to "choose" whether the backedge is taken or not (by
+ // choosing a sufficiently evil value for the poison feeding into the branch)
+ // for every iteration including and after the one in which \p I first became
+  // poison. There are two possibilities (call the iteration in which \p I
+  // first became poison K):
+ //
+ // 1. In the set of iterations including and after K, the loop body executes
+ // no side effects. In this case executing the backege an infinte number
+ // of times will yield undefined behavior.
+ //
+ // 2. In the set of iterations including and after K, the loop body executes
+ // at least one side effect. In this case, that specific instance of side
+ // effect is control dependent on poison, which also yields undefined
+ // behavior.
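+  //
+  // For instance, in
+  //
+  //   for (i = 0; i != n; i += 1) { v[i] = 0; }
+  //
+  // if the increment of i became poison in some iteration K, then from K
+  // onward the latch condition "i != n" is poison: either no further stores
+  // execute and the loop spins forever (case 1), or a later store to v
+  // executes control dependent on poison (case 2).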
+
+ auto *ExitingBB = L->getExitingBlock();
+ auto *LatchBB = L->getLoopLatch();
+ if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
+ return false;
+
+ SmallPtrSet<const Instruction *, 16> Pushed;
+ SmallVector<const Instruction *, 8> PoisonStack;
+
+ // We start by assuming \c I, the post-inc add recurrence, is poison. Only
+ // things that are known to be fully poison under that assumption go on the
+ // PoisonStack.
+ Pushed.insert(I);
+ PoisonStack.push_back(I);
+
+ bool LatchControlDependentOnPoison = false;
+ while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
+ const Instruction *Poison = PoisonStack.pop_back_val();
+
+ for (auto *PoisonUser : Poison->users()) {
+ if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
+ if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
+ PoisonStack.push_back(cast<Instruction>(PoisonUser));
+ } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
+ assert(BI->isConditional() && "Only possibility!");
+ if (BI->getParent() == LatchBB) {
+ LatchControlDependentOnPoison = true;
+ break;
+ }
+ }
+ }
+ }
+
+ return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
+}
+
+bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) {
+ auto Itr = LoopHasNoAbnormalExits.find(L);
+ if (Itr == LoopHasNoAbnormalExits.end()) {
+ auto NoAbnormalExitInBB = [&](BasicBlock *BB) {
+ return all_of(*BB, [](Instruction &I) {
+ return isGuaranteedToTransferExecutionToSuccessor(&I);
+ });
+ };
+
+ auto InsertPair = LoopHasNoAbnormalExits.insert(
+ {L, all_of(L->getBlocks(), NoAbnormalExitInBB)});
+ assert(InsertPair.second && "We just checked!");
+ Itr = InsertPair.first;
+ }
+
+ return Itr->second;
}
-/// createSCEV - We know that there is no SCEV for the specified value. Analyze
-/// the expression.
-///
const SCEV *ScalarEvolution::createSCEV(Value *V) {
if (!isSCEVable(V->getType()))
return getUnknown(V);
- unsigned Opcode = Instruction::UserOp1;
if (Instruction *I = dyn_cast<Instruction>(V)) {
- Opcode = I->getOpcode();
-
// Don't attempt to analyze instructions in blocks that aren't
// reachable. Such instructions don't matter, and they aren't required
// to obey basic rules for definitions dominating uses which this
// analysis depends on.
if (!DT.isReachableFromEntry(I->getParent()))
return getUnknown(V);
- } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
- Opcode = CE->getOpcode();
- else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+ } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI);
else if (isa<ConstantPointerNull>(V))
return getZero(V->getType());
else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
- return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
- else
+ return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
+ else if (!isa<ConstantExpr>(V))
return getUnknown(V);
Operator *U = cast<Operator>(V);
- switch (Opcode) {
- case Instruction::Add: {
- // The simple thing to do would be to just call getSCEV on both operands
- // and call getAddExpr with the result. However if we're looking at a
- // bunch of things all added together, this can be quite inefficient,
- // because it leads to N-1 getAddExpr calls for N ultimate operands.
- // Instead, gather up all the operands and make a single getAddExpr call.
- // LLVM IR canonical form means we need only traverse the left operands.
- SmallVector<const SCEV *, 4> AddOps;
- for (Value *Op = U;; Op = U->getOperand(0)) {
- U = dyn_cast<Operator>(Op);
- unsigned Opcode = U ? U->getOpcode() : 0;
- if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) {
- assert(Op != V && "V should be an add");
- AddOps.push_back(getSCEV(Op));
- break;
- }
+ if (auto BO = MatchBinaryOp(U, DT)) {
+ switch (BO->Opcode) {
+ case Instruction::Add: {
+ // The simple thing to do would be to just call getSCEV on both operands
+ // and call getAddExpr with the result. However if we're looking at a
+ // bunch of things all added together, this can be quite inefficient,
+ // because it leads to N-1 getAddExpr calls for N ultimate operands.
+ // Instead, gather up all the operands and make a single getAddExpr call.
+ // LLVM IR canonical form means we need only traverse the left operands.
+ SmallVector<const SCEV *, 4> AddOps;
+ do {
+ if (BO->Op) {
+ if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+ AddOps.push_back(OpSCEV);
+ break;
+ }
- if (auto *OpSCEV = getExistingSCEV(U)) {
- AddOps.push_back(OpSCEV);
- break;
- }
+ // If a NUW or NSW flag can be applied to the SCEV for this
+ // addition, then compute the SCEV for this addition by itself
+ // with a separate call to getAddExpr. We need to do that
+ // instead of pushing the operands of the addition onto AddOps,
+ // since the flags are only known to apply to this particular
+ // addition - they may not apply to other additions that can be
+ // formed with operands from AddOps.
+ const SCEV *RHS = getSCEV(BO->RHS);
+ SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+ if (Flags != SCEV::FlagAnyWrap) {
+ const SCEV *LHS = getSCEV(BO->LHS);
+ if (BO->Opcode == Instruction::Sub)
+ AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
+ else
+ AddOps.push_back(getAddExpr(LHS, RHS, Flags));
+ break;
+ }
+ }
- // If a NUW or NSW flag can be applied to the SCEV for this
- // addition, then compute the SCEV for this addition by itself
- // with a separate call to getAddExpr. We need to do that
- // instead of pushing the operands of the addition onto AddOps,
- // since the flags are only known to apply to this particular
- // addition - they may not apply to other additions that can be
- // formed with operands from AddOps.
- const SCEV *RHS = getSCEV(U->getOperand(1));
- SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
- if (Flags != SCEV::FlagAnyWrap) {
- const SCEV *LHS = getSCEV(U->getOperand(0));
- if (Opcode == Instruction::Sub)
- AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
+ if (BO->Opcode == Instruction::Sub)
+ AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
else
- AddOps.push_back(getAddExpr(LHS, RHS, Flags));
- break;
- }
+ AddOps.push_back(getSCEV(BO->RHS));
- if (Opcode == Instruction::Sub)
- AddOps.push_back(getNegativeSCEV(RHS));
- else
- AddOps.push_back(RHS);
+ auto NewBO = MatchBinaryOp(BO->LHS, DT);
+ if (!NewBO || (NewBO->Opcode != Instruction::Add &&
+ NewBO->Opcode != Instruction::Sub)) {
+ AddOps.push_back(getSCEV(BO->LHS));
+ break;
+ }
+ BO = NewBO;
+ } while (true);
+
+ return getAddExpr(AddOps);
}
- return getAddExpr(AddOps);
- }
- case Instruction::Mul: {
- SmallVector<const SCEV *, 4> MulOps;
- for (Value *Op = U;; Op = U->getOperand(0)) {
- U = dyn_cast<Operator>(Op);
- if (!U || U->getOpcode() != Instruction::Mul) {
- assert(Op != V && "V should be a mul");
- MulOps.push_back(getSCEV(Op));
- break;
- }
+ case Instruction::Mul: {
+ SmallVector<const SCEV *, 4> MulOps;
+ do {
+ if (BO->Op) {
+ if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
+ MulOps.push_back(OpSCEV);
+ break;
+ }
- if (auto *OpSCEV = getExistingSCEV(U)) {
- MulOps.push_back(OpSCEV);
- break;
- }
+ SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
+ if (Flags != SCEV::FlagAnyWrap) {
+ MulOps.push_back(
+ getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
+ break;
+ }
+ }
- SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
- if (Flags != SCEV::FlagAnyWrap) {
- MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)),
- getSCEV(U->getOperand(1)), Flags));
- break;
+ MulOps.push_back(getSCEV(BO->RHS));
+ auto NewBO = MatchBinaryOp(BO->LHS, DT);
+ if (!NewBO || NewBO->Opcode != Instruction::Mul) {
+ MulOps.push_back(getSCEV(BO->LHS));
+ break;
+ }
+ BO = NewBO;
+ } while (true);
+
+ return getMulExpr(MulOps);
+ }
+ case Instruction::UDiv:
+ return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+ case Instruction::Sub: {
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+ if (BO->Op)
+ Flags = getNoWrapFlagsFromUB(BO->Op);
+ return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
+ }
+ case Instruction::And:
+ // For an expression like x&255 that merely masks off the high bits,
+ // use zext(trunc(x)) as the SCEV expression.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ if (CI->isNullValue())
+ return getSCEV(BO->RHS);
+ if (CI->isAllOnesValue())
+ return getSCEV(BO->LHS);
+ const APInt &A = CI->getValue();
+
+ // Instcombine's ShrinkDemandedConstant may strip bits out of
+ // constants, obscuring what would otherwise be a low-bits mask.
+ // Use computeKnownBits to compute what ShrinkDemandedConstant
+ // knew about to reconstruct a low-bits mask value.
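+      // E.g. in i32, "x & 0x0FF0" has LZ == 20 and TZ == 4, and is modeled
+      // below as (zext (trunc (x /u 16) to i8) to i32) * 16.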
+ unsigned LZ = A.countLeadingZeros();
+ unsigned TZ = A.countTrailingZeros();
+ unsigned BitWidth = A.getBitWidth();
+ APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+ computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(),
+ 0, &AC, nullptr, &DT);
+
+ APInt EffectiveMask =
+ APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
+ if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
+ const SCEV *MulCount = getConstant(ConstantInt::get(
+ getContext(), APInt::getOneBitSet(BitWidth, TZ)));
+ return getMulExpr(
+ getZeroExtendExpr(
+ getTruncateExpr(
+ getUDivExactExpr(getSCEV(BO->LHS), MulCount),
+ IntegerType::get(getContext(), BitWidth - LZ - TZ)),
+ BO->LHS->getType()),
+ MulCount);
+ }
}
+ break;
- MulOps.push_back(getSCEV(U->getOperand(1)));
- }
- return getMulExpr(MulOps);
- }
- case Instruction::UDiv:
- return getUDivExpr(getSCEV(U->getOperand(0)),
- getSCEV(U->getOperand(1)));
- case Instruction::Sub:
- return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)),
- getNoWrapFlagsFromUB(U));
- case Instruction::And:
- // For an expression like x&255 that merely masks off the high bits,
- // use zext(trunc(x)) as the SCEV expression.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
- if (CI->isNullValue())
- return getSCEV(U->getOperand(1));
- if (CI->isAllOnesValue())
- return getSCEV(U->getOperand(0));
- const APInt &A = CI->getValue();
-
- // Instcombine's ShrinkDemandedConstant may strip bits out of
- // constants, obscuring what would otherwise be a low-bits mask.
- // Use computeKnownBits to compute what ShrinkDemandedConstant
- // knew about to reconstruct a low-bits mask value.
- unsigned LZ = A.countLeadingZeros();
- unsigned TZ = A.countTrailingZeros();
- unsigned BitWidth = A.getBitWidth();
- APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
- computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(),
- 0, &AC, nullptr, &DT);
-
- APInt EffectiveMask =
- APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
- if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
- const SCEV *MulCount = getConstant(
- ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ)));
- return getMulExpr(
- getZeroExtendExpr(
- getTruncateExpr(
- getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount),
- IntegerType::get(getContext(), BitWidth - LZ - TZ)),
- U->getType()),
- MulCount);
+ case Instruction::Or:
+ // If the RHS of the Or is a constant, we may have something like:
+ // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
+ // optimizations will transparently handle this case.
+ //
+ // In order for this transformation to be safe, the LHS must be of the
+ // form X*(2^n) and the Or constant must be less than 2^n.
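+      // E.g. "(X * 8) | 5" is modeled as "X * 8 + 5": 5 needs only the three
+      // trailing bits that X * 8 is known to have clear.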
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ const SCEV *LHS = getSCEV(BO->LHS);
+ const APInt &CIVal = CI->getValue();
+ if (GetMinTrailingZeros(LHS) >=
+ (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
+ // Build a plain add SCEV.
+ const SCEV *S = getAddExpr(LHS, getSCEV(CI));
+ // If the LHS of the add was an addrec and it has no-wrap flags,
+ // transfer the no-wrap flags, since an or won't introduce a wrap.
+ if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
+ const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
+ const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
+ OldAR->getNoWrapFlags());
+ }
+ return S;
+ }
}
- }
- break;
+ break;
- case Instruction::Or:
- // If the RHS of the Or is a constant, we may have something like:
- // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
- // optimizations will transparently handle this case.
- //
- // In order for this transformation to be safe, the LHS must be of the
- // form X*(2^n) and the Or constant must be less than 2^n.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
- const SCEV *LHS = getSCEV(U->getOperand(0));
- const APInt &CIVal = CI->getValue();
- if (GetMinTrailingZeros(LHS) >=
- (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
- // Build a plain add SCEV.
- const SCEV *S = getAddExpr(LHS, getSCEV(CI));
- // If the LHS of the add was an addrec and it has no-wrap flags,
- // transfer the no-wrap flags, since an or won't introduce a wrap.
- if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
- const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
- const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
- OldAR->getNoWrapFlags());
- }
- return S;
+ case Instruction::Xor:
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
+ // If the RHS of xor is -1, then this is a not operation.
+ if (CI->isAllOnesValue())
+ return getNotSCEV(getSCEV(BO->LHS));
+
+ // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
+ // This is a variant of the check for xor with -1, and it handles
+ // the case where instcombine has trimmed non-demanded bits out
+ // of an xor with -1.
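+        // E.g. with C == 0x0F, ((x & 0x0F) ^ 0x0F) == ((~x) & 0x0F), which
+        // is modeled below as the zext of the complemented narrow value.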
+ if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
+ if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
+ if (LBO->getOpcode() == Instruction::And &&
+ LCI->getValue() == CI->getValue())
+ if (const SCEVZeroExtendExpr *Z =
+ dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
+ Type *UTy = BO->LHS->getType();
+ const SCEV *Z0 = Z->getOperand();
+ Type *Z0Ty = Z0->getType();
+ unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
+
+ // If C is a low-bits mask, the zero extend is serving to
+ // mask off the high bits. Complement the operand and
+ // re-apply the zext.
+ if (APIntOps::isMask(Z0TySize, CI->getValue()))
+ return getZeroExtendExpr(getNotSCEV(Z0), UTy);
+
+ // If C is a single bit, it may be in the sign-bit position
+ // before the zero-extend. In this case, represent the xor
+ // using an add, which is equivalent, and re-apply the zext.
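+              // E.g. with Z0 of type i8 and C == 0x80, zext(Z0) ^ 0x80 ==
+              // zext(Z0 + 0x80), since adding the sign bit flips it exactly
+              // as xor does.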
+ APInt Trunc = CI->getValue().trunc(Z0TySize);
+ if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
+ Trunc.isSignBit())
+ return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
+ UTy);
+ }
}
- }
- break;
- case Instruction::Xor:
- if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
- // If the RHS of the xor is a signbit, then this is just an add.
- // Instcombine turns add of signbit into xor as a strength reduction step.
- if (CI->getValue().isSignBit())
- return getAddExpr(getSCEV(U->getOperand(0)),
- getSCEV(U->getOperand(1)));
-
- // If the RHS of xor is -1, then this is a not operation.
- if (CI->isAllOnesValue())
- return getNotSCEV(getSCEV(U->getOperand(0)));
-
- // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
- // This is a variant of the check for xor with -1, and it handles
- // the case where instcombine has trimmed non-demanded bits out
- // of an xor with -1.
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
- if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
- if (BO->getOpcode() == Instruction::And &&
- LCI->getValue() == CI->getValue())
- if (const SCEVZeroExtendExpr *Z =
- dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
- Type *UTy = U->getType();
- const SCEV *Z0 = Z->getOperand();
- Type *Z0Ty = Z0->getType();
- unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
-
- // If C is a low-bits mask, the zero extend is serving to
- // mask off the high bits. Complement the operand and
- // re-apply the zext.
- if (APIntOps::isMask(Z0TySize, CI->getValue()))
- return getZeroExtendExpr(getNotSCEV(Z0), UTy);
-
- // If C is a single bit, it may be in the sign-bit position
- // before the zero-extend. In this case, represent the xor
- // using an add, which is equivalent, and re-apply the zext.
- APInt Trunc = CI->getValue().trunc(Z0TySize);
- if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
- Trunc.isSignBit())
- return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
- UTy);
- }
- }
- break;
+ break;
case Instruction::Shl:
// Turn shift left of a constant amount into a multiply.
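+      // E.g. "X << 3" is modeled as "X * 8".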
- if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
- uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+ uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
// If the shift count is not less than the bitwidth, the result of
// the shift is undefined. Don't try to analyze it, because the
@@ -4700,58 +5158,43 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
// http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
// and http://reviews.llvm.org/D8890 .
auto Flags = SCEV::FlagAnyWrap;
- if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U);
+ if (BO->Op && SA->getValue().ult(BitWidth - 1))
+ Flags = getNoWrapFlagsFromUB(BO->Op);
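+        // Roughly: for a shift amount of BitWidth - 1, the multiply constant
+        // 1 << (BitWidth - 1) is the signed minimum (e.g. -128 in i8), so
+        // UB-derived flags on the shift are not transferred to the multiply.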
Constant *X = ConstantInt::get(getContext(),
APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
- return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags);
+ return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
}
break;
- case Instruction::LShr:
- // Turn logical shift right of a constant into a unsigned divide.
- if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
- uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
-
- // If the shift count is not less than the bitwidth, the result of
- // the shift is undefined. Don't try to analyze it, because the
- // resolution chosen here may differ from the resolution chosen in
- // other parts of the compiler.
- if (SA->getValue().uge(BitWidth))
- break;
+ case Instruction::AShr:
+ // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
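+    // E.g. in i32, "ashr (shl X, 24), 24" is modeled as
+    // "(sext (trunc X to i8) to i32)".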
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
+ if (Operator *L = dyn_cast<Operator>(BO->LHS))
+ if (L->getOpcode() == Instruction::Shl &&
+ L->getOperand(1) == BO->RHS) {
+ uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (CI->getValue().uge(BitWidth))
+ break;
- Constant *X = ConstantInt::get(getContext(),
- APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
- return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
+ uint64_t Amt = BitWidth - CI->getZExtValue();
+ if (Amt == BitWidth)
+ return getSCEV(L->getOperand(0)); // shift by zero --> noop
+ return getSignExtendExpr(
+ getTruncateExpr(getSCEV(L->getOperand(0)),
+ IntegerType::get(getContext(), Amt)),
+ BO->LHS->getType());
+ }
+ break;
}
- break;
-
- case Instruction::AShr:
- // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
- if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
- if (L->getOpcode() == Instruction::Shl &&
- L->getOperand(1) == U->getOperand(1)) {
- uint64_t BitWidth = getTypeSizeInBits(U->getType());
-
- // If the shift count is not less than the bitwidth, the result of
- // the shift is undefined. Don't try to analyze it, because the
- // resolution chosen here may differ from the resolution chosen in
- // other parts of the compiler.
- if (CI->getValue().uge(BitWidth))
- break;
-
- uint64_t Amt = BitWidth - CI->getZExtValue();
- if (Amt == BitWidth)
- return getSCEV(L->getOperand(0)); // shift by zero --> noop
- return
- getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
- IntegerType::get(getContext(),
- Amt)),
- U->getType());
- }
- break;
+ }
+ switch (U->getOpcode()) {
case Instruction::Trunc:
return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
@@ -4786,8 +5229,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
if (isa<Instruction>(U))
return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
U->getOperand(1), U->getOperand(2));
+ break;
- default: // We cannot analyze this expression.
+ case Instruction::Call:
+ case Instruction::Invoke:
+ if (Value *RV = CallSite(U).getReturnedArgOperand())
+ return getSCEV(RV);
break;
}
@@ -4808,16 +5255,6 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
return 0;
}
-/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
-/// normal unsigned value. Returns 0 if the trip count is unknown or not
-/// constant. Will also return 0 if the maximum trip count is very large (>=
-/// 2^32).
-///
-/// This "trip count" assumes that control exits via ExitingBlock. More
-/// precisely, it is the number of times that control may reach ExitingBlock
-/// before taking the branch. For loops with multiple exits, it may not be the
-/// number times that the loop header executes because the loop may exit
-/// prematurely via another branch.
unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
BasicBlock *ExitingBlock) {
assert(ExitingBlock && "Must pass a non-null exiting block!");
@@ -4846,10 +5283,10 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
return 0;
}
-/// getSmallConstantTripMultiple - Returns the largest constant divisor of the
-/// trip count of this loop as a normal unsigned value, if possible. This
-/// means that the actual trip count is always a multiple of the returned
-/// value (don't forget the trip count could very well be zero as well!).
+/// Returns the largest constant divisor of the trip count of this loop as a
+/// normal unsigned value, if possible. This means that the actual trip count is
+/// always a multiple of the returned value (don't forget the trip count could
+/// very well be zero as well!).
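+/// E.g. if the trip count is known to always be a multiple of 4 (possibly
+/// zero), this returns 4.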
///
 /// Returns 1 if the trip count is unknown or not guaranteed to be a
/// multiple of a constant (which is also the case if the trip count is simply
@@ -4891,37 +5328,30 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
return (unsigned)Result->getZExtValue();
}
-// getExitCount - Get the expression for the number of loop iterations for which
-// this loop is guaranteed not to exit via ExitingBlock. Otherwise return
-// SCEVCouldNotCompute.
+/// Get the expression for the number of loop iterations for which this loop is
+/// guaranteed not to exit via ExitingBlock. Otherwise return
+/// SCEVCouldNotCompute.
const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
}
-/// getBackedgeTakenCount - If the specified loop has a predictable
-/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
-/// object. The backedge-taken count is the number of times the loop header
-/// will be branched to from within the loop. This is one less than the
-/// trip count of the loop, since it doesn't count the first iteration,
-/// when the header is branched to from outside the loop.
-///
-/// Note that it is not valid to call this method on a loop without a
-/// loop-invariant backedge-taken count (see
-/// hasLoopInvariantBackedgeTakenCount).
-///
+const SCEV *
+ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
+ SCEVUnionPredicate &Preds) {
+ return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds);
+}
+
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
return getBackedgeTakenInfo(L).getExact(this);
}
-/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
-/// return the least SCEV value that is known never to be less than the
-/// actual backedge taken count.
+/// Similar to getBackedgeTakenCount, except return the least SCEV value that is
+/// known never to be less than the actual backedge taken count.
const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
return getBackedgeTakenInfo(L).getMax(this);
}
-/// PushLoopPHIs - Push PHI nodes in the header of the given loop
-/// onto the given Worklist.
+/// Push PHI nodes in the header of the given loop onto the given Worklist.
static void
PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
BasicBlock *Header = L->getHeader();
@@ -4933,6 +5363,23 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
}
const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
+ auto &BTI = getBackedgeTakenInfo(L);
+ if (BTI.hasFullInfo())
+ return BTI;
+
+ auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
+
+ if (!Pair.second)
+ return Pair.first->second;
+
+ BackedgeTakenInfo Result =
+ computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
+
+ return PredicatedBackedgeTakenCounts.find(L)->second = Result;
+}
+
+const ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// Initially insert an invalid entry for this loop. If the insertion
// succeeds, proceed to actually compute a backedge-taken count and
@@ -4940,7 +5387,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// code elsewhere that it shouldn't attempt to request a new
// backedge-taken count, which could result in infinite recursion.
std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
- BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
+ BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
if (!Pair.second)
return Pair.first->second;
@@ -5007,17 +5454,19 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
return BackedgeTakenCounts.find(L)->second = Result;
}
-/// forgetLoop - This method should be called by the client when it has
-/// changed a loop in a way that may effect ScalarEvolution's ability to
-/// compute a trip count, or if the loop is deleted.
void ScalarEvolution::forgetLoop(const Loop *L) {
// Drop any stored trip count value.
- DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
- BackedgeTakenCounts.find(L);
- if (BTCPos != BackedgeTakenCounts.end()) {
- BTCPos->second.clear();
- BackedgeTakenCounts.erase(BTCPos);
- }
+ auto RemoveLoopFromBackedgeMap =
+ [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+ auto BTCPos = Map.find(L);
+ if (BTCPos != Map.end()) {
+ BTCPos->second.clear();
+ Map.erase(BTCPos);
+ }
+ };
+
+ RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
+ RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
// Drop information about expressions based on loop-header PHIs.
SmallVector<Instruction *, 16> Worklist;
@@ -5043,13 +5492,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
// Forget all contained loops too, to avoid dangling entries in the
// ValuesAtScopes map.
- for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
- forgetLoop(*I);
+ for (Loop *I : *L)
+ forgetLoop(I);
+
+ LoopHasNoAbnormalExits.erase(L);
}
-/// forgetValue - This method should be called by the client when it has
-/// changed a value in a way that may effect its value, or which may
-/// disconnect it from a def-use chain linking it to a loop.
void ScalarEvolution::forgetValue(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return;
@@ -5077,16 +5525,17 @@ void ScalarEvolution::forgetValue(Value *V) {
}
}
-/// getExact - Get the exact loop backedge taken count considering all loop
-/// exits. A computable result can only be returned for loops with a single
-/// exit. Returning the minimum taken count among all exits is incorrect
-/// because one of the loop's exit limit's may have been skipped. HowFarToZero
-/// assumes that the limit of each loop test is never skipped. This is a valid
-/// assumption as long as the loop exits via that test. For precise results, it
-/// is the caller's responsibility to specify the relevant loop exit using
+/// Get the exact loop backedge taken count considering all loop exits. A
+/// computable result can only be returned for loops with a single exit.
+/// Returning the minimum taken count among all exits is incorrect because one
+/// of the loop's exit limits may have been skipped. howFarToZero assumes that
+/// the limit of each loop test is never skipped. This is a valid assumption as
+/// long as the loop exits via that test. For precise results, it is the
+/// caller's responsibility to specify the relevant loop exit using
/// getExact(ExitingBlock, SE).
const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
+ScalarEvolution::BackedgeTakenInfo::getExact(
+ ScalarEvolution *SE, SCEVUnionPredicate *Preds) const {
// If any exits were not computable, the loop is not computable.
if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
@@ -5095,36 +5544,42 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
const SCEV *BECount = nullptr;
- for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
- ENT != nullptr; ENT = ENT->getNextExit()) {
-
- assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
+ for (auto &ENT : ExitNotTaken) {
+ assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
if (!BECount)
- BECount = ENT->ExactNotTaken;
- else if (BECount != ENT->ExactNotTaken)
+ BECount = ENT.ExactNotTaken;
+ else if (BECount != ENT.ExactNotTaken)
return SE->getCouldNotCompute();
+ if (Preds && ENT.getPred())
+ Preds->add(ENT.getPred());
+
+ assert((Preds || ENT.hasAlwaysTruePred()) &&
+ "Predicate should be always true!");
}
+
assert(BECount && "Invalid not taken count for loop exit");
return BECount;
}
-/// getExact - Get the exact not taken count for this loop exit.
+/// Get the exact not taken count for this loop exit.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
ScalarEvolution *SE) const {
- for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
- ENT != nullptr; ENT = ENT->getNextExit()) {
+ for (auto &ENT : ExitNotTaken)
+ if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred())
+ return ENT.ExactNotTaken;
- if (ENT->ExitingBlock == ExitingBlock)
- return ENT->ExactNotTaken;
- }
return SE->getCouldNotCompute();
}
/// getMax - Get the max backedge taken count for the loop.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
+ for (auto &ENT : ExitNotTaken)
+ if (!ENT.hasAlwaysTruePred())
+ return SE->getCouldNotCompute();
+
return Max ? Max : SE->getCouldNotCompute();
}
@@ -5136,22 +5591,19 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
if (!ExitNotTaken.ExitingBlock)
return false;
- for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
- ENT != nullptr; ENT = ENT->getNextExit()) {
-
- if (ENT->ExactNotTaken != SE->getCouldNotCompute()
- && SE->hasOperand(ENT->ExactNotTaken, S)) {
+ for (auto &ENT : ExitNotTaken)
+ if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&
+ SE->hasOperand(ENT.ExactNotTaken, S))
return true;
- }
- }
+
return false;
}
/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
- SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
- bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
+ SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount)
+ : Max(MaxCount) {
if (!Complete)
ExitNotTaken.setIncomplete();
@@ -5159,36 +5611,63 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
unsigned NumExits = ExitCounts.size();
if (NumExits == 0) return;
- ExitNotTaken.ExitingBlock = ExitCounts[0].first;
- ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
- if (NumExits == 1) return;
+ ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock;
+ ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken;
+
+ // Determine the number of ExitNotTakenExtras structures that we need.
+ unsigned ExtraInfoSize = 0;
+ if (NumExits > 1)
+ ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()),
+ ExitCounts.end(), [](EdgeInfo &Entry) {
+ return !Entry.Pred.isAlwaysTrue();
+ });
+ else if (!ExitCounts[0].Pred.isAlwaysTrue())
+ ExtraInfoSize = 1;
+
+ ExitNotTakenExtras *ENT = nullptr;
+
+ // Allocate the ExitNotTakenExtras structures and initialize the first
+ // element (ExitNotTaken).
+ if (ExtraInfoSize > 0) {
+ ENT = new ExitNotTakenExtras[ExtraInfoSize];
+ ExitNotTaken.ExtraInfo = &ENT[0];
+ *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred);
+ }
+
+ if (NumExits == 1)
+ return;
+
+ assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit");
+
+ auto &Exits = ExitNotTaken.ExtraInfo->Exits;
// Handle the rare case of multiple computable exits.
- ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
+ for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) {
+ ExitNotTakenExtras *Ptr = nullptr;
+ if (!ExitCounts[i].Pred.isAlwaysTrue()) {
+ Ptr = &ENT[PredPos++];
+ Ptr->Pred = std::move(ExitCounts[i].Pred);
+ }
- ExitNotTakenInfo *PrevENT = &ExitNotTaken;
- for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
- PrevENT->setNextExit(ENT);
- ENT->ExitingBlock = ExitCounts[i].first;
- ENT->ExactNotTaken = ExitCounts[i].second;
+ Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr);
}
}
-/// clear - Invalidate this result and free the ExitNotTakenInfo array.
+/// Invalidate this result and free the ExitNotTakenExtras array.
void ScalarEvolution::BackedgeTakenInfo::clear() {
ExitNotTaken.ExitingBlock = nullptr;
ExitNotTaken.ExactNotTaken = nullptr;
- delete[] ExitNotTaken.getNextExit();
+ delete[] ExitNotTaken.ExtraInfo;
}
-/// computeBackedgeTakenCount - Compute the number of times the backedge
-/// of the specified loop will execute.
+/// Compute the number of times the backedge of the specified loop will execute.
ScalarEvolution::BackedgeTakenInfo
-ScalarEvolution::computeBackedgeTakenCount(const Loop *L) {
+ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
+ bool AllowPredicates) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
- SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
+ SmallVector<EdgeInfo, 4> ExitCounts;
bool CouldComputeBECount = true;
BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
const SCEV *MustExitMaxBECount = nullptr;
@@ -5196,9 +5675,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) {
// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
// and compute maxBECount.
+ // Do a union of all the predicates here.
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
BasicBlock *ExitBB = ExitingBlocks[i];
- ExitLimit EL = computeExitLimit(L, ExitBB);
+ ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
+
+ assert((AllowPredicates || EL.Pred.isAlwaysTrue()) &&
+ "Predicated exit limit when predicates are not allowed!");
// 1. For each exit that can be computed, add an entry to ExitCounts.
// CouldComputeBECount is true only if all exits can be computed.
@@ -5207,7 +5690,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) {
// we won't be able to compute an exact value for the loop.
CouldComputeBECount = false;
else
- ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact));
+ ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred));
// 2. Derive the loop's MaxBECount from each exit's max number of
// non-exiting iterations. Partition the loop exits into two kinds:
@@ -5241,20 +5724,20 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) {
}
ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
+ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+ bool AllowPredicates) {
// Okay, we've chosen an exiting block. See what condition causes us to exit
// at this block and remember the exit block and whether all other targets
// lead to the loop header.
bool MustExecuteLoopHeader = true;
BasicBlock *Exit = nullptr;
- for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock);
- SI != SE; ++SI)
- if (!L->contains(*SI)) {
+ for (auto *SBB : successors(ExitingBlock))
+ if (!L->contains(SBB)) {
if (Exit) // Multiple exit successors.
return getCouldNotCompute();
- Exit = *SI;
- } else if (*SI != L->getHeader()) {
+ Exit = SBB;
+ } else if (SBB != L->getHeader()) {
MustExecuteLoopHeader = false;
}
@@ -5307,9 +5790,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
// Proceed to the next level to examine the exit condition expression.
- return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
- BI->getSuccessor(1),
- /*ControlsExit=*/IsOnlyExit);
+ return computeExitLimitFromCond(
+ L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1),
+ /*ControlsExit=*/IsOnlyExit, AllowPredicates);
}
if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
@@ -5319,29 +5802,24 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
return getCouldNotCompute();
}
-/// computeExitLimitFromCond - Compute the number of times the
-/// backedge of the specified loop will execute if its exit condition
-/// were a conditional branch of ExitCond, TBB, and FBB.
-///
-/// @param ControlsExit is true if ExitCond directly controls the exit
-/// branch. In this case, we can assume that the loop exits only if the
-/// condition is true and can infer that failing to meet the condition prior to
-/// integer wraparound results in undefined behavior.
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromCond(const Loop *L,
Value *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
- bool ControlsExit) {
+ bool ControlsExit,
+ bool AllowPredicates) {
// Check if the controlling expression for this loop is an And or Or.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
if (BO->getOpcode() == Instruction::And) {
// Recurse on the operands of the and.
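+      // E.g. for "br (and %c0, %c1), <in loop>, <exit>", the loop exits as
+      // soon as either condition goes false, so when either operand may
+      // exit, the operands' exit counts combine, roughly, as an unsigned
+      // minimum.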
bool EitherMayExit = L->contains(TBB);
ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
- ControlsExit && !EitherMayExit);
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
- ControlsExit && !EitherMayExit);
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (EitherMayExit) {
@@ -5368,6 +5846,9 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.Exact;
}
+ SCEVUnionPredicate NP;
+ NP.add(&EL0.Pred);
+ NP.add(&EL1.Pred);
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
// to be more aggressive when computing BECount than when computing
// MaxBECount. In these cases it is possible for EL0.Exact and EL1.Exact
@@ -5376,15 +5857,17 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount);
+ return ExitLimit(BECount, MaxBECount, NP);
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
bool EitherMayExit = L->contains(FBB);
ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
- ControlsExit && !EitherMayExit);
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
- ControlsExit && !EitherMayExit);
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (EitherMayExit) {
@@ -5411,14 +5894,25 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.Exact;
}
- return ExitLimit(BECount, MaxBECount);
+ SCEVUnionPredicate NP;
+ NP.add(&EL0.Pred);
+ NP.add(&EL1.Pred);
+ return ExitLimit(BECount, MaxBECount, NP);
}
}
// With an icmp, it may be feasible to compute an exact backedge-taken count.
// Proceed to the next level to examine the icmp.
- if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
- return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
+ if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
+ ExitLimit EL =
+ computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
+ if (EL.hasFullInfo() || !AllowPredicates)
+ return EL;
+
+ // Try again, but use SCEV predicates this time.
+ return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit,
+ /*AllowPredicates=*/true);
+ }
// Check for a constant condition. These are normally stripped out by
// SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
@@ -5442,7 +5936,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
- bool ControlsExit) {
+ bool ControlsExit,
+ bool AllowPredicates) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Cond;
@@ -5460,11 +5955,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
return ItCnt;
}
- ExitLimit ShiftEL = computeShiftCompareExitLimit(
- ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond);
- if (ShiftEL.hasAnyInfo())
- return ShiftEL;
-
const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
@@ -5499,34 +5989,46 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
switch (Cond) {
case ICmpInst::ICMP_NE: { // while (X != Y)
// Convert to: while (X-Y != 0)
- ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
+ ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
+ AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_EQ: { // while (X == Y)
// Convert to: while (X-Y == 0)
- ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
+ ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { // while (X < Y)
bool IsSigned = Cond == ICmpInst::ICMP_SLT;
- ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit);
+ ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
+ AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT: { // while (X > Y)
bool IsSigned = Cond == ICmpInst::ICMP_SGT;
- ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit);
+ ExitLimit EL =
+ howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
+ AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
default:
break;
}
- return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+
+ auto *ExhaustiveCount =
+ computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+
+ if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
+ return ExhaustiveCount;
+
+ return computeShiftCompareExitLimit(ExitCond->getOperand(0),
+ ExitCond->getOperand(1), L, Cond);
}
ScalarEvolution::ExitLimit
@@ -5546,7 +6048,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
// while (X != Y) --> while (X-Y != 0)
- ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
+ ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
if (EL.hasAnyInfo())
return EL;
@@ -5563,9 +6065,8 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
return cast<SCEVConstant>(Val)->getValue();
}
-/// computeLoadConstantCompareExitLimit - Given an exit condition of
-/// 'icmp op load X, cst', try to see if we can compute the backedge
-/// execution count.
+/// Given an exit condition of 'icmp op load X, cst', try to see if we can
+/// compute the backedge execution count.
ScalarEvolution::ExitLimit
ScalarEvolution::computeLoadConstantCompareExitLimit(
LoadInst *LI,
@@ -5781,14 +6282,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
- return ExitLimit(getCouldNotCompute(), UpperBound);
+ SCEVUnionPredicate P;
+ return ExitLimit(getCouldNotCompute(), UpperBound, P);
}
return getCouldNotCompute();
}
-/// CanConstantFold - Return true if we can constant fold an instruction of the
-/// specified type, assuming that all operands were constants.
+/// Return true if we can constant fold an instruction of the specified type,
+/// assuming that all operands were constants.
static bool CanConstantFold(const Instruction *I) {
if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
@@ -5916,10 +6418,9 @@ static Constant *EvaluateExpression(Value *V, const Loop *L,
Operands[1], DL, TLI);
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
if (!LI->isVolatile())
- return ConstantFoldLoadFromConstPtr(Operands[0], DL);
+ return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
}
- return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL,
- TLI);
+ return ConstantFoldInstOperands(I, Operands, DL, TLI);
}
@@ -6107,16 +6608,6 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
return getCouldNotCompute();
}
-/// getSCEVAtScope - Return a SCEV expression for the specified value
-/// at the specified scope in the program. The L value specifies a loop
-/// nest to evaluate the expression at, where null is the top-level or a
-/// specified loop is immediately inside of the loop.
-///
-/// This method can be used to compute the exit value for a variable defined
-/// in a loop by querying what the value will hold in the parent loop.
-///
-/// In the case that a relevant loop exit value cannot be computed, the
-/// original value V is returned.
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
ValuesAtScopes[V];
@@ -6305,10 +6796,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
Operands[1], DL, &TLI);
else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
if (!LI->isVolatile())
- C = ConstantFoldLoadFromConstPtr(Operands[0], DL);
+ C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
} else
- C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands,
- DL, &TLI);
+ C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
if (!C) return V;
return getSCEV(C);
}
@@ -6428,14 +6918,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
llvm_unreachable("Unknown SCEV type!");
}
-/// getSCEVAtScope - This is a convenience function which does
-/// getSCEVAtScope(getSCEV(V), L).
const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
return getSCEVAtScope(getSCEV(V), L);
}
-/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
-/// following equation:
+/// Finds the minimum unsigned root of the following equation:
///
/// A * X = B (mod N)
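+///
+/// E.g. 4 * X == 8 (mod 256) has minimum unsigned root X == 2.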
///
@@ -6482,11 +6969,11 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
return SE.getConstant(Result.trunc(BW));
}
-/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
-/// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which
-/// might be the same) or two SCEVCouldNotCompute objects.
+/// Find the roots of the quadratic equation for the given quadratic chrec
+/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or
+/// None if the roots cannot be computed.
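+/// E.g. the chrec {-1,+,1,+,2} evaluates to x*x - 1 at iteration x, which is
+/// zero at iteration x == 1.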
///
-static std::pair<const SCEV *,const SCEV *>
+static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
@@ -6494,10 +6981,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
// We currently can only solve this if the coefficients are constants.
- if (!LC || !MC || !NC) {
- const SCEV *CNC = SE.getCouldNotCompute();
- return std::make_pair(CNC, CNC);
- }
+ if (!LC || !MC || !NC)
+ return None;
uint32_t BitWidth = LC->getAPInt().getBitWidth();
const APInt &L = LC->getAPInt();
@@ -6524,8 +7009,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
if (SqrtTerm.isNegative()) {
// The loop is provably infinite.
- const SCEV *CNC = SE.getCouldNotCompute();
- return std::make_pair(CNC, CNC);
+ return None;
}
// Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
@@ -6536,10 +7020,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
// The divisions must be performed as signed divisions.
APInt NegB(-B);
APInt TwoA(A << 1);
- if (TwoA.isMinValue()) {
- const SCEV *CNC = SE.getCouldNotCompute();
- return std::make_pair(CNC, CNC);
- }
+ if (TwoA.isMinValue())
+ return None;
LLVMContext &Context = SE.getContext();
@@ -6548,20 +7030,21 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
ConstantInt *Solution2 =
ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
- return std::make_pair(SE.getConstant(Solution1),
- SE.getConstant(Solution2));
+ return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)),
+ cast<SCEVConstant>(SE.getConstant(Solution2)));
} // end APIntOps namespace
}
-/// HowFarToZero - Return the number of times a backedge comparing the specified
-/// value to zero will execute. If not computable, return CouldNotCompute.
-///
-/// This is only used for loops with a "x != y" exit test. The exit condition is
-/// now expressed as a single expression, V = x-y. So the exit test is
-/// effectively V != 0. We know and take advantage of the fact that this
-/// expression only being used in a comparison by zero context.
ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
+ bool AllowPredicates) {
+
+  // This is only used for loops with an "x != y" exit test. The exit condition
+  // is now expressed as a single expression, V = x-y. So the exit test is
+  // effectively V != 0. We know and take advantage of the fact that this
+  // expression is only used in a comparison-with-zero context.
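+  //
+  // E.g. if V is {10,+,-1}, V reaches zero at iteration 10, so the backedge
+  // executes 10 times.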
+
+ SCEVUnionPredicate P;
// If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times.
@@ -6570,31 +7053,33 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
}
const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
+ if (!AddRec && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
+
if (!AddRec || AddRec->getLoop() != L)
return getCouldNotCompute();
// If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
// the quadratic equation to solve it.
if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
- std::pair<const SCEV *,const SCEV *> Roots =
- SolveQuadraticEquation(AddRec, *this);
- const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
- const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
- if (R1 && R2) {
+ if (auto Roots = SolveQuadraticEquation(AddRec, *this)) {
+ const SCEVConstant *R1 = Roots->first;
+ const SCEVConstant *R2 = Roots->second;
// Pick the smallest positive root value.
- if (ConstantInt *CB =
- dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
- R1->getValue(),
- R2->getValue()))) {
+ if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+ CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
if (!CB->getZExtValue())
- std::swap(R1, R2); // R1 is the minimum root now.
+ std::swap(R1, R2); // R1 is the minimum root now.
// We can only use this value if the chrec ends up with an exact zero
// value at this index. When solving for "X*X != 5", for example, we
// should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero())
- return R1; // We found a quadratic root!
+ return ExitLimit(R1, R1, P); // We found a quadratic root!
}
}
return getCouldNotCompute();
@@ -6651,7 +7136,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin());
- return ExitLimit(Distance, MaxBECount);
+ return ExitLimit(Distance, MaxBECount, P);
}
// As a special case, handle the instance where Step is a positive power of
@@ -6704,7 +7189,9 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth);
auto *WideTy = Distance->getType();
- return getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
+ const SCEV *Limit =
+ getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
+ return ExitLimit(Limit, Limit, P);
}
}
@@ -6713,24 +7200,24 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
// compute the backedge count. In this case, the step may not divide the
// distance, but we don't care because if the condition is "missed" the loop
// will have undefined behavior due to wrapping.
- if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
+ if (ControlsExit && AddRec->hasNoSelfWrap() &&
+ loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
- return ExitLimit(Exact, Exact);
+ return ExitLimit(Exact, Exact, P);
}
// Then, try to solve the above equation provided that Start is constant.
- if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
- return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(),
- *this);
+ if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
+ const SCEV *E = SolveLinEquationWithOverflow(
+ StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
+ return ExitLimit(E, E, P);
+ }
return getCouldNotCompute();
}
-/// HowFarToNonZero - Return the number of times a backedge checking the
-/// specified value for nonzero will execute. If not computable, return
-/// CouldNotCompute
ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
+ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
// Loops that look like: while (X == 0) are very strange indeed. We don't
// handle them yet except for the trivial case. This could be expanded in the
// future as needed.
@@ -6748,33 +7235,27 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
return getCouldNotCompute();
}
-/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
-/// (which may not be an immediate predecessor) which has exactly one
-/// successor from which BB is reachable, or null if no such block is
-/// found.
-///
std::pair<BasicBlock *, BasicBlock *>
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
// If the block has a unique predecessor, then there is no path from the
// predecessor to the block that does not go through the direct edge
// from the predecessor to the block.
if (BasicBlock *Pred = BB->getSinglePredecessor())
- return std::make_pair(Pred, BB);
+ return {Pred, BB};
// A loop's header is defined to be a block that dominates the loop.
// If the header has a unique predecessor outside the loop, it must be
// a block that has exactly one successor that can reach the loop.
if (Loop *L = LI.getLoopFor(BB))
- return std::make_pair(L->getLoopPredecessor(), L->getHeader());
+ return {L->getLoopPredecessor(), L->getHeader()};
- return std::pair<BasicBlock *, BasicBlock *>();
+ return {nullptr, nullptr};
}
-/// HasSameValue - SCEV structural equivalence is usually sufficient for
-/// testing whether two expressions are equal, however for the purposes of
-/// looking for a condition guarding a loop, it can be useful to be a little
-/// more general, since a front-end may have replicated the controlling
-/// expression.
+/// SCEV structural equivalence is usually sufficient for testing whether two
+/// expressions are equal; however, for the purposes of looking for a
+/// condition guarding a loop, it can be useful to be a little more general,
+/// since a front-end may have replicated the controlling expression.
///
static bool HasSameValue(const SCEV *A, const SCEV *B) {
// Quick check to see if they are the same SCEV.
@@ -6800,9 +7281,6 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
return false;
}
-/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
-/// predicate Pred. Return true iff any changes were made.
-///
bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
const SCEV *&LHS, const SCEV *&RHS,
unsigned Depth) {
@@ -7134,7 +7612,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
return true;
// Otherwise see what can be done with known constant ranges.
- return isKnownPredicateWithRanges(Pred, LHS, RHS);
+ return isKnownPredicateViaConstantRanges(Pred, LHS, RHS);
}
bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
@@ -7180,7 +7658,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
case ICmpInst::ICMP_UGE:
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
- if (!LHS->getNoWrapFlags(SCEV::FlagNUW))
+ if (!LHS->hasNoUnsignedWrap())
return false;
Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
@@ -7190,7 +7668,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
case ICmpInst::ICMP_SGE:
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
- if (!LHS->getNoWrapFlags(SCEV::FlagNSW))
+ if (!LHS->hasNoSignedWrap())
return false;
const SCEV *Step = LHS->getStepRecurrence(*this);
@@ -7264,78 +7742,34 @@ bool ScalarEvolution::isLoopInvariantPredicate(
return true;
}
-bool
-ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
- const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownPredicateViaConstantRanges(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
if (HasSameValue(LHS, RHS))
return ICmpInst::isTrueWhenEqual(Pred);
// This code is split out from isKnownPredicate because it is called from
// within isLoopEntryGuardedByCond.
- switch (Pred) {
- default:
- llvm_unreachable("Unexpected ICmpInst::Predicate value!");
- case ICmpInst::ICMP_SGT:
- std::swap(LHS, RHS);
- case ICmpInst::ICMP_SLT: {
- ConstantRange LHSRange = getSignedRange(LHS);
- ConstantRange RHSRange = getSignedRange(RHS);
- if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
- return true;
- if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
- return false;
- break;
- }
- case ICmpInst::ICMP_SGE:
- std::swap(LHS, RHS);
- case ICmpInst::ICMP_SLE: {
- ConstantRange LHSRange = getSignedRange(LHS);
- ConstantRange RHSRange = getSignedRange(RHS);
- if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
- return true;
- if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
- return false;
- break;
- }
- case ICmpInst::ICMP_UGT:
- std::swap(LHS, RHS);
- case ICmpInst::ICMP_ULT: {
- ConstantRange LHSRange = getUnsignedRange(LHS);
- ConstantRange RHSRange = getUnsignedRange(RHS);
- if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
- return true;
- if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
- return false;
- break;
- }
- case ICmpInst::ICMP_UGE:
- std::swap(LHS, RHS);
- case ICmpInst::ICMP_ULE: {
- ConstantRange LHSRange = getUnsignedRange(LHS);
- ConstantRange RHSRange = getUnsignedRange(RHS);
- if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
- return true;
- if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
- return false;
- break;
- }
- case ICmpInst::ICMP_NE: {
- if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
- return true;
- if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
- return true;
- const SCEV *Diff = getMinusSCEV(LHS, RHS);
- if (isKnownNonZero(Diff))
- return true;
- break;
- }
- case ICmpInst::ICMP_EQ:
- // The check at the top of the function catches the case where
- // the values are known to be equal.
- break;
- }
- return false;
+ auto CheckRanges =
+ [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
+ return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
+ .contains(RangeLHS);
+ };
+
+ // The check at the top of the function catches the case where the values are
+ // known to be equal.
+ if (Pred == CmpInst::ICMP_EQ)
+ return false;
+
+ if (Pred == CmpInst::ICMP_NE)
+ return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
+ CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
+ isKnownNonZero(getMinusSCEV(LHS, RHS));
+
+ if (CmpInst::isSigned(Pred))
+ return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
+
+ return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
}
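
The rewrite above replaces the per-predicate range comparisons with a single containment test: Pred is known to hold when every value the LHS range can take satisfies Pred against every value in the RHS range. A minimal standalone sketch of that test for ULT, with a hypothetical non-wrapping Range type standing in for llvm::ConstantRange:

#include <cstdint>
#include <cstdio>

struct Range { uint64_t Lo, Hi; }; // non-wrapping [Lo, Hi), Lo < Hi

// "LHS u< RHS" is known when LHS is contained in the region satisfying
// "x u< y for every y in RHS", which for ULT is [0, RHS.Lo).
static bool knownULT(const Range &LHS, const Range &RHS) {
  return LHS.Hi <= RHS.Lo;
}

int main() {
  Range L = {2, 8}, R = {10, 20}; // max(LHS) = 7 u< min(RHS) = 10
  std::printf("LHS u< RHS: %s\n", knownULT(L, R) ? "known" : "unknown");
}
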
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
@@ -7416,6 +7850,23 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // No need to even try if we know the module has no guards.
+ if (!HasGuards)
+ return false;
+
+ return any_of(*BB, [&](Instruction &I) {
+ using namespace llvm::PatternMatch;
+
+ Value *Condition;
+ return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+ m_Value(Condition))) &&
+ isImpliedCond(Pred, LHS, RHS, Condition, false);
+ });
+}
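
A standalone model of this scan, with a hypothetical Inst struct standing in for IR instructions and a numeric bound standing in for the guard's i1 condition; the shape is the same as above: a cheap module-level early out, then a linear any_of over the block.

#include <algorithm>
#include <cstdio>
#include <vector>

struct Inst {
  bool IsGuard;
  int GuardedLowerBound; // stands in for the guard's i1 condition
};

static bool isImpliedViaGuard(const std::vector<Inst> &BB, bool HasGuards,
                              int QueriedLowerBound) {
  if (!HasGuards) // module-level early out
    return false;
  return std::any_of(BB.begin(), BB.end(), [&](const Inst &I) {
    // guard(x >= 10) implies x >= 5, and so on.
    return I.IsGuard && I.GuardedLowerBound >= QueriedLowerBound;
  });
}

int main() {
  std::vector<Inst> BB = {{false, 0}, {true, 10}, {false, 0}};
  std::printf("x >= 5 implied: %s\n",
              isImpliedViaGuard(BB, true, 5) ? "yes" : "no"); // yes
}
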
+
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// eliminate casts.
@@ -7427,7 +7878,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return true;
- if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+ if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+ return true;
BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
@@ -7482,12 +7934,18 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
if (!DT.isReachableFromEntry(L->getHeader()))
return false;
+ if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+ return true;
+
for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
DTN != HeaderDTN; DTN = DTN->getIDom()) {
assert(DTN && "should reach the loop header before reaching the root!");
BasicBlock *BB = DTN->getBlock();
+ if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+ return true;
+
BasicBlock *PBB = BB->getSinglePredecessor();
if (!PBB)
continue;
@@ -7518,9 +7976,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
return false;
}
-/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
-/// by a conditional between LHS and RHS. This is used to help avoid max
-/// expressions in loop trip counts, and to eliminate casts.
bool
ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
ICmpInst::Predicate Pred,
@@ -7529,7 +7984,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return false;
- if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+ if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+ return true;
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
@@ -7539,6 +7995,9 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
Pair.first;
Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+ if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
+ return true;
+
BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate ||
@@ -7586,8 +8045,6 @@ struct MarkPendingLoopPredicate {
};
} // end anonymous namespace
-/// isImpliedCond - Test whether the condition described by Pred, LHS,
-/// and RHS is true whenever the given Cond value evaluates to true.
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
Value *FoundCondValue,
@@ -7910,9 +8367,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
getConstant(FoundRHSLimit));
}
-/// isImpliedCondOperands - Test whether the condition described by Pred,
-/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
-/// and FoundRHS is true.
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
@@ -8037,9 +8491,6 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
llvm_unreachable("covered switch fell through?!");
}
-/// isImpliedCondOperandsHelper - Test whether the condition described by
-/// Pred, LHS, and RHS is true whenever the condition described by Pred,
-/// FoundLHS, and FoundRHS is true.
bool
ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
@@ -8047,7 +8498,7 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
const SCEV *FoundRHS) {
auto IsKnownPredicateFull =
[this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
- return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
+ return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
@@ -8089,8 +8540,6 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
return false;
}
-/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands.
-/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1".
bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS,
@@ -8129,9 +8578,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
return SatisfyingLHSRange.contains(LHSRange);
}
-// Verify if an linear IV with positive stride can overflow when in a
-// less-than comparison, knowing the invariant term of the comparison, the
-// stride and the knowledge of NSW/NUW flags on the recurrence.
bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned, bool NoWrap) {
if (NoWrap) return false;
@@ -8158,9 +8604,6 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
}
-// Verify if an linear IV with negative stride can overflow when in a
-// greater-than comparison, knowing the invariant term of the comparison,
-// the stride and the knowledge of NSW/NUW flags on the recurrence.
bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned, bool NoWrap) {
if (NoWrap) return false;
@@ -8187,8 +8630,6 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
}
-// Compute the backedge taken count knowing the interval difference, the
-// stride and presence of the equality in the comparison.
const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
bool Equality) {
const SCEV *One = getOne(Step->getType());
@@ -8197,22 +8638,21 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
return getUDivExpr(Delta, Step);
}
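
Assuming the middle of this function elided by the diff matches upstream, Delta is biased by Step (for a non-strict exit test) or Step - 1 (for a strict one) before the divide, so a strict bound yields ceil(Delta/Step). A standalone arithmetic sketch of that formula:

#include <cstdint>
#include <cstdio>

static uint64_t computeBECount(uint64_t Delta, uint64_t Step, bool Equality) {
  // Delta + Step - 1 rounds the division up; Delta + Step also admits the
  // final iteration where IV == End.
  return (Delta + (Equality ? Step : Step - 1)) / Step;
}

int main() {
  // IV = {1,+,3} tested against End = 10: "IV < 10" holds at 1, 4, 7.
  std::printf("%llu\n", (unsigned long long)computeBECount(9, 3, false)); // 3
  // "IV <= 10" additionally holds at IV == 10.
  std::printf("%llu\n", (unsigned long long)computeBECount(9, 3, true));  // 4
}
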
-/// HowManyLessThans - Return the number of times a backedge containing the
-/// specified less-than comparison will execute. If not computable, return
-/// CouldNotCompute.
-///
-/// @param ControlsExit is true when the LHS < RHS condition directly controls
-/// the branch (loops exits only if condition is true). In this case, we can use
-/// NoWrapFlags to skip overflow checks.
ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
- bool ControlsExit) {
+ bool ControlsExit, bool AllowPredicates) {
+ SCEVUnionPredicate P;
// We handle only IV < Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!IV && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
// Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8238,19 +8678,8 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
: ICmpInst::ICMP_ULT;
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
- const SCEV *Diff = getMinusSCEV(RHS, Start);
- // If we have NoWrap set, then we can assume that the increment won't
- // overflow, in which case if RHS - Start is a constant, we don't need to
- // do a max operation since we can just figure it out statically
- if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
- if (D.isNegative())
- End = Start;
- } else
- End = IsSigned ? getSMaxExpr(RHS, Start)
- : getUMaxExpr(RHS, Start);
- }
+ if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+ End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
@@ -8281,18 +8710,24 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount);
+ return ExitLimit(BECount, MaxBECount, P);
}
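
The max expression introduced above protects the Delta computation when the loop is not provably entered. A standalone model of why that clamp matters when Start exceeds RHS, reusing the simplified ceil-division count sketched for computeBECount:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t Start = 20, RHS = 10, Stride = 1; // the loop body never executes
  uint32_t End = std::max(RHS, Start);       // clamp: 20 rather than 10
  // Without the clamp, End - Start would wrap to 4294967286 and the count
  // would be nonsense; with it, Delta is 0 and so is the count.
  uint32_t BECount = (End - Start + Stride - 1) / Stride;
  std::printf("backedge-taken count = %u\n", BECount); // 0
}
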
ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
- bool ControlsExit) {
+ bool ControlsExit, bool AllowPredicates) {
+ SCEVUnionPredicate P;
// We handle only IV > Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!IV && AllowPredicates)
+ // Try to make this an AddRec using runtime tests, in the first X
+ // iterations of this loop, where X is the SCEV expression found by the
+ // algorithm below.
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
// Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8319,19 +8754,8 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
- const SCEV *Diff = getMinusSCEV(RHS, Start);
- // If we have NoWrap set, then we can assume that the increment won't
- // overflow, in which case if RHS - Start is a constant, we don't need to
- // do a max operation since we can just figure it out statically
- if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
- if (!D.isNegative())
- End = Start;
- } else
- End = IsSigned ? getSMinExpr(RHS, Start)
- : getUMinExpr(RHS, Start);
- }
+ if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
+ End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
@@ -8363,15 +8787,10 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount);
+ return ExitLimit(BECount, MaxBECount, P);
}
-/// getNumIterationsInRange - Return the number of iterations of this loop that
-/// produce values in the specified constant range. Another way of looking at
-/// this is that it returns the first iteration number where the value is not in
-/// the condition, thus computing the exit count. If the iteration count can't
-/// be computed, an instance of SCEVCouldNotCompute is returned.
-const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
+const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
ScalarEvolution &SE) const {
if (Range.isFullSet()) // Infinite loop.
return SE.getCouldNotCompute();
@@ -8445,22 +8864,21 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
FlagAnyWrap);
// Next, solve the constructed addrec
- auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
- const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
- const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
- if (R1) {
+ if (auto Roots =
+ SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) {
+ const SCEVConstant *R1 = Roots->first;
+ const SCEVConstant *R2 = Roots->second;
// Pick the smallest positive root value.
if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
if (!CB->getZExtValue())
- std::swap(R1, R2); // R1 is the minimum root now.
+ std::swap(R1, R2); // R1 is the minimum root now.
// Make sure the root is not off by one. The returned iteration should
// not be in the range, but the previous one should be. When solving
// for "X*X < 5", for example, we should not return a root of 2.
- ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
- R1->getValue(),
- SE);
+ ConstantInt *R1Val =
+ EvaluateConstantChrecAtConstant(this, R1->getValue(), SE);
if (Range.contains(R1Val->getValue())) {
// The next iteration must be out of the range...
ConstantInt *NextVal =
@@ -8469,7 +8887,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (!Range.contains(R1Val->getValue()))
return SE.getConstant(NextVal);
- return SE.getCouldNotCompute(); // Something strange happened
+ return SE.getCouldNotCompute(); // Something strange happened
}
// If R1 was not in the range, then it is a good return value. Make
@@ -8479,7 +8897,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (Range.contains(R1Val->getValue()))
return R1;
- return SE.getCouldNotCompute(); // Something strange happened
+ return SE.getCouldNotCompute(); // Something strange happened
}
}
}
@@ -8789,12 +9207,9 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
return getSizeOfExpr(ETy, Ty);
}
-/// Second step of delinearization: compute the array dimensions Sizes from the
-/// set of Terms extracted from the memory access function of this SCEVAddRec.
void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
SmallVectorImpl<const SCEV *> &Sizes,
const SCEV *ElementSize) const {
-
if (Terms.size() < 1 || !ElementSize)
return;
@@ -8858,8 +9273,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
});
}
-/// Third step of delinearization: compute the access functions for the
-/// Subscripts based on the dimensions in Sizes.
void ScalarEvolution::computeAccessFunctions(
const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
SmallVectorImpl<const SCEV *> &Sizes) {
@@ -9012,7 +9425,7 @@ void ScalarEvolution::SCEVCallbackVH::deleted() {
assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
SE->ConstantEvolutionLoopExitValue.erase(PN);
- SE->ValueExprMap.erase(getValPtr());
+ SE->eraseValueFromMap(getValPtr());
// this now dangles!
}
@@ -9035,13 +9448,13 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
continue;
if (PHINode *PN = dyn_cast<PHINode>(U))
SE->ConstantEvolutionLoopExitValue.erase(PN);
- SE->ValueExprMap.erase(U);
+ SE->eraseValueFromMap(U);
Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
}
// Delete the Old value.
if (PHINode *PN = dyn_cast<PHINode>(Old))
SE->ConstantEvolutionLoopExitValue.erase(PN);
- SE->ValueExprMap.erase(Old);
+ SE->eraseValueFromMap(Old);
// this now dangles!
}
@@ -9059,14 +9472,31 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
CouldNotCompute(new SCEVCouldNotCompute()),
WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
- FirstUnknown(nullptr) {}
+ FirstUnknown(nullptr) {
+
+ // To use guards for proving predicates, we need to scan every instruction in
+ // relevant basic blocks, and not just terminators. Doing this is a waste of
+ // time if the IR does not actually contain any calls to
+ // @llvm.experimental.guard, so do a quick check and remember this beforehand.
+ //
+ // This pessimizes the case where a pass that preserves ScalarEvolution wants
+ // to _add_ guards to the module when there weren't any before, and wants
+ // ScalarEvolution to optimize based on those guards. For now we prefer to be
+ // efficient in lieu of being smart in that rather obscure case.
+
+ auto *GuardDecl = F.getParent()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_guard));
+ HasGuards = GuardDecl && !GuardDecl->use_empty();
+}
ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
- : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI),
- CouldNotCompute(std::move(Arg.CouldNotCompute)),
+ : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
+ LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
+ PredicatedBackedgeTakenCounts(
+ std::move(Arg.PredicatedBackedgeTakenCounts)),
ConstantEvolutionLoopExitValue(
std::move(Arg.ConstantEvolutionLoopExitValue)),
ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
@@ -9091,12 +9521,16 @@ ScalarEvolution::~ScalarEvolution() {
}
FirstUnknown = nullptr;
+ ExprValueMap.clear();
ValueExprMap.clear();
+ HasRecMap.clear();
// Free any extra memory created for ExitNotTakenInfo in the unlikely event
// that a loop had multiple computable exits.
for (auto &BTCI : BackedgeTakenCounts)
BTCI.second.clear();
+ for (auto &BTCI : PredicatedBackedgeTakenCounts)
+ BTCI.second.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
@@ -9110,8 +9544,8 @@ bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
const Loop *L) {
// Print all inner loops first
- for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
- PrintLoopInfo(OS, SE, *I);
+ for (Loop *I : *L)
+ PrintLoopInfo(OS, SE, I);
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -9139,9 +9573,35 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
OS << "Unpredictable max backedge-taken count. ";
}
+ OS << "\n"
+ "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+
+ SCEVUnionPredicate Pred;
+ auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
+ if (!isa<SCEVCouldNotCompute>(PBT)) {
+ OS << "Predicated backedge-taken count is " << *PBT << "\n";
+ OS << " Predicates:\n";
+ Pred.print(OS, 4);
+ } else {
+ OS << "Unpredictable predicated backedge-taken count. ";
+ }
OS << "\n";
}
+static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
+ switch (LD) {
+ case ScalarEvolution::LoopVariant:
+ return "Variant";
+ case ScalarEvolution::LoopInvariant:
+ return "Invariant";
+ case ScalarEvolution::LoopComputable:
+ return "Computable";
+ }
+ llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
+}
+
void ScalarEvolution::print(raw_ostream &OS) const {
// ScalarEvolution's implementation of the print method is to print
// out SCEV values of all instructions that are interesting. Doing
@@ -9189,6 +9649,35 @@ void ScalarEvolution::print(raw_ostream &OS) const {
} else {
OS << *ExitValue;
}
+
+ bool First = true;
+ for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
+ if (First) {
+ OS << "\t\t" "LoopDispositions: { ";
+ First = false;
+ } else {
+ OS << ", ";
+ }
+
+ Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
+ }
+
+ for (auto *InnerL : depth_first(L)) {
+ if (InnerL == L)
+ continue;
+ if (First) {
+ OS << "\t\t" "LoopDispositions: { ";
+ First = false;
+ } else {
+ OS << ", ";
+ }
+
+ InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
+ }
+
+ OS << " }";
}
OS << "\n";
@@ -9197,8 +9686,8 @@ void ScalarEvolution::print(raw_ostream &OS) const {
OS << "Determining loop execution counts for: ";
F.printAsOperand(OS, /*PrintType=*/false);
OS << "\n";
- for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
- PrintLoopInfo(OS, &SE, *I);
+ for (Loop *I : LI)
+ PrintLoopInfo(OS, &SE, I);
}
ScalarEvolution::LoopDisposition
@@ -9420,17 +9909,23 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
BlockDispositions.erase(S);
UnsignedRanges.erase(S);
SignedRanges.erase(S);
+ ExprValueMap.erase(S);
+ HasRecMap.erase(S);
+
+ auto RemoveSCEVFromBackedgeMap =
+ [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+ for (auto I = Map.begin(), E = Map.end(); I != E;) {
+ BackedgeTakenInfo &BEInfo = I->second;
+ if (BEInfo.hasOperand(S, this)) {
+ BEInfo.clear();
+ Map.erase(I++);
+ } else
+ ++I;
+ }
+ };
- for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
- BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) {
- BackedgeTakenInfo &BEInfo = I->second;
- if (BEInfo.hasOperand(S, this)) {
- BEInfo.clear();
- BackedgeTakenCounts.erase(I++);
- }
- else
- ++I;
- }
+ RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
+ RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
}
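
The lambda above erases entries while walking a map, advancing the iterator before the erase so the live iterator never points at a removed entry. A standalone sketch of the same idiom, with std::map standing in for llvm::DenseMap:

#include <cstdio>
#include <map>

int main() {
  std::map<int, int> M = {{1, 10}, {2, 20}, {3, 30}};
  for (auto I = M.begin(), E = M.end(); I != E;) {
    if (I->second == 20)
      M.erase(I++); // advance first, then erase the old position
    else
      ++I;
  }
  std::printf("%zu entries left\n", M.size()); // 2
}
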
typedef DenseMap<const Loop *, std::string> VerifyMap;
@@ -9516,16 +10011,16 @@ void ScalarEvolution::verify() const {
char ScalarEvolutionAnalysis::PassID;
ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
- AnalysisManager<Function> *AM) {
- return ScalarEvolution(F, AM->getResult<TargetLibraryAnalysis>(F),
- AM->getResult<AssumptionAnalysis>(F),
- AM->getResult<DominatorTreeAnalysis>(F),
- AM->getResult<LoopAnalysis>(F));
+ AnalysisManager<Function> &AM) {
+ return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
+ AM.getResult<AssumptionAnalysis>(F),
+ AM.getResult<DominatorTreeAnalysis>(F),
+ AM.getResult<LoopAnalysis>(F));
}
PreservedAnalyses
-ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> *AM) {
- AM->getResult<ScalarEvolutionAnalysis>(F).print(OS);
+ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) {
+ AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
return PreservedAnalyses::all();
}
@@ -9590,36 +10085,121 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
return Eq;
}
+const SCEVPredicate *ScalarEvolution::getWrapPredicate(
+ const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Wrap);
+ ID.AddPointer(AR);
+ ID.AddInteger(AddedFlags);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ auto *OF = new (SCEVAllocator)
+ SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
+ UniquePreds.InsertNode(OF, IP);
+ return OF;
+}
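
getWrapPredicate follows the usual FoldingSet uniquing pattern: profile the identifying fields into an ID, return an existing node on a hit, otherwise allocate and insert. A standalone sketch with std::map as a stand-in intern table (arena-style, nodes are never freed):

#include <cstdio>
#include <map>
#include <tuple>

struct WrapPred { const void *AR; unsigned Flags; };

static WrapPred *getWrapPredicate(const void *AR, unsigned Flags) {
  using Key = std::tuple<int, const void *, unsigned>;
  static std::map<Key, WrapPred *> UniquePreds; // intern table
  Key ID{/*P_Wrap*/ 2, AR, Flags};              // profile the fields
  auto It = UniquePreds.find(ID);
  if (It != UniquePreds.end())                  // FindNodeOrInsertPos hit
    return It->second;
  auto *OF = new WrapPred{AR, Flags};           // never freed, like the arena
  UniquePreds[ID] = OF;
  return OF;
}

int main() {
  int AR; // stands in for a SCEVAddRecExpr
  bool Same = getWrapPredicate(&AR, 1) == getWrapPredicate(&AR, 1);
  std::printf("interned: %s\n", Same ? "same node" : "new node"); // same node
}
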
+
namespace {
+
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
- static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
- SCEVUnionPredicate &A) {
- SCEVPredicateRewriter Rewriter(SE, A);
- return Rewriter.visit(Scev);
+ // Rewrites \p S in the context of a loop L and the predicate A.
+ // If Assume is true, rewrite is free to add further predicates to A
+ // such that the result will be an AddRecExpr.
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+ SCEVUnionPredicate &A, bool Assume) {
+ SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
+ return Rewriter.visit(S);
}
- SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
- : SCEVRewriteVisitor(SE), P(P) {}
+ SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+ SCEVUnionPredicate &P, bool Assume)
+ : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
auto ExprPreds = P.getPredicatesForExpr(Expr);
for (auto *Pred : ExprPreds)
- if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+ if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
if (IPred->getLHS() == Expr)
return IPred->getRHS();
return Expr;
}
+ const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nuw
+ // flag. Add the nusw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+ return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getZeroExtendExpr(Operand, Expr->getType());
+ }
+
+ const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nsw
+ // flag. Add the nssw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+ return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getSignExtendExpr(Operand, Expr->getType());
+ }
+
private:
+ bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ auto *A = SE.getWrapPredicate(AR, AddedFlags);
+ if (!Assume) {
+ // Check if we've already made this assumption.
+ if (P.implies(A))
+ return true;
+ return false;
+ }
+ P.add(A);
+ return true;
+ }
+
SCEVUnionPredicate &P;
+ const Loop *L;
+ bool Assume;
};
} // end anonymous namespace
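
The two visitors above trade a failed static fold for a runtime predicate: once {Start,+,Step} is assumed nusw (or nssw), the extension commutes with the recurrence, so zext({Start,+,Step}) can be rewritten as {zext(Start),+,sext(Step)} in the wider type. A standalone model of that commuting property on i8 -> i16, checked on iterations where the narrow recurrence does not wrap:

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t Start = 250;
  int8_t Step = 1; // {250,+,1} on i8, assumed nusw for the first 5 steps
  bool Match = true;
  for (int k = 0; k < 5; ++k) {
    uint8_t Narrow = (uint8_t)(Start + k * Step);                 // i8 addrec
    uint16_t Wide = (uint16_t)((uint16_t)Start + k * (int16_t)Step);
    Match &= ((uint16_t)Narrow == Wide); // zext(narrow) == wide recurrence
  }
  std::printf("zext commutes with the recurrence: %s\n", Match ? "yes" : "no");
}
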
-const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) {
- return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+ return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+}
+
+const SCEVAddRecExpr *
+ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
+ SCEVUnionPredicate &Preds) {
+ SCEVUnionPredicate TransformPreds;
+ S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+
+ if (!AddRec)
+ return nullptr;
+
+ // Since the transformation was successful, we can now transfer the SCEV
+ // predicates.
+ Preds.add(&TransformPreds);
+ return AddRec;
}
/// SCEV predicates
@@ -9633,7 +10213,7 @@ SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
: SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
- const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+ const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
if (!Op)
return false;
@@ -9649,6 +10229,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
}
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVAddRecExpr *AR,
+ IncrementWrapFlags Flags)
+ : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+ return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+ SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+ IncrementWrapFlags IFlags = Flags;
+
+ if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+ IFlags = clearFlags(IFlags, IncrementNSSW);
+
+ return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << *getExpr() << " Added Flags: ";
+ if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+ OS << "<nusw>";
+ if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+ OS << "<nssw>";
+ OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+ ScalarEvolution &SE) {
+ IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+ SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+ // We can safely transfer the NSW flag as NSSW.
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+ ImpliedFlags = IncrementNSSW;
+
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+ // If the increment is positive, the SCEV NUW flag will also imply the
+ // WrapPredicate NUSW flag.
+ if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+ if (Step->getValue()->getValue().isNonNegative())
+ ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+ }
+
+ return ImpliedFlags;
+}
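
The flag manipulation above is plain bit algebra. A standalone sketch, with enum values assumed to mirror SCEVWrapPredicate::IncrementWrapFlags, showing how statically implied flags are cleared from a predicate's requirement and how an empty remainder means the predicate is always true:

#include <cstdio>

// Values assumed to mirror SCEVWrapPredicate::IncrementWrapFlags.
enum IncrementWrapFlags {
  IncrementAnyWrap = 0,
  IncrementNUSW = 1,
  IncrementNSSW = 2
};

static int setFlags(int Flags, int On) { return Flags | On; }
static int clearFlags(int Flags, int Off) { return Flags & ~Off; }

int main() {
  int Required = setFlags(IncrementNUSW, IncrementNSSW); // predicate asks both
  int Implied = IncrementNSSW;                           // AR already has <nsw>
  int Remaining = clearFlags(Required, Implied);
  std::printf("always true: %s\n",
              Remaining == IncrementAnyWrap ? "yes" : "no"); // no: nusw unproven
}
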
+
/// Union predicates don't get cached so create a dummy set ID for it.
SCEVUnionPredicate::SCEVUnionPredicate()
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
@@ -9667,7 +10300,7 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
}
bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
- if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+ if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
return all_of(Set->Preds,
[this](const SCEVPredicate *I) { return this->implies(I); });
@@ -9688,7 +10321,7 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
}
void SCEVUnionPredicate::add(const SCEVPredicate *N) {
- if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+ if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (auto Pred : Set->Preds)
add(Pred);
return;
@@ -9705,8 +10338,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
Preds.push_back(N);
}
-PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
- : SE(SE), Generation(0) {}
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+ Loop &L)
+ : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
const SCEV *Expr = SE.getSCEV(V);
@@ -9721,12 +10355,21 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
if (Entry.second)
Expr = Entry.second;
- const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+ const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
Entry = {Generation, NewSCEV};
return NewSCEV;
}
+const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
+ if (!BackedgeCount) {
+ SCEVUnionPredicate BackedgePred;
+ BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
+ addPredicate(BackedgePred);
+ }
+ return BackedgeCount;
+}
+
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds.implies(&Pred))
return;
@@ -9743,7 +10386,82 @@ void PredicatedScalarEvolution::updateGeneration() {
if (++Generation == 0) {
for (auto &II : RewriteMap) {
const SCEV *Rewritten = II.second.second;
- II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+ II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
}
}
}
+
+void PredicatedScalarEvolution::setNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+ // Clear the statically implied flags.
+ Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+ addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+ auto II = FlagsMap.insert({V, Flags});
+ if (!II.second)
+ II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
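
The FlagsMap update above is the usual insert-or-merge map idiom: insert reports whether the key was new, and on a collision the freshly requested flags are folded into the stored ones. A standalone sketch with std::map:

#include <cstdio>
#include <map>
#include <string>

int main() {
  std::map<std::string, unsigned> FlagsMap;
  auto Merge = [&](const std::string &V, unsigned Flags) {
    auto II = FlagsMap.insert({V, Flags});
    if (!II.second)              // key was already present:
      II.first->second |= Flags; // fold the new flags into the old entry
  };
  Merge("iv", 1); // inserts {iv, 1}
  Merge("iv", 2); // merges to {iv, 3}
  std::printf("flags(iv) = %u\n", FlagsMap["iv"]); // 3
}
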
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ Flags = SCEVWrapPredicate::clearFlags(
+ Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+ auto II = FlagsMap.find(V);
+
+ if (II != FlagsMap.end())
+ Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+ return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+ const SCEV *Expr = this->getSCEV(V);
+ auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+
+ if (!New)
+ return nullptr;
+
+ updateGeneration();
+ RewriteMap[SE.getSCEV(V)] = {Generation, New};
+ return New;
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(
+ const PredicatedScalarEvolution &Init)
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ for (const auto &I : Init.FlagsMap)
+ FlagsMap.insert(I);
+}
+
+void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
+ // For each block.
+ for (auto *BB : L.getBlocks())
+ for (auto &I : *BB) {
+ if (!SE.isSCEVable(I.getType()))
+ continue;
+
+ auto *Expr = SE.getSCEV(&I);
+ auto II = RewriteMap.find(Expr);
+
+ if (II == RewriteMap.end())
+ continue;
+
+ // Don't print things that are not interesting.
+ if (II->second.second == Expr)
+ continue;
+
+ OS.indent(Depth) << "[PSE]" << I << ":\n";
+ OS.indent(Depth + 2) << *Expr << "\n";
+ OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
+ }
+}