diff options
Diffstat (limited to 'contrib/llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/ScalarEvolution.cpp | 1251 |
1 files changed, 730 insertions, 521 deletions
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp index ee42737..0e9f812 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp @@ -68,6 +68,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -87,7 +88,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetLibraryInfo.h" #include <algorithm> using namespace llvm; @@ -117,9 +117,9 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopInfo) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) char ScalarEvolution::ID = 0; @@ -726,6 +726,13 @@ public: return; } + // A simple case when N/1. The quotient is N. + if (Denominator->isOne()) { + *Quotient = Numerator; + *Remainder = D.Zero; + return; + } + // Split the Denominator when it is a product. if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { const SCEV *Q, *R; @@ -788,6 +795,14 @@ public: assert(Numerator->isAffine() && "Numerator should be affine"); divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + // Bail out if the types do not match. + Type *Ty = Denominator->getType(); + if (Ty != StartQ->getType() || Ty != StartR->getType() || + Ty != StepQ->getType() || Ty != StepR->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), Numerator->getNoWrapFlags()); Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), @@ -1102,13 +1117,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return getTruncateOrZeroExtend(SZ->getOperand(), Ty); // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) { SmallVector<const SCEV *, 4> Operands; bool hasTrunc = false; for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty); - hasTrunc = isa<SCEVTruncateExpr>(S); + if (!isa<SCEVCastExpr>(SA->getOperand(i))) + hasTrunc = isa<SCEVTruncateExpr>(S); Operands.push_back(S); } if (!hasTrunc) @@ -1117,13 +1133,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, } // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) { SmallVector<const SCEV *, 4> Operands; bool hasTrunc = false; for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty); - hasTrunc = isa<SCEVTruncateExpr>(S); + if (!isa<SCEVCastExpr>(SM->getOperand(i))) + hasTrunc = isa<SCEVTruncateExpr>(S); Operands.push_back(S); } if (!hasTrunc) @@ -1148,6 +1165,262 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return S; } +// Get the limit of a recurrence such that incrementing by Step cannot cause +// signed overflow as long as the value of the recurrence within the +// loop does not exceed this limit before incrementing. +static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + if (SE->isKnownPositive(Step)) { + *Pred = ICmpInst::ICMP_SLT; + return SE->getConstant(APInt::getSignedMinValue(BitWidth) - + SE->getSignedRange(Step).getSignedMax()); + } + if (SE->isKnownNegative(Step)) { + *Pred = ICmpInst::ICMP_SGT; + return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - + SE->getSignedRange(Step).getSignedMin()); + } + return nullptr; +} + +// Get the limit of a recurrence such that incrementing by Step cannot cause +// unsigned overflow as long as the value of the recurrence within the loop does +// not exceed this limit before incrementing. +static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + *Pred = ICmpInst::ICMP_ULT; + + return SE->getConstant(APInt::getMinValue(BitWidth) - + SE->getUnsignedRange(Step).getUnsignedMax()); +} + +namespace { + +struct ExtendOpTraitsBase { + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); +}; + +// Used to make code generic over signed and unsigned overflow. +template <typename ExtendOp> struct ExtendOpTraits { + // Members present: + // + // static const SCEV::NoWrapFlags WrapType; + // + // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; + // + // static const SCEV *getOverflowLimitForStep(const SCEV *Step, + // ICmpInst::Predicate *Pred, + // ScalarEvolution *SE); +}; + +template <> +struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getSignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; + +template <> +struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getUnsignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; +} + +// The recurrence AR has been shown to have no signed/unsigned wrap or something +// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as +// easily prove NSW/NUW for its preincrement or postincrement sibling. This +// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step + +// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the +// expression "Step + sext/zext(PreIncAR)" is congruent with +// "sext/zext(PostIncAR)" +template <typename ExtendOpTy> +static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; + auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; + + const Loop *L = AR->getLoop(); + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // Check for a simple looking step prior to loop entry. + const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start); + if (!SA) + return nullptr; + + // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV + // subtraction is expensive. For this purpose, perform a quick and dirty + // difference, by checking for Step in the operand list. + SmallVector<const SCEV *, 4> DiffOps; + for (const SCEV *Op : SA->operands()) + if (Op != Step) + DiffOps.push_back(Op); + + if (DiffOps.size() == SA->getNumOperands()) + return nullptr; + + // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + + // `Step`: + + // 1. NSW/NUW flags on the step increment. + const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); + const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>( + SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); + + // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies + // "S+X does not sign/unsign-overflow". + // + + const SCEV *BECount = SE->getBackedgeTakenCount(L); + if (PreAR && PreAR->getNoWrapFlags(WrapType) && + !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount)) + return PreStart; + + // 2. Direct overflow check on the step operation's expression. + unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); + const SCEV *OperandExtendedStart = + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), + (SE->*GetExtendExpr)(Step, WideTy)); + if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + if (PreAR && AR->getNoWrapFlags(WrapType)) { + // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW + // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then + // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. + const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType); + } + return PreStart; + } + + // 3. Loop precondition. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE); + + if (OverflowLimit && + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + return PreStart; + } + return nullptr; +} + +// Get the normalized zero or sign extended expression for this AddRec's Start. +template <typename ExtendOpTy> +static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; + + const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE); + if (!PreStart) + return (SE->*GetExtendExpr)(AR->getStart(), Ty); + + return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), + (SE->*GetExtendExpr)(PreStart, Ty)); +} + +// Try to prove away overflow by looking at "nearby" add recurrences. A +// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it +// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. +// +// Formally: +// +// {S,+,X} == {S-T,+,X} + T +// => Ext({S,+,X}) == Ext({S-T,+,X} + T) +// +// If ({S-T,+,X} + T) does not overflow ... (1) +// +// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) +// +// If {S-T,+,X} does not overflow ... (2) +// +// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) +// == {Ext(S-T)+Ext(T),+,Ext(X)} +// +// If (S-T)+T does not overflow ... (3) +// +// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} +// == {Ext(S),+,Ext(X)} == LHS +// +// Thus, if (1), (2) and (3) are true for some T, then +// Ext({S,+,X}) == {Ext(S),+,Ext(X)} +// +// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) +// does not overflow" restricted to the 0th iteration. Therefore we only need +// to check for (1) and (2). +// +// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T +// is `Delta` (defined below). +// +template <typename ExtendOpTy> +bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, + const SCEV *Step, + const Loop *L) { + auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; + + // We restrict `Start` to a constant to prevent SCEV from spending too much + // time here. It is correct (but more expensive) to continue with a + // non-constant `Start` and do a general SCEV subtraction to compute + // `PreStart` below. + // + const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start); + if (!StartC) + return false; + + APInt StartAI = StartC->getValue()->getValue(); + + for (unsigned Delta : {-2, -1, 1, 2}) { + const SCEV *PreStart = getConstant(StartAI - Delta); + + // Give up if we don't already have the add recurrence we need because + // actually constructing an add recurrence is relatively expensive. + const SCEVAddRecExpr *PreAR = [&]() { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + return static_cast<SCEVAddRecExpr *>( + this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + }(); + + if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) + const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; + const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep( + DeltaS, &Pred, this); + if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) + return true; + } + } + + return false; +} + const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -1201,9 +1474,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNUW)) - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1240,9 +1513,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1255,9 +1528,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1275,9 +1548,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - @@ -1290,12 +1563,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } + + if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); + return getAddRecExpr( + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -1307,104 +1587,6 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } -// Get the limit of a recurrence such that incrementing by Step cannot cause -// signed overflow as long as the value of the recurrence within the loop does -// not exceed this limit before incrementing. -static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); - if (SE->isKnownPositive(Step)) { - *Pred = ICmpInst::ICMP_SLT; - return SE->getConstant(APInt::getSignedMinValue(BitWidth) - - SE->getSignedRange(Step).getSignedMax()); - } - if (SE->isKnownNegative(Step)) { - *Pred = ICmpInst::ICMP_SGT; - return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - - SE->getSignedRange(Step).getSignedMin()); - } - return nullptr; -} - -// The recurrence AR has been shown to have no signed wrap. Typically, if we can -// prove NSW for AR, then we can just as easily prove NSW for its preincrement -// or postincrement sibling. This allows normalizing a sign extended AddRec as -// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a -// result, the expression "Step + sext(PreIncAR)" is congruent with -// "sext(PostIncAR)" -static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); - - // Check for a simple looking step prior to loop entry. - const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start); - if (!SA) - return nullptr; - - // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV - // subtraction is expensive. For this purpose, perform a quick and dirty - // difference, by checking for Step in the operand list. - SmallVector<const SCEV *, 4> DiffOps; - for (const SCEV *Op : SA->operands()) - if (Op != Step) - DiffOps.push_back(Op); - - if (DiffOps.size() == SA->getNumOperands()) - return nullptr; - - // This is a postinc AR. Check for overflow on the preinc recurrence using the - // same three conditions that getSignExtendedExpr checks. - - // 1. NSW flags on the step increment. - const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); - const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>( - SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); - - if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW)) - return PreStart; - - // 2. Direct overflow check on the step operation's expression. - unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); - Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = - SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy), - SE->getSignExtendExpr(Step, WideTy)); - if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) { - // Cache knowledge of PreAR NSW. - if (PreAR) - const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW); - // FIXME: this optimization needs a unit test - DEBUG(dbgs() << "SCEV: untested prestart overflow check\n"); - return PreStart; - } - - // 3. Loop precondition. - ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE); - - if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { - return PreStart; - } - return nullptr; -} - -// Get the normalized sign-extended expression for this AddRec's Start. -static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE); - if (!PreStart) - return SE->getSignExtendExpr(AR->getStart(), Ty); - - return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty), - SE->getSignExtendExpr(PreStart, Ty)); -} - const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -1483,9 +1665,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNSW)) - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1522,9 +1704,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -1533,12 +1715,20 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy))); if (SAdd == OperandExtendedAdd) { - // Cache knowledge of AR NSW, which is propagated to this AddRec. - const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); + // If AR wraps around then + // + // abs(Step) * MaxBECount > unsigned-max(AR->getType()) + // => SAdd != OperandExtendedAdd + // + // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> + // (SAdd == OperandExtendedAdd => AR is NW) + + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); + // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1547,7 +1737,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // with the start value and the backedge is guarded by a comparison // with the post-inc value, the addrec is safe. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this); + const SCEV *OverflowLimit = + getSignedOverflowLimitForStep(Step, &Pred, this); if (OverflowLimit && (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && @@ -1555,9 +1746,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, OverflowLimit)))) { // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } // If Start and Step are constants, check if we can apply this @@ -1576,6 +1767,13 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); } } + + if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -2169,8 +2367,7 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) { if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) { const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr); - for (const SCEV *Operand : CurrentNAry->operands()) - Ops.push_back(Operand); + Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); } } return false; @@ -2729,6 +2926,56 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, return S; } +const SCEV * +ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, + const SmallVectorImpl<const SCEV *> &IndexExprs, + bool InBounds) { + // getSCEV(Base)->getType() has the same address space as Base->getType() + // because SCEV::getType() preserves the address space. + Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); + // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP + // instruction to its SCEV, because the Instruction may be guarded by control + // flow and the no-overflow bits may not be valid for the expression in any + // context. + SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + + const SCEV *TotalOffset = getConstant(IntPtrTy, 0); + // The address space is unimportant. The first thing we do on CurTy is getting + // its element type. + Type *CurTy = PointerType::getUnqual(PointeeType); + for (const SCEV *IndexExpr : IndexExprs) { + // Compute the (potentially symbolic) offset in bytes for this index. + if (StructType *STy = dyn_cast<StructType>(CurTy)) { + // For a struct, add the member offset. + ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); + unsigned FieldNo = Index->getZExtValue(); + const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + + // Add the field offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, FieldOffset); + + // Update CurTy to the type of the field at Index. + CurTy = STy->getTypeAtIndex(Index); + } else { + // Update CurTy to its element type. + CurTy = cast<SequentialType>(CurTy)->getElementType(); + // For an array, add the element offset, explicitly scaled. + const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); + // Getelementptr indices are signed. + IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); + + // Multiply the index by the element size to compute the element offset. + const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); + + // Add the element offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, LocalOffset); + } + } + + // Add the total offset from all the GEP indices to the base. + return getAddExpr(BaseExpr, TotalOffset, Wrap); +} + const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector<const SCEV *, 2> Ops; @@ -2950,39 +3197,23 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) - return getConstant(IntTy, DL->getTypeAllocSize(AllocTy)); - - Constant *C = ConstantExpr::getSizeOf(AllocTy); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); - assert(Ty == IntTy && "Effective SCEV type doesn't match"); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant(IntTy, + F->getParent()->getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) { - return getConstant(IntTy, - DL->getStructLayout(STy)->getElementOffset(FieldNo)); - } - - Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant( + IntTy, + F->getParent()->getDataLayout().getStructLayout(STy)->getElementOffset( + FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -3024,19 +3255,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const { /// for which isSCEVable must return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - - // If we have a DataLayout, use it! - if (DL) - return DL->getTypeSizeInBits(Ty); - - // Integer types have fixed sizes. - if (Ty->isIntegerTy()) - return Ty->getPrimitiveSizeInBits(); - - // The only other support type is pointer. Without DataLayout, conservatively - // assume pointers are 64-bit. - assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!"); - return 64; + return F->getParent()->getDataLayout().getTypeSizeInBits(Ty); } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -3052,12 +3271,7 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); - - if (DL) - return DL->getIntPtrType(Ty); - - // Without DataLayout, conservatively assume pointers are 64-bit. - return Type::getInt64Ty(getContext()); + return F->getParent()->getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { @@ -3444,10 +3658,12 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // If the increment doesn't overflow, then neither the addrec nor // the post-increment will overflow. if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) { - if (OBO->hasNoUnsignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNSW); + if (OBO->getOperand(0) == PN) { + if (OBO->hasNoUnsignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (OBO->hasNoSignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNSW); + } } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { // If the increment is an inbounds GEP, then we know the address // space cannot be wrapped around. We cannot make any guarantee @@ -3455,7 +3671,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // unsigned but we may have a negative index from the base // pointer. We can guarantee that no unsigned wrap occurs if the // indices form a positive value. - if (GEP->isInBounds()) { + if (GEP->isInBounds() && GEP->getOperand(0) == PN) { Flags = setFlags(Flags, SCEV::FlagNW); const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); @@ -3521,7 +3737,8 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC)) + if (Value *V = + SimplifyInstruction(PN, F->getParent()->getDataLayout(), TLI, DT, AC)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3533,52 +3750,16 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { /// operations. This allows them to be analyzed by regular SCEV code. /// const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. if (!Base->getType()->getPointerElementType()->isSized()) return getUnknown(GEP); - // Don't blindly transfer the inbounds flag from the GEP instruction to the - // Add expression, because the Instruction may be guarded by control flow - // and the no-overflow bits may not be valid for the expression in any - // context. - SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - - const SCEV *TotalOffset = getConstant(IntPtrTy, 0); - gep_type_iterator GTI = gep_type_begin(GEP); - for (GetElementPtrInst::op_iterator I = std::next(GEP->op_begin()), - E = GEP->op_end(); - I != E; ++I) { - Value *Index = *I; - // Compute the (potentially symbolic) offset in bytes for this index. - if (StructType *STy = dyn_cast<StructType>(*GTI++)) { - // For a struct, add the member offset. - unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue(); - const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); - - // Add the field offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, FieldOffset); - } else { - // For an array, add the element offset, explicitly scaled. - const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, *GTI); - const SCEV *IndexS = getSCEV(Index); - // Getelementptr indices are signed. - IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy); - - // Multiply the index by the element size to compute the element offset. - const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, Wrap); - - // Add the element offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, LocalOffset); - } - } - - // Get the SCEV for the GEP base. - const SCEV *BaseS = getSCEV(Base); - - // Add the total offset from all the GEP indices to the base. - return getAddExpr(BaseS, TotalOffset, Wrap); + SmallVector<const SCEV *, 4> IndexExprs; + for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) + IndexExprs.push_back(getSCEV(*Index)); + return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs, + GEP->isInBounds()); } /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is @@ -3653,7 +3834,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); + computeKnownBits(U->getValue(), Zeros, Ones, + F->getParent()->getDataLayout(), 0, AC, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3688,79 +3870,93 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { return None; } -/// getUnsignedRange - Determine the unsigned range for a particular SCEV. +/// getRange - Determine the range for a particular SCEV. If SignHint is +/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges +/// with a "cleaner" unsigned (resp. signed) representation. /// ConstantRange -ScalarEvolution::getUnsignedRange(const SCEV *S) { +ScalarEvolution::getRange(const SCEV *S, + ScalarEvolution::RangeSignHint SignHint) { + DenseMap<const SCEV *, ConstantRange> &Cache = + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges + : SignedRanges; + // See if we've computed this range already. - DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S); - if (I != UnsignedRanges.end()) + DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); + if (I != Cache.end()) return I->second; if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) - return setUnsignedRange(C, ConstantRange(C->getValue()->getValue())); + return setRange(C, SignHint, ConstantRange(C->getValue()->getValue())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); - // If the value has known zeros, the maximum unsigned value will have those - // known zeros as well. + // If the value has known zeros, the maximum value will have those known zeros + // as well. uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) - ConservativeResult = - ConstantRange(APInt::getMinValue(BitWidth), - APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + if (TZ != 0) { + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) + ConservativeResult = + ConstantRange(APInt::getMinValue(BitWidth), + APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + else + ConservativeResult = ConstantRange( + APInt::getSignedMinValue(BitWidth), + APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); + } if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { - ConstantRange X = getUnsignedRange(Add->getOperand(0)); + ConstantRange X = getRange(Add->getOperand(0), SignHint); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.add(getUnsignedRange(Add->getOperand(i))); - return setUnsignedRange(Add, ConservativeResult.intersectWith(X)); + X = X.add(getRange(Add->getOperand(i), SignHint)); + return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { - ConstantRange X = getUnsignedRange(Mul->getOperand(0)); + ConstantRange X = getRange(Mul->getOperand(0), SignHint); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getUnsignedRange(Mul->getOperand(i))); - return setUnsignedRange(Mul, ConservativeResult.intersectWith(X)); + X = X.multiply(getRange(Mul->getOperand(i), SignHint)); + return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { - ConstantRange X = getUnsignedRange(SMax->getOperand(0)); + ConstantRange X = getRange(SMax->getOperand(0), SignHint); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) - X = X.smax(getUnsignedRange(SMax->getOperand(i))); - return setUnsignedRange(SMax, ConservativeResult.intersectWith(X)); + X = X.smax(getRange(SMax->getOperand(i), SignHint)); + return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { - ConstantRange X = getUnsignedRange(UMax->getOperand(0)); + ConstantRange X = getRange(UMax->getOperand(0), SignHint); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) - X = X.umax(getUnsignedRange(UMax->getOperand(i))); - return setUnsignedRange(UMax, ConservativeResult.intersectWith(X)); + X = X.umax(getRange(UMax->getOperand(i), SignHint)); + return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { - ConstantRange X = getUnsignedRange(UDiv->getLHS()); - ConstantRange Y = getUnsignedRange(UDiv->getRHS()); - return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); + ConstantRange X = getRange(UDiv->getLHS(), SignHint); + ConstantRange Y = getRange(UDiv->getRHS(), SignHint); + return setRange(UDiv, SignHint, + ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { - ConstantRange X = getUnsignedRange(ZExt->getOperand()); - return setUnsignedRange(ZExt, - ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); + ConstantRange X = getRange(ZExt->getOperand(), SignHint); + return setRange(ZExt, SignHint, + ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { - ConstantRange X = getUnsignedRange(SExt->getOperand()); - return setUnsignedRange(SExt, - ConservativeResult.intersectWith(X.signExtend(BitWidth))); + ConstantRange X = getRange(SExt->getOperand(), SignHint); + return setRange(SExt, SignHint, + ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { - ConstantRange X = getUnsignedRange(Trunc->getOperand()); - return setUnsignedRange(Trunc, - ConservativeResult.intersectWith(X.truncate(BitWidth))); + ConstantRange X = getRange(Trunc->getOperand(), SignHint); + return setRange(Trunc, SignHint, + ConservativeResult.intersectWith(X.truncate(BitWidth))); } if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { @@ -3773,143 +3969,6 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { ConservativeResult.intersectWith( ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); - // TODO: non-affine addrec - if (AddRec->isAffine()) { - Type *Ty = AddRec->getType(); - const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); - if (!isa<SCEVCouldNotCompute>(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); - - const SCEV *Start = AddRec->getStart(); - const SCEV *Step = AddRec->getStepRecurrence(*this); - - ConstantRange StartRange = getUnsignedRange(Start); - ConstantRange StepRange = getSignedRange(Step); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange EndRange = - StartRange.add(MaxBECountRange.multiply(StepRange)); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1); - if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != - ExtEndRange) - return setUnsignedRange(AddRec, ConservativeResult); - - APInt Min = APIntOps::umin(StartRange.getUnsignedMin(), - EndRange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), - EndRange.getUnsignedMax()); - if (Min.isMinValue() && Max.isMaxValue()) - return setUnsignedRange(AddRec, ConservativeResult); - return setUnsignedRange(AddRec, - ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); - } - } - - return setUnsignedRange(AddRec, ConservativeResult); - } - - if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { - // Check if the IR explicitly contains !range metadata. - Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); - if (MDRange.hasValue()) - ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); - - // For a SCEVUnknown, ask ValueTracking. - APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); - if (Ones == ~Zeros + 1) - return setUnsignedRange(U, ConservativeResult); - return setUnsignedRange(U, - ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1))); - } - - return setUnsignedRange(S, ConservativeResult); -} - -/// getSignedRange - Determine the signed range for a particular SCEV. -/// -ConstantRange -ScalarEvolution::getSignedRange(const SCEV *S) { - // See if we've computed this range already. - DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S); - if (I != SignedRanges.end()) - return I->second; - - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) - return setSignedRange(C, ConstantRange(C->getValue()->getValue())); - - unsigned BitWidth = getTypeSizeInBits(S->getType()); - ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); - - // If the value has known zeros, the maximum signed value will have those - // known zeros as well. - uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) - ConservativeResult = - ConstantRange(APInt::getSignedMinValue(BitWidth), - APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); - - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { - ConstantRange X = getSignedRange(Add->getOperand(0)); - for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.add(getSignedRange(Add->getOperand(i))); - return setSignedRange(Add, ConservativeResult.intersectWith(X)); - } - - if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { - ConstantRange X = getSignedRange(Mul->getOperand(0)); - for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getSignedRange(Mul->getOperand(i))); - return setSignedRange(Mul, ConservativeResult.intersectWith(X)); - } - - if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { - ConstantRange X = getSignedRange(SMax->getOperand(0)); - for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) - X = X.smax(getSignedRange(SMax->getOperand(i))); - return setSignedRange(SMax, ConservativeResult.intersectWith(X)); - } - - if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { - ConstantRange X = getSignedRange(UMax->getOperand(0)); - for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) - X = X.umax(getSignedRange(UMax->getOperand(i))); - return setSignedRange(UMax, ConservativeResult.intersectWith(X)); - } - - if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { - ConstantRange X = getSignedRange(UDiv->getLHS()); - ConstantRange Y = getSignedRange(UDiv->getRHS()); - return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); - } - - if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { - ConstantRange X = getSignedRange(ZExt->getOperand()); - return setSignedRange(ZExt, - ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); - } - - if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { - ConstantRange X = getSignedRange(SExt->getOperand()); - return setSignedRange(SExt, - ConservativeResult.intersectWith(X.signExtend(BitWidth))); - } - - if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { - ConstantRange X = getSignedRange(Trunc->getOperand()); - return setSignedRange(Trunc, - ConservativeResult.intersectWith(X.truncate(BitWidth))); - } - - if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { @@ -3935,41 +3994,66 @@ ScalarEvolution::getSignedRange(const SCEV *S) { const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { + + // Check for overflow. This must be done with ConstantRange arithmetic + // because we could be called from within the ScalarEvolution overflow + // checking code. + MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); + ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); + ConstantRange ZExtMaxBECountRange = + MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); const SCEV *Start = AddRec->getStart(); const SCEV *Step = AddRec->getStepRecurrence(*this); + ConstantRange StepSRange = getSignedRange(Step); + ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StartURange = getUnsignedRange(Start); + ConstantRange EndURange = + StartURange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for unsigned overflow. + ConstantRange ZExtStartURange = + StartURange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); + if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + ZExtEndURange) { + APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), + EndURange.getUnsignedMin()); + APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), + EndURange.getUnsignedMax()); + bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); + if (!IsFullRange) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); + } - ConstantRange StartRange = getSignedRange(Start); - ConstantRange StepRange = getSignedRange(Step); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange EndRange = - StartRange.add(MaxBECountRange.multiply(StepRange)); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1); - if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != - ExtEndRange) - return setSignedRange(AddRec, ConservativeResult); - - APInt Min = APIntOps::smin(StartRange.getSignedMin(), - EndRange.getSignedMin()); - APInt Max = APIntOps::smax(StartRange.getSignedMax(), - EndRange.getSignedMax()); - if (Min.isMinSignedValue() && Max.isMaxSignedValue()) - return setSignedRange(AddRec, ConservativeResult); - return setSignedRange(AddRec, - ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); + ConstantRange StartSRange = getSignedRange(Start); + ConstantRange EndSRange = + StartSRange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for signed overflow. This must be done with ConstantRange + // arithmetic because we could be called from within the ScalarEvolution + // overflow checking code. + ConstantRange SExtStartSRange = + StartSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); + if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + SExtEndSRange) { + APInt Min = APIntOps::smin(StartSRange.getSignedMin(), + EndSRange.getSignedMin()); + APInt Max = APIntOps::smax(StartSRange.getSignedMax(), + EndSRange.getSignedMax()); + bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); + if (!IsFullRange) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); + } } } - return setSignedRange(AddRec, ConservativeResult); + return setRange(AddRec, SignHint, ConservativeResult); } if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { @@ -3978,18 +4062,31 @@ ScalarEvolution::getSignedRange(const SCEV *S) { if (MDRange.hasValue()) ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); - // For a SCEVUnknown, ask ValueTracking. - if (!U->getValue()->getType()->isIntegerTy() && !DL) - return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); - if (NS <= 1) - return setSignedRange(U, ConservativeResult); - return setSignedRange(U, ConservativeResult.intersectWith( - ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), - APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1))); + // Split here to avoid paying the compile-time cost of calling both + // computeKnownBits and ComputeNumSignBits. This restriction can be lifted + // if needed. + const DataLayout &DL = F->getParent()->getDataLayout(); + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { + // For a SCEVUnknown, ask ValueTracking. + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); + if (Ones != ~Zeros + 1) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); + } else { + assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && + "generalize as needed!"); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); + if (NS > 1) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); + } + + return setRange(U, SignHint, ConservativeResult); } - return setSignedRange(S, ConservativeResult); + return setRange(S, SignHint, ConservativeResult); } /// createSCEV - We know that there is no SCEV for the specified value. @@ -4088,8 +4185,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC, - nullptr, DT); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, + F->getParent()->getDataLayout(), 0, AC, nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -4280,9 +4377,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4303,9 +4401,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? b+x : a+x -> umin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4320,11 +4419,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_NE: // n != 0 ? n+x : 1+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa<ConstantInt>(RHS) && - cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4335,11 +4434,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_EQ: // n == 0 ? 1+x : n+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa<ConstantInt>(RHS) && - cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, One); @@ -5238,12 +5337,9 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { if (!L->contains(I)) return false; if (isa<PHINode>(I)) { - if (L->getHeader() == I->getParent()) - return true; - else - // We don't currently keep track of the control flow needed to evaluate - // PHIs, so we cannot handle PHIs inside of loops. - return false; + // We don't currently keep track of the control flow needed to evaluate + // PHIs, so we cannot handle PHIs inside of loops. + return L->getHeader() == I->getParent(); } // If we won't be able to constant fold this expression even if the operands @@ -5314,7 +5410,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, const Loop *L, DenseMap<Instruction *, Constant *> &Vals, - const DataLayout *DL, + const DataLayout &DL, const TargetLibraryInfo *TLI) { // Convenient constant check, but redundant for recursive calls. if (Constant *C = dyn_cast<Constant>(V)) return C; @@ -5403,6 +5499,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; + const DataLayout &DL = F->getParent()->getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! @@ -5410,8 +5507,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, // Compute the value of the PHIs for the next iteration. // EvaluateExpression adds non-phi values to the CurrentIterVals map. DenseMap<Instruction *, Constant *> NextIterVals; - Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, - TLI); + Constant *NextPHI = + EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI); if (!NextPHI) return nullptr; // Couldn't evaluate! NextIterVals[PN] = NextPHI; @@ -5487,12 +5584,11 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of // "ExitWhen". - unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. + const DataLayout &DL = F->getParent()->getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ - ConstantInt *CondVal = - dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L, CurrentIterVals, - DL, TLI)); + ConstantInt *CondVal = dyn_cast_or_null<ConstantInt>( + EvaluateExpression(Cond, L, CurrentIterVals, DL, TLI)); // Couldn't symbolically evaluate. if (!CondVal) return getCouldNotCompute(); @@ -5623,7 +5719,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { if (PTy->getElementType()->isStructTy()) C2 = ConstantExpr::getIntegerCast( C2, Type::getInt32Ty(C->getContext()), true); - C = ConstantExpr::getGetElementPtr(C, C2); + C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); } else C = ConstantExpr::getAdd(C, C2); } @@ -5725,16 +5821,16 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { Constant *C = nullptr; + const DataLayout &DL = F->getParent()->getDataLayout(); if (const CmpInst *CI = dyn_cast<CmpInst>(I)) - C = ConstantFoldCompareInstOperands(CI->getPredicate(), - Operands[0], Operands[1], DL, - TLI); + C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], + Operands[1], DL, TLI); else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) C = ConstantFoldLoadFromConstPtr(Operands[0], DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - Operands, DL, TLI); + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, + DL, TLI); if (!C) return V; return getSCEV(C); } @@ -6016,7 +6112,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (CB->getZExtValue() == false) + if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero @@ -6631,6 +6727,65 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, return true; } + struct ClearWalkingBEDominatingCondsOnExit { + ScalarEvolution &SE; + + explicit ClearWalkingBEDominatingCondsOnExit(ScalarEvolution &SE) + : SE(SE){}; + + ~ClearWalkingBEDominatingCondsOnExit() { + SE.WalkingBEDominatingConds = false; + } + }; + + // We don't want more than one activation of the following loop on the stack + // -- that can lead to O(n!) time complexity. + if (WalkingBEDominatingConds) + return false; + + WalkingBEDominatingConds = true; + ClearWalkingBEDominatingCondsOnExit ClearOnExit(*this); + + // If the loop is not reachable from the entry block, we risk running into an + // infinite loop as we walk up into the dom tree. These loops do not matter + // anyway, so we just return a conservative answer when we see them. + if (!DT->isReachableFromEntry(L->getHeader())) + return false; + + for (DomTreeNode *DTN = (*DT)[Latch], *HeaderDTN = (*DT)[L->getHeader()]; + DTN != HeaderDTN; + DTN = DTN->getIDom()) { + + assert(DTN && "should reach the loop header before reaching the root!"); + + BasicBlock *BB = DTN->getBlock(); + BasicBlock *PBB = BB->getSinglePredecessor(); + if (!PBB) + continue; + + BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator()); + if (!ContinuePredicate || !ContinuePredicate->isConditional()) + continue; + + Value *Condition = ContinuePredicate->getCondition(); + + // If we have an edge `E` within the loop body that dominates the only + // latch, the condition guarding `E` also guards the backedge. This + // reasoning works only for loops with a single latch. + + BasicBlockEdge DominatingEdge(PBB, BB); + if (DominatingEdge.isSingleEdge()) { + // We're constructively (and conservatively) enumerating edges within the + // loop body that dominate the latch. The dominator tree better agree + // with us on this: + assert(DT->dominates(DominatingEdge, Latch) && "should be!"); + + if (isImpliedCond(Pred, LHS, RHS, Condition, + BB != ContinuePredicate->getSuccessor(0))) + return true; + } + } + return false; } @@ -6726,15 +6881,6 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue); if (!ICI) return false; - // Bail if the ICmp's operands' types are wider than the needed type - // before attempting to call getSCEV on them. This avoids infinite - // recursion, since the analysis of widening casts can require loop - // exit condition information for overflow checking, which would - // lead back here. - if (getTypeSizeInBits(LHS->getType()) < - getTypeSizeInBits(ICI->getOperand(0)->getType())) - return false; - // Now that we found a conditional branch that dominates the loop or controls // the loop latch. Check to see if it is the comparison we are looking for. ICmpInst::Predicate FoundPred; @@ -6746,9 +6892,17 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); - // Balance the types. The case where FoundLHS' type is wider than - // LHS' type is checked for above. - if (getTypeSizeInBits(LHS->getType()) > + // Balance the types. + if (getTypeSizeInBits(LHS->getType()) < + getTypeSizeInBits(FoundLHS->getType())) { + if (CmpInst::isSigned(Pred)) { + LHS = getSignExtendExpr(LHS, FoundLHS->getType()); + RHS = getSignExtendExpr(RHS, FoundLHS->getType()); + } else { + LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); + RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); + } + } else if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(FoundLHS->getType())) { if (CmpInst::isSigned(FoundPred)) { FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); @@ -6874,6 +7028,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y @@ -7011,8 +7168,49 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the +/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands. +/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1". +bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS)) + // The restriction on `FoundRHS` be lifted easily -- it exists only to + // reduce the compile time impact of this optimization. + return false; + + const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS); + if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || + !isa<SCEVConstant>(AddLHS->getOperand(0))) + return false; + + APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getValue()->getValue(); + + // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the + // antecedent "`FoundLHS` `Pred` `FoundRHS`". + ConstantRange FoundLHSRange = + ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); + + // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range + // for `LHS`: + APInt Addend = + cast<SCEVConstant>(AddLHS->getOperand(0))->getValue()->getValue(); + ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); + + // We can also compute the range of values for `LHS` that satisfy the + // consequent, "`LHS` `Pred` `RHS`": + APInt ConstRHS = cast<SCEVConstant>(RHS)->getValue()->getValue(); + ConstantRange SatisfyingLHSRange = + ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); + + // The antecedent implies the consequent if every value of `LHS` that + // satisfies the antecedent also satisfies the consequent. + return SatisfyingLHSRange.contains(LHSRange); +} + +// Verify if an linear IV with positive stride can overflow when in a +// less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { @@ -7040,7 +7238,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a +// Verify if an linear IV with negative stride can overflow when in a // greater-than comparison, knowing the invariant term of the comparison, // the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -7071,7 +7269,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, // Compute the backedge taken count knowing the interval difference, the // stride and presence of the equality in the comparison. -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getConstant(Step->getType(), 1); Delta = Equality ? getAddExpr(Delta, Step) @@ -7111,7 +7309,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7191,7 +7389,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7239,7 +7437,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVConstant>(BECount)) MaxBECount = BECount; else - MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), + MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa<SCEVCouldNotCompute>(MaxBECount)) @@ -7339,7 +7537,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (CB->getZExtValue() == false) + if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should @@ -7858,18 +8056,16 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(ID), ValuesAtScopes(64), LoopDispositions(64), - BlockDispositions(64), FirstUnknown(nullptr) { + : FunctionPass(ID), WalkingBEDominatingConds(false), ValuesAtScopes(64), + LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) { initializeScalarEvolutionPass(*PassRegistry::getPassRegistry()); } bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - LI = &getAnalysis<LoopInfo>(); - DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); - DL = DLP ? &DLP->getDataLayout() : nullptr; - TLI = &getAnalysis<TargetLibraryInfo>(); + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); return false; } @@ -7892,6 +8088,7 @@ void ScalarEvolution::releaseMemory() { } assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); + assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); BackedgeTakenCounts.clear(); ConstantEvolutionLoopExitValue.clear(); @@ -7907,9 +8104,9 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<AssumptionCacheTracker>(); - AU.addRequiredTransitive<LoopInfo>(); + AU.addRequiredTransitive<LoopInfoWrapperPass>(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfo>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { @@ -7969,6 +8166,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { OS << " --> "; const SCEV *SV = SE.getSCEV(&*I); SV->print(OS); + if (!isa<SCEVCouldNotCompute>(SV)) { + OS << " U: "; + SE.getUnsignedRange(SV).print(OS); + OS << " S: "; + SE.getSignedRange(SV).print(OS); + } const Loop *L = LI->getLoopFor((*I).getParent()); @@ -7976,6 +8179,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { if (AtUse != SV) { OS << " --> "; AtUse->print(OS); + if (!isa<SCEVCouldNotCompute>(AtUse)) { + OS << " U: "; + SE.getUnsignedRange(AtUse).print(OS); + OS << " S: "; + SE.getSignedRange(AtUse).print(OS); + } } if (L) { @@ -8000,17 +8209,17 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { ScalarEvolution::LoopDisposition ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { - SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values = LoopDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == L) - return Values[u].second; + auto &Values = LoopDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == L) + return V.getInt(); } - Values.push_back(std::make_pair(L, LoopVariant)); + Values.emplace_back(L, LoopVariant); LoopDisposition D = computeLoopDisposition(S, L); - SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values2 = LoopDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == L) { - Values2[u - 1].second = D; + auto &Values2 = LoopDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == L) { + V.setInt(D); break; } } @@ -8106,17 +8315,17 @@ bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { ScalarEvolution::BlockDisposition ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { - SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values = BlockDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == BB) - return Values[u].second; + auto &Values = BlockDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == BB) + return V.getInt(); } - Values.push_back(std::make_pair(BB, DoesNotDominateBlock)); + Values.emplace_back(BB, DoesNotDominateBlock); BlockDisposition D = computeBlockDisposition(S, BB); - SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values2 = BlockDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == BB) { - Values2[u - 1].second = D; + auto &Values2 = BlockDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == BB) { + V.setInt(D); break; } } |