author     dim <dim@FreeBSD.org>  2016-12-26 20:36:37 +0000
committer  dim <dim@FreeBSD.org>  2016-12-26 20:36:37 +0000
commit     06210ae42d418d50d8d9365d5c9419308ae9e7ee (patch)
tree       ab60b4cdd6e430dda1f292a46a77ddb744723f31 /contrib/llvm/lib/Analysis/ScalarEvolution.cpp
parent     2dd166267f53df1c3748b4325d294b9b839de74b (diff)
MFC r309124:
Upgrade our copies of clang, llvm, lldb, compiler-rt and libc++ to 3.9.0
release, and add lld 3.9.0. Also completely revamp the build system for
clang, llvm, lldb and their related tools.
Please note that from 3.5.0 onwards, clang, llvm and lldb require C++11
support to build; see UPDATING for more information.
Release notes for llvm, clang and lld are available here:
<http://llvm.org/releases/3.9.0/docs/ReleaseNotes.html>
<http://llvm.org/releases/3.9.0/tools/clang/docs/ReleaseNotes.html>
<http://llvm.org/releases/3.9.0/tools/lld/docs/ReleaseNotes.html>
Thanks to Ed Maste, Bryan Drewery, Andrew Turner, Antoine Brodin and Jan
Beich for their help.
Relnotes: yes
MFC r309147:
Pull in r282174 from upstream llvm trunk (by Krzysztof Parzyszek):
[PPC] Set SP after loading data from stack frame, if no red zone is
present
Follow-up to r280705: Make sure that the SP is only restored after
all data is loaded from the stack frame, if there is no red zone.
This completes the fix for
https://llvm.org/bugs/show_bug.cgi?id=26519.
Differential Revision: https://reviews.llvm.org/D24466
Reported by: Mark Millard
PR: 214433
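As a self-contained illustration of the invariant this fix enforces (an editorial C++ sketch, not code from the commit): when there is no red zone, memory below the stack pointer may be clobbered asynchronously, so every reload from the stack frame must be emitted before the stack pointer is adjusted.

    #include <cassert>
    #include <cstdint>

    int main() {
      uint8_t Stack[64] = {};
      unsigned SP = 16;     // a 16-byte frame lives at Stack[16..31]
      Stack[20] = 42;       // a callee-saved value spilled into the frame

      // Correct epilogue order: reload from the frame, then release it.
      uint8_t Saved = Stack[20];
      SP += 16;             // pop the frame

      // With no red zone, anything below SP may be overwritten at any
      // moment (e.g. by a signal handler); model that here:
      for (unsigned i = 0; i < SP; ++i)
        Stack[i] = 0xCC;

      assert(Saved == 42);  // holds only because the load preceded the pop
      return 0;
    }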
MFC r309149:
Pull in r283060 from upstream llvm trunk (by Hal Finkel):
[PowerPC] Refactor soft-float support, and enable PPC64 soft float
This change enables soft-float for PowerPC64, and also makes
soft-float disable all vector instruction sets for both 32-bit and
64-bit modes. This latter part is necessary because the PPC backend
canonicalizes many Altivec vector types to floating-point types, and
so soft-float breaks scalarization support for many operations. Both
for embedded targets and for operating-system kernels desiring
soft-float support, it seems reasonable that disabling hardware
floating-point also disables vector instructions (embedded targets
without hardware floating point support are unlikely to have Altivec,
etc. and operating system kernels desiring not to use floating-point
registers to lower syscall cost are unlikely to want to use vector
registers either). If someone needs this to work, we'll need to
change the fact that we promote many Altivec operations to act on
v4f32. To make it possible to disable Altivec when soft-float is
enabled, hardware floating-point support needs to be expressed as a
positive feature, like the others, and not a negative feature,
because target features cannot have dependencies on the disabling of
some other feature. So +soft-float has now become -hard-float.
Fixes PR26970.
Pull in r283061 from upstream clang trunk (by Hal Finkel):
[PowerPC] Enable soft-float for PPC64, and +soft-float -> -hard-float
Enable soft-float support on PPC64, as the backend now supports it.
Also, the backend now uses -hard-float instead of +soft-float, so set
the target features accordingly.
Fixes PR26970.
Reported by: Mark Millard
PR: 214433
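A minimal sketch of the feature dependency described above, using illustrative names rather than the real LLVM subtarget machinery: once hardware floating point is a positive feature, vector support can simply require it, and dropping it ("-hard-float") drops the vector ISAs with it.

    #include <set>
    #include <string>

    using FeatureSet = std::set<std::string>;  // stand-in for LLVM's feature bits

    // Altivec is usable only while the positive "hard-float" feature is
    // present; a soft-float build removes "hard-float" and with it every
    // vector instruction set that requires it.
    bool altivecAvailable(const FeatureSet &FS) {
      return FS.count("hard-float") && FS.count("altivec");
    }

Target features can only state what they require, not force another feature off, which is why "+soft-float disables Altivec" had to be recast as "vector features require hard-float".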
MFC r309212:
Add a few missed clang 3.9.0 files to OptionalObsoleteFiles.
MFC r309262:
Fix packaging for clang, lldb and lld 3.9.0
During the upgrade of clang/llvm etc to 3.9.0 in r309124, the PACKAGE
directive in the usr.bin/clang/*.mk files got dropped accidentally.
Restore it, with a few minor changes and additions:
* Correct license in clang.ucl to NCSA
* Add PACKAGE=clang for clang and most of the "ll" tools
* Put lldb in its own package
* Put lld in its own package
Reviewed by: gjb, jmallett
Differential Revision: https://reviews.freebsd.org/D8666
MFC r309656:
During the bootstrap phase, when building the minimal llvm library on
PowerPC, add lib/Support/Atomic.cpp. This is needed because upstream
llvm revision r271821 disabled the use of std::call_once, which causes
some fallback functions from Atomic.cpp to be used instead.
Reported by: Mark Millard
PR: 214902
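For context, a rough sketch of the kind of fallback Atomic.cpp supplies when std::call_once cannot be used (illustrative code, not the actual LLVM implementation): one-time initialization built from an atomic flag.

    #include <atomic>

    enum InitStatus { Uninitialized, Wait, Done };

    template <typename F>
    void callOnceFallback(std::atomic<int> &Flag, F Init) {
      int Expected = Uninitialized;
      if (Flag.compare_exchange_strong(Expected, Wait)) {
        Init();                        // first caller runs the initializer
        Flag.store(Done);
      } else {
        while (Flag.load() != Done) {  // later callers wait for completion
        }
      }
    }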
MFC r309835:
Tentatively apply https://reviews.llvm.org/D18730 to work around gcc PR
70528 (bogus error: constructor required before non-static data member).
This should fix buildworld with the external gcc package.
Reported by: https://jenkins.freebsd.org/job/FreeBSD_HEAD_amd64_gcc/
MFC r310194:
Upgrade our copies of clang, llvm, lld, lldb, compiler-rt and libc++ to
3.9.1 release.
Please note that from 3.5.0 onwards, clang, llvm and lldb require C++11
support to build; see UPDATING for more information.
Release notes for llvm, clang and lld will be available here:
<http://releases.llvm.org/3.9.1/docs/ReleaseNotes.html>
<http://releases.llvm.org/3.9.1/tools/clang/docs/ReleaseNotes.html>
<http://releases.llvm.org/3.9.1/tools/lld/docs/ReleaseNotes.html>
Relnotes: yes
Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
 contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 2506 ++++++++++++++-------
 1 file changed, 1612 insertions(+), 894 deletions(-)
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp index ef1bb3a..e42a4b5 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp @@ -111,10 +111,14 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, "derived loop"), cl::init(100)); -// FIXME: Enable this with XDEBUG when the test suite is clean. +// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. static cl::opt<bool> VerifySCEV("verify-scev", cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); +static cl::opt<bool> + VerifySCEVMap("verify-scev-maps", + cl::desc("Verify no dangling value in ScalarEvolution's " + "ExprValueMap (slow)")); //===----------------------------------------------------------------------===// // SCEV class definitions @@ -162,11 +166,11 @@ void SCEV::print(raw_ostream &OS) const { for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) OS << ",+," << *AR->getOperand(i); OS << "}<"; - if (AR->getNoWrapFlags(FlagNUW)) + if (AR->hasNoUnsignedWrap()) OS << "nuw><"; - if (AR->getNoWrapFlags(FlagNSW)) + if (AR->hasNoSignedWrap()) OS << "nsw><"; - if (AR->getNoWrapFlags(FlagNW) && + if (AR->hasNoSelfWrap() && !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) OS << "nw><"; AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); @@ -196,9 +200,9 @@ void SCEV::print(raw_ostream &OS) const { switch (NAry->getSCEVType()) { case scAddExpr: case scMulExpr: - if (NAry->getNoWrapFlags(FlagNUW)) + if (NAry->hasNoUnsignedWrap()) OS << "<nuw>"; - if (NAry->getNoWrapFlags(FlagNSW)) + if (NAry->hasNoSignedWrap()) OS << "<nsw>"; } return; @@ -283,8 +287,6 @@ bool SCEV::isAllOnesValue() const { return false; } -/// isNonConstantNegative - Return true if the specified scev is negated, but -/// not a constant. bool SCEV::isNonConstantNegative() const { const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this); if (!Mul) return false; @@ -620,10 +622,10 @@ public: }; } // end anonymous namespace -/// GroupByComplexity - Given a list of SCEV objects, order them by their -/// complexity, and group objects of the same complexity together by value. -/// When this routine is finished, we know that any duplicates in the vector are -/// consecutive and that complexity is monotonically increasing. +/// Given a list of SCEV objects, order them by their complexity, and group +/// objects of the same complexity together by value. When this routine is +/// finished, we know that any duplicates in the vector are consecutive and that +/// complexity is monotonically increasing. /// /// Note that we go take special precautions to ensure that we get deterministic /// results from this routine. In other words, we don't want the results of @@ -723,7 +725,7 @@ public: } // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { + if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { const SCEV *Q, *R; *Quotient = Numerator; for (const SCEV *Op : T->operands()) { @@ -922,8 +924,7 @@ private: // Simple SCEV method implementations //===----------------------------------------------------------------------===// -/// BinomialCoefficient - Compute BC(It, K). The result has width W. -/// Assume, K > 0. +/// Compute BC(It, K). The result has width W. Assume, K > 0. 
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy) { @@ -1034,10 +1035,10 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, SE.getTruncateOrZeroExtend(DivResult, ResultTy)); } -/// evaluateAtIteration - Return the value of this chain of recurrences at -/// the specified iteration number. We can evaluate this recurrence by -/// multiplying each element in the chain by the binomial coefficient -/// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: +/// Return the value of this chain of recurrences at the specified iteration +/// number. We can evaluate this recurrence by multiplying each element in the +/// chain by the binomial coefficient corresponding to it. In other words, we +/// can evaluate {A,+,B,+,C,+,D} as: /// /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// @@ -1450,9 +1451,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + if (!AR->hasNoUnsignedWrap()) { + auto NewFlags = proveNoWrapViaConstantRanges(AR); + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); + } + // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->getNoWrapFlags(SCEV::FlagNUW)) + if (AR->hasNoUnsignedWrap()) return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); @@ -1512,11 +1518,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + } - // If the backedge is guarded by a comparison with the pre-inc value - // the addrec is safe. Also, if the entry is guarded by a comparison - // with the start value and the backedge is guarded by a comparison - // with the post-inc value, the addrec is safe. + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { + // If the backedge is guarded by a comparison with the pre-inc + // value the addrec is safe. Also, if the entry is guarded by + // a comparison with the start value and the backedge is + // guarded by a comparison with the post-inc value, the addrec + // is safe. if (isKnownPositive(Step)) { const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - getUnsignedRange(Step).getUnsignedMax()); @@ -1524,7 +1541,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR->getPostIncExpr(*this), N))) { - // Cache knowledge of AR NUW, which is propagated to this AddRec. + // Cache knowledge of AR NUW, which is propagated to this + // AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( @@ -1538,8 +1556,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR->getPostIncExpr(*this), N))) { - // Cache knowledge of AR NW, which is propagated to this AddRec. - // Negative step causes unsigned wrap, but it still can't self-wrap. + // Cache knowledge of AR NW, which is propagated to this + // AddRec. Negative step causes unsigned wrap, but it + // still can't self-wrap. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( @@ -1559,7 +1578,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw> - if (SA->getNoWrapFlags(SCEV::FlagNUW)) { + if (SA->hasNoUnsignedWrap()) { // If the addition does not unsign overflow then we can, by definition, // commute the zero extension with the addition operation. SmallVector<const SCEV *, 4> Ops; @@ -1608,10 +1627,6 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - // If the input value is provably positive, build a zext instead. - if (isKnownNonNegative(Op)) - return getZeroExtendExpr(Op, Ty); - // sext(trunc(x)) --> sext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { // It's possible the bits taken off by the truncate were all sign bits. If @@ -1643,7 +1658,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> - if (SA->getNoWrapFlags(SCEV::FlagNSW)) { + if (SA->hasNoSignedWrap()) { // If the addition does not sign overflow then we can, by definition, // commute the sign extension with the addition operation. SmallVector<const SCEV *, 4> Ops; @@ -1663,9 +1678,14 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + if (!AR->hasNoSignedWrap()) { + auto NewFlags = proveNoWrapViaConstantRanges(AR); + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); + } + // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->getNoWrapFlags(SCEV::FlagNSW)) + if (AR->hasNoSignedWrap()) return getAddRecExpr( getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); @@ -1732,11 +1752,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + } - // If the backedge is guarded by a comparison with the pre-inc value - // the addrec is safe. Also, if the entry is guarded by a comparison - // with the start value and the backedge is guarded by a comparison - // with the post-inc value, the addrec is safe. + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. 
+ + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { + // If the backedge is guarded by a comparison with the pre-inc + // value the addrec is safe. Also, if the entry is guarded by + // a comparison with the start value and the backedge is + // guarded by a comparison with the post-inc value, the addrec + // is safe. ICmpInst::Predicate Pred; const SCEV *OverflowLimit = getSignedOverflowLimitForStep(Step, &Pred, this); @@ -1752,6 +1784,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + // If Start and Step are constants, check if we can apply this // transformation: // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 @@ -1777,6 +1810,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } } + // If the input value is provably positive and we could not simplify + // away the sext build a zext instead. + if (isKnownNonNegative(Op)) + return getZeroExtendExpr(Op, Ty); + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -1836,11 +1874,10 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, return ZExt; } -/// CollectAddOperandsWithScales - Process the given Ops list, which is -/// a list of operands to be added under the given scale, update the given -/// map. This is a helper function for getAddRecExpr. As an example of -/// what it does, given a sequence of operands that would form an add -/// expression like this: +/// Process the given Ops list, which is a list of operands to be added under +/// the given scale, update the given map. This is a helper function for +/// getAddRecExpr. As an example of what it does, given a sequence of operands +/// that would form an add expression like this: /// /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) /// @@ -1899,7 +1936,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, // the map. SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); const SCEV *Key = SE.getMulExpr(MulOps); - auto Pair = M.insert(std::make_pair(Key, NewScale)); + auto Pair = M.insert({Key, NewScale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { @@ -1912,7 +1949,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, } else { // An ordinary operand. Update the map. 
std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = - M.insert(std::make_pair(Ops[i], Scale)); + M.insert({Ops[i], Scale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { @@ -1965,15 +2002,14 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { - auto NSWRegion = - ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap); + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, C, OBO::NoSignedWrap); if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); } if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { - auto NUWRegion = - ConstantRange::makeNoWrapRegion(Instruction::Add, C, - OBO::NoUnsignedWrap); + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, C, OBO::NoUnsignedWrap); if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); } @@ -1982,8 +2018,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } -/// getAddExpr - Get a canonical add expression, or something simpler if -/// possible. +/// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && @@ -2266,7 +2301,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), AddRec->op_end()); - AddRecOps[0] = getAddExpr(LIOps); + // This follows from the fact that the no-wrap flags on the outer add + // expression are applicable on the 0th iteration, when the add recurrence + // will be equal to its start value. + AddRecOps[0] = getAddExpr(LIOps, Flags); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. @@ -2391,8 +2429,7 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) { return false; } -/// getMulExpr - Get a canonical multiply expression, or something simpler if -/// possible. +/// Get a canonical multiply expression, or something simpler if possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && @@ -2632,8 +2669,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, return S; } -/// getUDivExpr - Get a canonical unsigned division expression, or something -/// simpler if possible. +/// Get a canonical unsigned division expression, or something simpler if +/// possible. const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == @@ -2764,10 +2801,10 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { return APIntOps::GreatestCommonDivisor(A, B); } -/// getUDivExactExpr - Get a canonical unsigned division expression, or -/// something simpler if possible. There is no representation for an exact udiv -/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS. -/// We can't do this when it's not exact because the udiv may be clearing bits. +/// Get a canonical unsigned division expression, or something simpler if +/// possible. 
There is no representation for an exact udiv in SCEV IR, but we +/// can attempt to remove factors from the LHS and RHS. We can't do this when +/// it's not exact because the udiv may be clearing bits. const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, const SCEV *RHS) { // TODO: we could try to find factors in all sorts of things, but for now we @@ -2821,8 +2858,8 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, return getUDivExpr(LHS, RHS); } -/// getAddRecExpr - Get an add recurrence expression for the specified loop. -/// Simplify the expression as much as possible. +/// Get an add recurrence expression for the specified loop. Simplify the +/// expression as much as possible. const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags) { @@ -2838,8 +2875,8 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, return getAddRecExpr(Operands, L, Flags); } -/// getAddRecExpr - Get an add recurrence expression for the specified loop. -/// Simplify the expression as much as possible. +/// Get an add recurrence expression for the specified loop. Simplify the +/// expression as much as possible. const SCEV * ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, const Loop *L, SCEV::NoWrapFlags Flags) { @@ -2985,9 +3022,7 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector<const SCEV *, 2> Ops; - Ops.push_back(LHS); - Ops.push_back(RHS); + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; return getSMaxExpr(Ops); } @@ -3088,9 +3123,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector<const SCEV *, 2> Ops; - Ops.push_back(LHS); - Ops.push_back(RHS); + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; return getUMaxExpr(Ops); } @@ -3244,26 +3277,25 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { // Basic SCEV Analysis and PHI Idiom Recognition Code // -/// isSCEVable - Test if values of the given type are analyzable within -/// the SCEV framework. This primarily includes integer types, and it -/// can optionally include pointer types if the ScalarEvolution class -/// has access to target-specific information. +/// Test if values of the given type are analyzable within the SCEV +/// framework. This primarily includes integer types, and it can optionally +/// include pointer types if the ScalarEvolution class has access to +/// target-specific information. bool ScalarEvolution::isSCEVable(Type *Ty) const { // Integers and pointers are always SCEVable. return Ty->isIntegerTy() || Ty->isPointerTy(); } -/// getTypeSizeInBits - Return the size in bits of the specified type, -/// for which isSCEVable must return true. +/// Return the size in bits of the specified type, for which isSCEVable must +/// return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); return getDataLayout().getTypeSizeInBits(Ty); } -/// getEffectiveSCEVType - Return a type with the same bitwidth as -/// the given type and which represents how SCEV will treat the given -/// type, for which isSCEVable must return true. For pointer types, -/// this is the pointer-sized integer type. 
+/// Return a type with the same bitwidth as the given type and which represents +/// how SCEV will treat the given type, for which isSCEVable must return +/// true. For pointer types, this is the pointer-sized integer type. Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); @@ -3310,15 +3342,88 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const { return !F.FindOne; } -/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the -/// expression and create a new one. +namespace { +// Helper class working with SCEVTraversal to figure out if a SCEV contains +// a sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set +// iff if such sub scAddRecExpr type SCEV is found. +struct FindAddRecurrence { + bool FoundOne; + FindAddRecurrence() : FoundOne(false) {} + + bool follow(const SCEV *S) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scAddRecExpr: + FoundOne = true; + case scConstant: + case scUnknown: + case scCouldNotCompute: + return false; + default: + return true; + } + } + bool isDone() const { return FoundOne; } +}; +} + +bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { + HasRecMapType::iterator I = HasRecMap.find_as(S); + if (I != HasRecMap.end()) + return I->second; + + FindAddRecurrence F; + SCEVTraversal<FindAddRecurrence> ST(F); + ST.visitAll(S); + HasRecMap.insert({S, F.FoundOne}); + return F.FoundOne; +} + +/// Return the Value set from S. +SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) { + ExprValueMapType::iterator SI = ExprValueMap.find_as(S); + if (SI == ExprValueMap.end()) + return nullptr; +#ifndef NDEBUG + if (VerifySCEVMap) { + // Check there is no dangling Value in the set returned. + for (const auto &VE : SI->second) + assert(ValueExprMap.count(VE)); + } +#endif + return &SI->second; +} + +/// Erase Value from ValueExprMap and ExprValueMap. If ValueExprMap.erase(V) is +/// not used together with forgetMemoizedResults(S), eraseValueFromMap should be +/// used instead to ensure whenever V->S is removed from ValueExprMap, V is also +/// removed from the set of ExprValueMap[S]. +void ScalarEvolution::eraseValueFromMap(Value *V) { + ValueExprMapType::iterator I = ValueExprMap.find_as(V); + if (I != ValueExprMap.end()) { + const SCEV *S = I->second; + SetVector<Value *> *SV = getSCEVValues(S); + // Remove V from the set of ExprValueMap[S] + if (SV) + SV->remove(V); + ValueExprMap.erase(V); + } +} + +/// Return an existing SCEV if it exists, otherwise analyze the expression and +/// create a new one. const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); const SCEV *S = getExistingSCEV(V); if (S == nullptr) { S = createSCEV(V); - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + // During PHI resolution, it is possible to create two SCEVs for the same + // V, so it is needed to double check whether V->S is inserted into + // ValueExprMap before insert S->V into ExprValueMap. 
+ std::pair<ValueExprMapType::iterator, bool> Pair = + ValueExprMap.insert({SCEVCallbackVH(V, this), S}); + if (Pair.second) + ExprValueMap[S].insert(V); } return S; } @@ -3331,12 +3436,13 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { const SCEV *S = I->second; if (checkValidity(S)) return S; + forgetMemoizedResults(S); ValueExprMap.erase(I); } return nullptr; } -/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V +/// Return a SCEV corresponding to -V = -1*V /// const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags) { @@ -3350,7 +3456,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags); } -/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V +/// Return a SCEV corresponding to ~V = -1-V const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) return getConstant( @@ -3363,7 +3469,6 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { return getMinusSCEV(AllOnes, V); } -/// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1. const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags) { // Fast path: X - X --> 0. @@ -3402,9 +3507,6 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); } -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. const SCEV * ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3418,9 +3520,6 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { return getZeroExtendExpr(V, Ty); } -/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is sign -/// extended. const SCEV * ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty) { @@ -3435,9 +3534,6 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, return getSignExtendExpr(V, Ty); } -/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. The conversion must not be narrowing. const SCEV * ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3451,9 +3547,6 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { return getZeroExtendExpr(V, Ty); } -/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is sign -/// extended. The conversion must not be narrowing. const SCEV * ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3467,10 +3560,6 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { return getSignExtendExpr(V, Ty); } -/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of -/// the input value to the specified type. If the type must be extended, -/// it is extended with unspecified bits. The conversion must not be -/// narrowing. 
const SCEV * ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3484,8 +3573,6 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { return getAnyExtendExpr(V, Ty); } -/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. The conversion must not be widening. const SCEV * ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3499,9 +3586,6 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { return getTruncateExpr(V, Ty); } -/// getUMaxFromMismatchedTypes - Promote the operands to the wider of -/// the types using zero-extension, and then perform a umax operation -/// with them. const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; @@ -3515,9 +3599,6 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, return getUMaxExpr(PromotedLHS, PromotedRHS); } -/// getUMinFromMismatchedTypes - Promote the operands to the wider of -/// the types using zero-extension, and then perform a umin operation -/// with them. const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; @@ -3531,10 +3612,6 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, return getUMinExpr(PromotedLHS, PromotedRHS); } -/// getPointerBase - Transitively follow the chain of pointer-type operands -/// until reaching a SCEV that does not have a single pointer operand. This -/// returns a SCEVUnknown pointer for well-formed pointer-type expressions, -/// but corner cases do exist. const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { // A pointer operand may evaluate to a nonpointer expression, such as null. if (!V->getType()->isPointerTy()) @@ -3559,8 +3636,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { return V; } -/// PushDefUseChildren - Push users of the given Instruction -/// onto the given Worklist. +/// Push users of the given Instruction onto the given Worklist. static void PushDefUseChildren(Instruction *I, SmallVectorImpl<Instruction *> &Worklist) { @@ -3569,12 +3645,7 @@ PushDefUseChildren(Instruction *I, Worklist.push_back(cast<Instruction>(U)); } -/// ForgetSymbolicValue - This looks up computed SCEV values for all -/// instructions that depend on the given instruction and removes them from -/// the ValueExprMapType map if they reference SymName. This is used during PHI -/// resolution. -void -ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { +void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { SmallVector<Instruction *, 16> Worklist; PushDefUseChildren(PN, Worklist); @@ -3616,10 +3687,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { namespace { class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVInitRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(Scev); + const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? 
Result : SE.getCouldNotCompute(); } @@ -3649,10 +3720,10 @@ private: class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVShiftRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(Scev); + const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } @@ -3680,6 +3751,167 @@ private: }; } // end anonymous namespace +SCEV::NoWrapFlags +ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { + if (!AR->isAffine()) + return SCEV::FlagAnyWrap; + + typedef OverflowingBinaryOperator OBO; + SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; + + if (!AR->hasNoSignedWrap()) { + ConstantRange AddRecRange = getSignedRange(AR); + ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); + + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoSignedWrap); + if (NSWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); + } + + if (!AR->hasNoUnsignedWrap()) { + ConstantRange AddRecRange = getUnsignedRange(AR); + ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); + + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoUnsignedWrap); + if (NUWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); + } + + return Result; +} + +namespace { +/// Represents an abstract binary operation. This may exist as a +/// normal instruction or constant expression, or may have been +/// derived from an expression tree. +struct BinaryOp { + unsigned Opcode; + Value *LHS; + Value *RHS; + bool IsNSW; + bool IsNUW; + + /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or + /// constant expression. + Operator *Op; + + explicit BinaryOp(Operator *Op) + : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), + IsNSW(false), IsNUW(false), Op(Op) { + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { + IsNSW = OBO->hasNoSignedWrap(); + IsNUW = OBO->hasNoUnsignedWrap(); + } + } + + explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, + bool IsNUW = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), + Op(nullptr) {} +}; +} + + +/// Try to map \p V into a BinaryOp, and return \c None on failure. +static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { + auto *Op = dyn_cast<Operator>(V); + if (!Op) + return None; + + // Implementation detail: all the cleverness here should happen without + // creating new SCEV expressions -- our caller knowns tricks to avoid creating + // SCEV expressions when possible, and we should not break that. + + switch (Op->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::UDiv: + case Instruction::And: + case Instruction::Or: + case Instruction::AShr: + case Instruction::Shl: + return BinaryOp(Op); + + case Instruction::Xor: + if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. 
+ if (RHSC->getValue().isSignBit()) + return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); + return BinaryOp(Op); + + case Instruction::LShr: + // Turn logical shift right of a constant into a unsigned divide. + if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) { + uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth(); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (SA->getValue().ult(BitWidth)) { + Constant *X = + ConstantInt::get(SA->getContext(), + APInt::getOneBitSet(BitWidth, SA->getZExtValue())); + return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); + } + } + return BinaryOp(Op); + + case Instruction::ExtractValue: { + auto *EVI = cast<ExtractValueInst>(Op); + if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) + break; + + auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand()); + if (!CI) + break; + + if (auto *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: { + if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1)); + + // Now that we know that all uses of the arithmetic-result component of + // CI are guarded by the overflow check, we can go ahead and pretend + // that the arithmetic is non-overflowing. + if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow) + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ true, + /* IsNUW = */ false); + else + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ false, + /* IsNUW*/ true); + } + + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + return BinaryOp(Instruction::Sub, CI->getArgOperand(0), + CI->getArgOperand(1)); + + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + return BinaryOp(Instruction::Mul, CI->getArgOperand(0), + CI->getArgOperand(1)); + default: + break; + } + } + + default: + break; + } + + return None; +} + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -3710,7 +3942,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const SCEV *SymbolicName = getUnknown(PN); assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && "PHI node already processed?"); - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); + ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. @@ -3747,13 +3979,11 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; - // If the increment doesn't overflow, then neither the addrec nor - // the post-increment will overflow. 
- if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) { - if (OBO->getOperand(0) == PN) { - if (OBO->hasNoUnsignedWrap()) + if (auto BO = MatchBinaryOp(BEValueV, DT)) { + if (BO->Opcode == Instruction::Add && BO->LHS == PN) { + if (BO->IsNUW) Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) + if (BO->IsNSW) Flags = setFlags(Flags, SCEV::FlagNSW); } } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { @@ -3779,16 +4009,19 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const SCEV *StartVal = getSCEV(StartValueV); const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - // Since the no-wrap flags are on the increment, they apply to the - // post-incremented value as well. - if (isLoopInvariant(Accum, L)) - (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); - // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); + forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + + // We can add Flags to the post-inc expression only if we + // know that it us *undefined behavior* for BEValueV to + // overflow. + if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + return PHISCEV; } } @@ -3811,12 +4044,18 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); + forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; return Shifted; } } } + + // Remove the temporary PHI node SCEV that has been inserted while intending + // to create an AddRecExpr for this PHI node. We can not keep this temporary + // as it will prevent later (possibly simpler) SCEV expressions to be added + // to the ValueExprMap. + ValueExprMap.erase(PN); } return nullptr; @@ -4083,26 +4322,21 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, return getUnknown(I); } -/// createNodeForGEP - Expand GEP instructions into add and multiply -/// operations. This allows them to be analyzed by regular SCEV code. -/// +/// Expand GEP instructions into add and multiply operations. This allows them +/// to be analyzed by regular SCEV code. const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. - if (!Base->getType()->getPointerElementType()->isSized()) + if (!GEP->getSourceElementType()->isSized()) return getUnknown(GEP); SmallVector<const SCEV *, 4> IndexExprs; for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) IndexExprs.push_back(getSCEV(*Index)); - return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs, - GEP->isInBounds()); + return getGEPExpr(GEP->getSourceElementType(), + getSCEV(GEP->getPointerOperand()), + IndexExprs, GEP->isInBounds()); } -/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is -/// guaranteed to end in (at every loop iteration). It is, at the same time, -/// the minimum number of times S is divisible by 2. 
For example, given {4,+,8} -/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) @@ -4180,8 +4414,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } -/// GetRangeFromMetadata - Helper method to assign a range to V from -/// metadata present in the IR. +/// Helper method to assign a range to V from metadata present in the IR. static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) @@ -4190,10 +4423,9 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { return None; } -/// getRange - Determine the range for a particular SCEV. If SignHint is +/// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. -/// ConstantRange ScalarEvolution::getRange(const SCEV *S, ScalarEvolution::RangeSignHint SignHint) { @@ -4282,7 +4514,7 @@ ScalarEvolution::getRange(const SCEV *S, if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { // If there's no unsigned wrap, the value will never be less than its // initial value. - if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) + if (AddRec->hasNoUnsignedWrap()) if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) if (!C->getValue()->isZero()) ConservativeResult = ConservativeResult.intersectWith( @@ -4290,7 +4522,7 @@ ScalarEvolution::getRange(const SCEV *S, // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. - if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { + if (AddRec->hasNoSignedWrap()) { bool AllNonNeg = true; bool AllNonPos = true; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { @@ -4309,66 +4541,22 @@ ScalarEvolution::getRange(const SCEV *S, // TODO: non-affine addrec if (AddRec->isAffine()) { - Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - - MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); - - const SCEV *Start = AddRec->getStart(); - const SCEV *Step = AddRec->getStepRecurrence(*this); - ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); - - ConstantRange StartURange = getUnsignedRange(Start); - ConstantRange EndURange = - StartURange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for unsigned overflow. 
- ConstantRange ZExtStartURange = - StartURange.zextOrTrunc(BitWidth * 2 + 1); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); - if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - ZExtEndURange) { - APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), - EndURange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), - EndURange.getUnsignedMax()); - bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); - if (!IsFullRange) - ConservativeResult = - ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); - } - - ConstantRange StartSRange = getSignedRange(Start); - ConstantRange EndSRange = - StartSRange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for signed overflow. This must be done with ConstantRange - // arithmetic because we could be called from within the ScalarEvolution - // overflow checking code. - ConstantRange SExtStartSRange = - StartSRange.sextOrTrunc(BitWidth * 2 + 1); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); - if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - SExtEndSRange) { - APInt Min = APIntOps::smin(StartSRange.getSignedMin(), - EndSRange.getSignedMin()); - APInt Max = APIntOps::smax(StartSRange.getSignedMax(), - EndSRange.getSignedMax()); - bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); - if (!IsFullRange) - ConservativeResult = - ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); - } + auto RangeFromAffine = getRangeForAffineAR( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromAffine.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine); + + auto RangeFromFactoring = getRangeViaFactoring( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromFactoring.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring); } } @@ -4408,6 +4596,186 @@ ScalarEvolution::getRange(const SCEV *S, return setRange(S, SignHint, ConservativeResult); } +ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + assert(!isa<SCEVCouldNotCompute>(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + "Precondition!"); + + ConstantRange Result(BitWidth, /* isFullSet = */ true); + + // Check for overflow. This must be done with ConstantRange arithmetic + // because we could be called from within the ScalarEvolution overflow + // checking code. + + MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); + ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); + ConstantRange ZExtMaxBECountRange = + MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StepSRange = getSignedRange(Step); + ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StartURange = getUnsignedRange(Start); + ConstantRange EndURange = + StartURange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for unsigned overflow. 
+ ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); + if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + ZExtEndURange) { + APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), + EndURange.getUnsignedMin()); + APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), + EndURange.getUnsignedMax()); + bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); + if (!IsFullRange) + Result = + Result.intersectWith(ConstantRange(Min, Max + 1)); + } + + ConstantRange StartSRange = getSignedRange(Start); + ConstantRange EndSRange = + StartSRange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for signed overflow. This must be done with ConstantRange + // arithmetic because we could be called from within the ScalarEvolution + // overflow checking code. + ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); + if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + SExtEndSRange) { + APInt Min = + APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); + APInt Max = + APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); + bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); + if (!IsFullRange) + Result = + Result.intersectWith(ConstantRange(Min, Max + 1)); + } + + return Result; +} + +ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) + // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + + struct SelectPattern { + Value *Condition = nullptr; + APInt TrueValue; + APInt FalseValue; + + explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, + const SCEV *S) { + Optional<unsigned> CastOp; + APInt Offset(BitWidth, 0); + + assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && + "Should be!"); + + // Peel off a constant offset: + if (auto *SA = dyn_cast<SCEVAddExpr>(S)) { + // In the future we could consider being smarter here and handle + // {Start+Step,+,Step} too. 
+ if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0))) + return; + + Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt(); + S = SA->getOperand(1); + } + + // Peel off a cast operation + if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) { + CastOp = SCast->getSCEVType(); + S = SCast->getOperand(); + } + + using namespace llvm::PatternMatch; + + auto *SU = dyn_cast<SCEVUnknown>(S); + const APInt *TrueVal, *FalseVal; + if (!SU || + !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), + m_APInt(FalseVal)))) { + Condition = nullptr; + return; + } + + TrueValue = *TrueVal; + FalseValue = *FalseVal; + + // Re-apply the cast we peeled off earlier + if (CastOp.hasValue()) + switch (*CastOp) { + default: + llvm_unreachable("Unknown SCEV cast type!"); + + case scTruncate: + TrueValue = TrueValue.trunc(BitWidth); + FalseValue = FalseValue.trunc(BitWidth); + break; + case scZeroExtend: + TrueValue = TrueValue.zext(BitWidth); + FalseValue = FalseValue.zext(BitWidth); + break; + case scSignExtend: + TrueValue = TrueValue.sext(BitWidth); + FalseValue = FalseValue.sext(BitWidth); + break; + } + + // Re-apply the constant offset we peeled off earlier + TrueValue += Offset; + FalseValue += Offset; + } + + bool isRecognized() { return Condition != nullptr; } + }; + + SelectPattern StartPattern(*this, BitWidth, Start); + if (!StartPattern.isRecognized()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + SelectPattern StepPattern(*this, BitWidth, Step); + if (!StepPattern.isRecognized()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + if (StartPattern.Condition != StepPattern.Condition) { + // We don't handle this case today; but we could, by considering four + // possibilities below instead of two. I'm not sure if there are cases where + // that will help over what getRange already does, though. + return ConstantRange(BitWidth, /* isFullSet = */ true); + } + + // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to + // construct arbitrary general SCEV expressions here. This function is called + // from deep in the call stack, and calling getSCEV (on a sext instruction, + // say) can end up caching a suboptimal value. + + // FIXME: without the explicit `this` receiver below, MSVC errors out with + // C2352 and C2512 (otherwise it isn't needed). 
+ + const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); + const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); + const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); + const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + + ConstantRange TrueRange = + this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); + ConstantRange FalseRange = + this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + + return TrueRange.unionWith(FalseRange); +} + SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; const BinaryOperator *BinOp = cast<BinaryOperator>(V); @@ -4418,273 +4786,363 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); if (BinOp->hasNoSignedWrap()) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); - if (Flags == SCEV::FlagAnyWrap) { + if (Flags == SCEV::FlagAnyWrap) return SCEV::FlagAnyWrap; - } - // Here we check that BinOp is in the header of the innermost loop - // containing BinOp, since we only deal with instructions in the loop - // header. The actual loop we need to check later will come from an add - // recurrence, but getting that requires computing the SCEV of the operands, - // which can be expensive. This check we can do cheaply to rule out some - // cases early. - Loop *innermostContainingLoop = LI.getLoopFor(BinOp->getParent()); - if (innermostContainingLoop == nullptr || - innermostContainingLoop->getHeader() != BinOp->getParent()) - return SCEV::FlagAnyWrap; + return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; +} - // Only proceed if we can prove that BinOp does not yield poison. - if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap; +bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { + // Here we check that I is in the header of the innermost loop containing I, + // since we only deal with instructions in the loop header. The actual loop we + // need to check later will come from an add recurrence, but getting that + // requires computing the SCEV of the operands, which can be expensive. This + // check we can do cheaply to rule out some cases early. + Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); + if (InnermostContainingLoop == nullptr || + InnermostContainingLoop->getHeader() != I->getParent()) + return false; + + // Only proceed if we can prove that I does not yield poison. + if (!isKnownNotFullPoison(I)) return false; - // At this point we know that if V is executed, then it does not wrap - // according to at least one of NSW or NUW. If V is not executed, then we do - // not know if the calculation that V represents would wrap. Multiple - // instructions can map to the same SCEV. If we apply NSW or NUW from V to + // At this point we know that if I is executed, then it does not wrap + // according to at least one of NSW or NUW. If I is not executed, then we do + // not know if the calculation that I represents would wrap. Multiple + // instructions can map to the same SCEV. If we apply NSW or NUW from I to // the SCEV, we must guarantee no wrapping for that SCEV also when it is // derived from other instructions that map to the same SCEV. We cannot make - // that guarantee for cases where V is not executed. So we need to find the - // loop that V is considered in relation to and prove that V is executed for - // every iteration of that loop. 
That implies that the value that V + // that guarantee for cases where I is not executed. So we need to find the + // loop that I is considered in relation to and prove that I is executed for + // every iteration of that loop. That implies that the value that I // calculates does not wrap anywhere in the loop, so then we can apply the // flags to the SCEV. // - // We check isLoopInvariant to disambiguate in case we are adding two - // recurrences from different loops, so that we know which loop to prove - // that V is executed in. - for (int OpIndex = 0; OpIndex < 2; ++OpIndex) { - const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex)); + // We check isLoopInvariant to disambiguate in case we are adding recurrences + // from different loops, so that we know which loop to prove that I is + // executed in. + for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { + // I could be an extractvalue from a call to an overflow intrinsic. + // TODO: We can do better here in some cases. + if (!isSCEVable(I->getOperand(OpIndex)->getType())) + return false; + const SCEV *Op = getSCEV(I->getOperand(OpIndex)); if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { - const int OtherOpIndex = 1 - OpIndex; - const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex)); - if (isLoopInvariant(OtherOp, AddRec->getLoop()) && - isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop())) - return Flags; + bool AllOtherOpsLoopInvariant = true; + for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); + ++OtherOpIndex) { + if (OtherOpIndex != OpIndex) { + const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); + if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { + AllOtherOpsLoopInvariant = false; + break; + } + } + } + if (AllOtherOpsLoopInvariant && + isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) + return true; } } - return SCEV::FlagAnyWrap; + return false; +} + +bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { + // If we know that \c I can never be poison, period, then that's enough. + if (isSCEVExprNeverPoison(I)) + return true; + + // For an add recurrence specifically, we assume that infinite loops without + // side effects are undefined behavior, and then reason as follows: + // + // If the add recurrence is poison in any iteration, it is poison on all + // future iterations (since incrementing poison yields poison). If the result + // of the add recurrence is fed into the loop latch condition and the loop + // does not contain any throws or exiting blocks other than the latch, we now + // have the ability to "choose" whether the backedge is taken or not (by + // choosing a sufficiently evil value for the poison feeding into the branch) + // for every iteration including and after the one in which \p I first became + // poison. There are two possibilities (let's call the iteration in which \p + // I first became poison K): + // + // 1. In the set of iterations including and after K, the loop body executes + // no side effects. In this case executing the backedge an infinite number + // of times will yield undefined behavior. + // + // 2. In the set of iterations including and after K, the loop body executes + // at least one side effect. In this case, that specific instance of side + // effect is control dependent on poison, which also yields undefined + // behavior.
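To make the two cases concrete, a C-level sketch (illustrative only): the loop below has a single latch that is also its only exit, and the post-increment value feeds the exit test. If the increment ever produced poison, every later iteration's store would be control dependent on it (case 2), and a store-free variant of the body would spin forever (case 1); both are undefined behavior, which is what licenses tagging the recurrence no-wrap.

// The post-increment value i + 1 is the add recurrence of interest.
void fill(int *a, int n) {
  int i = 0;
  do {
    a[i] = i;       // a side effect inside the loop body (case 2)
    i = i + 1;      // if this ever produced poison...
  } while (i != n); // ...the only exit, the latch, would depend on it
}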
+ + auto *ExitingBB = L->getExitingBlock(); + auto *LatchBB = L->getLoopLatch(); + if (!ExitingBB || !LatchBB || ExitingBB != LatchBB) + return false; + + SmallPtrSet<const Instruction *, 16> Pushed; + SmallVector<const Instruction *, 8> PoisonStack; + + // We start by assuming \c I, the post-inc add recurrence, is poison. Only + // things that are known to be fully poison under that assumption go on the + // PoisonStack. + Pushed.insert(I); + PoisonStack.push_back(I); + + bool LatchControlDependentOnPoison = false; + while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { + const Instruction *Poison = PoisonStack.pop_back_val(); + + for (auto *PoisonUser : Poison->users()) { + if (propagatesFullPoison(cast<Instruction>(PoisonUser))) { + if (Pushed.insert(cast<Instruction>(PoisonUser)).second) + PoisonStack.push_back(cast<Instruction>(PoisonUser)); + } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { + assert(BI->isConditional() && "Only possibility!"); + if (BI->getParent() == LatchBB) { + LatchControlDependentOnPoison = true; + break; + } + } + } + } + + return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); +} + +bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { + auto Itr = LoopHasNoAbnormalExits.find(L); + if (Itr == LoopHasNoAbnormalExits.end()) { + auto NoAbnormalExitInBB = [&](BasicBlock *BB) { + return all_of(*BB, [](Instruction &I) { + return isGuaranteedToTransferExecutionToSuccessor(&I); + }); + }; + + auto InsertPair = LoopHasNoAbnormalExits.insert( + {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); + assert(InsertPair.second && "We just checked!"); + Itr = InsertPair.first; + } + + return Itr->second; } -/// createSCEV - We know that there is no SCEV for the specified value. Analyze -/// the expression. -/// const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); - unsigned Opcode = Instruction::UserOp1; if (Instruction *I = dyn_cast<Instruction>(V)) { - Opcode = I->getOpcode(); - // Don't attempt to analyze instructions in blocks that aren't // reachable. Such instructions don't matter, and they aren't required // to obey basic rules for definitions dominating uses which this // analysis depends on. if (!DT.isReachableFromEntry(I->getParent())) return getUnknown(V); - } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) - Opcode = CE->getOpcode(); - else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); else if (isa<ConstantPointerNull>(V)) return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) - return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); - else + return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); + else if (!isa<ConstantExpr>(V)) return getUnknown(V); Operator *U = cast<Operator>(V); - switch (Opcode) { - case Instruction::Add: { - // The simple thing to do would be to just call getSCEV on both operands - // and call getAddExpr with the result. However if we're looking at a - // bunch of things all added together, this can be quite inefficient, - // because it leads to N-1 getAddExpr calls for N ultimate operands. - // Instead, gather up all the operands and make a single getAddExpr call. - // LLVM IR canonical form means we need only traverse the left operands. - SmallVector<const SCEV *, 4> AddOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - unsigned Opcode = U ? 
U->getOpcode() : 0; - if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) { - assert(Op != V && "V should be an add"); - AddOps.push_back(getSCEV(Op)); - break; - } + if (auto BO = MatchBinaryOp(U, DT)) { + switch (BO->Opcode) { + case Instruction::Add: { + // The simple thing to do would be to just call getSCEV on both operands + // and call getAddExpr with the result. However if we're looking at a + // bunch of things all added together, this can be quite inefficient, + // because it leads to N-1 getAddExpr calls for N ultimate operands. + // Instead, gather up all the operands and make a single getAddExpr call. + // LLVM IR canonical form means we need only traverse the left operands. + SmallVector<const SCEV *, 4> AddOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + AddOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - AddOps.push_back(OpSCEV); - break; - } + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + const SCEV *RHS = getSCEV(BO->RHS); + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + const SCEV *LHS = getSCEV(BO->LHS); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + else + AddOps.push_back(getAddExpr(LHS, RHS, Flags)); + break; + } + } - // If a NUW or NSW flag can be applied to the SCEV for this - // addition, then compute the SCEV for this addition by itself - // with a separate call to getAddExpr. We need to do that - // instead of pushing the operands of the addition onto AddOps, - // since the flags are only known to apply to this particular - // addition - they may not apply to other additions that can be - // formed with operands from AddOps. 
- const SCEV *RHS = getSCEV(U->getOperand(1)); - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); else - AddOps.push_back(getAddExpr(LHS, RHS, Flags)); - break; - } + AddOps.push_back(getSCEV(BO->RHS)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getNegativeSCEV(RHS)); - else - AddOps.push_back(RHS); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || (NewBO->Opcode != Instruction::Add && + NewBO->Opcode != Instruction::Sub)) { + AddOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getAddExpr(AddOps); } - return getAddExpr(AddOps); - } - case Instruction::Mul: { - SmallVector<const SCEV *, 4> MulOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - if (!U || U->getOpcode() != Instruction::Mul) { - assert(Op != V && "V should be a mul"); - MulOps.push_back(getSCEV(Op)); - break; - } + case Instruction::Mul: { + SmallVector<const SCEV *, 4> MulOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + MulOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - MulOps.push_back(OpSCEV); - break; - } + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + MulOps.push_back( + getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); + break; + } + } - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1)), Flags)); - break; + MulOps.push_back(getSCEV(BO->RHS)); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || NewBO->Opcode != Instruction::Mul) { + MulOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getMulExpr(MulOps); + } + case Instruction::UDiv: + return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::Sub: { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BO->Op) + Flags = getNoWrapFlagsFromUB(BO->Op); + return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); + } + case Instruction::And: + // For an expression like x&255 that merely masks off the high bits, + // use zext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + if (CI->isNullValue()) + return getSCEV(BO->RHS); + if (CI->isAllOnesValue()) + return getSCEV(BO->LHS); + const APInt &A = CI->getValue(); + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use computeKnownBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. 
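A worked instance of the reconstruction just described, with illustrative values: for x & 0xF0 in i32, A = 0xF0 gives LZ = 24, TZ = 4, and EffectiveMask = getLowBitsSet(32 - 24 - 4) << 4 = 0xF0. Because 0xF0 is one contiguous run of bits, ~A & EffectiveMask is zero and the rewrite below fires regardless of what computeKnownBits learns:

// x & 0xF0 is modeled as (zext i4-to-i32 (trunc i32-to-i4 (x /u 16))) * 16,
// i.e. "divide out the trailing zeros, keep the 4 interesting bits, scale
// back up" - the zext(trunc(...)) shape the rest of SCEV understands.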
+ unsigned LZ = A.countLeadingZeros(); + unsigned TZ = A.countTrailingZeros(); + unsigned BitWidth = A.getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); + + APInt EffectiveMask = + APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); + if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(ConstantInt::get( + getContext(), APInt::getOneBitSet(BitWidth, TZ))); + return getMulExpr( + getZeroExtendExpr( + getTruncateExpr( + getUDivExactExpr(getSCEV(BO->LHS), MulCount), + IntegerType::get(getContext(), BitWidth - LZ - TZ)), + BO->LHS->getType()), + MulCount); + } } + break; - MulOps.push_back(getSCEV(U->getOperand(1))); - } - return getMulExpr(MulOps); - } - case Instruction::UDiv: - return getUDivExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - case Instruction::Sub: - return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)), - getNoWrapFlagsFromUB(U)); - case Instruction::And: - // For an expression like x&255 that merely masks off the high bits, - // use zext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - if (CI->isNullValue()) - return getSCEV(U->getOperand(1)); - if (CI->isAllOnesValue()) - return getSCEV(U->getOperand(0)); - const APInt &A = CI->getValue(); - - // Instcombine's ShrinkDemandedConstant may strip bits out of - // constants, obscuring what would otherwise be a low-bits mask. - // Use computeKnownBits to compute what ShrinkDemandedConstant - // knew about to reconstruct a low-bits mask value. - unsigned LZ = A.countLeadingZeros(); - unsigned TZ = A.countTrailingZeros(); - unsigned BitWidth = A.getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), - 0, &AC, nullptr, &DT); - - APInt EffectiveMask = - APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); - if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant( - ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ))); - return getMulExpr( - getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount), - IntegerType::get(getContext(), BitWidth - LZ - TZ)), - U->getType()), - MulCount); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + const SCEV *LHS = getSCEV(BO->LHS); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { + // Build a plain add SCEV. + const SCEV *S = getAddExpr(LHS, getSCEV(CI)); + // If the LHS of the add was an addrec and it has no-wrap flags, + // transfer the no-wrap flags, since an or won't introduce a wrap. 
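In numbers, the Or-to-Add rewrite described above (illustrative): with LHS = {0,+,4} (so GetMinTrailingZeros(LHS) = 2) and an Or constant of 1, which occupies 32 - 31 = 1 low bit:

// {0,+,4} | 1 == {0,+,4} + 1 == {1,+,4}
// The constant only sets bits below the addrec's guaranteed trailing
// zeros, so the Or is a plain Add and any nsw/nuw flags carry over.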
+ if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); + const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( + OldAR->getNoWrapFlags()); + } + return S; + } } - } - break; + break; - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { - // Build a plain add SCEV. - const SCEV *S = getAddExpr(LHS, getSCEV(CI)); - // If the LHS of the add was an addrec and it has no-wrap flags, - // transfer the no-wrap flags, since an or won't introduce a wrap. - if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); - const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( - OldAR->getNoWrapFlags()); - } - return S; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + // If the RHS of xor is -1, then this is a not operation. + if (CI->isAllOnesValue()) + return getNotSCEV(getSCEV(BO->LHS)); + + // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. + // This is a variant of the check for xor with -1, and it handles + // the case where instcombine has trimmed non-demanded bits out + // of an xor with -1. + if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) + if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) + if (LBO->getOpcode() == Instruction::And && + LCI->getValue() == CI->getValue()) + if (const SCEVZeroExtendExpr *Z = + dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { + Type *UTy = BO->LHS->getType(); + const SCEV *Z0 = Z->getOperand(); + Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is serving to + // mask off the high bits. Complement the operand and + // re-apply the zext. + if (APIntOps::isMask(Z0TySize, CI->getValue())) + return getZeroExtendExpr(getNotSCEV(Z0), UTy); + + // If C is a single bit, it may be in the sign-bit position + // before the zero-extend. In this case, represent the xor + // using an add, which is equivalent, and re-apply the zext. + APInt Trunc = CI->getValue().trunc(Z0TySize); + if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && + Trunc.isSignBit()) + return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), + UTy); + } } - } - break; - case Instruction::Xor: - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (CI->getValue().isSignBit()) - return getAddExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - - // If the RHS of xor is -1, then this is a not operation. - if (CI->isAllOnesValue()) - return getNotSCEV(getSCEV(U->getOperand(0))); - - // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. - // This is a variant of the check for xor with -1, and it handles - // the case where instcombine has trimmed non-demanded bits out - // of an xor with -1. 
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) - if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1))) - if (BO->getOpcode() == Instruction::And && - LCI->getValue() == CI->getValue()) - if (const SCEVZeroExtendExpr *Z = - dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) { - Type *UTy = U->getType(); - const SCEV *Z0 = Z->getOperand(); - Type *Z0Ty = Z0->getType(); - unsigned Z0TySize = getTypeSizeInBits(Z0Ty); - - // If C is a low-bits mask, the zero extend is serving to - // mask off the high bits. Complement the operand and - // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) - return getZeroExtendExpr(getNotSCEV(Z0), UTy); - - // If C is a single bit, it may be in the sign-bit position - // before the zero-extend. In this case, represent the xor - // using an add, which is equivalent, and re-apply the zext. - APInt Trunc = CI->getValue().trunc(Z0TySize); - if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && - Trunc.isSignBit()) - return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), - UTy); - } - } - break; + break; case Instruction::Shl: // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); + if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { + uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the @@ -4700,58 +5158,43 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html // and http://reviews.llvm.org/D8890 . auto Flags = SCEV::FlagAnyWrap; - if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U); + if (BO->Op && SA->getValue().ult(BitWidth - 1)) + Flags = getNoWrapFlagsFromUB(BO->Op); Constant *X = ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags); + return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); } break; - case Instruction::LShr: - // Turn logical shift right of a constant into a unsigned divide. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (SA->getValue().uge(BitWidth)) - break; + case Instruction::AShr: + // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) + if (Operator *L = dyn_cast<Operator>(BO->LHS)) + if (L->getOpcode() == Instruction::Shl && + L->getOperand(1) == BO->RHS) { + uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. 
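For example (illustrative), with i32 and a shift amount of 24 this is the classic sign-extend-in-register of the low 8 bits: Amt = 32 - 24 = 8, so ashr(shl(x, 24), 24) becomes sext(trunc(x to i8) to i32). In C terms, assuming the usual fixed-width types:

#include <cstdint>
// Conceptually the two-shift form (x << 24) >> 24 with an arithmetic
// right shift; the cast form below computes the same value without the
// signed-shift pitfalls of writing the shifts out in C++.
int32_t sextInReg8(int32_t x) { return (int32_t)(int8_t)x; }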
+ if (CI->getValue().uge(BitWidth)) + break; - Constant *X = ConstantInt::get(getContext(), - APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + uint64_t Amt = BitWidth - CI->getZExtValue(); + if (Amt == BitWidth) + return getSCEV(L->getOperand(0)); // shift by zero --> noop + return getSignExtendExpr( + getTruncateExpr(getSCEV(L->getOperand(0)), + IntegerType::get(getContext(), Amt)), + BO->LHS->getType()); + } + break; } - break; - - case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) - if (Operator *L = dyn_cast<Operator>(U->getOperand(0))) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == U->getOperand(1)) { - uint64_t BitWidth = getTypeSizeInBits(U->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; - - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop - return - getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), - Amt)), - U->getType()); - } - break; + } + switch (U->getOpcode()) { case Instruction::Trunc: return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); @@ -4786,8 +5229,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (isa<Instruction>(U)) return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0), U->getOperand(1), U->getOperand(2)); + break; - default: // We cannot analyze this expression. + case Instruction::Call: + case Instruction::Invoke: + if (Value *RV = CallSite(U).getReturnedArgOperand()) + return getSCEV(RV); break; } @@ -4808,16 +5255,6 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a -/// normal unsigned value. Returns 0 if the trip count is unknown or not -/// constant. Will also return 0 if the maximum trip count is very large (>= -/// 2^32). -/// -/// This "trip count" assumes that control exits via ExitingBlock. More -/// precisely, it is the number of times that control may reach ExitingBlock -/// before taking the branch. For loops with multiple exits, it may not be the -/// number times that the loop header executes because the loop may exit -/// prematurely via another branch. unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); @@ -4846,10 +5283,10 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { return 0; } -/// getSmallConstantTripMultiple - Returns the largest constant divisor of the -/// trip count of this loop as a normal unsigned value, if possible. This -/// means that the actual trip count is always a multiple of the returned -/// value (don't forget the trip count could very well be zero as well!). +/// Returns the largest constant divisor of the trip count of this loop as a +/// normal unsigned value, if possible. This means that the actual trip count is +/// always a multiple of the returned value (don't forget the trip count could +/// very well be zero as well!). 
/// /// Returns 1 if the trip count is unknown or not guaranteed to be the /// multiple of a constant (which is also the case if the trip count is simply @@ -4891,37 +5328,30 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return (unsigned)Result->getZExtValue(); } -// getExitCount - Get the expression for the number of loop iterations for which -// this loop is guaranteed not to exit via ExitingBlock. Otherwise return -// SCEVCouldNotCompute. +/// Get the expression for the number of loop iterations for which this loop is +/// guaranteed not to exit via ExitingBlock. Otherwise return +/// SCEVCouldNotCompute. const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } -/// getBackedgeTakenCount - If the specified loop has a predictable -/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute -/// object. The backedge-taken count is the number of times the loop header -/// will be branched to from within the loop. This is one less than the -/// trip count of the loop, since it doesn't count the first iteration, -/// when the header is branched to from outside the loop. -/// -/// Note that it is not valid to call this method on a loop without a -/// loop-invariant backedge-taken count (see -/// hasLoopInvariantBackedgeTakenCount). -/// +const SCEV * +ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, + SCEVUnionPredicate &Preds) { + return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds); +} + const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getExact(this); } -/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except -/// return the least SCEV value that is known never to be less than the -/// actual backedge taken count. +/// Similar to getBackedgeTakenCount, except return the least SCEV value that is +/// known never to be less than the actual backedge taken count. const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getMax(this); } -/// PushLoopPHIs - Push PHI nodes in the header of the given loop -/// onto the given Worklist. +/// Push PHI nodes in the header of the given loop onto the given Worklist. static void PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { BasicBlock *Header = L->getHeader(); @@ -4933,6 +5363,23 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { } const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { + auto &BTI = getBackedgeTakenInfo(L); + if (BTI.hasFullInfo()) + return BTI; + + auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); + + if (!Pair.second) + return Pair.first->second; + + BackedgeTakenInfo Result = + computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + + return PredicatedBackedgeTakenCounts.find(L)->second = Result; +} + +const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and @@ -4940,7 +5387,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. 
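Schematically, the insert-an-invalid-entry-first idiom that this comment (and getPredicatedBackedgeTakenInfo above) relies on looks like the following; computeFor is a hypothetical stand-in for the real computation:

DenseMap<const Loop *, BackedgeTakenInfo> Counts;

const BackedgeTakenInfo &get(const Loop *L) {
  // Park a placeholder first: a re-entrant query for L issued while we are
  // still computing sees it and stops instead of recursing forever.
  auto Pair = Counts.insert({L, BackedgeTakenInfo()});
  if (!Pair.second)
    return Pair.first->second;              // cached, or in progress
  BackedgeTakenInfo Result = computeFor(L); // may indirectly call get(L)
  // Re-find rather than reusing Pair.first: the recursive computation may
  // have grown the map and invalidated the old iterator.
  return Counts.find(L)->second = Result;
}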
std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair = - BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo())); + BackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); if (!Pair.second) return Pair.first->second; @@ -5007,17 +5454,19 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { return BackedgeTakenCounts.find(L)->second = Result; } -/// forgetLoop - This method should be called by the client when it has -/// changed a loop in a way that may effect ScalarEvolution's ability to -/// compute a trip count, or if the loop is deleted. void ScalarEvolution::forgetLoop(const Loop *L) { // Drop any stored trip count value. - DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos = - BackedgeTakenCounts.find(L); - if (BTCPos != BackedgeTakenCounts.end()) { - BTCPos->second.clear(); - BackedgeTakenCounts.erase(BTCPos); - } + auto RemoveLoopFromBackedgeMap = + [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { + auto BTCPos = Map.find(L); + if (BTCPos != Map.end()) { + BTCPos->second.clear(); + Map.erase(BTCPos); + } + }; + + RemoveLoopFromBackedgeMap(BackedgeTakenCounts); + RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts); // Drop information about expressions based on loop-header PHIs. SmallVector<Instruction *, 16> Worklist; @@ -5043,13 +5492,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) { // Forget all contained loops too, to avoid dangling entries in the // ValuesAtScopes map. - for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) - forgetLoop(*I); + for (Loop *I : *L) + forgetLoop(I); + + LoopHasNoAbnormalExits.erase(L); } -/// forgetValue - This method should be called by the client when it has -/// changed a value in a way that may effect its value, or which may -/// disconnect it from a def-use chain linking it to a loop. void ScalarEvolution::forgetValue(Value *V) { Instruction *I = dyn_cast<Instruction>(V); if (!I) return; @@ -5077,16 +5525,17 @@ void ScalarEvolution::forgetValue(Value *V) { } } -/// getExact - Get the exact loop backedge taken count considering all loop -/// exits. A computable result can only be returned for loops with a single -/// exit. Returning the minimum taken count among all exits is incorrect -/// because one of the loop's exit limit's may have been skipped. HowFarToZero -/// assumes that the limit of each loop test is never skipped. This is a valid -/// assumption as long as the loop exits via that test. For precise results, it -/// is the caller's responsibility to specify the relevant loop exit using +/// Get the exact loop backedge taken count considering all loop exits. A +/// computable result can only be returned for loops with a single exit. +/// Returning the minimum taken count among all exits is incorrect because one +/// of the loop's exit limits may have been skipped. howFarToZero assumes that +/// the limit of each loop test is never skipped. This is a valid assumption as +/// long as the loop exits via that test. For precise results, it is the +/// caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { +ScalarEvolution::BackedgeTakenInfo::getExact( + ScalarEvolution *SE, SCEVUnionPredicate *Preds) const { // If any exits were not computable, the loop is not computable.
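Concretely (illustrative): if a loop has two computable exits whose not-taken counts are both n, getExact returns n; if one exit gives n and another gives m != n, it returns SCEVCouldNotCompute rather than min(n, m), since the smaller count may belong to a test that is skipped. With the new out-parameter, a caller that tolerates predicated answers collects the per-exit guards, e.g. (BTI being some BackedgeTakenInfo in hand):

SCEVUnionPredicate Preds;
const SCEV *Exact = BTI.getExact(&SE, &Preds);
// Exact is valid only under the predicates accumulated into Preds.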
if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); @@ -5095,36 +5544,42 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); const SCEV *BECount = nullptr; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + for (auto &ENT : ExitNotTaken) { + assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); if (!BECount) - BECount = ENT->ExactNotTaken; - else if (BECount != ENT->ExactNotTaken) + BECount = ENT.ExactNotTaken; + else if (BECount != ENT.ExactNotTaken) return SE->getCouldNotCompute(); + if (Preds && ENT.getPred()) + Preds->add(ENT.getPred()); + + assert((Preds || ENT.hasAlwaysTruePred()) && + "Predicate should be always true!"); } + assert(BECount && "Invalid not taken count for loop exit"); return BECount; } -/// getExact - Get the exact not taken count for this loop exit. +/// Get the exact not taken count for this loop exit. const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) + return ENT.ExactNotTaken; - if (ENT->ExitingBlock == ExitingBlock) - return ENT->ExactNotTaken; - } return SE->getCouldNotCompute(); } /// getMax - Get the max backedge taken count for the loop. const SCEV * ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { + for (auto &ENT : ExitNotTaken) + if (!ENT.hasAlwaysTruePred()) + return SE->getCouldNotCompute(); + return Max ? Max : SE->getCouldNotCompute(); } @@ -5136,22 +5591,19 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, if (!ExitNotTaken.ExitingBlock) return false; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - if (ENT->ExactNotTaken != SE->getCouldNotCompute() - && SE->hasOperand(ENT->ExactNotTaken, S)) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE->getCouldNotCompute() && + SE->hasOperand(ENT.ExactNotTaken, S)) return true; - } - } + return false; } /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( - SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts, - bool Complete, const SCEV *MaxCount) : Max(MaxCount) { + SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount) + : Max(MaxCount) { if (!Complete) ExitNotTaken.setIncomplete(); @@ -5159,36 +5611,63 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( unsigned NumExits = ExitCounts.size(); if (NumExits == 0) return; - ExitNotTaken.ExitingBlock = ExitCounts[0].first; - ExitNotTaken.ExactNotTaken = ExitCounts[0].second; - if (NumExits == 1) return; + ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; + ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; + + // Determine the number of ExitNotTakenExtras structures that we need. 
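For instance (illustrative): with three computable exits of which only exit #2 carries a non-trivial predicate, the computation below gives ExtraInfoSize = 1 + 1 = 2, laid out as:

// NumExits = 3, predicates = {always-true, always-true, P2}:
//   ENT[0] <- ExitNotTaken.ExtraInfo (first exit's slot + the Exits array)
//   exit #1 -> extras pointer is null (its predicate is always true)
//   exit #2 -> extras pointer is &ENT[1], which holds P2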
+ unsigned ExtraInfoSize = 0; + if (NumExits > 1) + ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), + ExitCounts.end(), [](EdgeInfo &Entry) { + return !Entry.Pred.isAlwaysTrue(); + }); + else if (!ExitCounts[0].Pred.isAlwaysTrue()) + ExtraInfoSize = 1; + + ExitNotTakenExtras *ENT = nullptr; + + // Allocate the ExitNotTakenExtras structures and initialize the first + // element (ExitNotTaken). + if (ExtraInfoSize > 0) { + ENT = new ExitNotTakenExtras[ExtraInfoSize]; + ExitNotTaken.ExtraInfo = &ENT[0]; + *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); + } + + if (NumExits == 1) + return; + + assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); + + auto &Exits = ExitNotTaken.ExtraInfo->Exits; // Handle the rare case of multiple computable exits. - ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1]; + for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { + ExitNotTakenExtras *Ptr = nullptr; + if (!ExitCounts[i].Pred.isAlwaysTrue()) { + Ptr = &ENT[PredPos++]; + Ptr->Pred = std::move(ExitCounts[i].Pred); + } - ExitNotTakenInfo *PrevENT = &ExitNotTaken; - for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) { - PrevENT->setNextExit(ENT); - ENT->ExitingBlock = ExitCounts[i].first; - ENT->ExactNotTaken = ExitCounts[i].second; + Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); } } -/// clear - Invalidate this result and free the ExitNotTakenInfo array. +/// Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { ExitNotTaken.ExitingBlock = nullptr; ExitNotTaken.ExactNotTaken = nullptr; - delete[] ExitNotTaken.getNextExit(); + delete[] ExitNotTaken.ExtraInfo; } -/// computeBackedgeTakenCount - Compute the number of times the backedge -/// of the specified loop will execute. +/// Compute the number of times the backedge of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { +ScalarEvolution::computeBackedgeTakenCount(const Loop *L, + bool AllowPredicates) { SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts; + SmallVector<EdgeInfo, 4> ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. const SCEV *MustExitMaxBECount = nullptr; @@ -5196,9 +5675,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. + // Do a union of all the predicates here. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitBB = ExitingBlocks[i]; - ExitLimit EL = computeExitLimit(L, ExitBB); + ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); + + assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && + "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. @@ -5207,7 +5690,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else - ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact)); + ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); // 2. Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. 
Partition the loop exits into two kinds: @@ -5241,20 +5724,20 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { } ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + bool AllowPredicates) { // Okay, we've chosen an exiting block. See what condition causes us to exit // at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; BasicBlock *Exit = nullptr; - for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock); - SI != SE; ++SI) - if (!L->contains(*SI)) { + for (auto *SBB : successors(ExitingBlock)) + if (!L->contains(SBB)) { if (Exit) // Multiple exit successors. return getCouldNotCompute(); - Exit = *SI; - } else if (*SI != L->getHeader()) { + Exit = SBB; + } else if (SBB != L->getHeader()) { MustExecuteLoopHeader = false; } @@ -5307,9 +5790,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), - BI->getSuccessor(1), - /*ControlsExit=*/IsOnlyExit); + return computeExitLimitFromCond( + L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), + /*ControlsExit=*/IsOnlyExit, AllowPredicates); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) @@ -5319,29 +5802,24 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } -/// computeExitLimitFromCond - Compute the number of times the -/// backedge of the specified loop will execute if its exit condition -/// were a conditional branch of ExitCond, TBB, and FBB. -/// -/// @param ControlsExit is true if ExitCond directly controls the exit -/// branch. In this case, we can assume that the loop exits only if the -/// condition is true and can infer that failing to meet the condition prior to -/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5368,6 +5846,9 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing // MaxBECount. 
In these cases it is possible for EL0.Exact and EL1.Exact @@ -5376,15 +5857,17 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, !isa<SCEVCouldNotCompute>(BECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount, NP); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5411,14 +5894,25 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } - return ExitLimit(BECount, MaxBECount); + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); + return ExitLimit(BECount, MaxBECount, NP); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. - if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) - return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { + ExitLimit EL = + computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (EL.hasFullInfo() || !AllowPredicates) + return EL; + + // Try again, but use SCEV predicates this time. + return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit, + /*AllowPredicates=*/true); + } // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -5442,7 +5936,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -5460,11 +5955,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ItCnt; } - ExitLimit ShiftEL = computeShiftCompareExitLimit( - ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); - if (ShiftEL.hasAnyInfo()) - return ShiftEL; - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5499,34 +5989,46 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) - ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool 
IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = + howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } default: break; } - return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + auto *ExhaustiveCount = + computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) + return ExhaustiveCount; + + return computeShiftCompareExitLimit(ExitCond->getOperand(0), + ExitCond->getOperand(1), L, Cond); } ScalarEvolution::ExitLimit @@ -5546,7 +6048,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5563,9 +6065,8 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, return cast<SCEVConstant>(Val)->getValue(); } -/// computeLoadConstantCompareExitLimit - Given an exit condition of -/// 'icmp op load X, cst', try to see if we can compute the backedge -/// execution count. +/// Given an exit condition of 'icmp op load X, cst', try to see if we can +/// compute the backedge execution count. ScalarEvolution::ExitLimit ScalarEvolution::computeLoadConstantCompareExitLimit( LoadInst *LI, @@ -5781,14 +6282,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - return ExitLimit(getCouldNotCompute(), UpperBound); + SCEVUnionPredicate P; + return ExitLimit(getCouldNotCompute(), UpperBound, P); } return getCouldNotCompute(); } -/// CanConstantFold - Return true if we can constant fold an instruction of the -/// specified type, assuming that all operands were constants. +/// Return true if we can constant fold an instruction of the specified type, +/// assuming that all operands were constants. static bool CanConstantFold(const Instruction *I) { if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || @@ -5916,10 +6418,9 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, Operands[1], DL, TLI); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(Operands[0], DL); + return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL, - TLI); + return ConstantFoldInstOperands(I, Operands, DL, TLI); } @@ -6107,16 +6608,6 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, return getCouldNotCompute(); } -/// getSCEVAtScope - Return a SCEV expression for the specified value -/// at the specified scope in the program. The L value specifies a loop -/// nest to evaluate the expression at, where null is the top-level or a -/// specified loop is immediately inside of the loop. -/// -/// This method can be used to compute the exit value for a variable defined -/// in a loop by querying what the value will hold in the parent loop. -/// -/// In the case that a relevant loop exit value cannot be computed, the -/// original value V is returned. 
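A usage sketch (hypothetical names): for a loop whose backedge-taken count BTC is computable, evaluating the canonical induction PHI at the enclosing scope yields its value on the final trip through the header.

const SCEV *IV = SE.getSCEV(InductionPHI);              // {0,+,1}<%loop>
const SCEV *Final = SE.getSCEVAtScope(IV, L->getParentLoop());
// Final == BTC when the exit value is computable; otherwise
// getSCEVAtScope simply hands IV back unchanged.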
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V]; @@ -6305,10 +6796,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { Operands[1], DL, &TLI); else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - C = ConstantFoldLoadFromConstPtr(Operands[0], DL); + C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, - DL, &TLI); + C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; return getSCEV(C); } @@ -6428,14 +6918,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { llvm_unreachable("Unknown SCEV type!"); } -/// getSCEVAtScope - This is a convenience function which does -/// getSCEVAtScope(getSCEV(V), L). const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } -/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the -/// following equation: +/// Finds the minimum unsigned root of the following equation: /// /// A * X = B (mod N) /// @@ -6482,11 +6969,11 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, return SE.getConstant(Result.trunc(BW)); } -/// SolveQuadraticEquation - Find the roots of the quadratic equation for the -/// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which -/// might be the same) or two SCEVCouldNotCompute objects. +/// Find the roots of the quadratic equation for the given quadratic chrec +/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or +/// two SCEVCouldNotCompute objects. /// -static std::pair<const SCEV *,const SCEV *> +static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); @@ -6494,10 +6981,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. - if (!LC || !MC || !NC) { - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); - } + if (!LC || !MC || !NC) + return None; uint32_t BitWidth = LC->getAPInt().getBitWidth(); const APInt &L = LC->getAPInt(); @@ -6524,8 +7009,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { if (SqrtTerm.isNegative()) { // The loop is provably infinite. - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); + return None; } // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest @@ -6536,10 +7020,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // The divisions must be performed as signed divisions. 
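A worked example: the value of {L,+,M,+,N} at iteration i is L + M*i + N*i*(i-1)/2, so the roots solve (N/2)*i^2 + (M - N/2)*i + L = 0, matching the B^2 - 4ac discriminant formed above. For the chrec {-2,+,2,+,2}:

// A = N/2 = 1, B = M - N/2 = 1, C = L = -2
// discriminant = B^2 - 4AC = 1 + 8 = 9, sqrt = 3
// roots = (-1 +/- 3) / 2 = {1, -2}
// Check: value at i = 1 is -2 + 2*1 + 2*0 = 0, a genuine zero; the
// negative root is discarded later by the smallest-positive-root logic.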
APInt NegB(-B); APInt TwoA(A << 1); - if (TwoA.isMinValue()) { - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); - } + if (TwoA.isMinValue()) + return None; LLVMContext &Context = SE.getContext(); @@ -6548,20 +7030,21 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { ConstantInt *Solution2 = ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); - return std::make_pair(SE.getConstant(Solution1), - SE.getConstant(Solution2)); + return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), + cast<SCEVConstant>(SE.getConstant(Solution2))); } // end APIntOps namespace } -/// HowFarToZero - Return the number of times a backedge comparing the specified -/// value to zero will execute. If not computable, return CouldNotCompute. -/// -/// This is only used for loops with a "x != y" exit test. The exit condition is -/// now expressed as a single expression, V = x-y. So the exit test is -/// effectively V != 0. We know and take advantage of the fact that this -/// expression only being used in a comparison by zero context. ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { +ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, + bool AllowPredicates) { + + // This is only used for loops with a "x != y" exit test. The exit condition + // is now expressed as a single expression, V = x-y. So the exit test is + // effectively V != 0. We know and take advantage of the fact that this + // expression is only used in a comparison-with-zero context. + + SCEVUnionPredicate P; // If the value is a constant if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { // If the value is already zero, the branch will execute zero times. @@ -6570,31 +7053,33 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { } const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V); + if (!AddRec && AllowPredicates) + // Try to make this an AddRec using runtime tests, in the first X + // iterations of this loop, where X is the SCEV expression found by the + // algorithm below. + AddRec = convertSCEVToAddRecWithPredicates(V, L, P); + if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - std::pair<const SCEV *,const SCEV *> Roots = - SolveQuadraticEquation(AddRec, *this); - const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first); - const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second); - if (R1 && R2) { + if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { + const SCEVConstant *R1 = Roots->first; + const SCEVConstant *R2 = Roots->second; // Pick the smallest positive root value. - if (ConstantInt *CB = - dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT, - R1->getValue(), - R2->getValue()))) { + if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( + CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. + std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) - return R1; // We found a quadratic root!
+ return ExitLimit(R1, R1, P); // We found a quadratic root! } } return getCouldNotCompute(); @@ -6651,7 +7136,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount); + return ExitLimit(Distance, MaxBECount, P); } // As a special case, handle the instance where Step is a positive power of @@ -6704,7 +7189,9 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); auto *WideTy = Distance->getType(); - return getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + const SCEV *Limit = + getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + return ExitLimit(Limit, Limit, P); } } @@ -6713,24 +7200,24 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // compute the backedge count. In this case, the step may not divide the // distance, but we don't care because if the condition is "missed" the loop // will have undefined behavior due to wrapping. - if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + if (ControlsExit && AddRec->hasNoSelfWrap() && + loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact); + return ExitLimit(Exact, Exact, P); } // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) - return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(), - *this); + if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { + const SCEV *E = SolveLinEquationWithOverflow( + StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); + return ExitLimit(E, E, P); + } return getCouldNotCompute(); } -/// HowFarToNonZero - Return the number of times a backedge checking the -/// specified value for nonzero will execute. If not computable, return -/// CouldNotCompute ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { +ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -6748,33 +7235,27 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { return getCouldNotCompute(); } -/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB -/// (which may not be an immediate predecessor) which has exactly one -/// successor from which BB is reachable, or null if no such block is -/// found. -/// std::pair<BasicBlock *, BasicBlock *> ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // If the block has a unique predecessor, then there is no path from the // predecessor to the block that does not go through the direct edge // from the predecessor to the block. if (BasicBlock *Pred = BB->getSinglePredecessor()) - return std::make_pair(Pred, BB); + return {Pred, BB}; // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. 
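Returning to the no-self-wrap fast path of howFarToZero above, an illustrative instance:

// V = {16,+,-2}: counting down, so the distance to zero is the start 16
// and Exact = 16 /u 2 = 8 backedge executions.
// The udiv is sound only because <nw> plus loopHasNoAbnormalExits makes a
// "missed" zero-crossing undefined: with a step of -3 the sequence
// 16, 13, 10, 7, 4, 1, -2, ... never hits zero exactly.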
   if (Loop *L = LI.getLoopFor(BB))
-    return std::make_pair(L->getLoopPredecessor(), L->getHeader());
+    return {L->getLoopPredecessor(), L->getHeader()};
 
-  return std::pair<BasicBlock *, BasicBlock *>();
+  return {nullptr, nullptr};
 }
 
-/// HasSameValue - SCEV structural equivalence is usually sufficient for
-/// testing whether two expressions are equal, however for the purposes of
-/// looking for a condition guarding a loop, it can be useful to be a little
-/// more general, since a front-end may have replicated the controlling
-/// expression.
+/// SCEV structural equivalence is usually sufficient for testing whether two
+/// expressions are equal, however for the purposes of looking for a condition
+/// guarding a loop, it can be useful to be a little more general, since a
+/// front-end may have replicated the controlling expression.
 ///
 static bool HasSameValue(const SCEV *A, const SCEV *B) {
   // Quick check to see if they are the same SCEV.
@@ -6800,9 +7281,6 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
   return false;
 }
 
-/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
-/// predicate Pred. Return true iff any changes were made.
-///
 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
                                            const SCEV *&LHS, const SCEV *&RHS,
                                            unsigned Depth) {
@@ -7134,7 +7612,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
     return true;
 
   // Otherwise see what can be done with known constant ranges.
-  return isKnownPredicateWithRanges(Pred, LHS, RHS);
+  return isKnownPredicateViaConstantRanges(Pred, LHS, RHS);
 }
 
 bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
@@ -7180,7 +7658,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
   case ICmpInst::ICMP_UGE:
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_ULE:
-    if (!LHS->getNoWrapFlags(SCEV::FlagNUW))
+    if (!LHS->hasNoUnsignedWrap())
      return false;
 
     Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
@@ -7190,7 +7668,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
   case ICmpInst::ICMP_SGE:
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_SLE: {
-    if (!LHS->getNoWrapFlags(SCEV::FlagNSW))
+    if (!LHS->hasNoSignedWrap())
      return false;
 
     const SCEV *Step = LHS->getStepRecurrence(*this);
@@ -7264,78 +7742,34 @@ bool ScalarEvolution::isLoopInvariantPredicate(
   return true;
 }
 
-bool
-ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
-                                            const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownPredicateViaConstantRanges(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
   if (HasSameValue(LHS, RHS))
     return ICmpInst::isTrueWhenEqual(Pred);
 
   // This code is split out from isKnownPredicate because it is called from
   // within isLoopEntryGuardedByCond.
-  switch (Pred) {
-  default:
-    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
-  case ICmpInst::ICMP_SGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLT: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_SGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLE: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULT: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULE: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_NE: {
-    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
-      return true;
-    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
-      return true;
-    const SCEV *Diff = getMinusSCEV(LHS, RHS);
-    if (isKnownNonZero(Diff))
-      return true;
-    break;
-  }
-  case ICmpInst::ICMP_EQ:
-    // The check at the top of the function catches the case where
-    // the values are known to be equal.
-    break;
-  }
-  return false;
+  auto CheckRanges =
+      [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
+    return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
+        .contains(RangeLHS);
+  };
+
+  // The check at the top of the function catches the case where the values are
+  // known to be equal.
+  if (Pred == CmpInst::ICMP_EQ)
+    return false;
+
+  if (Pred == CmpInst::ICMP_NE)
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
+           CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
+           isKnownNonZero(getMinusSCEV(LHS, RHS));
+
+  if (CmpInst::isSigned(Pred))
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
+
+  return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
 }
 
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
@@ -7416,6 +7850,23 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
 }
 
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  // No need to even try if we know the module has no guards.
+  if (!HasGuards)
+    return false;
+
+  return any_of(*BB, [&](Instruction &I) {
+    using namespace llvm::PatternMatch;
+
+    Value *Condition;
+    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+                         m_Value(Condition))) &&
+           isImpliedCond(Pred, LHS, RHS, Condition, false);
+  });
+}
+
 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
 /// protected by a conditional between LHS and RHS.  This is used to
 /// to eliminate casts.
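The hunk above replaces the hand-rolled six-way switch with a single query against ConstantRange::makeSatisfyingICmpRegion. A minimal standalone sketch of that idea, assuming only the LLVM 3.9 headers (the helper name rangesProve and the example ranges are invented for illustration, they are not part of the patch):

// Sketch only: makeSatisfyingICmpRegion(Pred, RHS) returns the largest
// range R such that "x Pred y" holds for every x in R and every y in RHS,
// so the predicate is provably true whenever the whole LHS range fits
// inside R.
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

static bool rangesProve(CmpInst::Predicate Pred, const ConstantRange &LHS,
                        const ConstantRange &RHS) {
  return ConstantRange::makeSatisfyingICmpRegion(Pred, RHS).contains(LHS);
}

// For example, every value in [0, 10) is unsigned-less-than every value in
// [10, 20), so this returns true (8-bit ranges, invented for illustration):
//   rangesProve(CmpInst::ICMP_ULT,
//               ConstantRange(APInt(8, 0), APInt(8, 10)),
//               ConstantRange(APInt(8, 10), APInt(8, 20)));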
@@ -7427,7 +7878,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return true;
 
-  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+    return true;
 
   BasicBlock *Latch = L->getLoopLatch();
   if (!Latch)
@@ -7482,12 +7934,18 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   if (!DT.isReachableFromEntry(L->getHeader()))
     return false;
 
+  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+    return true;
+
   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
        DTN != HeaderDTN; DTN = DTN->getIDom()) {
     assert(DTN && "should reach the loop header before reaching the root!");
 
     BasicBlock *BB = DTN->getBlock();
+    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+      return true;
+
     BasicBlock *PBB = BB->getSinglePredecessor();
     if (!PBB)
       continue;
@@ -7518,9 +7976,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   return false;
 }
 
-/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
-/// by a conditional between LHS and RHS.  This is used to help avoid max
-/// expressions in loop trip counts, and to eliminate casts.
 bool
 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                           ICmpInst::Predicate Pred,
@@ -7529,7 +7984,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return false;
 
-  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+    return true;
 
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
@@ -7539,6 +7995,9 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
        Pair.first;
        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
 
+    if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
+      return true;
+
     BranchInst *LoopEntryPredicate =
       dyn_cast<BranchInst>(Pair.first->getTerminator());
     if (!LoopEntryPredicate ||
@@ -7586,8 +8045,6 @@ struct MarkPendingLoopPredicate {
 };
 } // end anonymous namespace
 
-/// isImpliedCond - Test whether the condition described by Pred, LHS,
-/// and RHS is true whenever the given Cond value evaluates to true.
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS,
                                     Value *FoundCondValue,
@@ -7910,9 +8367,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
                                   getConstant(FoundRHSLimit));
 }
 
-/// isImpliedCondOperands - Test whether the condition described by Pred,
-/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
-/// and FoundRHS is true.
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
@@ -8037,9 +8491,6 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
   llvm_unreachable("covered switch fell through?!");
 }
 
-/// isImpliedCondOperandsHelper - Test whether the condition described by
-/// Pred, LHS, and RHS is true whenever the condition described by Pred,
-/// FoundLHS, and FoundRHS is true.
 bool
 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *LHS, const SCEV *RHS,
@@ -8047,7 +8498,7 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *FoundRHS) {
   auto IsKnownPredicateFull =
       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
-    return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
+    return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
           IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
@@ -8089,8 +8540,6 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   return false;
 }
 
-/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands.
-/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1".
 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
                                                      const SCEV *LHS,
                                                      const SCEV *RHS,
@@ -8129,9 +8578,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
   return SatisfyingLHSRange.contains(LHSRange);
 }
 
-// Verify if an linear IV with positive stride can overflow when in a
-// less-than comparison, knowing the invariant term of the comparison, the
-// stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
   if (NoWrap) return false;
@@ -8158,9 +8604,6 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
 }
 
-// Verify if an linear IV with negative stride can overflow when in a
-// greater-than comparison, knowing the invariant term of the comparison,
-// the stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
   if (NoWrap) return false;
@@ -8187,8 +8630,6 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
   return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
 }
 
-// Compute the backedge taken count knowing the interval difference, the
-// stride and presence of the equality in the comparison.
 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
                                             bool Equality) {
   const SCEV *One = getOne(Step->getType());
@@ -8197,22 +8638,21 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
   return getUDivExpr(Delta, Step);
 }
 
-/// HowManyLessThans - Return the number of times a backedge containing the
-/// specified less-than comparison will execute.  If not computable, return
-/// CouldNotCompute.
-///
-/// @param ControlsExit is true when the LHS < RHS condition directly controls
-/// the branch (loops exits only if condition is true). In this case, we can use
-/// NoWrapFlags to skip overflow checks.
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
-                                  bool ControlsExit) {
+                                  bool ControlsExit, bool AllowPredicates) {
+  SCEVUnionPredicate P;
   // We handle only IV < Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
 
   // Avoid weird loops
   if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8238,19 +8678,8 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                               : ICmpInst::ICMP_ULT;
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
-  if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
-    const SCEV *Diff = getMinusSCEV(RHS, Start);
-    // If we have NoWrap set, then we can assume that the increment won't
-    // overflow, in which case if RHS - Start is a constant, we don't need to
-    // do a max operation since we can just figure it out statically
-    if (NoWrap && isa<SCEVConstant>(Diff)) {
-      APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
-      if (D.isNegative())
-        End = Start;
-    } else
-      End = IsSigned ? getSMaxExpr(RHS, Start)
-                     : getUMaxExpr(RHS, Start);
-  }
+  if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
+    End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
 
   const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
 
@@ -8281,18 +8710,24 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
 }
 
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
+ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
                                      const Loop *L, bool IsSigned,
-                                     bool ControlsExit) {
+                                     bool ControlsExit, bool AllowPredicates) {
+  SCEVUnionPredicate P;
   // We handle only IV > Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!IV && AllowPredicates)
+    // Try to make this an AddRec using runtime tests, in the first X
+    // iterations of this loop, where X is the SCEV expression found by the
+    // algorithm below.
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
 
   // Avoid weird loops
   if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8319,19 +8754,8 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
-  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
-    const SCEV *Diff = getMinusSCEV(RHS, Start);
-    // If we have NoWrap set, then we can assume that the increment won't
-    // overflow, in which case if RHS - Start is a constant, we don't need to
-    // do a max operation since we can just figure it out statically
-    if (NoWrap && isa<SCEVConstant>(Diff)) {
-      APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
-      if (!D.isNegative())
-        End = Start;
-    } else
-      End = IsSigned ? getSMinExpr(RHS, Start)
-                     : getUMinExpr(RHS, Start);
-  }
+  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
+    End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
 
   const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
 
@@ -8363,15 +8787,10 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount);
+  return ExitLimit(BECount, MaxBECount, P);
 }
 
-/// getNumIterationsInRange - Return the number of iterations of this loop that
-/// produce values in the specified constant range.  Another way of looking at
-/// this is that it returns the first iteration number where the value is not in
-/// the condition, thus computing the exit count.  If the iteration count can't
-/// be computed, an instance of SCEVCouldNotCompute is returned.
-const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
+const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
                                                     ScalarEvolution &SE) const {
   if (Range.isFullSet())  // Infinite loop.
     return SE.getCouldNotCompute();
@@ -8445,22 +8864,21 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                              FlagAnyWrap);
 
     // Next, solve the constructed addrec
-    auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
-    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
-    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
-    if (R1) {
+    if (auto Roots =
+            SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) {
+      const SCEVConstant *R1 = Roots->first;
+      const SCEVConstant *R2 = Roots->second;
       // Pick the smallest positive root value.
       if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
              ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
         if (!CB->getZExtValue())
-          std::swap(R1, R2);   // R1 is the minimum root now.
+          std::swap(R1, R2); // R1 is the minimum root now.
 
         // Make sure the root is not off by one.  The returned iteration should
         // not be in the range, but the previous one should be.  When solving
         // for "X*X < 5", for example, we should not return a root of 2.
-        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
-                                                             R1->getValue(),
-                                                             SE);
+        ConstantInt *R1Val =
+            EvaluateConstantChrecAtConstant(this, R1->getValue(), SE);
         if (Range.contains(R1Val->getValue())) {
           // The next iteration must be out of the range...
           ConstantInt *NextVal =
@@ -8469,7 +8887,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
           if (!Range.contains(R1Val->getValue()))
             return SE.getConstant(NextVal);
-          return SE.getCouldNotCompute();  // Something strange happened
+          return SE.getCouldNotCompute(); // Something strange happened
         }
 
         // If R1 was not in the range, then it is a good return value.  Make
@@ -8479,7 +8897,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
           if (Range.contains(R1Val->getValue()))
             return R1;
-          return SE.getCouldNotCompute();  // Something strange happened
+          return SE.getCouldNotCompute(); // Something strange happened
         }
       }
     }
@@ -8789,12 +9207,9 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
   return getSizeOfExpr(ETy, Ty);
 }
 
-/// Second step of delinearization: compute the array dimensions Sizes from the
-/// set of Terms extracted from the memory access function of this SCEVAddRec.
 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
                                           SmallVectorImpl<const SCEV *> &Sizes,
                                           const SCEV *ElementSize) const {
-
   if (Terms.size() < 1 || !ElementSize)
     return;
 
@@ -8858,8 +9273,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
   });
 }
 
-/// Third step of delinearization: compute the access functions for the
-/// Subscripts based on the dimensions in Sizes.
 void ScalarEvolution::computeAccessFunctions(
     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
     SmallVectorImpl<const SCEV *> &Sizes) {
@@ -9012,7 +9425,7 @@ void ScalarEvolution::SCEVCallbackVH::deleted() {
   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
     SE->ConstantEvolutionLoopExitValue.erase(PN);
-  SE->ValueExprMap.erase(getValPtr());
+  SE->eraseValueFromMap(getValPtr());
   // this now dangles!
 }
 
@@ -9035,13 +9448,13 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
       continue;
     if (PHINode *PN = dyn_cast<PHINode>(U))
       SE->ConstantEvolutionLoopExitValue.erase(PN);
-    SE->ValueExprMap.erase(U);
+    SE->eraseValueFromMap(U);
     Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
   }
   // Delete the Old value.
   if (PHINode *PN = dyn_cast<PHINode>(Old))
     SE->ConstantEvolutionLoopExitValue.erase(PN);
-  SE->ValueExprMap.erase(Old);
+  SE->eraseValueFromMap(Old);
   // this now dangles!
 }
 
@@ -9059,14 +9472,31 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
       CouldNotCompute(new SCEVCouldNotCompute()),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
-      FirstUnknown(nullptr) {}
+      FirstUnknown(nullptr) {
+
+  // To use guards for proving predicates, we need to scan every instruction in
+  // relevant basic blocks, and not just terminators.  Doing this is a waste of
+  // time if the IR does not actually contain any calls to
+  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
+  //
+  // This pessimizes the case where a pass that preserves ScalarEvolution wants
+  // to _add_ guards to the module when there weren't any before, and wants
+  // ScalarEvolution to optimize based on those guards.  For now we prefer to be
+  // efficient in lieu of being smart in that rather obscure case.
+
+  auto *GuardDecl = F.getParent()->getFunction(
+      Intrinsic::getName(Intrinsic::experimental_guard));
+  HasGuards = GuardDecl && !GuardDecl->use_empty();
+}
 
 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
-    : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI),
-      CouldNotCompute(std::move(Arg.CouldNotCompute)),
+    : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
+      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
+      PredicatedBackedgeTakenCounts(
+          std::move(Arg.PredicatedBackedgeTakenCounts)),
       ConstantEvolutionLoopExitValue(
          std::move(Arg.ConstantEvolutionLoopExitValue)),
       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
@@ -9091,12 +9521,16 @@ ScalarEvolution::~ScalarEvolution() {
   }
   FirstUnknown = nullptr;
 
+  ExprValueMap.clear();
   ValueExprMap.clear();
+  HasRecMap.clear();
 
   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
   // that a loop had multiple computable exits.
   for (auto &BTCI : BackedgeTakenCounts)
     BTCI.second.clear();
+  for (auto &BTCI : PredicatedBackedgeTakenCounts)
+    BTCI.second.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
@@ -9110,8 +9544,8 @@ bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                           const Loop *L) {
   // Print all inner loops first
-  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
-    PrintLoopInfo(OS, SE, *I);
+  for (Loop *I : *L)
+    PrintLoopInfo(OS, SE, I);
 
   OS << "Loop ";
   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -9139,9 +9573,35 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << "Unpredictable max backedge-taken count. ";
   }
 
+  OS << "\n"
+        "Loop ";
+  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+  OS << ": ";
+
+  SCEVUnionPredicate Pred;
+  auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
+  if (!isa<SCEVCouldNotCompute>(PBT)) {
+    OS << "Predicated backedge-taken count is " << *PBT << "\n";
+    OS << " Predicates:\n";
+    Pred.print(OS, 4);
+  } else {
+    OS << "Unpredictable predicated backedge-taken count. ";
+  }
+
   OS << "\n";
 }
 
+static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
+  switch (LD) {
+  case ScalarEvolution::LoopVariant:
+    return "Variant";
+  case ScalarEvolution::LoopInvariant:
+    return "Invariant";
+  case ScalarEvolution::LoopComputable:
+    return "Computable";
+  }
+  llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
+}
+
 void ScalarEvolution::print(raw_ostream &OS) const {
   // ScalarEvolution's implementation of the print method is to print
   // out SCEV values of all instructions that are interesting.  Doing
@@ -9189,6 +9649,35 @@ void ScalarEvolution::print(raw_ostream &OS) const {
         } else {
           OS << *ExitValue;
         }
+
+        bool First = true;
+        for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
+          if (First) {
+            OS << "\t\t" "LoopDispositions: { ";
+            First = false;
+          } else {
+            OS << ", ";
+          }
+
+          Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
+        }
+
+        for (auto *InnerL : depth_first(L)) {
+          if (InnerL == L)
+            continue;
+          if (First) {
+            OS << "\t\t" "LoopDispositions: { ";
+            First = false;
+          } else {
+            OS << ", ";
+          }
+
+          InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
+        }
+
+        OS << " }";
       }
 
       OS << "\n";
@@ -9197,8 +9686,8 @@ void ScalarEvolution::print(raw_ostream &OS) const {
   OS << "Determining loop execution counts for: ";
   F.printAsOperand(OS, /*PrintType=*/false);
   OS << "\n";
-  for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
-    PrintLoopInfo(OS, &SE, *I);
+  for (Loop *I : LI)
+    PrintLoopInfo(OS, &SE, I);
 }
 
 ScalarEvolution::LoopDisposition
@@ -9420,17 +9909,23 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   BlockDispositions.erase(S);
   UnsignedRanges.erase(S);
   SignedRanges.erase(S);
+  ExprValueMap.erase(S);
+  HasRecMap.erase(S);
+
+  auto RemoveSCEVFromBackedgeMap =
+      [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+        for (auto I = Map.begin(), E = Map.end(); I != E;) {
+          BackedgeTakenInfo &BEInfo = I->second;
+          if (BEInfo.hasOperand(S, this)) {
+            BEInfo.clear();
+            Map.erase(I++);
+          } else
+            ++I;
+        }
+      };
 
-  for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
-         BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) {
-    BackedgeTakenInfo &BEInfo = I->second;
-    if (BEInfo.hasOperand(S, this)) {
-      BEInfo.clear();
-      BackedgeTakenCounts.erase(I++);
-    }
-    else
-      ++I;
-  }
+  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
+  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
 typedef DenseMap<const Loop *, std::string> VerifyMap;
@@ -9516,16 +10011,16 @@ void ScalarEvolution::verify() const {
 char ScalarEvolutionAnalysis::PassID;
 
 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
-                                             AnalysisManager<Function> *AM) {
-  return ScalarEvolution(F, AM->getResult<TargetLibraryAnalysis>(F),
-                         AM->getResult<AssumptionAnalysis>(F),
-                         AM->getResult<DominatorTreeAnalysis>(F),
-                         AM->getResult<LoopAnalysis>(F));
+                                             AnalysisManager<Function> &AM) {
+  return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
+                         AM.getResult<AssumptionAnalysis>(F),
+                         AM.getResult<DominatorTreeAnalysis>(F),
+                         AM.getResult<LoopAnalysis>(F));
 }
 
 PreservedAnalyses
-ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> *AM) {
-  AM->getResult<ScalarEvolutionAnalysis>(F).print(OS);
+ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) {
+  AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
   return PreservedAnalyses::all();
 }
 
@@ -9590,36 +10085,121 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
   return Eq;
 }
 
+const SCEVPredicate *ScalarEvolution::getWrapPredicate(
+    const SCEVAddRecExpr *AR,
+    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Wrap);
+  ID.AddPointer(AR);
+  ID.AddInteger(AddedFlags);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  auto *OF = new (SCEVAllocator)
+      SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
+  UniquePreds.InsertNode(OF, IP);
+  return OF;
+}
+
 namespace {
+
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
-                             SCEVUnionPredicate &A) {
-    SCEVPredicateRewriter Rewriter(SE, A);
-    return Rewriter.visit(Scev);
+  // Rewrites \p S in the context of a loop L and the predicate A.
+  // If Assume is true, rewrite is free to add further predicates to A
+  // such that the result will be an AddRecExpr.
+  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A, bool Assume) {
+    SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
+    return Rewriter.visit(S);
   }
 
-  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
-      : SCEVRewriteVisitor(SE), P(P) {}
+  SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+                        SCEVUnionPredicate &P, bool Assume)
+      : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     auto ExprPreds = P.getPredicatesForExpr(Expr);
     for (auto *Pred : ExprPreds)
-      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+      if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
        if (IPred->getLHS() == Expr)
          return IPred->getRHS();
 
     return Expr;
   }
 
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nuw
+      // flag. Add the nusw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getZeroExtendExpr(Operand, Expr->getType());
+  }
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // This couldn't be folded because the operand didn't have the nsw
+      // flag. Add the nssw flag as an assumption that we could make.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+                                SE.getSignExtendExpr(Step, Ty), L,
+                                AR->getNoWrapFlags());
+    }
+    return SE.getSignExtendExpr(Operand, Expr->getType());
+  }
+
 private:
+  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+    auto *A = SE.getWrapPredicate(AR, AddedFlags);
+    if (!Assume) {
+      // Check if we've already made this assumption.
+      if (P.implies(A))
+        return true;
+      return false;
+    }
+    P.add(A);
+    return true;
+  }
+
   SCEVUnionPredicate &P;
+  const Loop *L;
+  bool Assume;
 };
 } // end anonymous namespace
 
-const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                                    SCEVUnionPredicate &Preds) {
-  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+  return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+}
+
+const SCEVAddRecExpr *
+ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
+                                                   SCEVUnionPredicate &Preds) {
+  SCEVUnionPredicate TransformPreds;
+  S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+
+  if (!AddRec)
+    return nullptr;
+
+  // Since the transformation was successful, we can now transfer the SCEV
+  // predicates.
+  Preds.add(&TransformPreds);
+  return AddRec;
 }
 
 /// SCEV predicates
@@ -9633,7 +10213,7 @@ SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
     : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
 
 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
-  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+  const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
 
   if (!Op)
     return false;
@@ -9649,6 +10229,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
 }
 
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+                                     const SCEVAddRecExpr *AR,
+                                     IncrementWrapFlags Flags)
+    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+  IncrementWrapFlags IFlags = Flags;
+
+  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+    IFlags = clearFlags(IFlags, IncrementNSSW);
+
+  return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << *getExpr() << " Added Flags: ";
+  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+    OS << "<nusw>";
+  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+    OS << "<nssw>";
+  OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+                                   ScalarEvolution &SE) {
+  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+  // We can safely transfer the NSW flag as NSSW.
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+    ImpliedFlags = IncrementNSSW;
+
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+    // If the increment is positive, the SCEV NUW flag will also imply the
+    // WrapPredicate NUSW flag.
+    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+      if (Step->getValue()->getValue().isNonNegative())
+        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+  }
+
+  return ImpliedFlags;
+}
+
 /// Union predicates don't get cached so create a dummy set ID for it.
 SCEVUnionPredicate::SCEVUnionPredicate()
     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
 
@@ -9667,7 +10300,7 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
 }
 
 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
-  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
    return all_of(Set->Preds,
                  [this](const SCEVPredicate *I) { return this->implies(I); });
 
@@ -9688,7 +10321,7 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
 }
 
 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
-  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
    for (auto Pred : Set->Preds)
      add(Pred);
    return;
@@ -9705,8 +10338,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
   Preds.push_back(N);
 }
 
-PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
-    : SE(SE), Generation(0) {}
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+                                                     Loop &L)
+    : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
 
 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   const SCEV *Expr = SE.getSCEV(V);
@@ -9721,12 +10355,21 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   if (Entry.second)
    Expr = Entry.second;
 
-  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
   Entry = {Generation, NewSCEV};
 
   return NewSCEV;
 }
 
+const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
+  if (!BackedgeCount) {
+    SCEVUnionPredicate BackedgePred;
+    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
+    addPredicate(BackedgePred);
+  }
+  return BackedgeCount;
+}
+
 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
   if (Preds.implies(&Pred))
    return;
@@ -9743,7 +10386,82 @@ void PredicatedScalarEvolution::updateGeneration() {
   if (++Generation == 0) {
    for (auto &II : RewriteMap) {
      const SCEV *Rewritten = II.second.second;
-      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
    }
   }
 }
 
+void PredicatedScalarEvolution::setNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+  // Clear the statically implied flags.
+  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+  addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+  auto II = FlagsMap.insert({V, Flags});
+  if (!II.second)
+    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+  const SCEV *Expr = getSCEV(V);
+  const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+  Flags = SCEVWrapPredicate::clearFlags(
+      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+  auto II = FlagsMap.find(V);
+
+  if (II != FlagsMap.end())
+    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+  const SCEV *Expr = this->getSCEV(V);
+  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+
+  if (!New)
+    return nullptr;
+
+  updateGeneration();
+  RewriteMap[SE.getSCEV(V)] = {Generation, New};
+  return New;
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(
+    const PredicatedScalarEvolution &Init)
+    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
+      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+  for (const auto &I : Init.FlagsMap)
+    FlagsMap.insert(I);
+}
+
+void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
+  // For each block.
+  for (auto *BB : L.getBlocks())
+    for (auto &I : *BB) {
+      if (!SE.isSCEVable(I.getType()))
+        continue;
+
+      auto *Expr = SE.getSCEV(&I);
+      auto II = RewriteMap.find(Expr);
+
+      if (II == RewriteMap.end())
+        continue;
+
+      // Don't print things that are not interesting.
+      if (II->second.second == Expr)
+        continue;
+
+      OS.indent(Depth) << "[PSE]" << I << ":\n";
+      OS.indent(Depth + 2) << *Expr << "\n";
+      OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
    }
+}
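For reference, the new isImpliedViaGuard added earlier in this diff keys off calls to the @llvm.experimental.guard intrinsic. A hedged, self-contained sketch of the scan it performs, using the same PatternMatch calls as the patch; the helper name blockHasGuard is invented for illustration:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;

// Returns true if BB contains a call to @llvm.experimental.guard and, if so,
// hands back the guarded i1 condition. The patch then asks whether that
// condition implies the predicate being queried (via isImpliedCond).
static bool blockHasGuard(BasicBlock &BB, Value *&Condition) {
  using namespace llvm::PatternMatch;
  for (Instruction &I : BB)
    if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
                      m_Value(Condition))))
      return true;
  return false;
}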
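Finally, a short usage sketch of the widened PredicatedScalarEvolution interface shown above, which now carries the loop and a lazily computed predicated trip count. The function predicatedTripCount and its SE/L/IV parameters are placeholders, not part of the patch:

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static const SCEV *predicatedTripCount(ScalarEvolution &SE, Loop &L,
                                       Value *IV) {
  // The constructor now takes the loop in addition to ScalarEvolution.
  PredicatedScalarEvolution PSE(SE, L);

  // Try to view IV as an add recurrence; where static no-wrap facts are
  // missing, SCEVWrapPredicates are added to PSE's predicate set as runtime
  // assumptions instead of giving up.
  if (!PSE.getAsAddRec(IV))
    return SE.getCouldNotCompute();

  // Valid only under the predicates accumulated in PSE.
  return PSE.getBackedgeTakenCount();
}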