Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 2340
1 file changed, 1626 insertions, 714 deletions
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp index ed328f1..9539fd7 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp @@ -89,9 +89,10 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/SaveAndRestore.h" +#include "llvm/Support/raw_ostream.h" #include <algorithm> using namespace llvm; @@ -125,18 +126,47 @@ static cl::opt<bool> static cl::opt<unsigned> MulOpsInlineThreshold( "scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), - cl::init(1000)); + cl::init(32)); + +static cl::opt<unsigned> AddOpsInlineThreshold( + "scev-addops-inline-threshold", cl::Hidden, + cl::desc("Threshold for inlining addition operands into a SCEV"), + cl::init(500)); static cl::opt<unsigned> MaxSCEVCompareDepth( "scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32)); +static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth( + "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV operations implication analysis"), + cl::init(2)); + static cl::opt<unsigned> MaxValueCompareDepth( "scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2)); +static cl::opt<unsigned> + MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, + cl::desc("Maximum depth of recursive arithmetics"), + cl::init(32)); + +static cl::opt<unsigned> MaxConstantEvolvingDepth( + "scalar-evolution-max-constant-evolving-depth", cl::Hidden, + cl::desc("Maximum depth of recursive constant evolving"), cl::init(32)); + +static cl::opt<unsigned> + MaxExtDepth("scalar-evolution-max-ext-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SExt/ZExt"), + cl::init(8)); + +static cl::opt<unsigned> + MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, + cl::desc("Max coefficients in AddRec during evolving"), + cl::init(16)); + //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -145,11 +175,12 @@ static cl::opt<unsigned> MaxValueCompareDepth( // Implementation of the SCEV class. // -LLVM_DUMP_METHOD -void SCEV::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } +#endif void SCEV::print(raw_ostream &OS) const { switch (static_cast<SCEVTypes>(getSCEVType())) { @@ -300,7 +331,7 @@ bool SCEV::isOne() const { bool SCEV::isAllOnesValue() const { if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) - return SC->getValue()->isAllOnesValue(); + return SC->getValue()->isMinusOne(); return false; } @@ -563,7 +594,7 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, static int CompareSCEVComplexity( SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, - unsigned Depth = 0) { + DominatorTree &DT, unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. 
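An aside on the pattern behind the new options: every threshold added above (MaxArithDepth, MaxExtDepth, MaxSCEVCompareDepth, and friends) guards a recursion by threading an explicit Depth parameter and bailing out to an unsimplified result once the budget is spent. A minimal standalone sketch of that shape, with toy types and a made-up fold rule rather than anything from this patch:

#include <cstdio>

// Toy illustration (not LLVM code): a folder whose recursion carries an
// explicit Depth budget, as the new getAddExpr/getMulExpr/getZeroExtendExpr/
// getSignExtendExpr parameters do.
struct Expr { long Value; bool Simplified; };

static const unsigned MaxDepth = 3; // stands in for -scalar-evolution-max-arith-depth

Expr fold(long N, unsigned Depth) {
  if (Depth > MaxDepth)
    return {N, false};               // budget exhausted: keep the raw node
  if (N <= 1)
    return {N, true};
  Expr Sub = fold(N / 2, Depth + 1); // every recursive call pays Depth + 1
  return {Sub.Value + N, Sub.Simplified};
}

int main() {
  Expr E = fold(1000, 0);
  std::printf("%ld (%s)\n", E.Value, E.Simplified ? "fully folded" : "capped");
  return 0;
}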
if (LHS == RHS) return 0; @@ -608,12 +639,19 @@ static int CompareSCEVComplexity( const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); - // Compare addrec loop depths. + // There is always a dominance between two recs that are used by one SCEV, + // so we can safely sort recs by loop header dominance. We require such + // order in getAddExpr. const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); if (LLoop != RLoop) { - unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; + const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader(); + assert(LHead != RHead && "Two loops share the same header?"); + if (DT.dominates(LHead, RHead)) + return 1; + else + assert(DT.dominates(RHead, LHead) && + "No dominance between recurrences used by one SCEV?"); + return -1; } // Addrec complexity grows with operand count. @@ -624,7 +662,7 @@ static int CompareSCEVComplexity( // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), - RA->getOperand(i), Depth + 1); + RA->getOperand(i), DT, Depth + 1); if (X != 0) return X; } @@ -648,7 +686,7 @@ static int CompareSCEVComplexity( if (i >= RNumOps) return 1; int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), - RC->getOperand(i), Depth + 1); + RC->getOperand(i), DT, Depth + 1); if (X != 0) return X; } @@ -662,10 +700,10 @@ static int CompareSCEVComplexity( // Lexicographically compare udiv expressions. int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), - Depth + 1); + DT, Depth + 1); if (X != 0) return X; - X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), + X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT, Depth + 1); if (X == 0) EqCacheSCEV.insert({LHS, RHS}); @@ -680,7 +718,7 @@ static int CompareSCEVComplexity( // Compare cast expressions by operand. int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), - RC->getOperand(), Depth + 1); + RC->getOperand(), DT, Depth + 1); if (X == 0) EqCacheSCEV.insert({LHS, RHS}); return X; @@ -703,7 +741,7 @@ static int CompareSCEVComplexity( /// land in memory. /// static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, - LoopInfo *LI) { + LoopInfo *LI, DominatorTree &DT) { if (Ops.size() < 2) return; // Noop SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache; @@ -711,15 +749,16 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; - if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0) + if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. 
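The header-dominance comparison above replaces the old loop-depth heuristic. The property it relies on is that two AddRec loops reachable from one SCEV always have comparable headers, so dominance can act as a total order in this context. A toy comparator with the same shape (Block and the Dominates callback are stand-ins, not LLVM's API):

#include <functional>

struct Block { int Id; };

int compareLoopHeaders(const Block *LHead, const Block *RHead,
                       const std::function<bool(const Block *, const Block *)> &Dominates) {
  if (LHead == RHead)
    return 0;
  if (Dominates(LHead, RHead))
    return 1;   // the dominating (outer) rec sorts later
  return -1;    // otherwise RHead must dominate LHead
}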
std::stable_sort(Ops.begin(), Ops.end(), - [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) { - return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0; + [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) { + return + CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0; }); // Now that we are sorted by complexity, group elements of the same @@ -1073,7 +1112,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, APInt Mult(W, i); unsigned TwoFactors = Mult.countTrailingZeros(); T += TwoFactors; - Mult = Mult.lshr(TwoFactors); + Mult.lshrInPlace(TwoFactors); OddFactorial *= Mult; } @@ -1230,12 +1269,12 @@ static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, if (SE->isKnownPositive(Step)) { *Pred = ICmpInst::ICMP_SLT; return SE->getConstant(APInt::getSignedMinValue(BitWidth) - - SE->getSignedRange(Step).getSignedMax()); + SE->getSignedRangeMax(Step)); } if (SE->isKnownNegative(Step)) { *Pred = ICmpInst::ICMP_SGT; return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - - SE->getSignedRange(Step).getSignedMin()); + SE->getSignedRangeMin(Step)); } return nullptr; } @@ -1250,13 +1289,14 @@ static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, *Pred = ICmpInst::ICMP_ULT; return SE->getConstant(APInt::getMinValue(BitWidth) - - SE->getUnsignedRange(Step).getUnsignedMax()); + SE->getUnsignedRangeMax(Step)); } namespace { struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *, + unsigned); }; // Used to make code generic over signed and unsigned overflow. @@ -1314,7 +1354,7 @@ const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< // "sext/zext(PostIncAR)" template <typename ExtendOpTy> static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { + ScalarEvolution *SE, unsigned Depth) { auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; @@ -1361,9 +1401,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); const SCEV *OperandExtendedStart = - SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), - (SE->*GetExtendExpr)(Step, WideTy)); - if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), + (SE->*GetExtendExpr)(Step, WideTy, Depth)); + if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { if (PreAR && AR->getNoWrapFlags(WrapType)) { // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then @@ -1388,15 +1428,17 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // Get the normalized zero or sign extended expression for this AddRec's Start. 
template <typename ExtendOpTy> static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { + ScalarEvolution *SE, + unsigned Depth) { auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; - const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE); + const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth); if (!PreStart) - return (SE->*GetExtendExpr)(AR->getStart(), Ty); + return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); - return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), - (SE->*GetExtendExpr)(PreStart, Ty)); + return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, + Depth), + (SE->*GetExtendExpr)(PreStart, Ty, Depth)); } // Try to prove away overflow by looking at "nearby" add recurrences. A @@ -1476,8 +1518,8 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, return false; } -const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV * +ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1491,7 +1533,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // zext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) - return getZeroExtendExpr(SZ->getOperand(), Ty); + return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. @@ -1501,6 +1543,12 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, ID.AddPointer(Ty); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (Depth > MaxExtDepth) { + SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), + Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; + } // zext(trunc(x)) --> zext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { @@ -1535,8 +1583,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // we don't need to do any further analysis. if (AR->hasNoUnsignedWrap()) return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1560,37 +1608,49 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. 
- const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy); - const SCEV *WideStart = getZeroExtendExpr(Start, WideTy); + const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step, + SCEV::FlagAnyWrap, Depth + 1); + const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul, + SCEV::FlagAnyWrap, + Depth + 1), + WideTy, Depth + 1); + const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); const SCEV *OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy))); + getZeroExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy))); + getSignExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NW, which is propagated to this AddRec. // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } } } @@ -1611,7 +1671,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // is safe. if (isKnownPositive(Step)) { const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - - getUnsignedRange(Step).getUnsignedMax()); + getUnsignedRangeMax(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, @@ -1621,12 +1681,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - - getSignedRange(Step).getSignedMin()); + getSignedRangeMin(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, @@ -1637,8 +1699,10 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } } } @@ -1646,8 +1710,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } } @@ -1658,8 +1722,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // commute the zero extension with the addition operation. SmallVector<const SCEV *, 4> Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getZeroExtendExpr(Op, Ty)); - return getAddExpr(Ops, SCEV::FlagNUW); + Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); + return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1); } } @@ -1672,8 +1736,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } -const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV * +ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1687,11 +1751,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // sext(sext(x)) --> sext(x) if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) - return getSignExtendExpr(SS->getOperand(), Ty); + return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1); // sext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) - return getZeroExtendExpr(SZ->getOperand(), Ty); + return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. @@ -1701,6 +1765,13 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, ID.AddPointer(Ty); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // Limit recursion depth. 
+ if (Depth > MaxExtDepth) { + SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), + Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; + } // sext(trunc(x)) --> sext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { @@ -1726,8 +1797,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) - return getAddExpr(getSignExtendExpr(SC1, Ty), - getSignExtendExpr(SMul, Ty)); + return getAddExpr(getSignExtendExpr(SC1, Ty, Depth + 1), + getSignExtendExpr(SMul, Ty, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); } } } @@ -1738,8 +1810,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // commute the sign extension with the addition operation. SmallVector<const SCEV *, 4> Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getSignExtendExpr(Op, Ty)); - return getAddExpr(Ops, SCEV::FlagNSW); + Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1)); + return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1); } } // If the input value is a chrec scev, and we can prove that the value @@ -1762,8 +1834,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // we don't need to do any further analysis. if (AR->hasNoSignedWrap()) return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1787,29 +1859,39 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. - const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); - const SCEV *WideStart = getSignExtendExpr(Start, WideTy); + const SCEV *SMul = getMulExpr(CastedMaxBECount, Step, + SCEV::FlagAnyWrap, Depth + 1); + const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul, + SCEV::FlagAnyWrap, + Depth + 1), + WideTy, Depth + 1); + const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); const SCEV *OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy))); + getSignExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, + Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. 
OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy))); + getZeroExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (SAdd == OperandExtendedAdd) { // If AR wraps around then // @@ -1823,8 +1905,10 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, + Depth + 1), + getZeroExtendExpr(Step, Ty, Depth + 1), L, + AR->getNoWrapFlags()); } } } @@ -1855,8 +1939,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } } @@ -1870,25 +1954,26 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { - Start = getSignExtendExpr(Start, Ty); + Start = getSignExtendExpr(Start, Ty, Depth + 1); const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, AR->getNoWrapFlags()); - return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); + return getAddExpr(Start, getSignExtendExpr(NewAR, Ty, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); } } if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } } // If the input value is provably positive and we could not simplify // away the sext build a zext instead. if (isKnownNonNegative(Op)) - return getZeroExtendExpr(Op, Ty); + return getZeroExtendExpr(Op, Ty, Depth + 1); // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. @@ -2093,9 +2178,66 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } +bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { + if (!isLoopInvariant(S, L)) + return false; + // If a value depends on a SCEVUnknown which is defined after the loop, we + // conservatively assume that we cannot calculate it at the loop's entry. 
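The FindDominatedSCEVUnknown struct that follows is written against SCEVTraversal's visitor contract: follow(S) decides whether the walk descends into S's operands, and isDone() aborts the entire traversal early. A self-contained sketch of that contract over a toy Node type (the real traversal also de-duplicates visited nodes):

#include <vector>

struct Node { std::vector<const Node *> Ops; bool Interesting = false; };

template <typename Visitor>
void visitAll(const Node *Root, Visitor &V) {
  std::vector<const Node *> Worklist{Root};
  while (!Worklist.empty() && !V.isDone()) {
    const Node *N = Worklist.back();
    Worklist.pop_back();
    if (V.follow(N))
      Worklist.insert(Worklist.end(), N->Ops.begin(), N->Ops.end());
  }
}

struct FindInteresting {
  bool Found = false;
  bool follow(const Node *N) {
    if (N->Interesting)
      Found = true;
    return !Found; // stop descending once a match is found
  }
  bool isDone() const { return Found; }
};

int main() {
  Node Leaf{{}, true}, Root{{&Leaf}, false};
  FindInteresting F;
  visitAll(&Root, F);
  return F.Found ? 0 : 1;
}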
+ struct FindDominatedSCEVUnknown { + bool Found = false; + const Loop *L; + DominatorTree &DT; + LoopInfo &LI; + + FindDominatedSCEVUnknown(const Loop *L, DominatorTree &DT, LoopInfo &LI) + : L(L), DT(DT), LI(LI) {} + + bool checkSCEVUnknown(const SCEVUnknown *SU) { + if (auto *I = dyn_cast<Instruction>(SU->getValue())) { + if (DT.dominates(L->getHeader(), I->getParent())) + Found = true; + else + assert(DT.dominates(I->getParent(), L->getHeader()) && + "No dominance relationship between SCEV and loop?"); + } + return false; + } + + bool follow(const SCEV *S) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scConstant: + return false; + case scAddRecExpr: + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUDivExpr: + return true; + case scUnknown: + return checkSCEVUnknown(cast<SCEVUnknown>(S)); + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + return false; + } + + bool isDone() { return Found; } + }; + + FindDominatedSCEVUnknown FSU(L, DT, LI); + SCEVTraversal<FindDominatedSCEVUnknown> ST(FSU); + ST.visitAll(S); + return !FSU.Found; +} + /// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2108,7 +2250,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); @@ -2134,6 +2276,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Ops.size() == 1) return Ops[0]; } + // Limit recursion calls depth. + if (Depth > MaxArithDepth) + return getOrCreateAddExpr(Ops, Flags); + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2147,7 +2293,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, ++Count; // Merge the values into a multiply. const SCEV *Scale = getConstant(Ty, Count); - const SCEV *Mul = getMulExpr(Scale, Ops[i]); + const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == Count) return Mul; Ops[i] = Mul; @@ -2197,7 +2343,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, } } if (Ok) - LargeOps.push_back(getMulExpr(LargeMulOps)); + LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1)); } else { Ok = false; break; @@ -2205,7 +2351,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, Flags); + const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. 
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) return getTruncateExpr(Fold, DstType); @@ -2220,6 +2366,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Idx < Ops.size()) { bool DeletedAdd = false; while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { + if (Ops.size() > AddOpsInlineThreshold || + Add->getNumOperands() > AddOpsInlineThreshold) + break; // If we have an add, expand the add operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); @@ -2231,7 +2380,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedAdd) - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Skip over the add expression until we get to a multiply. @@ -2266,13 +2415,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.push_back(getConstant(AccumulatedConstant)); for (auto &MulOp : MulOpLists) if (MulOp.first != 0) - Ops.push_back(getMulExpr(getConstant(MulOp.first), - getAddExpr(MulOp.second))); + Ops.push_back(getMulExpr( + getConstant(MulOp.first), + getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1)); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) return Ops[0]; - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } @@ -2295,11 +2446,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), Mul->op_begin()+MulOp); MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); - InnerMul = getMulExpr(MulOps); + InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - const SCEV *One = getOne(Ty); - const SCEV *AddOne = getAddExpr(One, InnerMul); - const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); + SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; + const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, + SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -2309,7 +2461,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Check this multiply against other multiplies being added together. 
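On the AddOpsInlineThreshold guard introduced earlier in this hunk: the rewrite it bounds is plain flattening, where an add operand that is itself an add has its operands spliced into the outer list, so (a + (b + c) + d) becomes (a + b + c + d). A toy version under that assumption:

#include <vector>

struct Sum { std::vector<const Sum *> Ops; }; // empty Ops == leaf

void flattenAdds(std::vector<const Sum *> &Ops, std::size_t Threshold) {
  for (std::size_t i = 0; i < Ops.size();) {
    const Sum *S = Ops[i];
    if (!S->Ops.empty() && Ops.size() <= Threshold &&
        S->Ops.size() <= Threshold) {
      Ops.erase(Ops.begin() + i);                          // drop the inner add
      Ops.insert(Ops.end(), S->Ops.begin(), S->Ops.end()); // splice its operands
    } else {
      ++i;
    }
  }
}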
@@ -2328,22 +2480,25 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), Mul->op_begin()+MulOp); MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); - InnerMul1 = getMulExpr(MulOps); + InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(), OtherMul->op_begin()+OMulOp); MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); - InnerMul2 = getMulExpr(MulOps); + InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); - const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); + SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; + const SCEV *InnerMulSum = + getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, + SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } } @@ -2363,7 +2518,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (isLoopInvariant(Ops[i], AddRecLoop)) { + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -2379,7 +2534,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. - AddRecOps[0] = getAddExpr(LIOps, Flags); + AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. @@ -2396,7 +2551,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops[i] = NewRec; break; } - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -2404,31 +2559,40 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // added together. If so, we can fold them. for (unsigned OtherIdx = Idx+1; OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) + ++OtherIdx) { + // We expect the AddRecExpr's to be sorted in reverse dominance order, + // so that the 1st found AddRecExpr is dominated by all others. 
+ assert(DT.dominates( + cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(), + AddRec->getLoop()->getHeader()) && + "AddRecExprs are not sorted in reverse dominance order?"); if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L> SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), AddRec->op_end()); for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) - if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx])) - if (OtherAddRec->getLoop() == AddRecLoop) { - for (unsigned i = 0, e = OtherAddRec->getNumOperands(); - i != e; ++i) { - if (i >= AddRecOps.size()) { - AddRecOps.append(OtherAddRec->op_begin()+i, - OtherAddRec->op_end()); - break; - } - AddRecOps[i] = getAddExpr(AddRecOps[i], - OtherAddRec->getOperand(i)); + ++OtherIdx) { + const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (OtherAddRec->getLoop() == AddRecLoop) { + for (unsigned i = 0, e = OtherAddRec->getNumOperands(); + i != e; ++i) { + if (i >= AddRecOps.size()) { + AddRecOps.append(OtherAddRec->op_begin()+i, + OtherAddRec->op_end()); + break; } - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + SmallVector<const SCEV *, 2> TwoOps = { + AddRecOps[i], OtherAddRec->getOperand(i)}; + AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); } + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + } + } // Step size has changed, so we cannot guarantee no self-wraparound. Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } + } // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -2436,17 +2600,44 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. + return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVAddExpr *S = - static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), + S = new (SCEVAllocator) + SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); + UniqueSCEVs.InsertNode(S, IP); + } + S->setNoWrapFlags(Flags); + return S; +} + +const SCEV * +ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scMulExpr); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = nullptr; + SCEVMulExpr *S = + static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } @@ -2506,7 +2697,8 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) { /// Get a canonical multiply expression, or something simpler if possible. 
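The fold above, Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>, amounts to element-wise addition of the two coefficient vectors, with the longer rec keeping its tail. A standalone sketch:

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<long> addAddRecCoeffs(const std::vector<long> &A,
                                  const std::vector<long> &B) {
  std::vector<long> R(std::max(A.size(), B.size()), 0);
  for (std::size_t i = 0; i < R.size(); ++i)
    R[i] = (i < A.size() ? A[i] : 0) + (i < B.size() ? B[i] : 0);
  return R; // e.g. {1,+,2} + {3,+,4,+,5} == {4,+,6,+,5}
}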
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); @@ -2519,10 +2711,14 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); + // Limit recursion calls depth. + if (Depth > MaxArithDepth) + return getOrCreateMulExpr(Ops, Flags); + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { @@ -2534,8 +2730,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // apply this transformation as well. if (Add->getNumOperands() == 2) if (containsConstantSomewhere(Add)) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), - getMulExpr(LHSC, Add->getOperand(1))); + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0), + SCEV::FlagAnyWrap, Depth + 1), + getMulExpr(LHSC, Add->getOperand(1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); ++Idx; while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { @@ -2549,7 +2748,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, } // If we are left with a constant one being multiplied, strip it off. - if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) { + if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) { Ops.erase(Ops.begin()); --Idx; } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { @@ -2563,17 +2762,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV *, 4> NewOps; bool AnyFolded = false; for (const SCEV *AddOp : Add->operands()) { - const SCEV *Mul = getMulExpr(Ops[0], AddOp); + const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, + Depth + 1); if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true; NewOps.push_back(Mul); } if (AnyFolded) - return getAddExpr(NewOps); + return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1); } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. SmallVector<const SCEV *, 4> Operands; for (const SCEV *AddRecOp : AddRec->operands()) - Operands.push_back(getMulExpr(Ops[0], AddRecOp)); + Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap, + Depth + 1)); return getAddRecExpr(Operands, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); @@ -2595,18 +2796,18 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { if (Ops.size() > MulOpsInlineThreshold) break; - // If we have an mul, expand the mul operands onto the end of the operands - // list. + // If we have an mul, expand the mul operands onto the end of the + // operands list. Ops.erase(Ops.begin()+Idx); Ops.append(Mul->op_begin(), Mul->op_end()); DeletedMul = true; } - // If we deleted at least one mul, we added operands to the end of the list, - // and they are not necessarily sorted. Recurse to resort and resimplify - // any operands we just acquired. + // If we deleted at least one mul, we added operands to the end of the + // list, and they are not necessarily sorted. 
Recurse to resort and + // resimplify any operands we just acquired. if (DeletedMul) - return getMulExpr(Ops); + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // If there are any add recurrences in the operands list, see if any other @@ -2617,13 +2818,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Scan over all recurrences, trying to fold loop invariants into them. for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { - // Scan all of the other operands to this mul and add them to the vector if - // they are loop invariant w.r.t. the recurrence. + // Scan all of the other operands to this mul and add them to the vector + // if they are loop invariant w.r.t. the recurrence. SmallVector<const SCEV *, 8> LIOps; const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (isLoopInvariant(Ops[i], AddRecLoop)) { + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -2634,9 +2835,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} SmallVector<const SCEV *, 4> NewOps; NewOps.reserve(AddRec->getNumOperands()); - const SCEV *Scale = getMulExpr(LIOps); + const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1); for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); + NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i), + SCEV::FlagAnyWrap, Depth + 1)); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer mul and the inner addrec are guaranteed to have no overflow. @@ -2655,12 +2857,12 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, Ops[i] = NewRec; break; } - return getMulExpr(Ops); + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } - // Okay, if there weren't any loop invariants to be folded, check to see if - // there are multiple AddRec's with the same loop induction variable being - // multiplied together. If so, we can fold them. + // Okay, if there weren't any loop invariants to be folded, check to see + // if there are multiple AddRec's with the same loop induction variable + // being multiplied together. If so, we can fold them. // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ @@ -2681,6 +2883,12 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; + // Limit max number of arguments to avoid creation of unreasonably big + // SCEVAddRecs with very complex operands. 
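For the affine case, the convolution in the comment above reduces to {a,+,b}<L> * {c,+,d}<L> == {a*c,+,a*d+b*c+b*d,+,2*b*d}<L>; AddRec coefficients are binomial, so {s0,+,s1,+,s2} evaluates to s0 + s1*i + s2*i*(i-1)/2 at iteration i. A numeric spot-check of that identity, ignoring overflow:

#include <cassert>

long evalAffine(long A, long B, long I) { return A + B * I; } // {A,+,B} at iteration I
long evalQuadratic(long S0, long S1, long S2, long I) {       // {S0,+,S1,+,S2} at I
  return S0 + S1 * I + S2 * I * (I - 1) / 2;
}

int main() {
  const long A = 3, B = 5, C = 7, D = 2;
  for (long I = 0; I < 16; ++I)
    assert(evalAffine(A, B, I) * evalAffine(C, D, I) ==
           evalQuadratic(A * C, A * D + B * C + B * D, 2 * B * D, I));
  return 0;
}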
+ if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 > + MaxAddRecSize) + continue; + bool Overflow = false; Type *Ty = AddRec->getType(); bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; @@ -2702,7 +2910,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEV *CoeffTerm = getConstant(Ty, Coeff); const SCEV *Term1 = AddRec->getOperand(y-z); const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2, + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); } } AddRecOps.push_back(Term); @@ -2720,7 +2930,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, } } if (OpsModified) - return getMulExpr(Ops); + return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -2728,22 +2938,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. - FoldingSetNodeID ID; - ID.AddInteger(scMulExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); - void *IP = nullptr; - SCEVMulExpr *S = - static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); - std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); - UniqueSCEVs.InsertNode(S, IP); - } - S->setNoWrapFlags(Flags); - return S; + return getOrCreateMulExpr(Ops, Flags); } /// Get a canonical unsigned division expression, or something simpler if @@ -2755,7 +2950,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, "SCEVUDivExpr operand types don't match!"); if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { - if (RHSC->getValue()->equalsInt(1)) + if (RHSC->getValue()->isOne()) return LHS; // X udiv 1 --> x // If the denominator is zero, the result of the udiv is undefined. Don't // try to analyze it, because the resolution chosen here may differ from @@ -2875,7 +3070,7 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { else if (ABW < BBW) A = A.zext(BBW); - return APIntOps::GreatestCommonDivisor(A, B); + return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B)); } /// Get a canonical unsigned division expression, or something simpler if @@ -2889,7 +3084,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // end of this file for inspiration. const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); - if (!Mul) + if (!Mul || !Mul->hasNoUnsignedWrap()) return getUDivExpr(LHS, RHS); if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { @@ -3116,7 +3311,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); // If there are any constants, fold them together. unsigned Idx = 0; @@ -3217,7 +3412,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); // If there are any constants, fold them together. 
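The new hasNoUnsignedWrap() requirement in getUDivExactExpr (earlier in this hunk) matters because cancelling a common factor out of a wrapping product is unsound: the identity (x * y) /u y == x only holds if the multiply does not wrap. A worked 8-bit example:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t Prod = static_cast<uint8_t>(17 * 16); // 272 wraps to 16 in i8
  assert(Prod / 16 == 1);                       // (17 * 16) /u 16 is 1 here, not 17
  return 0;
}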
unsigned Idx = 0; @@ -3385,6 +3580,10 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { return getDataLayout().getIntPtrType(Ty); } +Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { + return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; +} + const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } @@ -3542,7 +3741,8 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { } const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { // Fast path: X - X --> 0. if (LHS == RHS) return getZero(LHS->getType()); @@ -3551,7 +3751,7 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, // makes it so that we cannot make much use of NUW. auto AddFlags = SCEV::FlagAnyWrap; const bool RHSIsNotMinSigned = - !getSignedRange(RHS).getSignedMin().isMinSignedValue(); + !getSignedRangeMin(RHS).isMinSignedValue(); if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { // Let M be the minimum representable signed value. Then (-1)*RHS // signed-wraps if and only if RHS is M. That can happen even for @@ -3576,7 +3776,7 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, // larger scope than intended. auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); + return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth); } const SCEV * @@ -3770,7 +3970,7 @@ public: : SCEVRewriteVisitor(SE), L(L), Valid(true) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { - if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + if (!SE.isLoopInvariant(Expr, L)) Valid = false; return Expr; } @@ -3804,7 +4004,7 @@ public: const SCEV *visitUnknown(const SCEVUnknown *Expr) { // Only allow AddRecExprs for this loop. - if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + if (!SE.isLoopInvariant(Expr, L)) Valid = false; return Expr; } @@ -3909,9 +4109,9 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { case Instruction::Xor: if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (RHSC->getValue().isSignBit()) + // If the RHS of the xor is a signmask, then this is just an add. + // Instcombine turns add of signmask into xor as a strength reduction step. + if (RHSC->getValue().isSignMask()) return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); return BinaryOp(Op); @@ -3984,6 +4184,369 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { return None; } +/// Helper function to createAddRecFromPHIWithCasts. We have a phi +/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via +/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the +/// way. This function checks if \p Op, an operand of this SCEVAddExpr, +/// follows one of the following patterns: +/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) +/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) +/// If the SCEV expression of \p Op conforms with one of the expected patterns +/// we return the type of the truncation operation, and indicate whether the +/// truncated type should be treated as signed/unsigned by setting +/// \p Signed to true/false, respectively. 
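On the isSignMask() rewrite in MatchBinaryOp above: xor with the sign mask and addition of the sign mask agree for every input, because flipping the top bit can never carry into a lower bit and any carry out is discarded. A quick check:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t M = 0x80000000u;               // the sign mask for i32
  for (uint32_t X : {0x12345678u, 0x80000001u, 0u, ~0u})
    assert((X ^ M) == X + M);                   // holds for every X (mod 2^32)
  return 0;
}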
+static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, + bool &Signed, ScalarEvolution &SE) { + + // The case where Op == SymbolicPHI (that is, with no type conversions on + // the way) is handled by the regular add recurrence creating logic and + // would have already been triggered in createAddRecForPHI. Reaching it here + // means that createAddRecFromPHI had failed for this PHI before (e.g., + // because one of the other operands of the SCEVAddExpr updating this PHI is + // not invariant). + // + // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in + // this case predicates that allow us to prove that Op == SymbolicPHI will + // be added. + if (Op == SymbolicPHI) + return nullptr; + + unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType()); + unsigned NewBits = SE.getTypeSizeInBits(Op->getType()); + if (SourceBits != NewBits) + return nullptr; + + const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op); + const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op); + if (!SExt && !ZExt) + return nullptr; + const SCEVTruncateExpr *Trunc = + SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand()) + : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand()); + if (!Trunc) + return nullptr; + const SCEV *X = Trunc->getOperand(); + if (X != SymbolicPHI) + return nullptr; + Signed = SExt ? true : false; + return Trunc->getType(); +} + +static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { + if (!PN->getType()->isIntegerTy()) + return nullptr; + const Loop *L = LI.getLoopFor(PN->getParent()); + if (!L || L->getHeader() != PN->getParent()) + return nullptr; + return L; +} + +// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the +// computation that updates the phi follows the following pattern: +// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum +// which correspond to a phi->trunc->sext/zext->add->phi update chain. +// If so, try to see if it can be rewritten as an AddRecExpr under some +// Predicates. If successful, return them as a pair. Also cache the results +// of the analysis. +// +// Example usage scenario: +// Say the Rewriter is called for the following SCEV: +// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step) +// where: +// %X = phi i64 (%Start, %BEValue) +// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X), +// and call this function with %SymbolicPHI = %X. +// +// The analysis will find that the value coming around the backedge has +// the following SCEV: +// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step) +// Upon concluding that this matches the desired pattern, the function +// will return the pair {NewAddRec, SmallPredsVec} where: +// NewAddRec = {%Start,+,%Step} +// SmallPredsVec = {P1, P2, P3} as follows: +// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw> +// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64) +// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64) +// The returned pair means that SymbolicPHI can be rewritten into NewAddRec +// under the predicates {P1,P2,P3}. 
+// This predicated rewrite will be cached in PredicatedSCEVRewrites: +// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)} +// +// TODO's: +// +// 1) Extend the Induction descriptor to also support inductions that involve +// casts: When needed (namely, when we are called in the context of the +// vectorizer induction analysis), a Set of cast instructions will be +// populated by this method, and provided back to isInductionPHI. This is +// needed to allow the vectorizer to properly record them to be ignored by +// the cost model and to avoid vectorizing them (otherwise these casts, +// which are redundant under the runtime overflow checks, will be +// vectorized, which can be costly). +// +// 2) Support additional induction/PHISCEV patterns: We also want to support +// inductions where the sext-trunc / zext-trunc operations (partly) occur +// after the induction update operation (the induction increment): +// +// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix) +// which correspond to a phi->add->trunc->sext/zext->phi update chain. +// +// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix) +// which correspond to a phi->trunc->add->sext/zext->phi update chain. +// +// 3) Outline common code with createAddRecFromPHI to avoid duplication. +// +Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> +ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) { + SmallVector<const SCEVPredicate *, 3> Predicates; + + // *** Part1: Analyze if we have a phi-with-cast pattern for which we can + // return an AddRec expression under some predicate. + + auto *PN = cast<PHINode>(SymbolicPHI->getValue()); + const Loop *L = isIntegerLoopHeaderPHI(PN, LI); + assert (L && "Expecting an integer loop header phi"); + + // The loop may have multiple entrances or multiple exits; we can analyze + // this phi as an addrec if it has a unique entry value and a unique + // backedge value. + Value *BEValueV = nullptr, *StartValueV = nullptr; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Value *V = PN->getIncomingValue(i); + if (L->contains(PN->getIncomingBlock(i))) { + if (!BEValueV) { + BEValueV = V; + } else if (BEValueV != V) { + BEValueV = nullptr; + break; + } + } else if (!StartValueV) { + StartValueV = V; + } else if (StartValueV != V) { + StartValueV = nullptr; + break; + } + } + if (!BEValueV || !StartValueV) + return None; + + const SCEV *BEValue = getSCEV(BEValueV); + + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, possibly with casts that we can ignore under + // an appropriate runtime guard, then we found a simple induction variable! + const auto *Add = dyn_cast<SCEVAddExpr>(BEValue); + if (!Add) + return None; + + // If there is a single occurrence of the symbolic value, possibly + // casted, replace it with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + Type *TruncTy = nullptr; + bool Signed; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if ((TruncTy = + isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this))) + if (FoundIndex == e) { + FoundIndex = i; + break; + } + + if (FoundIndex == Add->getNumOperands()) + return None; + + // Create an add with everything but the specified operand. 
+ SmallVector<const SCEV *, 8> Ops; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + const SCEV *Accum = getAddExpr(Ops); + + // The runtime checks will not be valid if the step amount is + // varying inside the loop. + if (!isLoopInvariant(Accum, L)) + return None; + + + // *** Part2: Create the predicates + + // Analysis was successful: we have a phi-with-cast pattern for which we + // can return an AddRec expression under the following predicates: + // + // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum) + // fits within the truncated type (does not overflow) for i = 0 to n-1. + // P2: An Equal predicate that guarantees that + // Start = (Ext ix (Trunc iy (Start) to ix) to iy) + // P3: An Equal predicate that guarantees that + // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy) + // + // As we next prove, the above predicates guarantee that: + // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy) + // + // + // More formally, we want to prove that: + // Expr(i+1) = Start + (i+1) * Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // + // Given that: + // 1) Expr(0) = Start + // 2) Expr(1) = Start + Accum + // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2 + // 3) Induction hypothesis (step i): + // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + // + // Proof: + // Expr(i+1) = + // = Start + (i+1)*Accum + // = (Start + i*Accum) + Accum + // = Expr(i) + Accum + // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum + // :: from step i + // + // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum + // + // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + // + (Ext ix (Trunc iy (Accum) to ix) to iy) + // + Accum :: from P3 + // + // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) + // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y) + // + // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // + // By induction, the same applies to all iterations 1<=i<n: + // + + // Create a truncated addrec for which we will add a no overflow check (P1). + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = + getAddRecExpr(getTruncateExpr(StartVal, TruncTy), + getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); + const auto *AR = cast<SCEVAddRecExpr>(PHISCEV); + + SCEVWrapPredicate::IncrementWrapFlags AddedFlags = + Signed ? SCEVWrapPredicate::IncrementNSSW + : SCEVWrapPredicate::IncrementNUSW; + const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); + Predicates.push_back(AddRecPred); + + // Create the Equal Predicates P2,P3: + auto AppendPredicate = [&](const SCEV *Expr) -> void { + assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); + const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); + const SCEV *ExtendedExpr = + Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType()) + : getZeroExtendExpr(TruncatedExpr, Expr->getType()); + if (Expr != ExtendedExpr && + !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) { + const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr); + DEBUG (dbgs() << "Added Predicate: " << *Pred); + Predicates.push_back(Pred); + } + }; + + AppendPredicate(StartVal); + AppendPredicate(Accum); + + // *** Part3: Predicates are ready. Now go ahead and create the new addrec in + // which the casts had been folded away. 
The caller can rewrite SymbolicPHI
+  // into NewAR if it will also add the runtime overflow checks specified in
+  // Predicates.
+  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+      std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+  return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  if (!L)
+    return None;
+
+  // Check to see if we already analyzed this PHI.
+  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
+  if (I != PredicatedSCEVRewrites.end()) {
+    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
+        I->second;
+    // Analysis was done before and failed to create an AddRec:
+    if (Rewrite.first == SymbolicPHI)
+      return None;
+    // Analysis was done before and succeeded in creating an AddRec under
+    // a predicate:
+    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
+    assert(!(Rewrite.second).empty() && "Expected to find Predicates");
+    return Rewrite;
+  }
+
+  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+      Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
+
+  // Record in the cache that the analysis failed.
+  if (!Rewrite) {
+    SmallVector<const SCEVPredicate *, 3> Predicates;
+    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
+    return None;
+  }
+
+  return Rewrite;
+}
+
+/// A helper function for createAddRecFromPHI to handle simple cases.
+///
+/// This function tries to find an AddRec expression for the simplest (yet most
+/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
+/// If it fails, createAddRecFromPHI will use a more general, but slow,
+/// technique for finding the AddRec expression.
+const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
+                                                      Value *BEValueV,
+                                                      Value *StartValueV) {
+  const Loop *L = LI.getLoopFor(PN->getParent());
+  assert(L && L->getHeader() == PN->getParent());
+  assert(BEValueV && StartValueV);
+
+  auto BO = MatchBinaryOp(BEValueV, DT);
+  if (!BO)
+    return nullptr;
+
+  if (BO->Opcode != Instruction::Add)
+    return nullptr;
+
+  const SCEV *Accum = nullptr;
+  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
+    Accum = getSCEV(BO->RHS);
+  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
+    Accum = getSCEV(BO->LHS);
+
+  if (!Accum)
+    return nullptr;
+
+  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+  if (BO->IsNUW)
+    Flags = setFlags(Flags, SCEV::FlagNUW);
+  if (BO->IsNSW)
+    Flags = setFlags(Flags, SCEV::FlagNSW);
+
+  const SCEV *StartVal = getSCEV(StartValueV);
+  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+  ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+
+  // We can add Flags to the post-inc expression only if we
+  // know that it is *undefined behavior* for BEValueV to
+  // overflow.
+ if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + + return PHISCEV; +} + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -4009,127 +4572,134 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { break; } } - if (BEValueV && StartValueV) { - // While we are analyzing this PHI node, handle its value symbolically. - const SCEV *SymbolicName = getUnknown(PN); - assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && - "PHI node already processed?"); - ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); + if (!BEValueV || !StartValueV) + return nullptr; - // Using this symbolic name for the PHI, analyze the value coming around - // the back-edge. - const SCEV *BEValue = getSCEV(BEValueV); + assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && + "PHI node already processed?"); - // NOTE: If BEValue is loop invariant, we know that the PHI node just - // has a special value for the first iteration of the loop. + // First, try to find an AddRec expression without creating a fictitious + // symbolic value for PN. + if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) + return S; + + // Handle PHI node value symbolically. + const SCEV *SymbolicName = getUnknown(PN); + ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); + + // Using this symbolic name for the PHI, analyze the value coming around + // the back-edge. + const SCEV *BEValue = getSCEV(BEValueV); + + // NOTE: If BEValue is loop invariant, we know that the PHI node just + // has a special value for the first iteration of the loop. + + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, then we found a simple induction variable! + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { + // If there is a single occurrence of the symbolic value, replace it + // with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (Add->getOperand(i) == SymbolicName) + if (FoundIndex == e) { + FoundIndex = i; + break; - // If the value coming around the backedge is an add with the symbolic - // value we just inserted, then we found a simple induction variable! - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { - // If there is a single occurrence of the symbolic value, replace it - // with a recurrence. - unsigned FoundIndex = Add->getNumOperands(); + if (FoundIndex != Add->getNumOperands()) { + // Create an add with everything but the specified operand. + SmallVector<const SCEV *, 8> Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (Add->getOperand(i) == SymbolicName) - if (FoundIndex == e) { - FoundIndex = i; - break; + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + const SCEV *Accum = getAddExpr(Ops); + + // This is not a valid addrec if the step amount is varying each + // loop iteration, but is not itself an addrec in this loop.
+ if (isLoopInvariant(Accum, L) || + (isa<SCEVAddRecExpr>(Accum) && + cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + + if (auto BO = MatchBinaryOp(BEValueV, DT)) { + if (BO->Opcode == Instruction::Add && BO->LHS == PN) { + if (BO->IsNUW) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (BO->IsNSW) + Flags = setFlags(Flags, SCEV::FlagNSW); } - - if (FoundIndex != Add->getNumOperands()) { - // Create an add with everything but the specified operand. - SmallVector<const SCEV *, 8> Ops; - for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (i != FoundIndex) - Ops.push_back(Add->getOperand(i)); - const SCEV *Accum = getAddExpr(Ops); - - // This is not a valid addrec if the step amount is varying each - // loop iteration, but is not itself an addrec in this loop. - if (isLoopInvariant(Accum, L) || - (isa<SCEVAddRecExpr>(Accum) && - cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; - - if (auto BO = MatchBinaryOp(BEValueV, DT)) { - if (BO->Opcode == Instruction::Add && BO->LHS == PN) { - if (BO->IsNUW) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (BO->IsNSW) - Flags = setFlags(Flags, SCEV::FlagNSW); - } - } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { - // If the increment is an inbounds GEP, then we know the address - // space cannot be wrapped around. We cannot make any guarantee - // about signed or unsigned overflow because pointers are - // unsigned but we may have a negative index from the base - // pointer. We can guarantee that no unsigned wrap occurs if the - // indices form a positive value. - if (GEP->isInBounds() && GEP->getOperand(0) == PN) { - Flags = setFlags(Flags, SCEV::FlagNW); - - const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); - if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) - Flags = setFlags(Flags, SCEV::FlagNUW); - } - - // We cannot transfer nuw and nsw flags from subtraction - // operations -- sub nuw X, Y is not the same as add nuw X, -Y - // for instance. + } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { + // If the increment is an inbounds GEP, then we know the address + // space cannot be wrapped around. We cannot make any guarantee + // about signed or unsigned overflow because pointers are + // unsigned but we may have a negative index from the base + // pointer. We can guarantee that no unsigned wrap occurs if the + // indices form a positive value. + if (GEP->isInBounds() && GEP->getOperand(0) == PN) { + Flags = setFlags(Flags, SCEV::FlagNW); + + const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); + if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) + Flags = setFlags(Flags, SCEV::FlagNUW); } - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + // We cannot transfer nuw and nsw flags from subtraction + // operations -- sub nuw X, Y is not the same as add nuw X, -Y + // for instance. + } - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - forgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - // We can add Flags to the post-inc expression only if we - // know that it us *undefined behavior* for BEValueV to - // overflow. 
- if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) - if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) - (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + forgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; - } + // We can add Flags to the post-inc expression only if we + // know that it is *undefined behavior* for BEValueV to + // overflow. + if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + + return PHISCEV; } - } else { - // Otherwise, this could be a loop like this: - // i = 0; for (j = 1; ..; ++j) { .... i = j; } - // In this case, j = {1,+,1} and BEValue is j. - // Because the other in-value of i (0) fits the evolution of BEValue - // i really is an addrec evolution. - // - // We can generalize this saying that i is the shifted value of BEValue - // by one iteration: - // PHI(f(0), f({1,+,1})) --> f({0,+,1}) - const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); - const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); - if (Shifted != getCouldNotCompute() && - Start != getCouldNotCompute()) { - const SCEV *StartVal = getSCEV(StartValueV); - if (Start == StartVal) { - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - forgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; - return Shifted; - } + } + } else { + // Otherwise, this could be a loop like this: + // i = 0; for (j = 1; ..; ++j) { .... i = j; } + // In this case, j = {1,+,1} and BEValue is j. + // Because the other in-value of i (0) fits the evolution of BEValue + // i really is an addrec evolution. + // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { + const SCEV *StartVal = getSCEV(StartValueV); + if (Start == StartVal) { + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + forgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } - - // Remove the temporary PHI node SCEV that has been inserted while intending - // to create an AddRecExpr for this PHI node. We can not keep this temporary - // as it will prevent later (possibly simpler) SCEV expressions to be added - // to the ValueExprMap. - eraseValueFromMap(PN); } + // Remove the temporary PHI node SCEV that has been inserted while intending + // to create an AddRecExpr for this PHI node. We can not keep this temporary + // as it will prevent later (possibly simpler) SCEV expressions to be added + // to the ValueExprMap. 
+ eraseValueFromMap(PN); + return nullptr; } @@ -4289,7 +4859,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) + if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) if (LI.replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -4409,8 +4979,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { return getGEPExpr(GEP, IndexExprs); } -uint32_t -ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { +uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) return C->getAPInt().countTrailingZeros(); @@ -4420,14 +4989,16 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { @@ -4444,8 +5015,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), - BitWidth); + SumOpRes = + std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } @@ -4475,17 +5046,25 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { // For a SCEVUnknown, ask ValueTracking. - unsigned BitWidth = getTypeSizeInBits(U->getType()); - APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, - nullptr, &DT); - return Zeros.countTrailingOnes(); + KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); + return Known.countMinTrailingZeros(); } // SCEVUDivExpr return 0; } +uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { + auto I = MinTrailingZerosCache.find(S); + if (I != MinTrailingZerosCache.end()) + return I->second; + + uint32_t Result = GetMinTrailingZerosImpl(S); + auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + assert(InsertPair.second && "Should insert a new key"); + return InsertPair.first->second; +} + /// Helper method to assign a range to V from metadata present in the IR. static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -4498,9 +5077,9 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. 
-ConstantRange -ScalarEvolution::getRange(const SCEV *S, - ScalarEvolution::RangeSignHint SignHint) { +const ConstantRange & +ScalarEvolution::getRangeRef(const SCEV *S, + ScalarEvolution::RangeSignHint SignHint) { DenseMap<const SCEV *, ConstantRange> &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; @@ -4531,54 +5110,54 @@ ScalarEvolution::getRange(const SCEV *S, } if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { - ConstantRange X = getRange(Add->getOperand(0), SignHint); + ConstantRange X = getRangeRef(Add->getOperand(0), SignHint); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.add(getRange(Add->getOperand(i), SignHint)); + X = X.add(getRangeRef(Add->getOperand(i), SignHint)); return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { - ConstantRange X = getRange(Mul->getOperand(0), SignHint); + ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getRange(Mul->getOperand(i), SignHint)); + X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint)); return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { - ConstantRange X = getRange(SMax->getOperand(0), SignHint); + ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) - X = X.smax(getRange(SMax->getOperand(i), SignHint)); + X = X.smax(getRangeRef(SMax->getOperand(i), SignHint)); return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { - ConstantRange X = getRange(UMax->getOperand(0), SignHint); + ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) - X = X.umax(getRange(UMax->getOperand(i), SignHint)); + X = X.umax(getRangeRef(UMax->getOperand(i), SignHint)); return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { - ConstantRange X = getRange(UDiv->getLHS(), SignHint); - ConstantRange Y = getRange(UDiv->getRHS(), SignHint); + ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint); + ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint); return setRange(UDiv, SignHint, ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { - ConstantRange X = getRange(ZExt->getOperand(), SignHint); + ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint); return setRange(ZExt, SignHint, ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { - ConstantRange X = getRange(SExt->getOperand(), SignHint); + ConstantRange X = getRangeRef(SExt->getOperand(), SignHint); return setRange(SExt, SignHint, ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { - ConstantRange X = getRange(Trunc->getOperand(), SignHint); + ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint); return setRange(Trunc, SignHint, ConservativeResult.intersectWith(X.truncate(BitWidth))); } @@ -4632,7 +5211,7 @@ ScalarEvolution::getRange(const SCEV *S, } } - return setRange(AddRec, SignHint, ConservativeResult); + return setRange(AddRec, SignHint, std::move(ConservativeResult)); } 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
@@ -4647,11 +5226,11 @@ ScalarEvolution::getRange(const SCEV *S,
     const DataLayout &DL = getDataLayout();
     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
       // For a SCEVUnknown, ask ValueTracking.
-      APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
-      computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT);
-      if (Ones != ~Zeros + 1)
+      KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
+      if (Known.One != ~Known.Zero + 1)
         ConservativeResult =
-            ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
+            ConservativeResult.intersectWith(ConstantRange(Known.One,
+                                                           ~Known.Zero + 1));
     } else {
       assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
              "generalize as needed!");
@@ -4662,10 +5241,78 @@ ScalarEvolution::getRange(const SCEV *S,
             APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
     }
 
-    return setRange(U, SignHint, ConservativeResult);
+    return setRange(U, SignHint, std::move(ConservativeResult));
   }
 
-  return setRange(S, SignHint, ConservativeResult);
+  return setRange(S, SignHint, std::move(ConservativeResult));
+}
+
+// Given a StartRange, Step and MaxBECount for an expression compute a range of
+// values that the expression can take. Initially, the expression has a value
+// from StartRange and then is changed by Step up to MaxBECount times. Signed
+// argument defines if we treat Step as signed or unsigned.
+static ConstantRange getRangeForAffineARHelper(APInt Step,
+                                               const ConstantRange &StartRange,
+                                               const APInt &MaxBECount,
+                                               unsigned BitWidth, bool Signed) {
+  // If either Step or MaxBECount is 0, then the expression won't change, and we
+  // just need to return the initial range.
+  if (Step == 0 || MaxBECount == 0)
+    return StartRange;
+
+  // If we don't know anything about the initial value (i.e. StartRange is
+  // FullRange), then we don't know anything about the final range either.
+  // Return FullRange.
+  if (StartRange.isFullSet())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  // If Step is signed and negative, then we use its absolute value, but we also
+  // note that we're moving in the opposite direction.
+  bool Descending = Signed && Step.isNegative();
+
+  if (Signed)
+    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
+    //   abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
+    // These equations hold true due to the well-defined wrap-around behavior
+    // of APInt.
+    Step = Step.abs();
+
+  // Check if Offset would be more than the full span of BitWidth. If it is,
+  // the expression is guaranteed to overflow.
+  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  // Offset is by how much the expression can change. Checks above guarantee no
+  // overflow here.
+  APInt Offset = Step * MaxBECount;
+
+  // Minimum value of the final range will match the minimal value of StartRange
+  // if the expression is increasing and will be decreased by Offset otherwise.
+  // Maximum value of the final range will match the maximal value of StartRange
+  // if the expression is decreasing and will be increased by Offset otherwise.
+  APInt StartLower = StartRange.getLower();
+  APInt StartUpper = StartRange.getUpper() - 1;
+  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
+                                   : (StartUpper + std::move(Offset));
+
+  // It's possible that the new minimum/maximum value will fall into the initial
+  // range (due to wrap around).
This means that the expression can take any + // value in this bitwidth, and we have to return full range. + if (StartRange.contains(MovedBoundary)) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + APInt NewLower = + Descending ? std::move(MovedBoundary) : std::move(StartLower); + APInt NewUpper = + Descending ? std::move(StartUpper) : std::move(MovedBoundary); + NewUpper += 1; + + // If we end up with full range, return a proper full range. + if (NewLower == NewUpper) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // No overflow detected, return [StartLower, StartUpper + Offset + 1) range. + return ConstantRange(std::move(NewLower), std::move(NewUpper)); } ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, @@ -4676,60 +5323,29 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && "Precondition!"); - ConstantRange Result(BitWidth, /* isFullSet = */ true); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2); + APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount); + // First, consider step signed. + ConstantRange StartSRange = getSignedRange(Start); ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2); - - ConstantRange StartURange = getUnsignedRange(Start); - ConstantRange EndURange = - StartURange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for unsigned overflow. - ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2); - if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - ZExtEndURange) { - APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), - EndURange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), - EndURange.getUnsignedMax()); - bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } - ConstantRange StartSRange = getSignedRange(Start); - ConstantRange EndSRange = - StartSRange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for signed overflow. This must be done with ConstantRange - // arithmetic because we could be called from within the ScalarEvolution - // overflow checking code. - ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2); - if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - SExtEndSRange) { - APInt Min = - APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); - APInt Max = - APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); - bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } + // If Step can be both positive and negative, we need to find ranges for the + // maximum absolute step values in both directions and union them. 
+ ConstantRange SR = + getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, + MaxBECountValue, BitWidth, /* Signed = */ true); + SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), + StartSRange, MaxBECountValue, + BitWidth, /* Signed = */ true)); - return Result; + // Next, consider step unsigned. + ConstantRange UR = getRangeForAffineARHelper( + getUnsignedRangeMax(Step), getUnsignedRange(Start), + MaxBECountValue, BitWidth, /* Signed = */ false); + + // Finally, intersect signed and unsigned ranges. + return SR.intersectWith(UR); } ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, @@ -4875,7 +5491,8 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { return false; // Only proceed if we can prove that I does not yield poison. - if (!isKnownNotFullPoison(I)) return false; + if (!programUndefinedIfFullPoison(I)) + return false; // At this point we know that if I is executed, then it does not wrap // according to at least one of NSW or NUW. If I is not executed, then we do @@ -5128,9 +5745,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // For an expression like x&255 that merely masks off the high bits, // use zext(trunc(x)) as the SCEV expression. if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { - if (CI->isNullValue()) + if (CI->isZero()) return getSCEV(BO->RHS); - if (CI->isAllOnesValue()) + if (CI->isMinusOne()) return getSCEV(BO->LHS); const APInt &A = CI->getValue(); @@ -5141,19 +5758,34 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned LZ = A.countLeadingZeros(); unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), + KnownBits Known(BitWidth); + computeKnownBits(BO->LHS, Known, getDataLayout(), 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); - if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant(ConstantInt::get( - getContext(), APInt::getOneBitSet(BitWidth, TZ))); + if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + const SCEV *LHS = getSCEV(BO->LHS); + const SCEV *ShiftedLHS = nullptr; + if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { + if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { + // For an expression like (x * 8) & 8, simplify the multiply. + unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); + unsigned GCD = std::min(MulZeros, TZ); + APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); + SmallVector<const SCEV*, 4> MulOps; + MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); + MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); + auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); + } + } + if (!ShiftedLHS) + ShiftedLHS = getUDivExpr(LHS, MulCount); return getMulExpr( getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(BO->LHS), MulCount), + getTruncateExpr(ShiftedLHS, IntegerType::get(getContext(), BitWidth - LZ - TZ)), BO->LHS->getType()), MulCount); @@ -5190,7 +5822,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::Xor: if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { // If the RHS of xor is -1, then this is a not operation. 
- if (CI->isAllOnesValue()) + if (CI->isMinusOne()) return getNotSCEV(getSCEV(BO->LHS)); // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. @@ -5211,7 +5843,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // If C is a low-bits mask, the zero extend is serving to // mask off the high bits. Complement the operand and // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) + if (CI->getValue().isMask(Z0TySize)) return getZeroExtendExpr(getNotSCEV(Z0), UTy); // If C is a single bit, it may be in the sign-bit position @@ -5219,7 +5851,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // using an add, which is equivalent, and re-apply the zext. APInt Trunc = CI->getValue().trunc(Z0TySize); if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && - Trunc.isSignBit()) + Trunc.isSignMask()) return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), UTy); } @@ -5255,28 +5887,55 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) - if (Operator *L = dyn_cast<Operator>(BO->LHS)) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == BO->RHS) { - uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; + // AShr X, C, where C is a constant. + ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); + if (!CI) + break; + + Type *OuterTy = BO->LHS->getType(); + uint64_t BitWidth = getTypeSizeInBits(OuterTy); + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (CI->getValue().uge(BitWidth)) + break; - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop + if (CI->isZero()) + return getSCEV(BO->LHS); // shift by zero --> noop + + uint64_t AShrAmt = CI->getZExtValue(); + Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); + + Operator *L = dyn_cast<Operator>(BO->LHS); + if (L && L->getOpcode() == Instruction::Shl) { + // X = Shl A, n + // Y = AShr X, m + // Both n and m are constant. + + const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + if (L->getOperand(1) == BO->RHS) + // For a two-shift sext-inreg, i.e. n = m, + // use sext(trunc(x)) as the SCEV expression. + return getSignExtendExpr( + getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + + ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); + if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { + uint64_t ShlAmt = ShlAmtCI->getZExtValue(); + if (ShlAmt > AShrAmt) { + // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV + // expression. We already checked that ShlAmt < BitWidth, so + // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as + // ShlAmt - AShrAmt < Amt. 
+ APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, + ShlAmt - AShrAmt); return getSignExtendExpr( - getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), Amt)), - BO->LHS->getType()); + getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), + getConstant(Mul)), OuterTy); } + } + } break; } } @@ -5348,7 +6007,7 @@ static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { return ((unsigned)ExitConst->getZExtValue()) + 1; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripCount(L, ExitingBB); @@ -5356,7 +6015,7 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5366,13 +6025,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, return getConstantTripCount(ExitCount); } -unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) { const auto *MaxExitCount = dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); return getConstantTripCount(MaxExitCount); } -unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripMultiple(L, ExitingBB); @@ -5393,7 +6052,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. unsigned -ScalarEvolution::getSmallConstantTripMultiple(Loop *L, +ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5403,17 +6062,16 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return 1; // Get the trip count from the BE count by adding 1. - const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); - // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt - // to factor simple cases. - if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul)) - TCMul = Mul->getOperand(0); - - const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul); - if (!MulC) - return 1; + const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType())); + + const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); + if (!TC) + // Attempt to factor more general cases. Returns the greatest power of + // two divisor. If overflow happens, the trip count expression is still + // divisible by the greatest power of 2 divisor returned. + return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); - ConstantInt *Result = MulC->getValue(); + ConstantInt *Result = TC->getValue(); // Guard against huge trip counts (this requires checking // for zero to handle the case where the trip count == -1 and the @@ -5428,7 +6086,8 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, /// Get the expression for the number of loop iterations for which this loop is /// guaranteed not to exit via ExitingBlock. Otherwise return /// SCEVCouldNotCompute. 
-const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { +const SCEV *ScalarEvolution::getExitCount(const Loop *L, + BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } @@ -5569,6 +6228,16 @@ void ScalarEvolution::forgetLoop(const Loop *L) { RemoveLoopFromBackedgeMap(BackedgeTakenCounts); RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts); + // Drop information about predicated SCEV rewrites for this loop. + for (auto I = PredicatedSCEVRewrites.begin(); + I != PredicatedSCEVRewrites.end();) { + std::pair<const SCEV *, const Loop *> Entry = I->first; + if (Entry.second == L) + PredicatedSCEVRewrites.erase(I++); + else + ++I; + } + // Drop information about expressions based on loop-header PHIs. SmallVector<Instruction *, 16> Worklist; PushLoopPHIs(L, Worklist); @@ -5681,6 +6350,8 @@ ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) return SE->getCouldNotCompute(); + assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) && + "No point in having a non-constant max backedge taken count!"); return getMax(); } @@ -5705,6 +6376,45 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, return false; } +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) + : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList) + : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) { + assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || + !isa<SCEVCouldNotCompute>(MaxNotTaken)) && + "Exact is not allowed to be less precise than Max"); + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); + for (auto *PredSet : PredSetList) + for (auto *P : *PredSet) + addPredicate(P); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + const SmallPtrSetImpl<const SCEVPredicate *> &PredSet) + : ExitLimit(E, M, MaxOrZero, {&PredSet}) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, + bool MaxOrZero) + : ExitLimit(E, M, MaxOrZero, None) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( @@ -5728,6 +6438,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); }); + assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) && + "No point in having a non-constant max backedge taken count!"); } /// Invalidate this result and free the ExitNotTakenInfo array. 
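The asserts added above all enforce one new invariant: whenever an ExitLimit or BackedgeTakenInfo carries a computable max not-taken count, that count must be a SCEVConstant. A minimal sketch of how a call site is expected to satisfy this, mirroring the howFarToZero changes further down in this diff (SE and Predicates are assumed to be in scope, and computeExactNotTaken is a hypothetical placeholder, not part of the patch):

    // Derive a constant upper bound before constructing the ExitLimit.
    const SCEV *Exact = computeExactNotTaken(); // hypothetical helper
    const SCEV *Max =
        isa<SCEVCouldNotCompute>(Exact)
            ? Exact // no bound computable; CouldNotCompute is allowed
            : SE.getConstant(SE.getUnsignedRangeMax(Exact)); // constant bound
    ScalarEvolution::ExitLimit EL(Exact, Max, /*MaxOrZero=*/false, Predicates);

Passing a non-constant, computable Max would now trip the "No point in having a non-constant max backedge taken count!" assertion.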
@@ -5886,24 +6598,74 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, return getCouldNotCompute(); } -ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimitFromCond(const Loop *L, - Value *ExitCond, - BasicBlock *TBB, - BasicBlock *FBB, - bool ControlsExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( + const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, bool AllowPredicates) { + ScalarEvolution::ExitLimitCacheTy Cache(L, TBB, FBB, AllowPredicates); + return computeExitLimitFromCondCached(Cache, L, ExitCond, TBB, FBB, + ControlsExit, AllowPredicates); +} + +Optional<ScalarEvolution::ExitLimit> +ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, + BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, bool AllowPredicates) { + (void)this->L; + (void)this->TBB; + (void)this->FBB; + (void)this->AllowPredicates; + + assert(this->L == L && this->TBB == TBB && this->FBB == FBB && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + auto Itr = TripCountMap.find({ExitCond, ControlsExit}); + if (Itr == TripCountMap.end()) + return None; + return Itr->second; +} + +void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, + BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, + bool AllowPredicates, + const ExitLimit &EL) { + assert(this->L == L && this->TBB == TBB && this->FBB == FBB && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + + auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL}); + assert(InsertResult.second && "Expected successful insertion!"); + (void)InsertResult; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB, + BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { + + if (auto MaybeEL = + Cache.find(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates)) + return *MaybeEL; + + ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, TBB, FBB, + ControlsExit, AllowPredicates); + Cache.insert(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates, EL); + return EL; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB, + BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); - ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5939,7 +6701,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, // to not. 
if (isa<SCEVCouldNotCompute>(MaxBECount) && !isa<SCEVCouldNotCompute>(BECount)) - MaxBECount = BECount; + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); return ExitLimit(BECount, MaxBECount, false, {&EL0.Predicates, &EL1.Predicates}); @@ -5947,12 +6709,12 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); - ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -6337,13 +7099,12 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most // bitwidth(K) iterations. Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); - bool KnownZero, KnownOne; - ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, - Predecessor->getTerminator(), &DT); + KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); auto *Ty = cast<IntegerType>(RHS->getType()); - if (KnownZero) + if (Known.isNonNegative()) StableValue = ConstantInt::get(Ty, 0); - else if (KnownOne) + else if (Known.isNegative()) StableValue = ConstantInt::get(Ty, -1, true); else return getCouldNotCompute(); @@ -6383,7 +7144,7 @@ static bool CanConstantFold(const Instruction *I) { if (const CallInst *CI = dyn_cast<CallInst>(I)) if (const Function *F = CI->getCalledFunction()) - return canConstantFoldCallTo(F); + return canConstantFoldCallTo(CI, F); return false; } @@ -6408,7 +7169,10 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { /// recursing through each instruction operand until reaching a loop header phi. static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, - DenseMap<Instruction *, PHINode *> &PHIMap) { + DenseMap<Instruction *, PHINode *> &PHIMap, + unsigned Depth) { + if (Depth > MaxConstantEvolvingDepth) + return nullptr; // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. @@ -6428,7 +7192,7 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, if (!P) { // Recurse and memoize the results, whether a phi is found or not. // This recursive call invalidates pointers into PHIMap. - P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap); + P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1); PHIMap[OpInst] = P; } if (!P) @@ -6455,7 +7219,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { // Record non-constant instructions contained by the loop. 
   DenseMap<Instruction *, PHINode *> PHIMap;
-  return getConstantEvolvingPHIOperands(I, L, PHIMap);
+  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
 }
 
 /// EvaluateExpression - Given an expression that passes the
@@ -6829,6 +7593,25 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
         const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
         if (const SCEVConstant *BTCC =
               dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
+
+          // This trivial case can show up in some degenerate cases where
+          // the incoming IR has not yet been fully simplified.
+          if (BTCC->getValue()->isZero()) {
+            Value *InitValue = nullptr;
+            bool MultipleInitValues = false;
+            for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
+              if (!LI->contains(PN->getIncomingBlock(i))) {
+                if (!InitValue)
+                  InitValue = PN->getIncomingValue(i);
+                else if (InitValue != PN->getIncomingValue(i)) {
+                  MultipleInitValues = true;
+                  break;
+                }
+              }
+            }
+            if (!MultipleInitValues && InitValue)
+              return getSCEV(InitValue);
+          }
           // Okay, we know how many times the containing loop executes.  If
           // this is a constant evolving PHI node, get the final value at
           // the specified iteration number.
@@ -7014,10 +7797,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
 /// A and B isn't important.
 ///
 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
-static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
+static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
                                                 ScalarEvolution &SE) {
   uint32_t BW = A.getBitWidth();
-  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
+  assert(BW == SE.getTypeSizeInBits(B->getType()));
   assert(A != 0 && "A must be non-zero.");
 
   // 1. D = gcd(A, N)
@@ -7031,7 +7814,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
   //
   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
   // is not less than multiplicity of this prime factor for D.
-  if (B.countTrailingZeros() < Mult2)
+  if (SE.GetMinTrailingZeros(B) < Mult2)
     return SE.getCouldNotCompute();
 
   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
@@ -7049,9 +7832,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
   // I * (B / D) mod (N / D)
   // To simplify the computation, we factor out the divide by D:
   //   (I * B mod N) / D
-  APInt Result = (I * B).lshr(Mult2);
-
-  return SE.getConstant(Result);
+  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
+  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
 }
 
 /// Find the roots of the quadratic equation for the given quadratic chrec
@@ -7074,50 +7856,50 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
   const APInt &M = MC->getAPInt();
   const APInt &N = NC->getAPInt();
   APInt Two(BitWidth, 2);
-  APInt Four(BitWidth, 4);
-
-  {
-    using namespace APIntOps;
-    const APInt& C = L;
-    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
-    // The B coefficient is M-N/2
-    APInt B(M);
-    B -= sdiv(N,Two);
-
-    // The A coefficient is N/2
-    APInt A(N.sdiv(Two));
-
-    // Compute the B^2-4ac term.
-    APInt SqrtTerm(B);
-    SqrtTerm *= B;
-    SqrtTerm -= Four * (A * C);
-
-    if (SqrtTerm.isNegative()) {
-      // The loop is provably infinite.
-      return None;
-    }
-    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
-    // integer value or else APInt::sqrt() will assert.
- APInt SqrtVal(SqrtTerm.sqrt()); + // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - // Compute the two solutions for the quadratic formula. - // The divisions must be performed as signed divisions. - APInt NegB(-B); - APInt TwoA(A << 1); - if (TwoA.isMinValue()) - return None; + // The A coefficient is N/2 + APInt A = N.sdiv(Two); + + // The B coefficient is M-N/2 + APInt B = M; + B -= A; // A is the same as N/2. - LLVMContext &Context = SE.getContext(); + // The C coefficient is L. + const APInt& C = L; - ConstantInt *Solution1 = - ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = - ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + // Compute the B^2-4ac term. + APInt SqrtTerm = B; + SqrtTerm *= B; + SqrtTerm -= 4 * (A * C); - return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), - cast<SCEVConstant>(SE.getConstant(Solution2))); - } // end APIntOps namespace + if (SqrtTerm.isNegative()) { + // The loop is provably infinite. + return None; + } + + // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest + // integer value or else APInt::sqrt() will assert. + APInt SqrtVal = SqrtTerm.sqrt(); + + // Compute the two solutions for the quadratic formula. + // The divisions must be performed as signed divisions. + APInt NegB = -std::move(B); + APInt TwoA = std::move(A); + TwoA <<= 1; + if (TwoA.isNullValue()) + return None; + + LLVMContext &Context = SE.getContext(); + + ConstantInt *Solution1 = + ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = + ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + + return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), + cast<SCEVConstant>(SE.getConstant(Solution2))); } ScalarEvolution::ExitLimit @@ -7197,7 +7979,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. // We have not yet seen any such cases. const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step); - if (!StepC || StepC->getValue()->equalsInt(0)) + if (!StepC || StepC->getValue()->isZero()) return getCouldNotCompute(); // For positive steps (counting up until unsigned overflow): @@ -7211,8 +7993,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // Handle unitary steps, which cannot wraparound. // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) - if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) { - APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax(); + if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) { + APInt MaxBECount = getUnsignedRangeMax(Distance); // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, // we end up with a loop whose backedge-taken count is n - 1. Detect this @@ -7233,62 +8015,6 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); } - // As a special case, handle the instance where Step is a positive power of - // two. In this case, determining whether Step divides Distance evenly can be - // done by counting and comparing the number of trailing zeros of Step and - // Distance. - if (!CountDown) { - const APInt &StepV = StepC->getAPInt(); - // StepV.isPowerOf2() returns true if StepV is an positive power of two. 
It - // also returns true if StepV is maximally negative (eg, INT_MIN), but that - // case is not handled as this code is guarded by !CountDown. - if (StepV.isPowerOf2() && - GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { - // Here we've constrained the equation to be of the form - // - // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) - // - // where we're operating on a W bit wide integer domain and k is - // non-negative. The smallest unsigned solution for X is the trip count. - // - // (0) is equivalent to: - // - // 2^(N + k) * Distance' - 2^N * X = L * 2^W - // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N - // <=> 2^k * Distance' - X = L * 2^(W - N) - // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) - // - // The smallest X satisfying (1) is unsigned remainder of dividing the LHS - // by 2^(W - N). - // - // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) - // - // E.g. say we're solving - // - // 2 * Val = 2 * X (in i8) ... (3) - // - // then from (2), we get X = Val URem i8 128 (k = 0 in this case). - // - // Note: It is tempting to solve (3) by setting X = Val, but Val is not - // necessarily the smallest unsigned value of X that satisfies (3). - // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) - // is i8 1, not i8 -127 - - const auto *ModuloResult = getUDivExactExpr(Distance, Step); - - // Since SCEV does not have a URem node, we construct one using a truncate - // and a zero extend. - - unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros(); - auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); - auto *WideTy = Distance->getType(); - - const SCEV *Limit = - getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); - return ExitLimit(Limit, Limit, false, Predicates); - } - } - // If the condition controls loop exit (the loop exits only if the expression // is true) and the addition is no-wrap we can use unsigned divide to // compute the backedge count. In this case, the step may not divide the @@ -7298,16 +8024,20 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, false, Predicates); + const SCEV *Max = + Exact == getCouldNotCompute() + ? Exact + : getConstant(getUnsignedRangeMax(Exact)); + return ExitLimit(Exact, Max, false, Predicates); } - // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { - const SCEV *E = SolveLinEquationWithOverflow( - StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); - return ExitLimit(E, E, false, Predicates); - } - return getCouldNotCompute(); + // Solve the general equation. + const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(), + getNegativeSCEV(Start), *this); + const SCEV *M = E == getCouldNotCompute() + ? E + : getConstant(getUnsignedRangeMax(E)); + return ExitLimit(E, M, false, Predicates); } ScalarEvolution::ExitLimit @@ -7319,7 +8049,7 @@ ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { // If the value is a constant, check to see if it is known to be non-zero // already. If so, the backedge will execute zero times. 
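[Editorial sketch] The general equation Step * N = -Start (mod 2^BW) that howFarToZero now hands to SolveLinEquationWithOverflow can be modelled outside SCEV: D = gcd(Step, 2^BW) is always a power of two, divisibility of the constant term is checked via trailing zeros, and the odd part of Step is inverted modulo 2^BW. In the sketch below, solveLin and inverseMod2_64 are invented names, plain uint64_t stands in for a BW-bit APInt, and __builtin_ctzll is a GCC/Clang builtin.

#include <cassert>
#include <cstdint>

// Multiplicative inverse of an odd A modulo 2^64 by Newton iteration;
// each step doubles the number of correct low bits (3 -> 6 -> ... -> 96).
static uint64_t inverseMod2_64(uint64_t A) {
  assert((A & 1) && "only odd values are invertible modulo 2^64");
  uint64_t X = A; // A*A == 1 (mod 8), so X is already correct to 3 bits
  for (int i = 0; i < 5; ++i)
    X *= 2 - A * X;
  return X;
}

// Smallest X with A*X == B (mod 2^64); returns false when no solution
// exists, mirroring the getCouldNotCompute() path above.
static bool solveLin(uint64_t A, uint64_t B, uint64_t &X) {
  assert(A != 0 && "A must be non-zero.");
  unsigned Mult2 = __builtin_ctzll(A);   // D = gcd(A, 2^64) = 2^Mult2
  if (B != 0 && __builtin_ctzll(B) < Mult2)
    return false;                        // B is not divisible by D
  uint64_t I = inverseMod2_64(A >> Mult2);
  X = (I * B) >> Mult2;                  // (I * B mod 2^64) / D
  return true;
}

For Step = 1 this degenerates to X = B, which matches the unitary-step fast path (N = Distance) handled earlier in howFarToZero.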
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { - if (!C->getValue()->isNullValue()) + if (!C->getValue()->isZero()) return getZero(C->getType()); return getCouldNotCompute(); // Otherwise it will loop infinitely. } @@ -7503,12 +8233,12 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // adding or subtracting 1 from one of the operands. switch (Pred) { case ICmpInst::ICMP_SLE: - if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { + if (!getSignedRangeMax(RHS).isMaxSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; Changed = true; - } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { + } else if (!getSignedRangeMin(LHS).isMinSignedValue()) { LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; @@ -7516,12 +8246,12 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } break; case ICmpInst::ICMP_SGE: - if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { + if (!getSignedRangeMin(RHS).isMinSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; Changed = true; - } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { + } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) { LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; @@ -7529,23 +8259,23 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } break; case ICmpInst::ICMP_ULE: - if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { + if (!getUnsignedRangeMax(RHS).isMaxValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_ULT; Changed = true; - } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { + } else if (!getUnsignedRangeMin(LHS).isMinValue()) { LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); Pred = ICmpInst::ICMP_ULT; Changed = true; } break; case ICmpInst::ICMP_UGE: - if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { + if (!getUnsignedRangeMin(RHS).isMinValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; - } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { + } else if (!getUnsignedRangeMax(LHS).isMaxValue()) { LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_UGT; @@ -7579,19 +8309,19 @@ trivially_false: } bool ScalarEvolution::isKnownNegative(const SCEV *S) { - return getSignedRange(S).getSignedMax().isNegative(); + return getSignedRangeMax(S).isNegative(); } bool ScalarEvolution::isKnownPositive(const SCEV *S) { - return getSignedRange(S).getSignedMin().isStrictlyPositive(); + return getSignedRangeMin(S).isStrictlyPositive(); } bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { - return !getSignedRange(S).getSignedMin().isNegative(); + return !getSignedRangeMin(S).isNegative(); } bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { - return !getSignedRange(S).getSignedMax().isStrictlyPositive(); + return !getSignedRangeMax(S).isStrictlyPositive(); } bool ScalarEvolution::isKnownNonZero(const SCEV *S) { @@ -7822,6 +8552,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, case ICmpInst::ICMP_SGE: std::swap(LHS, RHS); + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: // X s<= (X + 
C)<nsw> if C >= 0
 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
 C.isNonNegative())
@@ -7835,6 +8566,7 @@
 case ICmpInst::ICMP_SGT:
 std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
 case ICmpInst::ICMP_SLT:
 // X s< (X + C)<nsw> if C > 0
 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
@@ -8175,7 +8907,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
 // predicate we're interested in folding.
 APInt Min = ICmpInst::isSigned(Pred) ?
- getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
+ getSignedRangeMin(V) : getUnsignedRangeMin(V);
 if (Min == C->getAPInt()) {
 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
@@ -8192,6 +8924,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin)))
 return true;
+ LLVM_FALLTHROUGH;
 
 case ICmpInst::ICMP_SGT:
 case ICmpInst::ICMP_UGT:
@@ -8206,6 +8939,7 @@
 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
 return true;
+ LLVM_FALLTHROUGH;
 
 default:
 // No change
@@ -8488,19 +9222,161 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
 llvm_unreachable("covered switch fell through?!");
 }
 
+bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS,
+ unsigned Depth) {
+ assert(getTypeSizeInBits(LHS->getType()) ==
+ getTypeSizeInBits(RHS->getType()) &&
+ "LHS and RHS have different sizes?");
+ assert(getTypeSizeInBits(FoundLHS->getType()) ==
+ getTypeSizeInBits(FoundRHS->getType()) &&
+ "FoundLHS and FoundRHS have different sizes?");
+ // We want to avoid hurting compile time by analyzing overly large trees.
+ if (Depth > MaxSCEVOperationsImplicationDepth)
+ return false;
+ // We only want to work with the ICMP_SGT comparison so far.
+ // TODO: Extend to ICMP_UGT?
+ if (Pred == ICmpInst::ICMP_SLT) {
+ Pred = ICmpInst::ICMP_SGT;
+ std::swap(LHS, RHS);
+ std::swap(FoundLHS, FoundRHS);
+ }
+ if (Pred != ICmpInst::ICMP_SGT)
+ return false;
+
+ auto GetOpFromSExt = [&](const SCEV *S) {
+ if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+ return Ext->getOperand();
+ // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
+ // the constant in some cases.
+ return S;
+ };
+
+ // Acquire values from extensions.
+ auto *OrigFoundLHS = FoundLHS;
+ LHS = GetOpFromSExt(LHS);
+ FoundLHS = GetOpFromSExt(FoundLHS);
+
+ // Whether the SGT predicate can be proved trivially or using the found
+ // context.
+ auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+ return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+ isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
+ FoundRHS, Depth + 1);
+ };
+
+ if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
+ // We want to avoid creation of any new non-constant SCEV. Since we are
+ // going to compare the operands to RHS, we should be certain that we don't
+ // need any size extensions for this. So we decline all cases where the
+ // types of LHS and RHS differ in size.
+ // TODO: Maybe try to get RHS from sext to catch more cases?
+ if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
+ return false;
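[Editorial sketch] The normalization at the top of isImpliedViaOperations reduces every supported query to a single predicate before any structural matching. A minimal stand-in, under the assumption that only SLT/SGT are handled; canonicalizeToSGT is a made-up name, and the real code mutates an ICmpInst::Predicate in place:

#include <utility>

enum class Pred { SLT, SGT, Other };

// "LHS s< RHS" and "RHS s> LHS" state the same fact, so flip SLT into SGT
// once up front; every later rule then only has to reason about SGT.
template <typename T>
bool canonicalizeToSGT(Pred &P, T &LHS, T &RHS, T &FoundLHS, T &FoundRHS) {
  if (P == Pred::SLT) {
    P = Pred::SGT;
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
  }
  return P == Pred::SGT; // anything else is not handled (yet)
}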
+
+ // Should not overflow.
+ if (!LHSAddExpr->hasNoSignedWrap())
+ return false;
+
+ auto *LL = LHSAddExpr->getOperand(0);
+ auto *LR = LHSAddExpr->getOperand(1);
+ auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+
+ // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
+ auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+ return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
+ };
+ // Try to prove the following rule:
+ // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
+ // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
+ if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
+ return true;
+ } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
+ Value *LL, *LR;
+ // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+ using namespace llvm::PatternMatch;
+ if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
+ // Rules for division.
+ // We are going to perform some comparisons with Denominator and its
+ // derivative expressions. In the general case, creating a SCEV for it may
+ // lead to a complex analysis of the entire graph, and in particular it
+ // can request trip count recalculation for the same loop. The result
+ // would be cached as SCEVCouldNotCompute to avoid the infinite recursion.
+ // To avoid this, we only want to create SCEVs that are constants in this
+ // section. So we bail if Denominator is not a constant.
+ if (!isa<ConstantInt>(LR))
+ return false;
+
+ auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
+
+ // We want to make sure that LHS = FoundLHS / Denominator. If so, then a
+ // SCEV for the numerator already exists and matches with FoundLHS.
+ auto *Numerator = getExistingSCEV(LL);
+ if (!Numerator || Numerator->getType() != FoundLHS->getType())
+ return false;
+
+ // Make sure that the numerator matches with FoundLHS and the denominator
+ // is positive.
+ if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
+ return false;
+
+ auto *DTy = Denominator->getType();
+ auto *FRHSTy = FoundRHS->getType();
+ if (DTy->isPointerTy() != FRHSTy->isPointerTy())
+ // One of the types is a pointer and the other is not. We cannot extend
+ // them properly to a wider type, so let us just reject this case.
+ // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
+ // to avoid this check.
+ return false;
+
+ // Given that:
+ // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
+ auto *WTy = getWiderType(DTy, FRHSTy);
+ auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+ auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+
+ // Try to prove the following rule:
+ // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
+ // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
+ // divide it by a Denominator < 4, we will have at least 1.
+ auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+ if (isKnownNonPositive(RHS) &&
+ IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
+ return true;
+
+ // Try to prove the following rule:
+ // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
+ // For example, given that FoundLHS > -3, FoundLHS is at least -2. If we
+ // divide it by a Denominator > 2, then:
+ // 1. If FoundLHS is negative, then the result is 0.
+ // 2. If FoundLHS is non-negative, then the result is non-negative.
+ // Either way, the result is non-negative.
+ auto *MinusOne = getNegativeSCEV(getOne(WTy)); + auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); + if (isKnownNegative(RHS) && + IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) + return true; + } + } + + return false; +} + +bool +ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); +} + bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { - auto IsKnownPredicateFull = - [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || - isKnownPredicateViaNoOverflow(Pred, LHS, RHS); - }; - switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -8510,30 +9386,34 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } @@ -8551,7 +9431,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, if (!Addend) return false; - APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); + const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the // antecedent "`FoundLHS` `Pred` `FoundRHS`". 
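[Editorial sketch] The three implication rules proved by isImpliedViaOperations above lend themselves to exhaustive checking over a tiny domain. Below is a throwaway test harness, not LLVM code: the bounds are arbitrary, and C++ integer division truncates toward zero, matching sdiv.

#include <cassert>

int main() {
  // Addition rule: LL >= 0 && LR > RHS ==> LL + LR > RHS (no wrap here).
  for (int LL = 0; LL <= 20; ++LL)
    for (int RHS = -10; RHS <= 10; ++RHS)
      for (int LR = RHS + 1; LR <= 20; ++LR)
        assert(LL + LR > RHS);

  // Division rule 1: FoundLHS > FoundRHS, FoundRHS > D - 2, RHS <= 0
  //                  ==> FoundLHS / D > RHS.
  for (int D = 1; D <= 8; ++D)
    for (int FoundRHS = D - 1; FoundRHS <= 32; ++FoundRHS)
      for (int FoundLHS = FoundRHS + 1; FoundLHS <= 33; ++FoundLHS)
        for (int RHS = -4; RHS <= 0; ++RHS)
          assert(FoundLHS / D > RHS);

  // Division rule 2: FoundLHS > FoundRHS, FoundRHS > -1 - D, RHS < 0
  //                  ==> FoundLHS / D > RHS, because FoundLHS >= 1 - D
  //                  makes the truncating quotient at least 0.
  for (int D = 1; D <= 8; ++D)
    for (int FoundRHS = -D; FoundRHS <= 16; ++FoundRHS)
      for (int FoundLHS = FoundRHS + 1; FoundLHS <= 17; ++FoundLHS)
        for (int RHS = -4; RHS < 0; ++RHS)
          assert(FoundLHS / D > RHS);
  return 0;
}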
@@ -8563,7 +9443,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": - APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); + const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); ConstantRange SatisfyingLHSRange = ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); @@ -8582,22 +9462,20 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, const SCEV *One = getOne(Stride->getType()); if (IsSigned) { - APInt MaxRHS = getSignedRange(RHS).getSignedMax(); + APInt MaxRHS = getSignedRangeMax(RHS); APInt MaxValue = APInt::getSignedMaxValue(BitWidth); - APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) - .getSignedMax(); + APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! - return (MaxValue - MaxStrideMinusOne).slt(MaxRHS); + return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS); } - APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax(); + APInt MaxRHS = getUnsignedRangeMax(RHS); APInt MaxValue = APInt::getMaxValue(BitWidth); - APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) - .getUnsignedMax(); + APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! - return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); + return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); } bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -8608,22 +9486,20 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, const SCEV *One = getOne(Stride->getType()); if (IsSigned) { - APInt MinRHS = getSignedRange(RHS).getSignedMin(); + APInt MinRHS = getSignedRangeMin(RHS); APInt MinValue = APInt::getSignedMinValue(BitWidth); - APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) - .getSignedMax(); + APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! - return (MinValue + MaxStrideMinusOne).sgt(MinRHS); + return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); } - APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin(); + APInt MinRHS = getUnsignedRangeMin(RHS); APInt MinValue = APInt::getMinValue(BitWidth); - APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) - .getUnsignedMax(); + APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! - return (MinValue + MaxStrideMinusOne).ugt(MinRHS); + return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); } const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, @@ -8759,8 +9635,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, } else { // Calculate the maximum backedge count based on the range of values // permitted by Start, End, and Stride. - APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() - : getUnsignedRange(Start).getUnsignedMin(); + APInt MinStart = IsSigned ? getSignedRangeMin(Start) + : getUnsignedRangeMin(Start); unsigned BitWidth = getTypeSizeInBits(LHS->getType()); @@ -8768,8 +9644,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (PositiveStride) StrideForMaxBECount = - IsSigned ? 
getSignedRange(Stride).getSignedMin() - : getUnsignedRange(Stride).getUnsignedMin(); + IsSigned ? getSignedRangeMin(Stride) + : getUnsignedRangeMin(Stride); else // Using a stride of 1 is safe when computing max backedge taken count for // a loop with unknown stride. @@ -8783,15 +9659,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // the case End = RHS. This is safe because in the other case (End - Start) // is zero, leading to a zero maximum backedge taken count. APInt MaxEnd = - IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) - : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); + IsSigned ? APIntOps::smin(getSignedRangeMax(RHS), Limit) + : APIntOps::umin(getUnsignedRangeMax(RHS), Limit); MaxBECount = computeBECount(getConstant(MaxEnd - MinStart), getConstant(StrideForMaxBECount), false); } - if (isa<SCEVCouldNotCompute>(MaxBECount)) - MaxBECount = BECount; + if (isa<SCEVCouldNotCompute>(MaxBECount) && + !isa<SCEVCouldNotCompute>(BECount)) + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); } @@ -8842,11 +9719,11 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); - APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax() - : getUnsignedRange(Start).getUnsignedMax(); + APInt MaxStart = IsSigned ? getSignedRangeMax(Start) + : getUnsignedRangeMax(Start); - APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() - : getUnsignedRange(Stride).getUnsignedMin(); + APInt MinStride = IsSigned ? getSignedRangeMin(Stride) + : getUnsignedRangeMin(Stride); unsigned BitWidth = getTypeSizeInBits(LHS->getType()); APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1) @@ -8856,8 +9733,8 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // the case End = RHS. This is safe because in the other case (Start - End) // is zero, leading to a zero maximum backedge taken count. APInt MinEnd = - IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit) - : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit); + IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit) + : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); const SCEV *MaxBECount = getCouldNotCompute(); @@ -8914,9 +9791,8 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // the upper value of the range must be the first possible exit value. // If A is negative then the lower of the range is the last possible loop // value. Also note that we already checked for a full range. - APInt One(BitWidth,1); APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt(); - APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); + APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower(); // The exit value should be (End+A)/A. APInt ExitVal = (End + A).udiv(A); @@ -8932,7 +9808,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // Ensure that the previous value is in the range. This is a sanity check. 
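[Editorial note] In the linear arm of getNumIterationsInRange above, the exit value (End + A)/A is just ceil((End + 1)/A) computed with a single unsigned division. A scalar model of the positive-step arm, assuming A >= 1 and that End + A does not wrap; firstIterOutOfRange is a made-up name, and Start is already folded out exactly as in the code:

#include <cstdint>

// For {0,+,A} with A >= 1 and the half-open range [0, End + 1), the first
// iteration x with A*x > End is floor(End / A) + 1, which equals
// (End + A) / A in unsigned arithmetic.
uint64_t firstIterOutOfRange(uint64_t End, uint64_t A) {
  return (End + A) / A;
}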
assert(Range.contains( EvaluateConstantChrecAtConstant(this, - ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && + ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); } else if (isQuadratic()) { @@ -9083,8 +9959,11 @@ struct SCEVCollectAddRecMultiplies { bool HasAddRec = false; SmallVector<const SCEV *, 0> Operands; for (auto Op : Mul->operands()) { - if (isa<SCEVUnknown>(Op)) { + const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op); + if (Unknown && !isa<CallInst>(Unknown->getValue())) { Operands.push_back(Op); + } else if (Unknown) { + HasAddRec = true; } else { bool ContainsAddRec; SCEVHasAddRec ContiansAddRec(ContainsAddRec); @@ -9238,7 +10117,7 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, SmallVectorImpl<const SCEV *> &Sizes, - const SCEV *ElementSize) const { + const SCEV *ElementSize) { if (Terms.size() < 1 || !ElementSize) return; @@ -9254,7 +10133,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, }); // Remove duplicates. - std::sort(Terms.begin(), Terms.end()); + array_pod_sort(Terms.begin(), Terms.end()); Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); // Put larger terms first. @@ -9262,13 +10141,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, return numberOfTerms(LHS) > numberOfTerms(RHS); }); - ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); - // Try to divide all terms by the element size. If term is not divisible by // element size, proceed with the original term. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); + SCEVDivision::divide(*this, Term, ElementSize, &Q, &R); if (!Q->isZero()) Term = Q; } @@ -9277,7 +10154,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, // Remove constant factors. for (const SCEV *T : Terms) - if (const SCEV *NewT = removeConstantFactors(SE, T)) + if (const SCEV *NewT = removeConstantFactors(*this, T)) NewTerms.push_back(NewT); DEBUG({ @@ -9286,8 +10163,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, dbgs() << *T << "\n"; }); - if (NewTerms.empty() || - !findArrayDimensionsRec(SE, NewTerms, Sizes)) { + if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) { Sizes.clear(); return; } @@ -9524,6 +10400,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), + MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), @@ -9538,6 +10415,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) UniqueSCEVs(std::move(Arg.UniqueSCEVs)), UniquePreds(std::move(Arg.UniquePreds)), SCEVAllocator(std::move(Arg.SCEVAllocator)), + PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)), FirstUnknown(Arg.FirstUnknown) { Arg.FirstUnknown = nullptr; } @@ -9621,6 +10499,13 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable predicated backedge-taken count. 
"; } OS << "\n"; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; + } } static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { @@ -9929,6 +10814,16 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { SignedRanges.erase(S); ExprValueMap.erase(S); HasRecMap.erase(S); + MinTrailingZerosCache.erase(S); + + for (auto I = PredicatedSCEVRewrites.begin(); + I != PredicatedSCEVRewrites.end();) { + std::pair<const SCEV *, const Loop *> Entry = I->first; + if (Entry.first == S) + PredicatedSCEVRewrites.erase(I++); + else + ++I; + } auto RemoveSCEVFromBackedgeMap = [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { @@ -9946,84 +10841,75 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); } -typedef DenseMap<const Loop *, std::string> VerifyMap; +void ScalarEvolution::verify() const { + ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + ScalarEvolution SE2(F, TLI, AC, DT, LI); -/// replaceSubString - Replaces all occurrences of From in Str with To. -static void replaceSubString(std::string &Str, StringRef From, StringRef To) { - size_t Pos = 0; - while ((Pos = Str.find(From, Pos)) != std::string::npos) { - Str.replace(Pos, From.size(), To.data(), To.size()); - Pos += To.size(); - } -} + SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end()); -/// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis. -static void -getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) { - std::string &S = Map[L]; - if (S.empty()) { - raw_string_ostream OS(S); - SE.getBackedgeTakenCount(L)->print(OS); + // Map's SCEV expressions from one ScalarEvolution "universe" to another. + struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> { + const SCEV *visitConstant(const SCEVConstant *Constant) { + return SE.getConstant(Constant->getAPInt()); + } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return SE.getUnknown(Expr->getValue()); + } - // false and 0 are semantically equivalent. This can happen in dead loops. - replaceSubString(OS.str(), "false", "0"); - // Remove wrap flags, their use in SCEV is highly fragile. - // FIXME: Remove this when SCEV gets smarter about them. - replaceSubString(OS.str(), "<nw>", ""); - replaceSubString(OS.str(), "<nsw>", ""); - replaceSubString(OS.str(), "<nuw>", ""); - } + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return SE.getCouldNotCompute(); + } + SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {} + }; - for (auto *R : reverse(*L)) - getLoopBackedgeTakenCounts(R, Map, SE); // recurse. -} + SCEVMapper SCM(SE2); -void ScalarEvolution::verify() const { - ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + while (!LoopStack.empty()) { + auto *L = LoopStack.pop_back_val(); + LoopStack.insert(LoopStack.end(), L->begin(), L->end()); - // Gather stringified backedge taken counts for all loops using SCEV's caches. - // FIXME: It would be much better to store actual values instead of strings, - // but SCEV pointers will change if we drop the caches. 
-
- // Gather stringified backedge taken counts for all loops using SCEV's caches.
- // FIXME: It would be much better to store actual values instead of strings,
- // but SCEV pointers will change if we drop the caches.
- VerifyMap BackedgeDumpsOld, BackedgeDumpsNew;
- for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
- getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE);
+ auto *CurBECount = SCM.visit(
+ const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
+ auto *NewBECount = SE2.getBackedgeTakenCount(L);
 
- // Gather stringified backedge taken counts for all loops using a fresh
- // ScalarEvolution object.
- ScalarEvolution SE2(F, TLI, AC, DT, LI);
- for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
- getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2);
-
- // Now compare whether they're the same with and without caches. This allows
- // verifying that no pass changed the cache.
- assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() &&
- "New loops suddenly appeared!");
-
- for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(),
- OldE = BackedgeDumpsOld.end(),
- NewI = BackedgeDumpsNew.begin();
- OldI != OldE; ++OldI, ++NewI) {
- assert(OldI->first == NewI->first && "Loop order changed!");
-
- // Compare the stringified SCEVs. We don't care if undef backedgetaken count
- // changes.
- // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This
- // means that a pass is buggy or SCEV has to learn a new pattern but is
- // usually not harmful.
- if (OldI->second != NewI->second &&
- OldI->second.find("undef") == std::string::npos &&
- NewI->second.find("undef") == std::string::npos &&
- OldI->second != "***COULDNOTCOMPUTE***" &&
- NewI->second != "***COULDNOTCOMPUTE***") {
- dbgs() << "SCEVValidator: SCEV for loop '"
- << OldI->first->getHeader()->getName()
- << "' changed from '" << OldI->second
- << "' to '" << NewI->second << "'!\n";
+ if (CurBECount == SE2.getCouldNotCompute() ||
+ NewBECount == SE2.getCouldNotCompute()) {
+ // NB! This situation is legal, but is very suspicious -- whatever pass
+ // changed the loop to make a trip count go from could not compute to
+ // computable or vice-versa *should have* invalidated SCEV. However, we
+ // choose not to assert here (for now) since we don't want false
+ // positives.
+ continue;
+ }
+
+ if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) {
+ // SCEV treats "undef" as an unknown but consistent value (i.e. it does
+ // not propagate undef aggressively). This means we can (and do) fail
+ // verification in cases where a transform makes the trip count of a loop
+ // go from "undef" to "undef+1" (say). The transform is fine, since in
+ // both cases the loop iterates "undef" times, but SCEV thinks we
+ // increased the trip count of the loop by 1 incorrectly.
+ continue;
+ }
+
+ if (SE.getTypeSizeInBits(CurBECount->getType()) >
+ SE.getTypeSizeInBits(NewBECount->getType()))
+ NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
+ else if (SE.getTypeSizeInBits(CurBECount->getType()) <
+ SE.getTypeSizeInBits(NewBECount->getType()))
+ CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
+
+ auto *ConstantDelta =
+ dyn_cast<SCEVConstant>(SE2.getMinusSCEV(CurBECount, NewBECount));
+
+ if (ConstantDelta && ConstantDelta->getAPInt() != 0) {
+ dbgs() << "Trip Count Changed!\n";
+ dbgs() << "Old: " << *CurBECount << "\n";
+ dbgs() << "New: " << *NewBECount << "\n";
+ dbgs() << "Delta: " << *ConstantDelta << "\n";
 std::abort();
 }
 }
-
- // TODO: Verify more things.
} bool ScalarEvolution::invalidate( @@ -10098,10 +10984,11 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); } -const SCEVPredicate * -ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, - const SCEVConstant *RHS) { +const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, + const SCEV *RHS) { FoldingSetNodeID ID; + assert(LHS->getType() == RHS->getType() && + "Type mismatch between LHS and RHS"); // Unique this node based on the arguments ID.AddInteger(SCEVPredicate::P_Equal); ID.AddPointer(LHS); @@ -10164,8 +11051,7 @@ public: if (IPred->getLHS() == Expr) return IPred->getRHS(); } - - return Expr; + return convertToAddRecWithPreds(Expr); } const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { @@ -10201,17 +11087,41 @@ public: } private: - bool addOverflowAssumption(const SCEVAddRecExpr *AR, - SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { - auto *A = SE.getWrapPredicate(AR, AddedFlags); + bool addOverflowAssumption(const SCEVPredicate *P) { if (!NewPreds) { // Check if we've already made this assumption. - return Pred && Pred->implies(A); + return Pred && Pred->implies(P); } - NewPreds->insert(A); + NewPreds->insert(P); return true; } + bool addOverflowAssumption(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + auto *A = SE.getWrapPredicate(AR, AddedFlags); + return addOverflowAssumption(A); + } + + // If \p Expr represents a PHINode, we try to see if it can be represented + // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible + // to add this predicate as a runtime overflow check, we return the AddRec. + // If \p Expr does not meet these conditions (is not a PHI node, or we + // couldn't create an AddRec for it, or couldn't add the predicate), we just + // return \p Expr. + const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) { + if (!isa<PHINode>(Expr->getValue())) + return Expr; + Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> + PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr); + if (!PredicatedRewrite) + return Expr; + for (auto *P : PredicatedRewrite->second){ + if (!addOverflowAssumption(P)) + return Expr; + } + return PredicatedRewrite->first; + } + SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; SCEVUnionPredicate *Pred; const Loop *L; @@ -10248,9 +11158,11 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, : FastID(ID), Kind(Kind) {} SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, - const SCEVUnknown *LHS, - const SCEVConstant *RHS) - : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} + const SCEV *LHS, const SCEV *RHS) + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) { + assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); + assert(LHS != RHS && "LHS and RHS are the same SCEV"); +} bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { const auto *Op = dyn_cast<SCEVEqualPredicate>(N); |