diff options
author | ed <ed@FreeBSD.org> | 2009-06-22 08:08:12 +0000 |
---|---|---|
committer | ed <ed@FreeBSD.org> | 2009-06-22 08:08:12 +0000 |
commit | a4c19d68f13cf0a83bc0da53bd6d547fcaf635fe (patch) | |
tree | 86c1bc482baa6c81fc70b8d715153bfa93377186 /lib/Analysis/ScalarEvolution.cpp | |
parent | db89e312d968c258aba3c79c1c398f5fb19267a3 (diff) | |
download | FreeBSD-src-a4c19d68f13cf0a83bc0da53bd6d547fcaf635fe.zip FreeBSD-src-a4c19d68f13cf0a83bc0da53bd6d547fcaf635fe.tar.gz |
Update LLVM sources to r73879.
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 811 |
1 files changed, 664 insertions, 147 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 98ab6f4..68aa595 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -68,6 +68,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Assembly/Writer.h" #include "llvm/Target/TargetData.h" #include "llvm/Support/CommandLine.h" @@ -132,7 +133,8 @@ bool SCEV::isOne() const { return false; } -SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} +SCEVCouldNotCompute::SCEVCouldNotCompute(const ScalarEvolution* p) : + SCEV(scCouldNotCompute, p) {} SCEVCouldNotCompute::~SCEVCouldNotCompute() {} bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { @@ -178,7 +180,7 @@ SCEVConstant::~SCEVConstant() { SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) { SCEVConstant *&R = (*SCEVConstants)[V]; - if (R == 0) R = new SCEVConstant(V); + if (R == 0) R = new SCEVConstant(V, this); return R; } @@ -186,6 +188,11 @@ SCEVHandle ScalarEvolution::getConstant(const APInt& Val) { return getConstant(ConstantInt::get(Val)); } +SCEVHandle +ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) { + return getConstant(ConstantInt::get(cast<IntegerType>(Ty), V, isSigned)); +} + const Type *SCEVConstant::getType() const { return V->getType(); } void SCEVConstant::print(raw_ostream &OS) const { @@ -193,8 +200,9 @@ void SCEVConstant::print(raw_ostream &OS) const { } SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy, - const SCEVHandle &op, const Type *ty) - : SCEV(SCEVTy), Op(op), Ty(ty) {} + const SCEVHandle &op, const Type *ty, + const ScalarEvolution* p) + : SCEV(SCEVTy, p), Op(op), Ty(ty) {} SCEVCastExpr::~SCEVCastExpr() {} @@ -208,8 +216,9 @@ bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, SCEVTruncateExpr*> > SCEVTruncates; -SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) - : SCEVCastExpr(scTruncate, op, ty) { +SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty, + const ScalarEvolution* p) + : SCEVCastExpr(scTruncate, op, ty, p) { assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && (Ty->isInteger() || isa<PointerType>(Ty)) && "Cannot truncate non-integer value!"); @@ -229,8 +238,9 @@ void SCEVTruncateExpr::print(raw_ostream &OS) const { static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, SCEVZeroExtendExpr*> > SCEVZeroExtends; -SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) - : SCEVCastExpr(scZeroExtend, op, ty) { +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty, + const ScalarEvolution* p) + : SCEVCastExpr(scZeroExtend, op, ty, p) { assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && (Ty->isInteger() || isa<PointerType>(Ty)) && "Cannot zero extend non-integer value!"); @@ -250,8 +260,9 @@ void SCEVZeroExtendExpr::print(raw_ostream &OS) const { static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, SCEVSignExtendExpr*> > SCEVSignExtends; -SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) - : SCEVCastExpr(scSignExtend, op, ty) { +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty, + const ScalarEvolution* p) + : SCEVCastExpr(scSignExtend, op, ty, p) { assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && (Ty->isInteger() || isa<PointerType>(Ty)) && "Cannot sign extend non-integer value!"); @@ -293,7 +304,7 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, SCEVHandle H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { - std::vector<SCEVHandle> NewOps; + SmallVector<SCEVHandle, 8> NewOps; NewOps.reserve(getNumOperands()); for (unsigned j = 0; j != i; ++j) NewOps.push_back(getOperand(j)); @@ -373,7 +384,7 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, SCEVHandle H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { - std::vector<SCEVHandle> NewOps; + SmallVector<SCEVHandle, 8> NewOps; NewOps.reserve(getNumOperands()); for (unsigned j = 0; j != i; ++j) NewOps.push_back(getOperand(j)); @@ -504,9 +515,18 @@ namespace { return false; } - // Constant sorting doesn't matter since they'll be folded. - if (isa<SCEVConstant>(LHS)) - return false; + // Compare constant values. + if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) { + const SCEVConstant *RC = cast<SCEVConstant>(RHS); + return LC->getValue()->getValue().ult(RC->getValue()->getValue()); + } + + // Compare addrec loop depths. + if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) { + const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + if (LA->getLoop()->getLoopDepth() != RA->getLoop()->getLoopDepth()) + return LA->getLoop()->getLoopDepth() < RA->getLoop()->getLoopDepth(); + } // Lexicographically compare n-ary expressions. if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) { @@ -558,7 +578,7 @@ namespace { /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. /// -static void GroupByComplexity(std::vector<SCEVHandle> &Ops, +static void GroupByComplexity(SmallVectorImpl<SCEVHandle> &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop if (Ops.size() == 2) { @@ -763,17 +783,16 @@ SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) return getTruncateOrZeroExtend(SZ->getOperand(), Ty); - // If the input value is a chrec scev made out of constants, truncate - // all of the constants. + // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { - std::vector<SCEVHandle> Operands; + SmallVector<SCEVHandle, 4> Operands; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); return getAddRecExpr(Operands, AddRec->getLoop()); } SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty); + if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty, this); return Result; } @@ -861,7 +880,7 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, } SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty); + if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty, this); return Result; } @@ -933,7 +952,7 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, } SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty, this); return Result; } @@ -979,9 +998,105 @@ SCEVHandle ScalarEvolution::getAnyExtendExpr(const SCEVHandle &Op, return ZExt; } +/// CollectAddOperandsWithScales - Process the given Ops list, which is +/// a list of operands to be added under the given scale, update the given +/// map. This is a helper function for getAddRecExpr. As an example of +/// what it does, given a sequence of operands that would form an add +/// expression like this: +/// +/// m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r) +/// +/// where A and B are constants, update the map with these values: +/// +/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) +/// +/// and add 13 + A*B*29 to AccumulatedConstant. +/// This will allow getAddRecExpr to produce this: +/// +/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) +/// +/// This form often exposes folding opportunities that are hidden in +/// the original operand list. +/// +/// Return true iff it appears that any interesting folding opportunities +/// may be exposed. This helps getAddRecExpr short-circuit extra work in +/// the common case where no interesting opportunities are present, and +/// is also used as a check to avoid infinite recursion. +/// +static bool +CollectAddOperandsWithScales(DenseMap<SCEVHandle, APInt> &M, + SmallVector<SCEVHandle, 8> &NewOps, + APInt &AccumulatedConstant, + const SmallVectorImpl<SCEVHandle> &Ops, + const APInt &Scale, + ScalarEvolution &SE) { + bool Interesting = false; + + // Iterate over the add operands. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); + if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { + APInt NewScale = + Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue(); + if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { + // A multiplication of a constant with another add; recurse. + Interesting |= + CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + cast<SCEVAddExpr>(Mul->getOperand(1)) + ->getOperands(), + NewScale, SE); + } else { + // A multiplication of a constant with some other value. Update + // the map. + SmallVector<SCEVHandle, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); + SCEVHandle Key = SE.getMulExpr(MulOps); + std::pair<DenseMap<SCEVHandle, APInt>::iterator, bool> Pair = + M.insert(std::make_pair(Key, APInt())); + if (Pair.second) { + Pair.first->second = NewScale; + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += NewScale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { + // Pull a buried constant out to the outside. + if (Scale != 1 || AccumulatedConstant != 0 || C->isZero()) + Interesting = true; + AccumulatedConstant += Scale * C->getValue()->getValue(); + } else { + // An ordinary operand. Update the map. + std::pair<DenseMap<SCEVHandle, APInt>::iterator, bool> Pair = + M.insert(std::make_pair(Ops[i], APInt())); + if (Pair.second) { + Pair.first->second = Scale; + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += Scale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } + + return Interesting; +} + +namespace { + struct APIntCompare { + bool operator()(const APInt &LHS, const APInt &RHS) const { + return LHS.ult(RHS); + } + }; +} + /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. -SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { +SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVHandle> &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1001,11 +1116,10 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + - RHSC->getValue()->getValue()); - Ops[0] = getConstant(Fold); + Ops[0] = getConstant(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); + if (Ops.size() == 2) return Ops[0]; Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; LHSC = cast<SCEVConstant>(Ops[0]); } @@ -1043,7 +1157,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]); const Type *DstType = Trunc->getType(); const Type *SrcType = Trunc->getOperand()->getType(); - std::vector<SCEVHandle> LargeOps; + SmallVector<SCEVHandle, 8> LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. @@ -1059,7 +1173,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { // is much more likely to be foldable here. LargeOps.push_back(getSignExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) { - std::vector<SCEVHandle> LargeMulOps; + SmallVector<SCEVHandle, 8> LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { @@ -1120,6 +1234,38 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) ++Idx; + // Check to see if there are any folding opportunities present with + // operands multiplied by constant values. + if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) { + uint64_t BitWidth = getTypeSizeInBits(Ty); + DenseMap<SCEVHandle, APInt> M; + SmallVector<SCEVHandle, 8> NewOps; + APInt AccumulatedConstant(BitWidth, 0); + if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + Ops, APInt(BitWidth, 1), *this)) { + // Some interesting folding opportunity is present, so its worthwhile to + // re-generate the operands list. Group the operands by constant scale, + // to avoid multiplying by the same constant scale multiple times. + std::map<APInt, SmallVector<SCEVHandle, 4>, APIntCompare> MulOpLists; + for (SmallVector<SCEVHandle, 8>::iterator I = NewOps.begin(), + E = NewOps.end(); I != E; ++I) + MulOpLists[M.find(*I)->second].push_back(*I); + // Re-generate the operands list. + Ops.clear(); + if (AccumulatedConstant != 0) + Ops.push_back(getConstant(AccumulatedConstant)); + for (std::map<APInt, SmallVector<SCEVHandle, 4>, APIntCompare>::iterator I = + MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) + if (I->first != 0) + Ops.push_back(getMulExpr(getConstant(I->first), getAddExpr(I->second))); + if (Ops.empty()) + return getIntegerSCEV(0, Ty); + if (Ops.size() == 1) + return Ops[0]; + return getAddExpr(Ops); + } + } + // If we are adding something to a multiply expression, make sure the // something is not already an operand of the multiply. If so, merge it into // the multiply. @@ -1128,13 +1274,13 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { const SCEV *MulOpSCEV = Mul->getOperand(MulOp); for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) - if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) { + if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(Ops[AddOp])) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) SCEVHandle InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. - std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector<SCEVHandle, 4> MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul = getMulExpr(MulOps); } @@ -1166,13 +1312,13 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { - std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector<SCEVHandle, 4> MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul1 = getMulExpr(MulOps); } SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { - std::vector<SCEVHandle> MulOps(OtherMul->op_begin(), + SmallVector<SCEVHandle, 4> MulOps(OtherMul->op_begin(), OtherMul->op_end()); MulOps.erase(MulOps.begin()+OMulOp); InnerMul2 = getMulExpr(MulOps); @@ -1199,7 +1345,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. - std::vector<SCEVHandle> LIOps; + SmallVector<SCEVHandle, 8> LIOps; const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1213,7 +1359,8 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); - std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end()); + SmallVector<SCEVHandle, 4> AddRecOps(AddRec->op_begin(), + AddRec->op_end()); AddRecOps[0] = getAddExpr(LIOps); SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); @@ -1238,7 +1385,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); if (AddRec->getLoop() == OtherAddRec->getLoop()) { // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} - std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end()); + SmallVector<SCEVHandle, 4> NewOps(AddRec->op_begin(), AddRec->op_end()); for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { if (i >= NewOps.size()) { NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i, @@ -1267,14 +1414,14 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr, SCEVOps)]; - if (Result == 0) Result = new SCEVAddExpr(Ops); + if (Result == 0) Result = new SCEVAddExpr(Ops, this); return Result; } /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. -SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { +SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVHandle> &Ops) { assert(!Ops.empty() && "Cannot get empty mul!"); #ifndef NDEBUG for (unsigned i = 1, e = Ops.size(); i != e; ++i) @@ -1355,7 +1502,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { // Scan all of the other operands to this mul and add them to the vector if // they are loop invariant w.r.t. the recurrence. - std::vector<SCEVHandle> LIOps; + SmallVector<SCEVHandle, 8> LIOps; const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1367,7 +1514,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} - std::vector<SCEVHandle> NewOps; + SmallVector<SCEVHandle, 4> NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { const SCEV *Scale = LIOps[0]; @@ -1375,7 +1522,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); } else { for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - std::vector<SCEVHandle> MulOps(LIOps); + SmallVector<SCEVHandle, 4> MulOps(LIOps.begin(), LIOps.end()); MulOps.push_back(AddRec->getOperand(i)); NewOps.push_back(getMulExpr(MulOps)); } @@ -1433,7 +1580,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr, SCEVOps)]; if (Result == 0) - Result = new SCEVMulExpr(Ops); + Result = new SCEVMulExpr(Ops, this); return Result; } @@ -1473,14 +1620,14 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop())) { - std::vector<SCEVHandle> Operands; + SmallVector<SCEVHandle, 4> Operands; for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); return getAddRecExpr(Operands, AR->getLoop()); } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) { - std::vector<SCEVHandle> Operands; + SmallVector<SCEVHandle, 4> Operands; for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) @@ -1489,7 +1636,9 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, SCEVHandle Op = M->getOperand(i); SCEVHandle Div = getUDivExpr(Op, RHSC); if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { - Operands = M->getOperands(); + const SmallVectorImpl<SCEVHandle> &MOperands = M->getOperands(); + Operands = SmallVector<SCEVHandle, 4>(MOperands.begin(), + MOperands.end()); Operands[i] = Div; return getMulExpr(Operands); } @@ -1497,7 +1646,7 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, } // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) { - std::vector<SCEVHandle> Operands; + SmallVector<SCEVHandle, 4> Operands; for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { @@ -1522,7 +1671,7 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, } SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); + if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS, this); return Result; } @@ -1531,7 +1680,7 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, /// Simplify the expression as much as possible. SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, const SCEVHandle &Step, const Loop *L) { - std::vector<SCEVHandle> Operands; + SmallVector<SCEVHandle, 4> Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step)) if (StepChrec->getLoop() == L) { @@ -1546,7 +1695,7 @@ SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. -SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands, +SCEVHandle ScalarEvolution::getAddRecExpr(SmallVectorImpl<SCEVHandle> &Operands, const Loop *L) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG @@ -1565,8 +1714,8 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands, if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) { const Loop* NestedLoop = NestedAR->getLoop(); if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { - std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(), - NestedAR->op_end()); + SmallVector<SCEVHandle, 4> NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); SCEVHandle NestedARHandle(NestedAR); Operands[0] = NestedAR->getStart(); NestedOperands[0] = getAddRecExpr(Operands, L); @@ -1576,19 +1725,20 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands, std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end()); SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, SCEVOps)]; - if (Result == 0) Result = new SCEVAddRecExpr(Operands, L); + if (Result == 0) Result = new SCEVAddRecExpr(Operands, L, this); return Result; } SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { - std::vector<SCEVHandle> Ops; + SmallVector<SCEVHandle, 2> Ops; Ops.push_back(LHS); Ops.push_back(RHS); return getSMaxExpr(Ops); } -SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) { +SCEVHandle +ScalarEvolution::getSMaxExpr(SmallVectorImpl<SCEVHandle> &Ops) { assert(!Ops.empty() && "Cannot get empty smax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1662,19 +1812,20 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) { std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, SCEVOps)]; - if (Result == 0) Result = new SCEVSMaxExpr(Ops); + if (Result == 0) Result = new SCEVSMaxExpr(Ops, this); return Result; } SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { - std::vector<SCEVHandle> Ops; + SmallVector<SCEVHandle, 2> Ops; Ops.push_back(LHS); Ops.push_back(RHS); return getUMaxExpr(Ops); } -SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) { +SCEVHandle +ScalarEvolution::getUMaxExpr(SmallVectorImpl<SCEVHandle> &Ops) { assert(!Ops.empty() && "Cannot get empty umax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1748,17 +1899,29 @@ SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) { std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, SCEVOps)]; - if (Result == 0) Result = new SCEVUMaxExpr(Ops); + if (Result == 0) Result = new SCEVUMaxExpr(Ops, this); return Result; } +SCEVHandle ScalarEvolution::getSMinExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + // ~smax(~x, ~y) == smin(x, y). + return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); +} + +SCEVHandle ScalarEvolution::getUMinExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + // ~umax(~x, ~y) == umin(x, y) + return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); +} + SCEVHandle ScalarEvolution::getUnknown(Value *V) { if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); if (isa<ConstantPointerNull>(V)) return getIntegerSCEV(0, V->getType()); SCEVUnknown *&Result = (*SCEVUnknowns)[V]; - if (Result == 0) Result = new SCEVUnknown(V); + if (Result == 0) Result = new SCEVUnknown(V, this); return Result; } @@ -1977,6 +2140,22 @@ ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) { return getTruncateExpr(V, Ty); } +/// getUMaxFromMismatchedTypes - Promote the operands to the wider of +/// the types using zero-extension, and then perform a umax operation +/// with them. +SCEVHandle ScalarEvolution::getUMaxFromMismatchedTypes(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + SCEVHandle PromotedLHS = LHS; + SCEVHandle PromotedRHS = RHS; + + if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) + PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); + else + PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); + + return getUMaxExpr(PromotedLHS, PromotedRHS); +} + /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for /// the specified instruction and replaces any references to the symbolic value /// SymName with the specified value. This is used during PHI resolution. @@ -2040,7 +2219,7 @@ SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { if (FoundIndex != Add->getNumOperands()) { // Create an add with everything but the specified operand. - std::vector<SCEVHandle> Ops; + SmallVector<SCEVHandle, 8> Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(Add->getOperand(i)); @@ -2143,73 +2322,134 @@ SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) { /// guaranteed to end in (at every loop iteration). It is, at the same time, /// the minimum number of times S is divisible by 2. For example, given {4,+,8} /// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. -static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) { +uint32_t +ScalarEvolution::GetMinTrailingZeros(const SCEVHandle &S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) return C->getValue()->getValue().countTrailingZeros(); if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) - return std::min(GetMinTrailingZeros(T->getOperand(), SE), - (uint32_t)SE.getTypeSizeInBits(T->getType())); + return std::min(GetMinTrailingZeros(T->getOperand()), + (uint32_t)getTypeSizeInBits(T->getType())); if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); - return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? - SE.getTypeSizeInBits(E->getType()) : OpRes; + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); - return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? - SE.getTypeSizeInBits(E->getType()) : OpRes; + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { // The result is the sum of all operands results. - uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE); - uint32_t BitWidth = SE.getTypeSizeInBits(M->getType()); + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE), + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } - // SCEVUDivExpr, SCEVUnknown + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // For a SCEVUnknown, ask ValueTracking. + unsigned BitWidth = getTypeSizeInBits(U->getType()); + APInt Mask = APInt::getAllOnesValue(BitWidth); + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones); + return Zeros.countTrailingOnes(); + } + + // SCEVUDivExpr return 0; } +uint32_t +ScalarEvolution::GetMinLeadingZeros(const SCEVHandle &S) { + // TODO: Handle other SCEV expression types here. + + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) + return C->getValue()->getValue().countLeadingZeros(); + + if (const SCEVZeroExtendExpr *C = dyn_cast<SCEVZeroExtendExpr>(S)) { + // A zero-extension cast adds zero bits. + return GetMinLeadingZeros(C->getOperand()) + + (getTypeSizeInBits(C->getType()) - + getTypeSizeInBits(C->getOperand()->getType())); + } + + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // For a SCEVUnknown, ask ValueTracking. + unsigned BitWidth = getTypeSizeInBits(U->getType()); + APInt Mask = APInt::getAllOnesValue(BitWidth); + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD); + return Zeros.countLeadingOnes(); + } + + return 1; +} + +uint32_t +ScalarEvolution::GetMinSignBits(const SCEVHandle &S) { + // TODO: Handle other SCEV expression types here. + + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { + const APInt &A = C->getValue()->getValue(); + return A.isNegative() ? A.countLeadingOnes() : + A.countLeadingZeros(); + } + + if (const SCEVSignExtendExpr *C = dyn_cast<SCEVSignExtendExpr>(S)) { + // A sign-extension cast adds sign bits. + return GetMinSignBits(C->getOperand()) + + (getTypeSizeInBits(C->getType()) - + getTypeSizeInBits(C->getOperand()->getType())); + } + + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + // For a SCEVUnknown, ask ValueTracking. + return ComputeNumSignBits(U->getValue(), TD); + } + + return 1; +} + /// createSCEV - We know that there is no SCEV for the specified value. /// Analyze the expression. /// @@ -2248,14 +2488,27 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (CI->isAllOnesValue()) return getSCEV(U->getOperand(0)); const APInt &A = CI->getValue(); - unsigned Ones = A.countTrailingOnes(); - if (APIntOps::isMask(Ones, A)) + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use ComputeMaskedBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. + unsigned LZ = A.countLeadingZeros(); + unsigned BitWidth = A.getBitWidth(); + APInt AllOnes = APInt::getAllOnesValue(BitWidth); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD); + + APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ); + + if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask)) return getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)), - IntegerType::get(Ones)), + IntegerType::get(BitWidth - LZ)), U->getType()); } break; + case Instruction::Or: // If the RHS of the Or is a constant, we may have something like: // X*4+1 which got turned into X*4|1. Handle this as an Add so loop @@ -2266,7 +2519,7 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { SCEVHandle LHS = getSCEV(U->getOperand(0)); const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS, *this) >= + if (GetMinTrailingZeros(LHS) >= (CIVal.getBitWidth() - CIVal.countLeadingZeros())) return getAddExpr(LHS, getSCEV(U->getOperand(1))); } @@ -2292,9 +2545,27 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (BO->getOpcode() == Instruction::And && LCI->getValue() == CI->getValue()) if (const SCEVZeroExtendExpr *Z = - dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) - return getZeroExtendExpr(getNotSCEV(Z->getOperand()), - U->getType()); + dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) { + const Type *UTy = U->getType(); + SCEVHandle Z0 = Z->getOperand(); + const Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is zerving to + // mask off the high bits. Complement the operand and + // re-apply the zext. + if (APIntOps::isMask(Z0TySize, CI->getValue())) + return getZeroExtendExpr(getNotSCEV(Z0), UTy); + + // If C is a single bit, it may be in the sign-bit position + // before the zero-extend. In this case, represent the xor + // using an add, which is equivalent, and re-apply the zext. + APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize); + if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() && + Trunc.isSignBit()) + return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), + UTy); + } } break; @@ -2385,10 +2656,7 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) return getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) - // ~smax(~x, ~y) == smin(x, y). - return getNotSCEV(getSMaxExpr( - getNotSCEV(getSCEV(LHS)), - getNotSCEV(getSCEV(RHS)))); + return getSMinExpr(getSCEV(LHS), getSCEV(RHS)); break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2399,9 +2667,25 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) return getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) - // ~umax(~x, ~y) == umin(x, y) - return getNotSCEV(getUMaxExpr(getNotSCEV(getSCEV(LHS)), - getNotSCEV(getSCEV(RHS)))); + return getUMinExpr(getSCEV(LHS), getSCEV(RHS)); + break; + case ICmpInst::ICMP_NE: + // n != 0 ? n : 1 -> umax(n, 1) + if (LHS == U->getOperand(1) && + isa<ConstantInt>(U->getOperand(2)) && + cast<ConstantInt>(U->getOperand(2))->isOne() && + isa<ConstantInt>(RHS) && + cast<ConstantInt>(RHS)->isZero()) + return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2))); + break; + case ICmpInst::ICMP_EQ: + // n == 0 ? 1 : n -> umax(n, 1) + if (LHS == U->getOperand(2) && + isa<ConstantInt>(U->getOperand(1)) && + cast<ConstantInt>(U->getOperand(1))->isOne() && + isa<ConstantInt>(RHS) && + cast<ConstantInt>(RHS)->isZero()) + return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1))); break; default: break; @@ -2462,9 +2746,13 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Update the value in the map. Pair.first->second = ItCount; - } else if (isa<PHINode>(L->getHeader()->begin())) { - // Only count loops that have phi nodes as not being computable. - ++NumTripCountsNotComputed; + } else { + if (ItCount.Max != CouldNotCompute) + // Update the value in the map. + Pair.first->second = ItCount; + if (isa<PHINode>(L->getHeader()->begin())) + // Only count loops that have phi nodes as not being computable. + ++NumTripCountsNotComputed; } // Now that we know more about the trip count for this loop, forget any @@ -2520,19 +2808,58 @@ void ScalarEvolution::forgetLoopPHIs(const Loop *L) { /// of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { - // If the loop has a non-one exit block count, we can't analyze it. - BasicBlock *ExitBlock = L->getExitBlock(); - if (!ExitBlock) - return CouldNotCompute; + SmallVector<BasicBlock*, 8> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Examine all exits and pick the most conservative values. + SCEVHandle BECount = CouldNotCompute; + SCEVHandle MaxBECount = CouldNotCompute; + bool CouldNotComputeBECount = false; + bool CouldNotComputeMaxBECount = false; + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + BackedgeTakenInfo NewBTI = + ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]); + + if (NewBTI.Exact == CouldNotCompute) { + // We couldn't compute an exact value for this exit, so + // we don't be able to compute an exact value for the loop. + CouldNotComputeBECount = true; + BECount = CouldNotCompute; + } else if (!CouldNotComputeBECount) { + if (BECount == CouldNotCompute) + BECount = NewBTI.Exact; + else { + // TODO: More analysis could be done here. For example, a + // loop with a short-circuiting && operator has an exact count + // of the min of both sides. + CouldNotComputeBECount = true; + BECount = CouldNotCompute; + } + } + if (NewBTI.Max == CouldNotCompute) { + // We couldn't compute an maximum value for this exit, so + // we don't be able to compute an maximum value for the loop. + CouldNotComputeMaxBECount = true; + MaxBECount = CouldNotCompute; + } else if (!CouldNotComputeMaxBECount) { + if (MaxBECount == CouldNotCompute) + MaxBECount = NewBTI.Max; + else + MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, NewBTI.Max); + } + } + + return BackedgeTakenInfo(BECount, MaxBECount); +} - // Okay, there is one exit block. Try to find the condition that causes the - // loop to be exited. - BasicBlock *ExitingBlock = L->getExitingBlock(); - if (!ExitingBlock) - return CouldNotCompute; // More than one block exiting! +/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge +/// of the specified loop will execute if it exits via the specified block. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L, + BasicBlock *ExitingBlock) { - // Okay, we've computed the exiting block. See what condition causes us to - // exit. + // Okay, we've chosen an exiting block. See what condition causes us to + // exit at this block. // // FIXME: we should be able to handle switch instructions (with a single exit) BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); @@ -2547,23 +2874,154 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // Currently we check for this by checking to see if the Exit branch goes to // the loop header. If so, we know it will always execute the same number of // times as the loop. We also handle the case where the exit block *is* the - // loop header. This is common for un-rotated loops. More extensive analysis - // could be done to handle more cases here. + // loop header. This is common for un-rotated loops. + // + // If both of those tests fail, walk up the unique predecessor chain to the + // header, stopping if there is an edge that doesn't exit the loop. If the + // header is reached, the execution count of the branch will be equal to the + // trip count of the loop. + // + // More extensive analysis could be done to handle more cases here. + // if (ExitBr->getSuccessor(0) != L->getHeader() && ExitBr->getSuccessor(1) != L->getHeader() && - ExitBr->getParent() != L->getHeader()) - return CouldNotCompute; - - ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition()); + ExitBr->getParent() != L->getHeader()) { + // The simple checks failed, try climbing the unique predecessor chain + // up to the header. + bool Ok = false; + for (BasicBlock *BB = ExitBr->getParent(); BB; ) { + BasicBlock *Pred = BB->getUniquePredecessor(); + if (!Pred) + return CouldNotCompute; + TerminatorInst *PredTerm = Pred->getTerminator(); + for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { + BasicBlock *PredSucc = PredTerm->getSuccessor(i); + if (PredSucc == BB) + continue; + // If the predecessor has a successor that isn't BB and isn't + // outside the loop, assume the worst. + if (L->contains(PredSucc)) + return CouldNotCompute; + } + if (Pred == L->getHeader()) { + Ok = true; + break; + } + BB = Pred; + } + if (!Ok) + return CouldNotCompute; + } + + // Procede to the next level to examine the exit condition expression. + return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(), + ExitBr->getSuccessor(0), + ExitBr->getSuccessor(1)); +} + +/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the +/// backedge of the specified loop will execute if its exit condition +/// were a conditional branch of ExitCond, TBB, and FBB. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L, + Value *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { + // Check if the controlling expression for this loop is an and or or. In + // such cases, an exact backedge-taken count may be infeasible, but a + // maximum count may still be feasible. + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { + if (BO->getOpcode() == Instruction::And) { + // Recurse on the operands of the and. + BackedgeTakenInfo BTI0 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); + BackedgeTakenInfo BTI1 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + SCEVHandle BECount = CouldNotCompute; + SCEVHandle MaxBECount = CouldNotCompute; + if (L->contains(TBB)) { + // Both conditions must be true for the loop to continue executing. + // Choose the less conservative count. + // TODO: Take the minimum of the exact counts. + if (BTI0.Exact == BTI1.Exact) + BECount = BTI0.Exact; + // TODO: Take the minimum of the maximum counts. + if (BTI0.Max == CouldNotCompute) + MaxBECount = BTI1.Max; + else if (BTI1.Max == CouldNotCompute) + MaxBECount = BTI0.Max; + else if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(BTI0.Max)) + if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(BTI1.Max)) + MaxBECount = getConstant(APIntOps::umin(C0->getValue()->getValue(), + C1->getValue()->getValue())); + } else { + // Both conditions must be true for the loop to exit. + assert(L->contains(FBB) && "Loop block has no successor in loop!"); + if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute) + BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute) + MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + } + + return BackedgeTakenInfo(BECount, MaxBECount); + } + if (BO->getOpcode() == Instruction::Or) { + // Recurse on the operands of the or. + BackedgeTakenInfo BTI0 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); + BackedgeTakenInfo BTI1 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + SCEVHandle BECount = CouldNotCompute; + SCEVHandle MaxBECount = CouldNotCompute; + if (L->contains(FBB)) { + // Both conditions must be false for the loop to continue executing. + // Choose the less conservative count. + // TODO: Take the minimum of the exact counts. + if (BTI0.Exact == BTI1.Exact) + BECount = BTI0.Exact; + // TODO: Take the minimum of the maximum counts. + if (BTI0.Max == CouldNotCompute) + MaxBECount = BTI1.Max; + else if (BTI1.Max == CouldNotCompute) + MaxBECount = BTI0.Max; + else if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(BTI0.Max)) + if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(BTI1.Max)) + MaxBECount = getConstant(APIntOps::umin(C0->getValue()->getValue(), + C1->getValue()->getValue())); + } else { + // Both conditions must be false for the loop to exit. + assert(L->contains(TBB) && "Loop block has no successor in loop!"); + if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute) + BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute) + MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + } + + return BackedgeTakenInfo(BECount, MaxBECount); + } + } + + // With an icmp, it may be feasible to compute an exact backedge-taken count. + // Procede to the next level to examine the icmp. + if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) + return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB); // If it's not an integer or pointer comparison then compute it the hard way. - if (ExitCond == 0) - return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(), - ExitBr->getSuccessor(0) == ExitBlock); + return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); +} + +/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the +/// backedge of the specified loop will execute if its exit condition +/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, + ICmpInst *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; - if (ExitBr->getSuccessor(1) == ExitBlock) + if (!L->contains(FBB)) Cond = ExitCond->getPredicate(); else Cond = ExitCond->getInversePredicate(); @@ -2573,7 +3031,12 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { SCEVHandle ItCnt = ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); - if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt; + if (!isa<SCEVCouldNotCompute>(ItCnt)) { + unsigned BitWidth = getTypeSizeInBits(ItCnt->getType()); + return BackedgeTakenInfo(ItCnt, + isa<SCEVConstant>(ItCnt) ? ItCnt : + getConstant(APInt::getMaxValue(BitWidth)-1)); + } } SCEVHandle LHS = getSCEV(ExitCond->getOperand(0)); @@ -2651,8 +3114,7 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { break; } return - ComputeBackedgeTakenCountExhaustively(L, ExitCond, - ExitBr->getSuccessor(0) == ExitBlock); + ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); } static ConstantInt * @@ -2750,7 +3212,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { ConstantInt *ItCst = - ConstantInt::get(IdxExpr->getType(), IterationNum); + ConstantInt::get(cast<IntegerType>(IdxExpr->getType()), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); // Form the GEP offset. @@ -2945,7 +3407,7 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) if (CondVal->getValue() == uint64_t(ExitWhen)) { ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return getConstant(ConstantInt::get(Type::Int32Ty, IterationNum)); + return getConstant(Type::Int32Ty, IterationNum); } // Compute the value of the PHI node for the next iteration. @@ -3074,7 +3536,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (OpAtScope != Comm->getOperand(i)) { // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. - std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i); + SmallVector<SCEVHandle, 8> NewOps(Comm->op_begin(), Comm->op_begin()+i); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) { @@ -3394,6 +3856,29 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { return 0; } +/// HasSameValue - SCEV structural equivalence is usually sufficient for +/// testing whether two expressions are equal, however for the purposes of +/// looking for a condition guarding a loop, it can be useful to be a little +/// more general, since a front-end may have replicated the controlling +/// expression. +/// +static bool HasSameValue(const SCEVHandle &A, const SCEVHandle &B) { + // Quick check to see if they are the same SCEV. + if (A == B) return true; + + // Otherwise, if they're both SCEVUnknown, it's possible that they hold + // two different instructions with the same value. Check for this case. + if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A)) + if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B)) + if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue())) + if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue())) + if (AI->isIdenticalTo(BI)) + return true; + + // Otherwise assume they may have a different value. + return false; +} + /// isLoopGuardedByCond - Test whether entry to the loop is protected by /// a conditional between LHS and RHS. This is used to help avoid max /// expressions in loop trip counts. @@ -3494,15 +3979,43 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); - if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || - (LHS == getNotSCEV(PreCondRHSSCEV) && - RHS == getNotSCEV(PreCondLHSSCEV))) + if ((HasSameValue(LHS, PreCondLHSSCEV) && + HasSameValue(RHS, PreCondRHSSCEV)) || + (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) && + HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV)))) return true; } return false; } +/// getBECount - Subtract the end and start values and divide by the step, +/// rounding up, to get the number of times the backedge is executed. Return +/// CouldNotCompute if an intermediate computation overflows. +SCEVHandle ScalarEvolution::getBECount(const SCEVHandle &Start, + const SCEVHandle &End, + const SCEVHandle &Step) { + const Type *Ty = Start->getType(); + SCEVHandle NegOne = getIntegerSCEV(-1, Ty); + SCEVHandle Diff = getMinusSCEV(End, Start); + SCEVHandle RoundUp = getAddExpr(Step, NegOne); + + // Add an adjustment to the difference between End and Start so that + // the division will effectively round up. + SCEVHandle Add = getAddExpr(Diff, RoundUp); + + // Check Add for unsigned overflow. + // TODO: More sophisticated things could be done here. + const Type *WideTy = IntegerType::get(getTypeSizeInBits(Ty) + 1); + SCEVHandle OperandExtendedAdd = + getAddExpr(getZeroExtendExpr(Diff, WideTy), + getZeroExtendExpr(RoundUp, WideTy)); + if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd) + return CouldNotCompute; + + return getUDivExpr(Add, Step); +} + /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. @@ -3520,7 +4033,6 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // FORNOW: We only support unit strides. unsigned BitWidth = getTypeSizeInBits(AddRec->getType()); SCEVHandle Step = AddRec->getStepRecurrence(*this); - SCEVHandle NegOne = getIntegerSCEV(-1, AddRec->getType()); // TODO: handle non-constant strides. const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step); @@ -3575,22 +4087,20 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : getUMaxExpr(RHS, Start); // Determine the maximum constant end value. - SCEVHandle MaxEnd = isa<SCEVConstant>(End) ? End : - getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) : - APInt::getMaxValue(BitWidth)); + SCEVHandle MaxEnd = + isa<SCEVConstant>(End) ? End : + getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) + .ashr(GetMinSignBits(End) - 1) : + APInt::getMaxValue(BitWidth) + .lshr(GetMinLeadingZeros(End))); // Finally, we subtract these two values and divide, rounding up, to get // the number of times the backedge is executed. - SCEVHandle BECount = getUDivExpr(getAddExpr(getMinusSCEV(End, Start), - getAddExpr(Step, NegOne)), - Step); + SCEVHandle BECount = getBECount(Start, End, Step); // The maximum backedge count is similar, except using the minimum start // value and the maximum end value. - SCEVHandle MaxBECount = getUDivExpr(getAddExpr(getMinusSCEV(MaxEnd, - MinStart), - getAddExpr(Step, NegOne)), - Step); + SCEVHandle MaxBECount = getBECount(MinStart, MaxEnd, Step);; return BackedgeTakenInfo(BECount, MaxBECount); } @@ -3611,7 +4121,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If the start is a non-zero constant, shift the range to simplify things. if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart())) if (!SC->getValue()->isZero()) { - std::vector<SCEVHandle> Operands(op_begin(), op_end()); + SmallVector<SCEVHandle, 4> Operands(op_begin(), op_end()); Operands[0] = SE.getIntegerSCEV(0, SC->getType()); SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop()); if (const SCEVAddRecExpr *ShiftedAddRec = @@ -3636,7 +4146,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // iteration exits. unsigned BitWidth = SE.getTypeSizeInBits(getType()); if (!Range.contains(APInt(BitWidth, 0))) - return SE.getConstant(ConstantInt::get(getType(),0)); + return SE.getIntegerSCEV(0, getType()); if (isAffine()) { // If this is an affine expression then we have this situation: @@ -3672,7 +4182,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. - std::vector<SCEVHandle> NewOps(op_begin(), op_end()); + SmallVector<SCEVHandle, 4> NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); @@ -3783,7 +4293,7 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(&ID), CouldNotCompute(new SCEVCouldNotCompute()) { + : FunctionPass(&ID), CouldNotCompute(new SCEVCouldNotCompute(0)) { } bool ScalarEvolution::runOnFunction(Function &F) { @@ -3847,11 +4357,18 @@ void ScalarEvolution::print(raw_ostream &OS, const Module* ) const { OS << " --> "; SCEVHandle SV = SE.getSCEV(&*I); SV->print(OS); - OS << "\t\t"; - if (const Loop *L = LI->getLoopFor((*I).getParent())) { - OS << "Exits: "; - SCEVHandle ExitValue = SE.getSCEVAtScope(&*I, L->getParentLoop()); + const Loop *L = LI->getLoopFor((*I).getParent()); + + SCEVHandle AtUse = SE.getSCEVAtScope(SV, L); + if (AtUse != SV) { + OS << " --> "; + AtUse->print(OS); + } + + if (L) { + OS << "\t\t" "Exits: "; + SCEVHandle ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); if (!ExitValue->isLoopInvariant(L)) { OS << "<<Unknown>>"; } else { |